I'm working with flex attention. For longer sequence lengths, e.g. 40320, I get an OOM at the line `scores = (query @ key.transpose(-2, -1)).to(dtype=working_precision)`, where it tries to allocate the full 40320^2 score matrix of floats (48 GB) all at once. Would it be possible to compute these scores in chunks to reduce peak memory usage, like the memory-efficient attention algorithm does? Thanks
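To illustrate what I mean, here's a rough sketch of query-chunked attention (just an example, not flex attention's actual code; the function name, `chunk_size`, and the scaling are my own). The full memory-efficient algorithm also chunks over keys with an online softmax, but even chunking only the queries bounds the score matrix:

```python
import torch

def query_chunked_attention(query, key, value, chunk_size=4096,
                            working_precision=torch.float32):
    """Illustrative sketch: process queries in chunks so the score matrix
    is (chunk_size, seq_len) instead of (seq_len, seq_len), cutting peak
    memory by roughly seq_len / chunk_size."""
    out = torch.empty_like(query)
    scale = query.shape[-1] ** -0.5  # standard 1/sqrt(d) softmax scaling
    for start in range(0, query.shape[-2], chunk_size):
        q = query[..., start:start + chunk_size, :]
        # Same computation as the OOMing line, but only for a slice of queries.
        scores = (q @ key.transpose(-2, -1)).to(dtype=working_precision) * scale
        probs = torch.softmax(scores, dim=-1)
        out[..., start:start + chunk_size, :] = (
            probs @ value.to(working_precision)
        ).to(query.dtype)
    return out
```

With `chunk_size=4096` and seq_len 40320, the peak score allocation shrinks by a factor of about 10 relative to materializing the full matrix.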