I haven't tested this, but it will likely cause the wrong block mask to be applied to non-first microbatches.
E.g., consider a local batch:
batch = [
    b0,
    b1,
    b2,
    b3,
]
The block mask `mask` is created for the whole `batch`, but after microbatching (say, with microbatch size 1) the same `mask` will be used for 4 different, smaller batches: [b0], [b1], [b2], [b3]. For [b0] it might be fine, since the mask's first entry was built for b0, but for the others the mask is wrong.
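A rough, untested illustration of the mismatch, using plain Python stand-ins rather than a real FlexAttention `BlockMask`: each sample is a list of document ids (packed sequences), and its mask is a dense document-causal mask. The helper names (`doc_causal_mask`, `build_batch_mask`, `stale_mask_is_correct`) are hypothetical. It also assumes, as "for [b0] it might be fine" suggests, that a size-1 microbatch paired with the full-batch mask ends up reading the mask's first entry.

```python
# Hypothetical stand-ins: samples are lists of document ids and masks are dense
# document-causal masks (nested lists of bools), not real FlexAttention BlockMasks.

def doc_causal_mask(doc_ids):
    """Position i may attend to j if j <= i and both tokens share a document."""
    n = len(doc_ids)
    return [[j <= i and doc_ids[i] == doc_ids[j] for j in range(n)] for i in range(n)]


def build_batch_mask(batch):
    """One mask per sample; entry k belongs to batch[k]."""
    return [doc_causal_mask(sample) for sample in batch]


def stale_mask_is_correct(batch, microbatch_size):
    """For each microbatch, check whether reusing the full-batch mask is right.

    Assumption: a microbatch paired with the full-batch mask reads the mask's
    leading entries (for size 1, the entry built for b0).
    """
    mask = build_batch_mask(batch)  # built once, before microbatching
    results = []
    for start in range(0, len(batch), microbatch_size):
        mb = batch[start:start + microbatch_size]
        stale = mask[:len(mb)]          # what the microbatch actually gets
        correct = build_batch_mask(mb)  # what it should get
        results.append(stale == correct)
    return results


batch = [
    [0, 0, 1, 1],  # b0
    [0, 1, 1, 1],  # b1
    [0, 0, 0, 1],  # b2
    [0, 0, 0, 0],  # b3
]
print(stale_mask_is_correct(batch, microbatch_size=1))  # [True, False, False, False]
```

Only the first microbatch gets a mask that matches its own document boundaries; with a microbatch size equal to the whole batch there is no mismatch.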
The solution could be either of:
- do `init_attention_mask` after microbatching, or
- when PP does microbatching, modify the block mask as well.
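A minimal sketch of the two options, again with hypothetical dense-mask stand-ins rather than torchtitan's actual `init_attention_mask` or the PP schedule's splitting code (`split`, `fix_build_after_split` and `fix_split_mask_too` are illustrative names only). Option 1 builds a mask per microbatch after splitting; option 2 builds it once and splits it along the batch dimension together with the data.

```python
# Hypothetical stand-ins: samples are lists of document ids and masks are dense
# document-causal masks, not real FlexAttention BlockMasks.

def doc_causal_mask(doc_ids):
    """Position i may attend to j if j <= i and both tokens share a document."""
    n = len(doc_ids)
    return [[j <= i and doc_ids[i] == doc_ids[j] for j in range(n)] for i in range(n)]


def build_batch_mask(batch):
    """One mask per sample; entry k belongs to batch[k]."""
    return [doc_causal_mask(sample) for sample in batch]


def split(items, size):
    """Split along the batch dimension into chunks of `size`."""
    return [items[i:i + size] for i in range(0, len(items), size)]


def fix_build_after_split(batch, size):
    """Option 1: microbatch first, then build one mask per microbatch."""
    return [(mb, build_batch_mask(mb)) for mb in split(batch, size)]


def fix_split_mask_too(batch, size):
    """Option 2: build the mask once, then split it alongside the data."""
    mask = build_batch_mask(batch)
    return list(zip(split(batch, size), split(mask, size)))


batch = [[0, 0, 1, 1], [0, 1, 1, 1], [0, 0, 0, 1], [0, 0, 0, 0]]
for size in (1, 2, 4):
    # Both options pair every microbatch with the mask built for its own samples.
    assert fix_build_after_split(batch, size) == fix_split_mask_too(batch, size)
```

Presumably the trade-off is that option 1 repeats mask creation for every microbatch, while option 2 creates it once but requires the PP microbatch splitting to know about the mask and slice it along its batch dimension (how that slicing looks for a real block mask object is not covered by this sketch).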