Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions torchtitan/models/llama3/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
for the purpose of broadcasting the frequency tensor during element-wise operations.

The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
and the first seqlen elements will be sliced, but dim must match x.
The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).

Args:
freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
Expand All @@ -68,10 +67,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
"""
ndim = x.ndim
assert ndim > 1
batch_size = x.shape[0]
seqlen = x.shape[1]
freqs_cis = freqs_cis[0:seqlen]
assert freqs_cis.shape == (seqlen, x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)


Expand Down Expand Up @@ -437,9 +436,18 @@ def get_attention_masks(
mask_mod, B, None, input_batch.shape[1], input_batch.shape[1]
)

def get_order_sensitive_buffers(
self,
batch_size: int,
seq_len: int,
) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder what's the benefit of keeping self.freqs_cis.

If seq len changes from iteration to iteration (e.g. in forge), it might be good to keep a central self.freqs_cis instead of computing it each iteration. The other benefit is that we may not want torchtitan model definition to deviate from "original" / "conventional" model definitions too much.

On the other hand, the dependency sounds indirect and error-prone:

  • we create self.freqs_cis in model code
  • then copy it to freqs_cis, which technically is outside the model
  • we then send freqs_cis into model

Would like to hear your thoughts.

Copy link
Contributor Author

@fegin fegin Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-computation is the main reason why I decided to keep self.freqs_cis. I agree it's a bit awkward.

One alternative is to sill keep self.freqs_cis but set it as an optional field (self.freqs_cis: torch.Tensor | None) for bookkeeping only. And we only initialize it in this function. So the creation logic flow (precompute and slicing) is mainly in this function. The model code still provides precompute function. So this way we do not change the code structure too much while keeping the logic together. Not a perfect solution though.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's fine to keep the current way for now, as it sounds more lightweight change, and as I mentioned downstream application (e.g. forge, and simple generation) may change seq_len from iteration to iteration, where we can avoid recomputation this way.

return ((freqs_cis,), (1,))

def forward(
self,
tokens: torch.Tensor,
freqs_cis: torch.Tensor,
attention_masks: AttentionMasksType | None = None,
input_batch: torch.Tensor | None = None,
):
Expand All @@ -464,7 +472,7 @@ def forward(
h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

for layer in self.layers.values():
h = layer(h, self.freqs_cis, attention_masks=attention_masks)
h = layer(h, freqs_cis, attention_masks=attention_masks)

h = self.norm(h) if self.norm else h
output = self.output(h) if self.output else h
Expand Down
8 changes: 8 additions & 0 deletions torchtitan/protocols/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,3 +70,11 @@ def get_attention_masks(
raise NotImplementedError(
"This model does not support attention masking/Flex Attention."
)

def get_order_sensitive_buffers(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Naming is a bit vague. I think here we are only targeting "sequence dim" order-sensitive buffers, not the batch dim.

self,
batch_size: int,
seq_len: int,
) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add some notes here what does this 2 return values mean? Seems like the first return value is the buffer itself

raise NotImplementedError()
return ((), ())
12 changes: 11 additions & 1 deletion torchtitan/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,11 @@ def forward_backward_step(
else None
)

# Get the order sensitive buffers
order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
inputs.size(0), inputs.size(1)
)

# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
Expand All @@ -449,6 +454,7 @@ def forward_backward_step(
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs,
*order_sensitive_buffers[0],
**extra_inputs,
attention_masks=attention_masks,
target=targets,
Expand All @@ -457,6 +463,7 @@ def forward_backward_step(
)
else:
self.pp_schedule.step(
*order_sensitive_buffers[0],
attention_masks=attention_masks,
target=targets,
losses=losses,
Expand All @@ -479,7 +486,10 @@ def forward_backward_step(
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](
inputs, **extra_inputs, attention_masks=attention_masks
inputs,
*order_sensitive_buffers[0],
**extra_inputs,
attention_masks=attention_masks,
)
loss = self.loss_fn(pred, labels)
# need to free pred before bwd to avoid peaking memory
Expand Down
Loading