11 changes: 6 additions & 5 deletions README.md
@@ -60,11 +60,12 @@ To accelerate contributions to and innovations around torchtitan, we are hosting
7. DDP and HSDP
8. [TorchFT](https://github.com/pytorch/torchft) integration
9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
-10. Flexible learning rate scheduler (warmup-stable-decay)
-11. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
-12. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
-13. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
-14. [Helper scripts](scripts/) to
+10. Gradient accumulation, enabled by giving an additional `--training.global_batch_size` argument in configuration
+11. Flexible learning rate scheduler (warmup-stable-decay)
+12. Loss, GPU memory, throughput (tokens/sec), TFLOPs, and MFU displayed and logged via [Tensorboard or Weights & Biases](/docs/metrics.md)
+13. [Debugging tools](docs/debugging.md) including CPU/GPU profiling, memory profiling, Flight Recorder, etc.
+14. All options easily configured via [toml files](torchtitan/models/llama3/train_configs/)
+15. [Helper scripts](scripts/) to
- download tokenizers from Hugging Face
- convert original Llama 3 checkpoints into the expected DCP format
- estimate FSDP/HSDP memory usage without materializing the model
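As a rough illustration of the gradient accumulation item (10) added above: a minimal sketch of how a global batch size can map onto accumulated microbatches. The helper names below (`local_batch_size`, `dp_degree`, `accum_steps`) are illustrative assumptions, not torchtitan's actual internals; only the `--training.global_batch_size` flag comes from the README change.

```python
# Hedged sketch of gradient accumulation driven by a global batch size.
# All names are illustrative; this is not torchtitan's implementation.
def accum_steps(global_batch_size: int, local_batch_size: int, dp_degree: int) -> int:
    per_step = local_batch_size * dp_degree  # samples per optimizer step without accumulation
    assert global_batch_size % per_step == 0, "global batch size must be a multiple of local * dp"
    return global_batch_size // per_step

def accumulated_step(model, optimizer, data_iter, loss_fn, n_microbatches: int) -> None:
    optimizer.zero_grad()
    for _ in range(n_microbatches):
        inputs, labels = next(data_iter)
        # Scale each microbatch loss so the accumulated gradient matches a single large-batch step.
        loss = loss_fn(model(inputs), labels) / n_microbatches
        loss.backward()  # gradients accumulate across microbatches
    optimizer.step()

# Example: --training.global_batch_size 64 with a per-device batch of 4 on 8 data-parallel ranks
# gives accum_steps(64, 4, 8) == 2 microbatches per optimizer step.
```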
77 changes: 17 additions & 60 deletions torchtitan/experiments/flux/train.py
@@ -146,66 +146,6 @@ def forward_backward_step(

        return loss

-    def train_step(
-        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
-    ):
-        input_dict, labels = next(data_iterator)
-
-        self.optimizers.zero_grad()
-
-        # Keep these variables local to shorten the code as these are
-        # the major variables that are used in the training loop.
-        model_parts = self.model_parts
-        assert len(self.model_parts) == 1
-        # explicitly convert the Flux model to bfloat16, whether or not FSDP is applied
-        model = self.model_parts[0]
-
-        world_mesh = self.world_mesh
-        parallel_dims = self.parallel_dims
-
-        loss = self.forward_backward_step(input_dict, labels)
-
-        dist_utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
-            self.job_config.training.max_norm,
-            foreach=True,
-            pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
-        )
-        self.checkpointer.maybe_wait_for_staging()
-        self.optimizers.step()
-        self.lr_schedulers.step()
-
-        # log metrics
-        if not self.metrics_processor.should_log(self.step):
-            return
-
-        if (
-            parallel_dims.dp_replicate_enabled
-            or parallel_dims.dp_shard_enabled
-            or parallel_dims.cp_enabled
-        ):
-            loss = loss.detach()
-            ft_pg = self.ft_manager.replicate_pg if self.ft_manager.enabled else None
-            global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
-            )
-        else:
-            global_avg_loss = global_max_loss = loss.item()
-
-        self.metrics_processor.log(self.step, global_avg_loss, global_max_loss)
-
-        # Evaluate the model during training
-        if (
-            self.step % self.job_config.eval.eval_freq == 0
-            or self.step == self.job_config.training.steps
-        ):
-            model.eval()
-            # We need to set reshard_after_forward before the last forward pass
-            # so the model weights are sharded the same way for checkpoint saving.
-            self.eval_step()
-            model.train()

    def eval_step(self, prompt: str = "A photo of a cat"):
        """
        Evaluate the Flux model.
@@ -247,6 +187,23 @@ def eval_step(self, prompt: str = "A photo of a cat"):
            if isinstance(module, FSDPModule):
                module.reshard()

+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        super().train_step(data_iterator)
+
+        # Evaluate the model during training
+        if (
+            self.step % self.job_config.eval.eval_freq == 0
+            or self.step == self.job_config.training.steps
+        ):
+            model = self.model_parts[0]
+            model.eval()
+            # We need to set reshard_after_forward before the last forward pass
+            # so the model weights are sharded the same way for checkpoint saving.
+            self.eval_step()
+            model.train()


if __name__ == "__main__":
    init_logger()
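The rewritten `train_step` above delegates the optimization step (forward/backward, gradient clipping, optimizer and LR scheduler updates, metrics logging) to the parent trainer via `super().train_step(...)` and only layers periodic evaluation on top. Below is a small sketch of when that evaluation condition fires, with made-up values for `eval_freq` and the total step count (not taken from any real config):

```python
# Hedged sketch: which training steps trigger eval_step() under the condition above.
# eval_freq and total_steps are illustrative values only.
eval_freq, total_steps = 100, 250

eval_at = [
    step
    for step in range(1, total_steps + 1)
    if step % eval_freq == 0 or step == total_steps
]
print(eval_at)  # [100, 200, 250] -- every eval_freq steps, plus a final eval at the last step
```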