diff --git a/README.md b/README.md
index 55d6ee9713..dd54765840 100644
--- a/README.md
+++ b/README.md
@@ -60,11 +60,12 @@ To accelerate contributions to and innovations around torchtitan, we are hosting
 7. DDP and HSDP
 8. [TorchFT](https://github.com/pytorch/torchft) integration
 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
-10. Flexible learning rate scheduler (warmup-stable-decay)
-11. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
-12. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
-13. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
-14. [Helper scripts](scripts/) to
+10. Gradient accumulation, enabled by passing an additional `--training.global_batch_size` argument in the configuration
+11. Flexible learning rate scheduler (warmup-stable-decay)
+12. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
+13. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
+14. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
+15. [Helper scripts](scripts/) to
     - download tokenizers from Hugging Face
     - convert original Llama 3 checkpoints into the expected DCP format
     - estimate FSDP/HSDP memory usage without materializing the model
diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py
index e982cb46ed..a65d754302 100644
--- a/torchtitan/experiments/flux/train.py
+++ b/torchtitan/experiments/flux/train.py
@@ -146,66 +146,6 @@ def forward_backward_step(
 
         return loss
 
-    def train_step(
-        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
-    ):
-        input_dict, labels = next(data_iterator)
-
-        self.optimizers.zero_grad()
-
-        # Keep these variables local to shorten the code as these are
-        # the major variables that are used in the training loop.
-        model_parts = self.model_parts
-        assert len(self.model_parts) == 1
-        # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not
-        model = self.model_parts[0]
-
-        world_mesh = self.world_mesh
-        parallel_dims = self.parallel_dims
-
-        loss = self.forward_backward_step(input_dict, labels)
-
-        dist_utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
-            self.job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
-        )
-        self.checkpointer.maybe_wait_for_staging()
-        self.optimizers.step()
-        self.lr_schedulers.step()
-
-        # log metrics
-        if not self.metrics_processor.should_log(self.step):
-            return
-
-        if (
-            parallel_dims.dp_replicate_enabled
-            or parallel_dims.dp_shard_enabled
-            or parallel_dims.cp_enabled
-        ):
-            loss = loss.detach()
-            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
-            global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
-            )
-        else:
-            global_avg_loss = global_max_loss = loss.item()
-
-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
-
-        # Evaluate the model during training
-        if (
-            self.step % self.job_config.eval.eval_freq == 0
-            or self.step == self.job_config.training.steps
-        ):
-            model.eval()
-            # We need to set reshard_after_forward before last forward pass.
-            # So the model wieghts are sharded the same way for checkpoint saving.
-            self.eval_step()
-            model.train()
-
     def eval_step(self, prompt: str = "A photo of a cat"):
         """
         Evaluate the Flux model.
@@ -247,6 +187,23 @@ def eval_step(self, prompt: str = "A photo of a cat"):
             if isinstance(module, FSDPModule):
                 module.reshard()
 
+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        super().train_step(data_iterator)
+
+        # Evaluate the model during training
+        if (
+            self.step % self.job_config.eval.eval_freq == 0
+            or self.step == self.job_config.training.steps
+        ):
+            model = self.model_parts[0]
+            model.eval()
+            # We need to set reshard_after_forward before the last forward pass
+            # so the model weights are sharded the same way for checkpoint saving.
+            self.eval_step()
+            model.train()
+
 
 if __name__ == "__main__":
     init_logger()
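Note on the gradient accumulation item added to the README above: the sketch below illustrates the general idea of accumulating gradients over several micro-batches before a single optimizer step, which is what a larger `--training.global_batch_size` implies. It is a minimal, hypothetical example, not torchtitan's implementation; the helper names, loss scaling, and toy model are assumptions made for illustration only.

```python
# Minimal sketch of gradient accumulation (NOT torchtitan's implementation).
# A "global batch" is split into micro-batches; gradients accumulate across
# backward passes and the optimizer steps once per global batch.
import torch
import torch.nn as nn


def accumulated_step(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    micro_batches: list[tuple[torch.Tensor, torch.Tensor]],
) -> float:
    optimizer.zero_grad()
    total_loss = 0.0
    for inputs, labels in micro_batches:
        loss = loss_fn(model(inputs), labels)
        # Scale so the accumulated gradient matches one pass over the full batch.
        (loss / len(micro_batches)).backward()
        total_loss += loss.item()
    optimizer.step()
    return total_loss / len(micro_batches)


if __name__ == "__main__":
    # Toy usage: a global batch of 8 split into 4 micro-batches of 2.
    model = nn.Linear(16, 4)
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3)
    batches = [(torch.randn(2, 16), torch.randint(0, 4, (2,))) for _ in range(4)]
    print(accumulated_step(model, opt, nn.CrossEntropyLoss(), batches))
```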