[Observability 6/7] Wire observability into Trainer, Flux, FT, and Forge #2606
felipemello1 wants to merge 2 commits into gh/felipemello1/22/base
Conversation
[ghstack-poisoned]
```python
logger_instance.close()
...
def ensure_pp_loss_visible(
```
This and `_get_metrics_rank` are no longer needed. For PP loss, we would need the head of the PP rank to log. Now, we only log on rank 0.
Q: What if rank 0 is NOT the head of PP?
A: That's OK, because we read the metrics from the logged JSON files. As long as the correct rank logged, rank 0 can parse it.
```python
return (world_size // pp_size) * (pp_size - 1)
...
class MetricsProcessor(Configurable):
```
Moved to `/observability`. I was going to keep it here so it's easier to compare/review, but I was hitting circular dependencies, so I just moved it to a new file.
```python
# Only record loss on ranks that computed it (last PP stage has
# the real loss; other PP stages have a -1/-2 sentinel).
if not parallel_dims.pp_enabled or self.pp_has_last_stage:
```
We only log it on the right ranks, since `record_metric` is not aware of which ranks to use for the reduction later on.
if some ranks are "missing" what happens? How do you guard against the case where all ranks should log something but some ranks didn't?
> if some ranks are "missing" what happens?

The aggregation function is simple:
- read the JSON from all ranks
- aggregate by `metric_name` and `step`

If ranks 3 and 6 don't post the metric, they won't be part of the aggregation. So this is enough:

```python
if not parallel_dims.pp_enabled or self.pp_has_last_stage:
    record_metric(...)
```
> How do you guard against the case where all ranks should log something but some ranks didn't?

That would be a tricky bug. Currently:
- We put a barrier before `.log()` in the MP, to make sure that all metrics made it to the JSON before we call aggregate.
- We have a small `time.sleep` to make sure that the logger had time to flush.
- We always open the file fresh (instead of keeping it open); Claude mentioned that helps with reading a fresh cache.

So I think we are safe. We could maybe force a `logger.flush` as well before aggregation.
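For reviewers, the aggregation flow described in this thread (read per-rank JSONL, group by metric name and step) can be sketched roughly as below. The file naming follows the `trainer_rank_*_experiment.jsonl` layout from the output folder; the JSON key names and the mean reduction are illustrative assumptions, not the actual torchtitan schema:

```python
# Hypothetical sketch of the rank-0 aggregation pass: each rank appends JSON
# lines like {"metric": ..., "step": ..., "value": ...} to its own file;
# rank 0 reads every file and groups values by (metric, step). Ranks that
# never recorded a metric (e.g. non-last PP stages for loss) simply
# contribute nothing to that metric's group.
import json
from collections import defaultdict
from pathlib import Path


def aggregate(log_dir: str) -> dict:
    grouped = defaultdict(list)
    for path in sorted(Path(log_dir).glob("trainer_rank_*_experiment.jsonl")):
        # Open the file anew on every pass (fresh read, no stale handle).
        with open(path) as f:
            for line in f:
                rec = json.loads(line)
                grouped[(rec["metric"], rec["step"])].append(rec["value"])
    # Mean reduction as a stand-in; the real processor picks mean/max per key.
    return {k: sum(v) / len(v) for k, v in grouped.items()}
```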
I think the internal version broadcasts the PP loss to all PP ranks. I would rather we do the same, instead of relying on "only saving on some ranks". But this is not urgent.
## How to review
Understand the changes in trainer.py. The changes in other trainers (Flux, FT, Forge) are exactly the same.
## Summary
Wires the observability library into TorchTitan's production Trainer and all derived trainers (Flux, FT, Forge).
- [x] (This PR) 1st phase - focuses on `record_metric` and excludes `record_span` and `record_event`.
- [ ] (Next PR) 2nd phase - adds `record_span`, `add_step_tag` and `record_event`.
The old `MetricsProcessor` in `components/metrics.py` is replaced by a new one in `observability/metrics_processor.py`. The old class (~327 lines) is deleted.
**Key improvements over the old MetricsProcessor:**
- **Real loss visible on rank 0 with PP.** The old code showed a sentinel loss (-2.0) on rank 0 because only the last PP stage computed the loss. It also had two extra helper functions for that, now deleted.
- **Zero backend overhead on the hot path.** Logging to backends happens in a background subprocess.
- **Decoupled from the training loop.** Metrics are recorded via `record_metric(key, value)` at the call site. No more passing dicts of extra_metrics between functions.
- **Modularization**: Better separation between training and validation metrics.
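As a rough illustration of the decoupling point, the call-site pattern could look like the toy sketch below. A process-global buffer stands in for the real `record_metric` implementation, and all names here are illustrative, not the library's actual API:

```python
# Toy sketch: record_metric(key, value) appends to a process-global buffer,
# so any function in the training loop can record without threading
# extra_metrics dicts through its callers. A background consumer would
# later drain the buffer; here flush() stands in for that step.
_METRIC_BUFFER: list[tuple[str, float]] = []


def record_metric(key: str, value: float) -> None:
    _METRIC_BUFFER.append((key, value))


def train_step(loss: float, ntokens: int) -> None:
    # Recording happens at the call site; nothing is returned to the caller.
    record_metric("training/loss", loss)
    record_metric("training/n_tokens", float(ntokens))


def flush() -> list[tuple[str, float]]:
    # Drain and clear the buffer atomically (single-threaded sketch).
    out, _METRIC_BUFFER[:] = _METRIC_BUFFER[:], []
    return out
```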
**Trainers migrated:**
All four trainers follow the same pattern: Flux and FT inherit `batch_generator` from the base Trainer, while Forge is standalone. The changes are the same in all four.
## Numerics
Loss and grad_norm match main exactly on all trainers with `--debug.seed=42 --debug.deterministic`:
| Trainer | Config | Step 1 loss | Step 5/10 loss |
|---------|--------|-------------|----------------|
| Base | PP=2 DP=2 TP=2 | 8.028 | 3.90 (step 10) |
| Flux | DP=4 CP=2 | 1.158 | 1.160 (step 5) |
| Forge | DP=2 TP=2 PP=2 | 8.028 | 5.19 (step 5) |
| FT | — | Not runnable | Pre-existing bug on main: torchft renamed `ProcessGroupNCCL` → `ProcessGroupBabyNCCL`; `experiments/ft/manager.py:95` needs updating |
## Test plan
```bash
# Base trainer (PP=2 DP=2 TP=2)
NGPU=8 LOG_RANK=0 ./run_train.sh --module llama3 --config llama3_debugmodel \
--training.steps 10 --parallelism.data_parallel_shard_degree=2 \
--parallelism.tensor_parallel_degree=2 --parallelism.pipeline_parallel_degree=2 \
--parallelism.pipeline_parallel_schedule=1F1B --debug.seed=42 --debug.deterministic \
--metrics.enable-wandb
# Flux (DP=4 CP=2)
NGPU=8 LOG_RANK=0 ./run_train.sh --module flux --config flux_debugmodel \
--training.steps 5 --debug.seed=42 --metrics.enable-wandb \
--parallelism.data_parallel_shard_degree=4 --parallelism.context_parallel_degree=2 \
--tokenizer.test_mode --encoder.random_init \
--encoder.clip_encoder tests/assets/flux_test_encoders/clip-vit-large-patch14/ \
--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/ \
--tokenizer.t5_tokenizer_path tests/assets/tokenizer \
--tokenizer.clip_tokenizer_path tests/assets/tokenizer \
--hf_assets_path tests/assets/tokenizer
# Forge (DP=2 TP=2 PP=2)
python -m torch.distributed.run --nproc_per_node=8 --local-ranks-filter 0 --role rank --tee 3 \
-m torchtitan.experiments.forge.example_train --module llama3 --config llama3_debugmodel \
--training.steps 5 --debug.seed=42 --debug.deterministic --metrics.enable-wandb \
--parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 \
--parallelism.pipeline_parallel_degree=2 --parallelism.pipeline_parallel_schedule=1F1B
```
- [x] pre-commit passes
- [x] Integration: base trainer PP=2 DP=2 TP=2 loss matches main
- [x] Integration: console on rank 0 shows real loss (not PP sentinel -2.0)
- [x] Integration: Flux DP=4 CP=2 loss matches main
- [x] Integration: Forge DP=2 TP=2 PP=2 loss matches main (shows real loss, not -2.0)
### Console output — before (main) vs after
**Main — Base trainer** (rank 0, PP=2 — shows sentinel loss -2.0):
```
[rank0]:[titan] 2026-03-16 08:53:41,731 - root - INFO - step: 1 loss: -2.00000 grad_norm: 50810.4961 memory: 0.19GiB(0.20%) tps: 421 tflops: 0.03 mfu: 0.00%
[rank0]:[titan] 2026-03-16 08:53:44,233 - root - INFO - step: 5 loss: -2.00000 grad_norm: 83210.2031 memory: 0.19GiB(0.20%) tps: 11,054 tflops: 0.79 mfu: 0.08%
```
**This PR — Base trainer** (rank 0 — real aggregated loss):
```
[rank0]:[titan] 2026-03-16 08:57:13,807 - torchtitan.observability.aggregation - INFO - step: 1 training/loss_mean: 8.028 training/grad_norm_max: 50810.5 trainer_memory/reserved_gib_max: 0.19 trainer_throughput/tps_mean: 361.4 trainer_throughput/tflops_mean: 0.026 trainer_throughput/mfu_pct_mean: 0.0026
[rank0]:[titan] 2026-03-16 08:57:17,135 - torchtitan.observability.aggregation - INFO - step: 10 training/loss_mean: 3.90 training/grad_norm_max: 52873.5 trainer_memory/reserved_gib_max: 0.20 trainer_throughput/tps_mean: 11729.9 trainer_throughput/tflops_mean: 0.84 trainer_throughput/mfu_pct_mean: 0.085
```
### Output folder
```
outputs/
├── system_logs/
│ └── trainer_rank_{0-7}_system.jsonl
├── experiment_logs/
│ └── trainer_rank_{0-7}_experiment.jsonl
├── tb/20260316-0857/
└── wandb/
```
### WandB runs
| | Main | This PR |
|---|------|------|
| Base (PP=2 DP=2 TP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/9ltdu69z | https://wandb.ai/cabernet-team/torchtitan/runs/uxcx428m |
| Flux (DP=4 CP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/twfvo2ys | https://wandb.ai/cabernet-team/torchtitan/runs/iua7gyq5 |
| Forge (DP=2 TP=2 PP=2) | https://wandb.ai/cabernet-team/torchtitan/runs/35ztpneg | https://wandb.ai/cabernet-team/torchtitan/runs/o1pqmpeb |
Trainer WandB run for this PR:
<img width="1674" height="986" alt="image" src="https://github.com/user-attachments/assets/dc9284e0-7657-43dc-bad5-2f8fc776c392" />
[ghstack-poisoned]
**tianyu-l** left a comment:
Several things to clarify:

1. Relationship and separation between `MetricsProcessor` / `RolloutLogger` and directly calling `record_metric`. Do we need an object only when something needs to be stateful for derived metrics? Do we need to start a `logging_worker` process for both?
2. Relationship between `record_span`, `record_metric`, and `record_event`. It seems:
   - `record_span` logs the span via `record_metric`
   - some metrics are also logged to `record_event`
   - (local loss, num_local_tokens) sounds like a natural MeanMetric, but right now it is not modeled that way.
3. Relationship between `MetricsProcessor` and `logger.info`. It seems some console logging is handled by `console_log_metric_keys`, but not all of it is.
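On the MeanMetric point, a token-weighted mean over (local loss, num_local_tokens) could be modeled along these lines. This is a minimal sketch in the style of torchmetrics' update/compute API, not actual torchtitan code:

```python
# Illustrative token-weighted MeanMetric: each rank calls update() with its
# local loss and local token count; compute() returns the token-weighted
# mean. In a distributed setting the two accumulators would be all-reduced
# (or aggregated from the per-rank logs) before dividing.
class MeanMetric:
    def __init__(self) -> None:
        self.weighted_sum = 0.0
        self.total_weight = 0.0

    def update(self, value: float, weight: float = 1.0) -> None:
        self.weighted_sum += value * weight
        self.total_weight += weight

    def compute(self) -> float:
        if self.total_weight == 0.0:
            raise ValueError("compute() called before any update()")
        return self.weighted_sum / self.total_weight
```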
```python
ft_enable: bool = False,
ft_replica_id: int = 0,
```
The main trainer shouldn't be aware of these fault-tolerance fields.
I copied it from what we have today. Happy to change, but I don't have much context about FT. What do you have in mind?
`torchtitan/components/metrics.py`, lines 327 to 328 in `0187d5f`
We can have extension points, but we shouldn't have them in core torchtitan. I don't mind if we remove them entirely for now.
```python
record_metric(
    "training/n_tokens_sum", NoOpMetric(value=global_ntokens_seen)
)
record_event(
```
Why are they both metrics and events?
I'm sure chatgpt/claude can give us some inspiration :)
```
3. Background logging subprocess on rank 0 that reads per-rank experiment
   JSONL, aggregates across ranks, and writes to WandB/TB/console.
...
The trainer calls methods in this order each step:
```
The order of record memory and record throughput should be interchangeable?
I purposely put record throughput last so that any slowdown from the previous functions would be captured. But I assume that record memory is mostly an O(1) lookup. Do you think it's worth the edit?
```python
m.train()
...
# confirm that user will be able to view loss metrics on the console
ensure_pp_loss_visible(
```
yes, we should remove this
```python
# If data runs out during gradient accumulation, that
# entire step will not be executed.
raise DataloaderExhaustedError() from ex
...
with record_span("trainer_time/data_loading_s", EventType.FETCHING_BATCH):
```
what does _s mean? seconds?
Yes. Happy to change the pattern to make it more explicit if you think it's not intuitive.
Fine to have, but please make it consistent across all metrics -- I saw some metrics that don't have the unit.
```python
) = model_config.get_nparams_and_flops(model, config.training.seq_len)
...
color = utils.Color()
logger.info(
```
This is still directly logging to console right?
If we intentionally pass some extra fields, then our filters may act on them. But otherwise, it goes to the console normally.

This is not accurate: `logger` is not something we register a handler on, so even if we pass an extra field, it won't be acted on.



