diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py
index 591f4fcbf..47a76653a 100644
--- a/torchtitan/models/llama3/model/model.py
+++ b/torchtitan/models/llama3/model/model.py
@@ -56,8 +56,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.
 
-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).
 
     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -68,10 +67,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)
 
 
@@ -437,9 +436,18 @@ def get_attention_masks(
             mask_mod, B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ((freqs_cis,), (1,))
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
@@ -464,7 +472,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)
 
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
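Below is a minimal, shape-only sketch (not part of the patch; the sizes are made up and the real freqs_cis buffer is complex-valued) of the contract the model.py changes establish: get_order_sensitive_buffers expands the model's (max_seqlen, dim) buffer into a per-batch (batch_size, seqlen, dim) tensor, and reshape_for_broadcast now keeps dims 0, 1, and -1 of x while inserting singleton dims elsewhere.

import torch

batch_size, seqlen, n_heads, head_dim = 2, 8, 4, 16
# stand-in for the model's registered buffer of shape (max_seqlen, dim)
freqs_cis_buffer = torch.randn(32, head_dim)
# what get_order_sensitive_buffers builds: a per-batch copy of the first seqlen rows
freqs_cis = freqs_cis_buffer[:seqlen].repeat(batch_size, 1, 1)
# activation laid out as (batch_size, seqlen, n_heads, head_dim)
x = torch.randn(batch_size, seqlen, n_heads, head_dim)
# same reshape rule as the updated reshape_for_broadcast
shape = [d if i in (0, 1, x.ndim - 1) else 1 for i, d in enumerate(x.shape)]
assert freqs_cis.view(*shape).shape == (batch_size, seqlen, 1, head_dim)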
diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py
index a713bec65..03f5ae3e7 100644
--- a/torchtitan/protocols/model.py
+++ b/torchtitan/protocols/model.py
@@ -70,3 +70,11 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        raise NotImplementedError()
+        return ((), ())
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 29e48e003..90a03d6a7 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -425,6 +425,11 @@ def forward_backward_step(
             else None
         )
 
+        # Get the order sensitive buffers
+        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+
        # apply context parallelism if cp is enabled
        # ensure CP handles the separate freqs_cis buffer for each pp stage
        cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
@@ -449,6 +454,7 @@
             if self.pp_has_first_stage:
                 self.pp_schedule.step(
                     inputs,
+                    *order_sensitive_buffers[0],
                     **extra_inputs,
                     attention_masks=attention_masks,
                     target=targets,
@@ -457,6 +463,7 @@
                 )
             else:
                 self.pp_schedule.step(
+                    *order_sensitive_buffers[0],
                     attention_masks=attention_masks,
                     target=targets,
                     losses=losses,
@@ -479,7 +486,10 @@
             assert len(model_parts) == 1
             with self.maybe_enable_amp:
                 pred = model_parts[0](
-                    inputs, **extra_inputs, attention_masks=attention_masks
+                    inputs,
+                    *order_sensitive_buffers[0],
+                    **extra_inputs,
+                    attention_masks=attention_masks,
                 )
                 loss = self.loss_fn(pred, labels)
             # need to free pred before bwd to avoid peaking memory
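For reference, a minimal usage sketch (assuming an already constructed Llama3 model and an integer token batch tokens of shape (batch_size, seqlen); not part of the patch) of how a caller now threads the order-sensitive buffers into forward, mirroring forward_backward_step above:

buffers, seq_dims = model.get_order_sensitive_buffers(tokens.size(0), tokens.size(1))
pred = model(tokens, *buffers, attention_masks=None)
# seq_dims is (1,): for each returned buffer it records which dim is the sequence
# dim, presumably so context parallelism can shard the per-batch freqs_cis along
# the sequence dimension together with the inputs.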