[Observability 7/7] Timing spans and step tags across all trainers#2607

Open
felipemello1 wants to merge 2 commits into gh/felipemello1/23/base from gh/felipemello1/23/head

Conversation

@felipemello1 felipemello1 commented Mar 16, 2026

Stack from ghstack (oldest at bottom):

## How to review

Understand the changes in trainer.py. The changes in other trainers (Flux, FT, Forge) are exactly the same.

## Summary

Wraps every training phase in `record_span` so the Gantt chart shows the full training timeline. Adds `add_step_tag` annotations for GC, validation, and checkpoint events.

- [ ] (Previous PR) 1st phase: focuses on experiment metrics; excludes `record_span` and `record_event`.
- [x] (This PR) 2nd phase: adds `record_span`, `add_step_tag`, and `record_event`.

All four trainers (base, Flux, FT, Forge) have the same span coverage.

- `checkpoint.save()` and `GarbageCollection.run()` return `bool` for step tagging

## Test plan

Flux, Forge, and the base trainer (trainer.py) were tested on main vs. this PR. Losses and grad norms are the same.

```bash
# Base trainer with PP
NGPU=8 LOG_RANK=0 ./run_train.sh --module llama3 --config llama3_debugmodel \
  --training.steps 10 --parallelism.data_parallel_shard_degree=2 \
  --parallelism.tensor_parallel_degree=2 --parallelism.pipeline_parallel_degree=2 \
  --parallelism.pipeline_parallel_schedule=1F1B --debug.seed=42 --debug.deterministic \
  --metrics.enable-wandb

# Base trainer with validation (no PP)
NGPU=8 LOG_RANK=0 ./run_train.sh --module llama3 --config llama3_debugmodel \
  --training.steps 10 --parallelism.data_parallel_shard_degree=4 \
  --parallelism.tensor_parallel_degree=2 --debug.seed=42 --debug.deterministic \
  --validator.enable --validator.freq=5 --validator.steps=2 --metrics.enable-wandb

# Flux (DP=4 CP=2)
NGPU=8 LOG_RANK=0 ./run_train.sh --module flux --config flux_debugmodel \
  --training.steps 5 --debug.seed=42 --metrics.enable-wandb \
  --parallelism.data_parallel_shard_degree=4 --parallelism.context_parallel_degree=2 \
  --tokenizer.test_mode --encoder.random_init \
  --encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/ \
  --encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/ \
  --tokenizer.t5_tokenizer_path tests/assets/tokenizer \
  --tokenizer.clip_tokenizer_path tests/assets/tokenizer \
  --hf_assets_path tests/assets/tokenizer

# Forge (DP=2 TP=2 PP=2)
python -m torch.distributed.run --nproc_per_node=8 --local-ranks-filter 0 --role rank --tee 3 \
  -m torchtitan.experiments.forge.example_train --module llama3 --config llama3_debugmodel \
  --training.steps 5 --debug.seed=42 --debug.deterministic --metrics.enable-wandb \
  --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 \
  --parallelism.pipeline_parallel_degree=2 --parallelism.pipeline_parallel_schedule=1F1B
```

- [x] Integration: base trainer PP=2 DP=2 TP=2 loss matches main
- [x] Integration: base trainer with validation shows training + validation console lines
- [x] Integration: Flux DP=4 CP=2 loss matches main
- [x] Integration: Forge DP=2 TP=2 PP=2 loss matches main (real loss, not -2.0)
- [x] Integration: Gantt (PP) has span types across all 8 ranks
- [x] Integration: Gantt (validation) includes eval spans
- [x] Integration: checkpoint tag only appears on steps where checkpoint was saved

### Console output — Base trainer with PP (PP=2 DP=2 TP=2)

```
[rank0]:[titan] 2026-03-16 08:57:13,807 - torchtitan.observability.aggregation - INFO - step:  1  training/loss_mean: 8.028  training/grad_norm_max: 50810.5  trainer_memory/reserved_gib_max: 0.19  trainer_throughput/tps_mean: 361.4  trainer_throughput/tflops_mean: 0.026  trainer_throughput/mfu_pct_mean: 0.0026
[rank0]:[titan] 2026-03-16 08:57:15,323 - torchtitan.observability.aggregation - INFO - step:  5  training/loss_mean: 5.021  training/grad_norm_max: 81269.6  trainer_memory/reserved_gib_max: 0.20  trainer_throughput/tps_mean: 10177.9  trainer_throughput/tflops_mean: 0.73  trainer_throughput/mfu_pct_mean: 0.074
[rank0]:[titan] 2026-03-16 08:57:17,135 - torchtitan.observability.aggregation - INFO - step: 10  training/loss_mean: 3.90  training/grad_norm_max: 52873.5  trainer_memory/reserved_gib_max: 0.20  trainer_throughput/tps_mean: 11729.9  trainer_throughput/tflops_mean: 0.84  trainer_throughput/mfu_pct_mean: 0.085
```

### Console output — Base trainer with validation (DP=4 TP=2)

```
[rank0]:[titan] 2026-03-16 08:58:05,206 - torchtitan.observability.aggregation - INFO - step:  1  training/loss_mean: 8.13  training/grad_norm_max: 1.54  trainer_memory/reserved_gib_max: 0.64  trainer_throughput/tps_mean: 1409.4  trainer_throughput/tflops_mean: 0.10  trainer_throughput/mfu_pct_mean: 0.010
[rank0]:[titan] 2026-03-16 08:58:05,207 - torchtitan.observability.aggregation - INFO - validate step:  1  validation/loss_mean: 7.81  validator_memory/reserved_gib_max: 0.64  validator_throughput/tps_mean: 11233.7
[rank0]:[titan] 2026-03-16 08:58:06,075 - torchtitan.observability.aggregation - INFO - step:  5  training/loss_mean: 5.32  training/grad_norm_max: 2.51  trainer_memory/reserved_gib_max: 0.67  trainer_throughput/tps_mean: 71439.5  trainer_throughput/tflops_mean: 5.11  trainer_throughput/mfu_pct_mean: 0.52
[rank0]:[titan] 2026-03-16 08:58:06,075 - torchtitan.observability.aggregation - INFO - validate step:  5  validation/loss_mean: 4.78  validator_memory/reserved_gib_max: 0.67  validator_throughput/tps_mean: 42884.5
[rank0]:[titan] 2026-03-16 08:58:07,045 - torchtitan.observability.aggregation - INFO - step: 10  training/loss_mean: 4.022  training/grad_norm_max: 1.88  trainer_memory/reserved_gib_max: 0.67  trainer_throughput/tps_mean: 66496.5  trainer_throughput/tflops_mean: 4.76  trainer_throughput/mfu_pct_mean: 0.48
[rank0]:[titan] 2026-03-16 08:58:07,046 - torchtitan.observability.aggregation - INFO - validate step: 10  validation/loss_mean: 4.034  validator_memory/reserved_gib_max: 0.67  validator_throughput/tps_mean: 45510.2
```

### WandB runs

| | Main | PR7c |
|---|------|------|
| Base (PP=2 DP=2 TP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/9ltdu69z | https://wandb.ai/cabernet-team/torchtitan/runs/uxcx428m |
| Base (validation) | — | https://wandb.ai/cabernet-team/torchtitan/runs/rsq5vxlx |
| Flux (DP=4 CP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/twfvo2ys | https://wandb.ai/cabernet-team/torchtitan/runs/iua7gyq5 |
| Forge (DP=2 TP=2 PP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/35ztpneg | https://wandb.ai/cabernet-team/torchtitan/runs/o1pqmpeb |

```python
self.states[MODEL].load_state_dict(state_dict)

@torch.no_grad()
def save(self, curr_step: int, last_step: bool = False) -> None:
```

**felipemello1** (Author): Make it return a boolean so we know if a checkpoint happened, so we can add a tag to the logs.
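A hedged sketch of the change described in this comment: `save()` reports whether a checkpoint was actually written, so the caller can decide whether to tag the step. The class name and interval logic here are illustrative, not the PR's actual code.

```python
class Checkpointer:
    """Illustrative stand-in for the checkpoint manager; the real class differs."""

    def __init__(self, interval: int = 5):
        self.interval = interval  # save every N steps (assumed policy)

    def save(self, curr_step: int, last_step: bool = False) -> bool:
        # Return False when this step does not trigger a save, so the
        # trainer only adds a "checkpoint" step tag when a save happened.
        if not (last_step or curr_step % self.interval == 0):
            return False
        # ... write checkpoint to disk here ...
        return True


ckpt = Checkpointer(interval=5)
print(ckpt.save(3))                   # False: not on the save interval
print(ckpt.save(5))                   # True: interval hit
print(ckpt.save(7, last_step=True))   # True: final step always saves
```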

**Contributor:** Should `record_span` only happen in trainer.py? Because if not, you could add it in this function.

**felipemello1** (Author), Mar 18, 2026:

> should record_span only happen in trainer.py?

No, we can use it everywhere. However, `init_observability` needs to be called first by the rank, and we probably want `set_step(step)` called as well so the step is attached as metadata.

For this specific case, I think having it explicit in trainer.py is better. Let me know if you disagree.

```python
)
self.model_config = model_config

logger.info(
```

**felipemello1** (Author): Nothing was deleted or changed here, just indented; not sure why GitHub isn't picking this up. This PR just adds `record_span` to function blocks.

```python
with torch.no_grad():
    cast(Decoder, m).init_weights(buffer_device=buffer_device)
m.train()
cast(BaseModel, model).init_weights(buffer_device=buffer_device)
```

**felipemello1** (Author): This was not changed; GitHub is comparing different code lines. Nothing was changed, just indented.

felipemello1 pushed a commit that referenced this pull request Mar 16, 2026
**@joecummings** (Member) left a comment:

I'd like to see timing on using NFS across 1, 2, 4, 8, N (as many as you can get) nodes.

Labels: ciflow/8gpu, CLA Signed
