From 668ac7641576962286a41036503af0609d3b7a29 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 5 Jun 2025 15:31:28 -0700 Subject: [PATCH 1/3] move eval to main train loop --- torchtitan/experiments/flux/train.py | 76 +++++----------------------- torchtitan/train.py | 8 +++ 2 files changed, 22 insertions(+), 62 deletions(-) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index e982cb46ed..c648ebd737 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -146,67 +146,7 @@ def forward_backward_step( return loss - def train_step( - self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] - ): - input_dict, labels = next(data_iterator) - - self.optimizers.zero_grad() - - # Keep these variables local to shorten the code as these are - # the major variables that are used in the training loop. - model_parts = self.model_parts - assert len(self.model_parts) == 1 - # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not - model = self.model_parts[0] - - world_mesh = self.world_mesh - parallel_dims = self.parallel_dims - - loss = self.forward_backward_step(input_dict, labels) - - dist_utils.clip_grad_norm_( - [p for m in model_parts for p in m.parameters()], - self.job_config.training.max_norm, - foreach=True, - pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, - ) - self.checkpointer.maybe_wait_for_staging() - self.optimizers.step() - self.lr_schedulers.step() - - # log metrics - if not self.metrics_processor.should_log(self.step): - return - - if ( - parallel_dims.dp_replicate_enabled - or parallel_dims.dp_shard_enabled - or parallel_dims.cp_enabled - ): - loss = loss.detach() - ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None - global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg), - ) - else: - global_avg_loss = global_max_loss = loss.item() - - self.metrics_processor.log(self.step, global_avg_loss, global_max_loss) - - # Evaluate the model during training - if ( - self.step % self.job_config.eval.eval_freq == 0 - or self.step == self.job_config.training.steps - ): - model.eval() - # We need to set reshard_after_forward before last forward pass. - # So the model wieghts are sharded the same way for checkpoint saving. - self.eval_step() - model.train() - - def eval_step(self, prompt: str = "A photo of a cat"): + def eval_step(self): """ Evaluate the Flux model. 1) generate and save images every few steps. Currently, we run the eval and on the same @@ -216,13 +156,23 @@ def eval_step(self, prompt: str = "A photo of a cat"): 2) [TODO] Calculate loss with fixed t value on validation set. """ + # NOTE: We put the check inside the function becuase other model's job cofig might not have eval config. + if ( + self.step % self.job_config.eval.eval_freq != 0 + and not self.step == self.job_config.training.steps + ): + return + + model = self.model_parts[0] + model.eval() t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config) + prompt = "A photo of a cat" image = generate_image( device=self.device, dtype=self._dtype, job_config=self.job_config, - model=self.model_parts[0], + model=model, prompt=prompt, # TODO(jianiw): change this to a prompt from validation set autoencoder=self.autoencoder, t5_tokenizer=t5_tokenizer, @@ -247,6 +197,8 @@ def eval_step(self, prompt: str = "A photo of a cat"): if isinstance(module, FSDPModule): module.reshard() + model.train() + if __name__ == "__main__": init_logger() diff --git a/torchtitan/train.py b/torchtitan/train.py index bba1f01e2c..894f85df77 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -462,6 +462,11 @@ def train_step( self.metrics_processor.log(self.step, global_avg_loss, global_max_loss) + def eval_step(self): + """Evaluates the model on the validation dataset. + Currently is not implemented and this is a placeholder.""" + return + @record def train(self): job_config = self.job_config @@ -491,6 +496,9 @@ def train(self): except DataloaderStopIteration: logger.warning("Ran out of data; last step was canceled.") break + + self.eval_step() + self.checkpointer.save( self.step, force=(self.step == job_config.training.steps) ) From 7e21a516f0ce267945695f0bb7933e22890db1ef Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 5 Jun 2025 15:52:19 -0700 Subject: [PATCH 2/3] minimal changes --- README.md | 11 +++++----- torchtitan/experiments/flux/train.py | 30 +++++++++++++++++----------- torchtitan/train.py | 8 -------- 3 files changed, 24 insertions(+), 25 deletions(-) diff --git a/README.md b/README.md index 55d6ee9713..dd54765840 100644 --- a/README.md +++ b/README.md @@ -60,11 +60,12 @@ To accelerate contributions to and innovations around torchtitan, we are hosting 7. DDP and HSDP 8. [TorchFT](https://github.com/pytorch/torchft) integration 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md) -10. Flexible learning rate scheduler (warmup-stable-decay) -11. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md) -12. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc. -13. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/) -14. [Helper scripts](scripts/) to +10. Gradient accumulation, enabled by giving an additional `--training.global_batch_size` argument in configuration +11. Flexible learning rate scheduler (warmup-stable-decay) +12. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md) +13. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc. +14. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/) +15. [Helper scripts](scripts/) to - download tokenizers from Hugging Face - convert original Llama 3 checkpoints into the expected DCP format - estimate FSDP/HSDP memory usage without materializing the model diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index c648ebd737..40968502f0 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -146,7 +146,7 @@ def forward_backward_step( return loss - def eval_step(self): + def eval_step(self, prompt: str = "A photo of a cat"): """ Evaluate the Flux model. 1) generate and save images every few steps. Currently, we run the eval and on the same @@ -156,15 +156,6 @@ def eval_step(self): 2) [TODO] Calculate loss with fixed t value on validation set. """ - # NOTE: We put the check inside the function becuase other model's job cofig might not have eval config. - if ( - self.step % self.job_config.eval.eval_freq != 0 - and not self.step == self.job_config.training.steps - ): - return - - model = self.model_parts[0] - model.eval() t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config) prompt = "A photo of a cat" @@ -172,7 +163,7 @@ def eval_step(self): device=self.device, dtype=self._dtype, job_config=self.job_config, - model=model, + model=self.model_parts[0], prompt=prompt, # TODO(jianiw): change this to a prompt from validation set autoencoder=self.autoencoder, t5_tokenizer=t5_tokenizer, @@ -197,7 +188,22 @@ def eval_step(self): if isinstance(module, FSDPModule): module.reshard() - model.train() + def train_step( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ): + super().train_step(data_iterator) + + # Evaluate the model during training + if ( + self.step % self.job_config.eval.eval_freq == 0 + or self.step == self.job_config.training.steps + ): + model = self.model_parts[0] + model.eval() + # We need to set reshard_after_forward before last forward pass. + # So the model wieghts are sharded the same way for checkpoint saving. + self.eval_step() + model.train() if __name__ == "__main__": diff --git a/torchtitan/train.py b/torchtitan/train.py index 894f85df77..bba1f01e2c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -462,11 +462,6 @@ def train_step( self.metrics_processor.log(self.step, global_avg_loss, global_max_loss) - def eval_step(self): - """Evaluates the model on the validation dataset. - Currently is not implemented and this is a placeholder.""" - return - @record def train(self): job_config = self.job_config @@ -496,9 +491,6 @@ def train(self): except DataloaderStopIteration: logger.warning("Ran out of data; last step was canceled.") break - - self.eval_step() - self.checkpointer.save( self.step, force=(self.step == job_config.training.steps) ) From c4795284174dc0ef6ab426801f4378dcd3416ef1 Mon Sep 17 00:00:00 2001 From: Jiani Wang Date: Thu, 5 Jun 2025 15:54:24 -0700 Subject: [PATCH 3/3] fix prompt --- torchtitan/experiments/flux/train.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 40968502f0..a65d754302 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -158,7 +158,6 @@ def eval_step(self, prompt: str = "A photo of a cat"): t5_tokenizer, clip_tokenizer = build_flux_tokenizer(self.job_config) - prompt = "A photo of a cat" image = generate_image( device=self.device, dtype=self._dtype,