
Refactor: Recording and logging training and evaluation metrics in all trainers #1815


Open

SurbhiJainUSC wants to merge 1 commit into main from summary_writer_refactor

Conversation

@SurbhiJainUSC (Collaborator) commented Jun 10, 2025

Description

This PR removes writer from the list of values returned by setup_mesh_and_model. This simplifies setup_mesh_and_model, since there is no need to create a TensorBoard summary writer object in that method. Additionally, this PR refactors the metrics-related code in all trainers (train.py, sft_trainer.py, grpo_trainer.py, elastic_train.py).
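For illustration, the shape of the change is roughly the following (a sketch only; the return tuple contents and the MetricLogger constructor are assumptions, not the exact MaxText code):

def setup_mesh_and_model(config):
    ...
    # Before: a TensorBoard SummaryWriter was also created and returned here.
    # After: the writer is no longer created in this method.
    return init_rng, checkpoint_manager, mesh, model, learning_rate_schedule, tx

# Metrics logging is now owned by the trainers themselves, e.g. via a metric
# logger whose record_train_metrics method is referenced in the review
# discussion below:
metric_logger = MetricLogger(config)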

Thanks to @richjames0 for contributing!

Tests

E2E testing with train.py and sft_trainer.py: verified that the metrics are correctly uploaded to TensorBoard.

"""
steps=10 eval_interval=1 eval_steps=2 checkpoint_period=5
"""

# loading batch + training + checkpointing
completed step: 0, seconds: 27.429, TFLOP/s/device: 1.540, Tokens/s/device: 37.333, total_weights: 3573, loss: 0.993

# loading batch + training
completed step: 1, seconds: 0.010, TFLOP/s/device: 4187.050, Tokens/s/device: 101496.680, total_weights: 3253, loss: 0.993
eval metrics after step: 1, loss=0.961, total_weights=5459.0, step_time_seconds=20.147
completed step: 2, seconds: 0.012, TFLOP/s/device: 3606.211, Tokens/s/device: 87416.766, total_weights: 1692, loss: 0.912
eval metrics after step: 2, loss=0.975, total_weights=10918.0, step_time_seconds=19.760
completed step: 3, seconds: 0.009, TFLOP/s/device: 4453.210, Tokens/s/device: 107948.556, total_weights: 2944, loss: 1.105
eval metrics after step: 3, loss=0.967, total_weights=16377.0, step_time_seconds=21.060
completed step: 4, seconds: 0.009, TFLOP/s/device: 4500.176, Tokens/s/device: 109087.035, total_weights: 3466, loss: 0.948
eval metrics after step: 4, loss=0.958, total_weights=21836.0, step_time_seconds=20.085

# loading batch + training + checkpointing (checkpoint_period=5)
completed step: 5, seconds: 6.094, TFLOP/s/device: 6.932, Tokens/s/device: 168.029, total_weights: 2420, loss: 0.947
eval metrics after step: 5, loss=0.951, total_weights=27295.0, step_time_seconds=22.772

# loading batch + training
completed step: 6, seconds: 0.011, TFLOP/s/device: 3979.197, Tokens/s/device: 96458.176, total_weights: 1655, loss: 1.059
eval metrics after step: 6, loss=0.945, total_weights=32754.0, step_time_seconds=20.550
completed step: 7, seconds: 0.010, TFLOP/s/device: 4161.477, Tokens/s/device: 100876.761, total_weights: 3335, loss: 1.000
eval metrics after step: 7, loss=0.941, total_weights=38213.0, step_time_seconds=20.495
completed step: 8, seconds: 0.010, TFLOP/s/device: 4421.977, Tokens/s/device: 107191.458, total_weights: 1676, loss: 0.851
eval metrics after step: 8, loss=0.937, total_weights=43672.0, step_time_seconds=20.176
completed step: 9, seconds: 0.010, TFLOP/s/device: 4262.249, Tokens/s/device: 103319.544, total_weights: 1676, loss: 0.965
eval metrics after step: 9, loss=0.934, total_weights=49131.0, step_time_seconds=20.452

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@SurbhiJainUSC changed the title from "Refactor: Recording and logging training and evaluation metrics in al…" to "Refactor: Recording and logging training and evaluation metrics in all trainers" on Jun 10, 2025
@SurbhiJainUSC marked this pull request as ready for review on June 10, 2025 20:36

# Write train config params, num model params, and XLA flags to tensorboard
max_utils.add_text_to_summary_writer("num_model_parameters", str(num_model_parameters), writer)
max_utils.add_text_to_summary_writer("libtpu_init_args", os.environ["LIBTPU_INIT_ARGS"], writer)
maxtext_utils.add_config_to_summary_writer(config, writer)
Collaborator

In order to verify that the logs are identical to the version prior to refactoring, were you able to write out the logs and then diff them?

Collaborator Author

Yes, I have verified that the logs are the same on TensorBoard before and after the refactoring.

@@ -143,18 +131,7 @@ def train_loop(config, recorder, state=None):
last_profiling_step = prof.finished_initial_profile_step

example_batch = None
last_step_completion = datetime.datetime.now()
Collaborator

But previously, last_step_completion was calculated here, after the write_setup_info_to_tensorboard work is done. Am I reading this right?

In fact, I think it should be calculated a bit later, right before the train step begins.

@SurbhiJainUSC (Collaborator Author) commented Jun 12, 2025

It is used to calculate step_time_delta:
step_time_delta for step n = (time when train step n is completed) - (time when train step n-1 is completed)

For step > 1, calculating last_step_completion in metric_logger.record_train_metrics() makes sense, because it is called just after the training step completes. However, it is unclear to me what the right spot is for calculating last_step_completion for the first step. Also, how should we interpret step_time_delta for step 0? What is your opinion, @A9isha?
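In code terms, the delta under discussion is roughly this (a minimal sketch; the variable names follow this thread, not the PR's actual code):

import datetime

last_step_completion = datetime.datetime.now()  # completion time of step n-1

# ...train step n runs here...

now = datetime.datetime.now()
step_time_delta = now - last_step_completion  # completion(n) - completion(n-1)
last_step_completion = now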

Collaborator Author

After looking into it more, I have moved the calculation for last_step_completion for the first step to just before we start the training run. This aligns with what you were suggesting previously.
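That is, roughly the following (a placement sketch only; the loop structure, train_step, and the record_train_metrics signature are stand-ins, not the PR's actual code):

import datetime

# Set just before the training loop starts, so the first step's delta
# measures only the first train step, not the setup work before it.
last_step_completion = datetime.datetime.now()
for step in range(start_step, config.steps):
    state, metrics = train_step(state, example_batch)
    # Per the discussion above, record_train_metrics recomputes
    # last_step_completion just after each training step completes.
    metric_logger.record_train_metrics(metrics, step)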

@SurbhiJainUSC (Collaborator Author) commented Jun 16, 2025

Here are the training logs.

@SurbhiJainUSC force-pushed the summary_writer_refactor branch 2 times, most recently from fe62587 to fc920f4, on June 12, 2025 21:55
@SurbhiJainUSC requested a review from A9isha on June 12, 2025 22:04
@SurbhiJainUSC force-pushed the summary_writer_refactor branch from fc920f4 to 1cab90d on June 18, 2025 19:42
@SurbhiJainUSC force-pushed the summary_writer_refactor branch from 1cab90d to a4e0fbc on June 18, 2025 19:46