Observability: structured logging + training instrumentation (#3176)
felipemello1 wants to merge 1 commit into pytorch:main
Conversation
@felipemello1 has exported this pull request. If you are a Meta employee, you can view the originating Diff in D101049878.
Questions: when resuming, do the logs from different runs get merged together? And do we by default expect this to run on all ranks + every step? It's going to be something like O(1M) steps per 5T tokens on each rank.
@rakkit, good questions!
If we are resuming, I assume it would be OK to have them together, no? But we could think about adding a run_id to the logs metadata, so users can filter by it. Another alternative is having the logs directed to a different dump_folder.
We expect to have N rank files, yes. For the Gantt chart, we could add flags so people can choose to only display N ranks, for example, or only the last N entries. But the main benefit here is being able to query the logs with some database, e.g. if you have an NCCL timeout, to find where/when it timed out. On the logger side, we could cycle the file, limiting it to N max lines, or break into new files every N entries. I thought that these refinements could be added as follow-ups as users stress test it. Does that match your intuition?
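For the "query it" idea, here is a minimal sketch of what a post-mortem query could look like, assuming per-rank JSONL files under a hypothetical output path and illustrative field names (`rank`, `step`, `event`, `ts`); the actual schema comes from whatever handler the structured logger is configured with.

```python
# Hypothetical post-mortem query over per-rank JSONL logs.
# Path and field names are illustrative, not defined by this PR.
import glob
import json

last_event_per_rank = {}
for path in glob.glob("outputs/structured_logs/rank_*.jsonl"):
    with open(path) as f:
        for line in f:
            rec = json.loads(line)
            last_event_per_rank[rec["rank"]] = rec

# A rank whose last recorded step/span lags the others is a straggler or
# timeout candidate.
for rank, rec in sorted(last_event_per_rank.items()):
    print(rank, rec.get("step"), rec.get("event"), rec.get("ts"))
```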
Yes, putting them together is OK, it just feels weird; or at least we should allow filtering by run. On Slurm clusters we usually get different nodes, and when a problem comes up we can usually locate "which run is broken" easily. In that case, checking that specific run's log would be much easier and cleaner.
Yeah, it's just that we recently suffered a lot from scaling. Everything at 1~2k GPUs x GPFS is a mess, and I got PTSD from this.
Makes sense! I can look into adding a 'run_id' metadata field. But, as it is, it should be very easy for users to also add whatever handler/metadata they prefer (example in the readme).
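Since the structured logger rides on Python's standard logging, one way a user could attach a run_id today is with a plain `logging.Filter`; this is a sketch (not part of this PR), and it assumes the JSONL handler serializes extra attributes found on the log record.

```python
# Sketch: attach a run_id to every record via a standard logging.Filter.
# Logger name and env var are illustrative assumptions.
import logging
import os
import uuid

RUN_ID = os.environ.get("RUN_ID", uuid.uuid4().hex[:8])

class RunIdFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.run_id = RUN_ID  # picked up by any handler that dumps record extras
        return True

logging.getLogger("torchtitan").addFilter(RunIdFilter())
```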
Makes sense! If/when you give it a try, let me know. I will think a bit more about remediating this in a follow-up as well. At first glance, do you see this type of logger being useful for your 1-2k GPU run?
A dumb solution is to take some magic hash or broadcast some ID. Speaking of that, it would help if we could also log info like "I am [global-rank] 0, [fsdp-rank] 0, [dp-rank] 0, [cp-rank] ....." etc. from each parallel dim.
Yes, I think it is going to help at both small and large scale for debugging and finding problems. Technically we could write even a short description of the metadata structure and ask Codex/Claude to vibe-code some magic view for diagnosis. At large scale the main concern is the log itself: it should not slow down training and should be friendly to the filesystem.
We log the global rank. I think that this can be done in postprocessing if we can map rank to parallel dim.
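A minimal sketch of that postprocessing mapping, assuming the usual row-major rank layout (outermost parallel dim varies slowest); the dim names and sizes below are examples, not taken from this PR.

```python
# Recover per-dim coordinates from the logged global rank, given the
# mesh dim sizes in outermost-to-innermost order (assumed row-major layout).
def rank_to_coords(global_rank: int, dims: dict[str, int]) -> dict[str, int]:
    coords = {}
    for name, size in reversed(list(dims.items())):  # innermost dim first
        coords[name] = global_rank % size
        global_rank //= size
    return coords

# e.g. a 2 (dp) x 4 (tp) layout: rank 5 -> {"tp": 1, "dp": 1}
print(rank_to_coords(5, {"dp": 2, "tp": 4}))
```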
Agreed. Internally I saw people doing it two ways:
The jsonl handler is a naive way of doing it, saving directly to the shared FS. At large scale, I assume one would want to change handlers if writing to the FS slows down the run. I would need to check whether Python's logger.info is blocking in this case, but that's basically all we are doing: logger.info with extra metadata.
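A sketch of that underlying mechanism, plain `logger.info` with extra metadata serialized to JSON lines by a handler; the formatter, file name, and field names here are illustrative, not the ones added in this PR.

```python
# Illustrative JSONL handler: logger.info plus `extra=` metadata.
import json
import logging
import time

class JsonLinesFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "ts": time.time(),
            "msg": record.getMessage(),
            # fields passed via `extra=` show up as attributes on the record
            "rank": getattr(record, "rank", None),
            "step": getattr(record, "step", None),
            "span": getattr(record, "span", None),
        }
        return json.dumps(payload)

handler = logging.FileHandler("rank_0.jsonl")
handler.setFormatter(JsonLinesFormatter())
log = logging.getLogger("structured")
log.addHandler(handler)
log.setLevel(logging.INFO)

log.info("checkpoint_save", extra={"rank": 0, "step": 100, "span": "checkpoint_save"})
```

For reference on the blocking question: the stdlib `FileHandler` writes synchronously on the calling thread; `logging.handlers.QueueHandler` + `QueueListener` is the standard way to move that I/O off the hot path.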
Did a quick browse; mostly looks good to me, as we already discussed internally. I have one high-level question, which I asked in the checkpointer too: if we have structured logging, should we trim some of the logger.info calls that try to record the time span of that action?
Another comment (we don't have to do this in this PR): I think we should review some code blocks. If the indentation is too deep, the block is too long, and it is wrapped in a time span, that code block may deserve a helper function with a decorator-style span, such as model init.
Finally, for @rakkit's scaling question, do you think we can add an option to only log on certain ranks? For example, sl gets the rank by using dist.get_rank() and only ranks where rank % X == 0 do the logging. For OSS users, a distributed filesystem is sometimes expensive, unlike Meta's internal setup.
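A rough sketch of that rank-gating suggestion; `init_structured_logger` is from this PR, but the `rank_stride` gate below is only an illustration of the proposed option, not something implemented here.

```python
# Sketch: only enable the structured logger on a strided subset of ranks.
import torch.distributed as dist

def maybe_init_structured_logger(rank_stride: int = 8) -> bool:
    rank = dist.get_rank() if dist.is_initialized() else 0
    if rank % rank_stride != 0:
        return False  # logger stays a no-op on non-sampled ranks
    # init_structured_logger(...)  # actual init from this PR would go here
    return True
```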
begin = time.monotonic()
logger.info("Saving the checkpoint (or staging if async is enabled).")
Do we still need this if we have the event logger?
    self._save_last_step(curr_step)
    return
sl.add_step_tag("checkpoint_save")
with sl.log_trace_span("checkpoint_save"):
Is it bad if we actually use a function decorator here? I understand there would be a very small span for every step; I just don't know how bad it would be if we actually did this.
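For reference, the tradeoff being discussed: a decorator wraps the whole function, so every call emits a (possibly near-zero-length) span, while the context-manager placement in the diff only spans the code after the early `return`. A rough sketch of the decorator form, assuming `sl` is the structured-logger module added in this PR (import path omitted):

```python
import functools

# assumes: `sl` is the structured-logger module from this PR
def trace_span(name: str):
    """Hypothetical decorator counterpart of `with sl.log_trace_span(name)`."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Every call emits a span, even steps that early-return without
            # doing any checkpoint work; the span is just very short.
            with sl.log_trace_span(name):
                return fn(*args, **kwargs)
        return wrapper
    return decorator
```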
Summary:
TLDR: Enables structured logging in torchtitan. Time spans, scalars and events can be logged, per rank, with metadata, to a jsonl file or database, using the Python standard logger. This can then be converted into a Gantt chart. It is cheap, runs on every step, and is useful for finding stragglers, high-level bottlenecks, and how asynchronous an RL run is, and for debugging timeouts or weird behavior on some rank on some step.
For details of the APIs, please read the added readme in torchtitan/observability/structured_loggger/README.md.
Examples:

Other comments
To minimize the impact of lines changed in the trainer.py code:
- We use decorators where we can, instead of context managers.
- We call tags and spans directly inside of some functions, e.g. checkpointing, garbage collection and the profiler.
- Initialization: we add it one time in the build call base class, so everything that calls build is automatically tracked.
Is it safe?
The workhorse here is logger.info. So the same way you treat logger.info, you should treat the structured logger.
It can be disabled with a global flag.
If the user forgets to set up `init_structured_logger`, we still call logger.info, but nothing is saved anywhere.
It detects `is_compiling` and skips if true.
Profiler default paths
To align with expectations set by some internal tooling, we also changed the default paths of the Kineto and memory profilers.
Test