[Observability 7/7] Timing spans and step tags across all trainers #2607
felipemello1 wants to merge 2 commits into gh/felipemello1/23/base
Conversation
```python
self.states[MODEL].load_state_dict(state_dict)

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
```
Make it return a boolean so we know whether a checkpoint happened, so we can add a tag to the logs.
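For illustration, the return-a-bool pattern suggested here could look like the sketch below. `CheckpointManager` and `add_step_tag` are stand-ins for this example, not torchtitan's actual API:

```python
# Sketch: save() reports whether a checkpoint was actually written, so the
# caller can attach a step tag only on those steps. All names here are
# illustrative stubs, not the real torchtitan implementations.
step_tags: dict[int, list[str]] = {}

def add_step_tag(step: int, tag: str) -> None:
    """Record a tag against a training step (stub for illustration)."""
    step_tags.setdefault(step, []).append(tag)

class CheckpointManager:
    def __init__(self, interval: int) -> None:
        self.interval = interval

    def save(self, curr_step: int, last_step: bool = False) -> bool:
        """Return True only when a checkpoint was actually written."""
        if last_step or curr_step % self.interval == 0:
            # ... write the state_dict to disk here ...
            return True
        return False

ckpt = CheckpointManager(interval=5)
for step in range(1, 11):
    if ckpt.save(step, last_step=(step == 10)):
        add_step_tag(step, "checkpoint")

print(sorted(step_tags))  # → [5, 10]
```

The tag then only shows up on steps where a checkpoint really happened, which is what the integration check at the bottom of the test plan verifies.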
Should `record_span` only happen in `trainer.py`? Because otherwise you could add it in this function.
> Should `record_span` only happen in `trainer.py`?

No, we can use it everywhere. However, `init_observability` needs to be called first by the rank, and we probably want `set_step(step)` called as well, so the step is added as metadata. For this specific case, I think having it explicit in `trainer.py` is better. Let me know if you disagree.
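The ordering described above (initialize once per rank, set the step, then open spans) can be sketched with stub implementations. These are stand-ins written for this example, not torchtitan's actual observability module:

```python
# Sketch of the call ordering: init_observability first, then set_step
# before each span so the current step lands in the span's metadata.
import time
from contextlib import contextmanager

_state = {"initialized": False, "step": None}
spans: list[dict] = []

def init_observability() -> None:
    """Must run once per rank before any spans are recorded."""
    _state["initialized"] = True

def set_step(step: int) -> None:
    """Set the current training step; spans pick it up as metadata."""
    _state["step"] = step

@contextmanager
def record_span(name: str):
    assert _state["initialized"], "call init_observability() first"
    start = time.perf_counter()
    try:
        yield
    finally:
        spans.append({
            "name": name,
            "step": _state["step"],
            "duration_s": time.perf_counter() - start,
        })

init_observability()
for step in (1, 2):
    set_step(step)
    with record_span("forward_backward"):
        pass  # training work would go here

print([(s["name"], s["step"]) for s in spans])
# → [('forward_backward', 1), ('forward_backward', 2)]
```

If `set_step` were skipped, the span would still be recorded but with `step=None`, which is why keeping the calls explicit in `trainer.py` is the safer default.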
```python
)
self.model_config = model_config

logger.info(
```
Nothing was deleted or changed here, just indented; not sure why GitHub is not picking this up. This PR just adds `record_span` to function blocks.
```python
with torch.no_grad():
    cast(Decoder, m).init_weights(buffer_device=buffer_device)
m.train()
cast(BaseModel, model).init_weights(buffer_device=buffer_device)
```
This was not changed; GitHub is comparing different code lines. Nothing changed, just indentation.
## How to review

Understand the changes in `trainer.py`. The changes in the other trainers (Flux, FT, Forge) are exactly the same.

## Summary

Wraps every training phase in `record_span` so the gantt chart shows the full training timeline. Adds `add_step_tag` annotations for GC, validation, and checkpoint events.

- [ ] (Previous PR) 1st phase: focuses on experiment metrics and excludes `record_span` and `record_event`.
- [x] (This PR) 2nd phase: adds `record_span`, `add_step_tag`, and `record_event`.

All four trainers (base, Flux, FT, Forge) have the same span coverage.

- `checkpoint.save()` and `GarbageCollection.run()` return `bool` for step tagging

<img width="1499" height="683" alt="image" src="https://github.com/user-attachments/assets/a60a55bf-1daf-4ef3-891d-3e3762dd8880" />

## Test plan

Flux, Forge, and `trainer.py` were tested on main vs. this PR. Losses and grad norms are the same.

```bash
# Base trainer with PP
NGPU=8 LOG_RANK=0 ./run_train.sh --module llama3 --config llama3_debugmodel \
  --training.steps 10 --parallelism.data_parallel_shard_degree=2 \
  --parallelism.tensor_parallel_degree=2 --parallelism.pipeline_parallel_degree=2 \
  --parallelism.pipeline_parallel_schedule=1F1B --debug.seed=42 --debug.deterministic \
  --metrics.enable-wandb

# Base trainer with validation (no PP)
NGPU=8 LOG_RANK=0 ./run_train.sh --module llama3 --config llama3_debugmodel \
  --training.steps 10 --parallelism.data_parallel_shard_degree=4 \
  --parallelism.tensor_parallel_degree=2 --debug.seed=42 --debug.deterministic \
  --validator.enable --validator.freq=5 --validator.steps=2 --metrics.enable-wandb

# Flux (DP=4 CP=2)
NGPU=8 LOG_RANK=0 ./run_train.sh --module flux --config flux_debugmodel \
  --training.steps 5 --debug.seed=42 --metrics.enable-wandb \
  --parallelism.data_parallel_shard_degree=4 --parallelism.context_parallel_degree=2 \
  --tokenizer.test_mode --encoder.random_init \
  --encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/ \
  --encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/ \
  --tokenizer.t5_tokenizer_path tests/assets/tokenizer \
  --tokenizer.clip_tokenizer_path tests/assets/tokenizer \
  --hf_assets_path tests/assets/tokenizer

# Forge (DP=2 TP=2 PP=2)
python -m torch.distributed.run --nproc_per_node=8 --local-ranks-filter 0 --role rank --tee 3 \
  -m torchtitan.experiments.forge.example_train --module llama3 --config llama3_debugmodel \
  --training.steps 5 --debug.seed=42 --debug.deterministic --metrics.enable-wandb \
  --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 \
  --parallelism.pipeline_parallel_degree=2 --parallelism.pipeline_parallel_schedule=1F1B
```

- [x] Integration: base trainer PP=2 DP=2 TP=2 loss matches main
- [x] Integration: base trainer with validation shows training + validation console lines
- [x] Integration: Flux DP=4 CP=2 loss matches main
- [x] Integration: Forge DP=2 TP=2 PP=2 loss matches main (real loss, not -2.0)
- [x] Integration: gantt (PP) has span types across all 8 ranks
- [x] Integration: gantt (validation) includes eval spans
- [x] Integration: checkpoint tag only appears on steps where checkpoint was saved

### Console output — Base trainer with PP (PP=2 DP=2 TP=2)

```
[rank0]:[titan] 2026-03-16 08:57:13,807 - torchtitan.observability.aggregation - INFO - step: 1 training/loss_mean: 8.028 training/grad_norm_max: 50810.5 trainer_memory/reserved_gib_max: 0.19 trainer_throughput/tps_mean: 361.4 trainer_throughput/tflops_mean: 0.026 trainer_throughput/mfu_pct_mean: 0.0026
[rank0]:[titan] 2026-03-16 08:57:15,323 - torchtitan.observability.aggregation - INFO - step: 5 training/loss_mean: 5.021 training/grad_norm_max: 81269.6 trainer_memory/reserved_gib_max: 0.20 trainer_throughput/tps_mean: 10177.9 trainer_throughput/tflops_mean: 0.73 trainer_throughput/mfu_pct_mean: 0.074
[rank0]:[titan] 2026-03-16 08:57:17,135 - torchtitan.observability.aggregation - INFO - step: 10 training/loss_mean: 3.90 training/grad_norm_max: 52873.5 trainer_memory/reserved_gib_max: 0.20 trainer_throughput/tps_mean: 11729.9 trainer_throughput/tflops_mean: 0.84 trainer_throughput/mfu_pct_mean: 0.085
```

### Console output — Base trainer with validation (DP=4 TP=2)

```
[rank0]:[titan] 2026-03-16 08:58:05,206 - torchtitan.observability.aggregation - INFO - step: 1 training/loss_mean: 8.13 training/grad_norm_max: 1.54 trainer_memory/reserved_gib_max: 0.64 trainer_throughput/tps_mean: 1409.4 trainer_throughput/tflops_mean: 0.10 trainer_throughput/mfu_pct_mean: 0.010
[rank0]:[titan] 2026-03-16 08:58:05,207 - torchtitan.observability.aggregation - INFO - validate step: 1 validation/loss_mean: 7.81 validator_memory/reserved_gib_max: 0.64 validator_throughput/tps_mean: 11233.7
[rank0]:[titan] 2026-03-16 08:58:06,075 - torchtitan.observability.aggregation - INFO - step: 5 training/loss_mean: 5.32 training/grad_norm_max: 2.51 trainer_memory/reserved_gib_max: 0.67 trainer_throughput/tps_mean: 71439.5 trainer_throughput/tflops_mean: 5.11 trainer_throughput/mfu_pct_mean: 0.52
[rank0]:[titan] 2026-03-16 08:58:06,075 - torchtitan.observability.aggregation - INFO - validate step: 5 validation/loss_mean: 4.78 validator_memory/reserved_gib_max: 0.67 validator_throughput/tps_mean: 42884.5
[rank0]:[titan] 2026-03-16 08:58:07,045 - torchtitan.observability.aggregation - INFO - step: 10 training/loss_mean: 4.022 training/grad_norm_max: 1.88 trainer_memory/reserved_gib_max: 0.67 trainer_throughput/tps_mean: 66496.5 trainer_throughput/tflops_mean: 4.76 trainer_throughput/mfu_pct_mean: 0.48
[rank0]:[titan] 2026-03-16 08:58:07,046 - torchtitan.observability.aggregation - INFO - validate step: 10 validation/loss_mean: 4.034 validator_memory/reserved_gib_max: 0.67 validator_throughput/tps_mean: 45510.2
```

### WandB runs

| | Main | PR7c |
|---|------|------|
| Base (PP=2 DP=2 TP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/9ltdu69z | https://wandb.ai/cabernet-team/torchtitan/runs/uxcx428m |
| Base (validation) | — | https://wandb.ai/cabernet-team/torchtitan/runs/rsq5vxlx |
| Flux (DP=4 CP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/twfvo2ys | https://wandb.ai/cabernet-team/torchtitan/runs/iua7gyq5 |
| Forge (DP=2 TP=2 PP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/35ztpneg | https://wandb.ai/cabernet-team/torchtitan/runs/o1pqmpeb |
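The "wraps every training phase in `record_span`" idea can be sketched as below. `record_span` here is a stand-in context manager, and the phase names are illustrative, not necessarily the span names torchtitan uses:

```python
# Sketch: each phase of a training step is wrapped in a span so a gantt
# chart can be built from the collected (name, start, end) tuples.
# record_span is a stub written for this example.
import time
from contextlib import contextmanager

spans: list[tuple[str, float, float]] = []

@contextmanager
def record_span(name: str):
    start = time.perf_counter()
    try:
        yield
    finally:
        spans.append((name, start, time.perf_counter()))

def train_step() -> None:
    with record_span("data_load"):
        pass  # e.g. next(dataloader)
    with record_span("forward_backward"):
        pass  # loss computation + backward
    with record_span("optimizer_step"):
        pass  # optimizer.step(); lr_scheduler.step()

train_step()
print([name for name, _, _ in spans])
# → ['data_load', 'forward_backward', 'optimizer_step']
```

Because every phase is wrapped, the union of spans covers the whole step, which is what makes the gantt chart in the screenshot above show a contiguous timeline per rank.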
joecummings
left a comment
I'd like to see timing on using NFS across 1, 2, 4, 8, N (as many as you can get) nodes.