
PP microbatching + block-causal FlexAttention composability issue #1723

@tianyu-l

Description

Today, init_attention_mask is called before PP does the microbatch split:
https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L421

I haven't tested this, but it will likely cause the wrong block mask to be applied to the non-first microbatches. E.g., consider a local batch

```python
batch = [
  b0,
  b1,
  b2,
  b3,
]
```

The block mask is created for batch, but after (say, size-1) microbatching the same mask will be used for 4 different smaller batches [b0], [b1], [b2], [b3]. For [b0] it might be fine, but for the others the mask is wrong.
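To make the mismatch concrete, here is a minimal sketch of how such a batch-dependent block-causal mask could look with FlexAttention. The doc_ids tensor and the block_causal mask_mod are illustrative stand-ins, not torchtitan's actual init_attention_mask code:

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

B, S = 4, 128  # local batch size, sequence length
# Per-sample document ids for packed sequences (illustrative data).
doc_ids = torch.randint(0, 3, (B, S)).sort(dim=-1).values

def block_causal(b, h, q_idx, kv_idx):
    # Causal attention restricted to the same document; note that `b`
    # indexes into the *full* local batch.
    return (q_idx >= kv_idx) & (doc_ids[b, q_idx] == doc_ids[b, kv_idx])

# One mask built up front for the whole local batch of size B = 4 ...
full_mask = create_block_mask(block_causal, B=B, H=None, Q_LEN=S, KV_LEN=S, device="cpu")

# ... but if PP then splits the input into (say) size-1 microbatches, the same
# full-batch mask is reused for [b0], [b1], [b2], [b3]; only [b0]'s rows line up.
```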

The solution could be either of the following:

  1. do init_attention_mask after microbatching (a rough sketch of this is below).
  2. when PP does microbatching, modify the block mask as well.
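Here is a rough sketch of option 1, assuming a hypothetical per-microbatch helper that rebuilds the mask from each microbatch's own tokens (make_block_causal_mask and doc_ids below are illustrative, not torchtitan's API):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask

def make_block_causal_mask(mb_doc_ids: torch.Tensor):
    # Close over this microbatch's doc ids, so batch indices are local to the
    # microbatch rather than to the full local batch.
    n_mb, S = mb_doc_ids.shape

    def block_causal(b, h, q_idx, kv_idx):
        return (q_idx >= kv_idx) & (mb_doc_ids[b, q_idx] == mb_doc_ids[b, kv_idx])

    return create_block_mask(
        block_causal, B=n_mb, H=None, Q_LEN=S, KV_LEN=S, device=str(mb_doc_ids.device)
    )

# After PP splits the local batch into microbatches, build one mask per chunk
# (i.e., do the init_attention_mask work once per microbatch):
doc_ids = torch.randint(0, 3, (4, 128)).sort(dim=-1).values
per_microbatch_masks = [make_block_causal_mask(mb) for mb in doc_ids.chunk(4, dim=0)]
```

Option 2 would instead keep a single full-batch mask and have the PP runtime slice/adjust it per microbatch, which requires the schedule to know how the mask's batch dimension maps to microbatches.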

cc @fegin @H-Huang @drisspg
