From 7c55f07f3624672ac45521ee2faf23da83e2eb20 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Dec 2024 14:17:36 +0100 Subject: [PATCH 01/28] Refactor batch loss and grad calculation --- torchtitan/train.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 14edd70ad4..45ec549a1a 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -316,11 +316,7 @@ def batch_generator( yield input_dict, labels - def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): - self.optimizers.zero_grad() - - # Keep these variables local to shorten the code as these are - # the major variables that are used in the training loop. + def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): model_parts = self.model_parts world_mesh = self.world_mesh parallel_dims = self.parallel_dims @@ -371,6 +367,18 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): # need to free to before bwd to avoid peaking memory del pred loss.backward() + return loss + + def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): + self.optimizers.zero_grad() + + # Keep these variables local to shorten the code as these are + # the major variables that are used in the training loop. + model_parts = self.model_parts + world_mesh = self.world_mesh + parallel_dims = self.parallel_dims + + loss = self.batch_backward(input_dict, labels) dist_utils.clip_grad_norm_( [p for m in model_parts for p in m.parameters()], From 83de5dc005b143ba28fcd005b8bc3806ca4ad5d5 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Dec 2024 14:18:43 +0100 Subject: [PATCH 02/28] Support gradient accumulation Fix #292. --- torchtitan/components/metrics.py | 2 + torchtitan/config_manager.py | 5 +++ torchtitan/train.py | 72 ++++++++++++++++++++++++++++---- 3 files changed, 71 insertions(+), 8 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 084c2c4ffe..73ec5c4af5 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -308,6 +308,7 @@ class MetricsProcessor: gpu_peak_flops: int ntokens_since_last_log: int data_loading_times: list[float] + accumulated_losses: list[torch.Tensor] time_last_log: float num_flops_per_token: int @@ -336,6 +337,7 @@ def __init__( ) self.ntokens_since_last_log = 0 self.data_loading_times = [] + self.accumulated_losses = [] self.time_last_log = time.perf_counter() self.device_memory_monitor.reset_peak_stats() diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index d8299f91a9..98d18d7313 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -192,6 +192,11 @@ class Training: batch_size: int = 8 """Batch size""" + global_batch_size: int | None = None + """ + Global batch size (defaults to `training.batch_size * data-parallel degree`) + """ + seq_len: int = 2048 """Sequence length""" diff --git a/torchtitan/train.py b/torchtitan/train.py index 45ec549a1a..c0f6c3f9ee 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools import importlib import os import time @@ -118,6 +119,27 @@ def __init__(self, job_config: JobConfig): ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) + # verify batch sizes + if job_config.training.global_batch_size is None: + job_config.training.global_batch_size = \ + job_config.training.batch_size * dp_degree + assert job_config.training.global_batch_size > 0 + assert ( + job_config.training.global_batch_size + % (job_config.training.batch_size * dp_degree) + == 0 + ), ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({job_config.training.global_batch_size} " + f"% ({job_config.training.batch_size} * {dp_degree}) != 0)" + ) + + self.gradient_accumulation_steps = ( + job_config.training.global_batch_size + // (job_config.training.batch_size * dp_degree) + ) + assert self.gradient_accumulation_steps > 0 + # build dataloader tokenizer = ( self.train_spec.build_tokenizer_fn(job_config) @@ -183,6 +205,15 @@ def __init__(self, job_config: JobConfig): self.loss_fn = self.train_spec.build_loss_fn(job_config) + unwrapped_loss_fn = self.loss_fn + + @functools.wraps(unwrapped_loss_fn) + def accumulated_loss_fn(*args, **kwargs): + loss = unwrapped_loss_fn(*args, **kwargs) + return loss / self.gradient_accumulation_steps + + self.loss_fn = accumulated_loss_fn + # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -288,7 +319,8 @@ def __init__(self, job_config: JobConfig): logger.info( "Trainer is initialized with " f"local batch size {job_config.training.batch_size}, " - f"global batch size {job_config.training.batch_size * dp_degree}, " + f"global batch size {job_config.training.global_batch_size}, " + f"gradient accumulation steps {self.gradient_accumulation_steps}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " f"(warmup {job_config.lr_scheduler.warmup_steps})." @@ -369,7 +401,14 @@ def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tens loss.backward() return loss - def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): + def train_step( + self, + data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]], + ) -> bool | None: + """ + Execute a training step and return whether the data loader ran + out of data. + """ self.optimizers.zero_grad() # Keep these variables local to shorten the code as these are @@ -378,7 +417,15 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): world_mesh = self.world_mesh parallel_dims = self.parallel_dims - loss = self.batch_backward(input_dict, labels) + for microbatch in range(self.gradient_accumulation_steps): + try: + input_dict, labels = next(data_iterator) + except StopIteration: + # If data runs out during gradient accumulation, that + # entire step will not be executed. + return True + loss = self.batch_backward(input_dict, labels) + self.metrics_processor.accumulated_losses.append(loss.detach()) dist_utils.clip_grad_norm_( [p for m in model_parts for p in m.parameters()], @@ -390,6 +437,10 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): self.optimizers.step() self.lr_schedulers.step() + # Reduce the data collected over gradient accumulation steps. + loss = torch.sum(torch.stack(self.metrics_processor.accumulated_losses)) + self.metrics_processor.accumulated_losses.clear() + # log metrics if not self.metrics_processor.should_log(self.step): return @@ -436,14 +487,19 @@ def train(self): sync_every=job_config.fault_tolerance.sync_steps, ), ): - for inputs, labels in self.batch_generator(self.dataloader): - if self.step >= job_config.training.steps: - break + data_iterator = self.batch_generator(self.dataloader) + data_ran_out = False + while self.step < job_config.training.steps and not data_ran_out: self.step += 1 self.gc_handler.run(self.step) - self.train_step(inputs, labels) + data_ran_out = self.train_step(data_iterator) + if data_ran_out: + logger.info( + "Ran out of data; last step was canceled. " + "Saving final checkpoint and exiting." + ) self.checkpointer.save( - self.step, force=(self.step == job_config.training.steps) + self.step, force=(self.step == job_config.training.steps or data_ran_out) ) # signal the profiler that the next profiling step has started From c9396c2c8624f467deb7338bd0d98945f7aeabf1 Mon Sep 17 00:00:00 2001 From: janEbert Date: Thu, 29 May 2025 14:52:07 +0200 Subject: [PATCH 03/28] Run pre-commit hooks --- torchtitan/train.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index c0f6c3f9ee..1dc326466f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -121,8 +121,9 @@ def __init__(self, job_config: JobConfig): # verify batch sizes if job_config.training.global_batch_size is None: - job_config.training.global_batch_size = \ + job_config.training.global_batch_size = ( job_config.training.batch_size * dp_degree + ) assert job_config.training.global_batch_size > 0 assert ( job_config.training.global_batch_size @@ -134,9 +135,8 @@ def __init__(self, job_config: JobConfig): f"% ({job_config.training.batch_size} * {dp_degree}) != 0)" ) - self.gradient_accumulation_steps = ( - job_config.training.global_batch_size - // (job_config.training.batch_size * dp_degree) + self.gradient_accumulation_steps = job_config.training.global_batch_size // ( + job_config.training.batch_size * dp_degree ) assert self.gradient_accumulation_steps > 0 @@ -402,8 +402,8 @@ def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tens return loss def train_step( - self, - data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]], + self, + data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]], ) -> bool | None: """ Execute a training step and return whether the data loader ran @@ -499,7 +499,8 @@ def train(self): "Saving final checkpoint and exiting." ) self.checkpointer.save( - self.step, force=(self.step == job_config.training.steps or data_ran_out) + self.step, + force=(self.step == job_config.training.steps or data_ran_out), ) # signal the profiler that the next profiling step has started From 1da4106da30afb2aeff483ac20e1f8c8d2ea9efb Mon Sep 17 00:00:00 2001 From: janEbert Date: Thu, 29 May 2025 20:33:00 +0200 Subject: [PATCH 04/28] Change `global_batch_size` type to `int` Previously `int | None`. Makes it possible to obtain the automatic calculation of it when it has already been set in a TOML config. --- torchtitan/config_manager.py | 2 +- torchtitan/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 98d18d7313..9e9c2387b7 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -192,7 +192,7 @@ class Training: batch_size: int = 8 """Batch size""" - global_batch_size: int | None = None + global_batch_size: int = -1 """ Global batch size (defaults to `training.batch_size * data-parallel degree`) """ diff --git a/torchtitan/train.py b/torchtitan/train.py index 1dc326466f..edc2a1919e 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -120,7 +120,7 @@ def __init__(self, job_config: JobConfig): self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # verify batch sizes - if job_config.training.global_batch_size is None: + if job_config.training.global_batch_size < 0: job_config.training.global_batch_size = ( job_config.training.batch_size * dp_degree ) From 82b5e5dcefce9994aa4447375d7116c0a1f3042d Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 14:53:26 +0200 Subject: [PATCH 05/28] Do not save checkpoint when data ran out @fegin said: > TorchTitan currently doesn't perform force checkpoint if data is > depleted. We can fix this but I suggest that we don't do this in this > PR. (See https://github.com/pytorch/torchtitan/pull/1238#discussion_r2115249675.) --- torchtitan/train.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index edc2a1919e..bd382d45bf 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -494,14 +494,9 @@ def train(self): self.gc_handler.run(self.step) data_ran_out = self.train_step(data_iterator) if data_ran_out: - logger.info( - "Ran out of data; last step was canceled. " - "Saving final checkpoint and exiting." - ) - self.checkpointer.save( - self.step, - force=(self.step == job_config.training.steps or data_ran_out), - ) + logger.info("Ran out of data; last step was canceled.") + break + self.checkpointer.save(self.step, force=(self.step == job_config.training.steps)) # signal the profiler that the next profiling step has started if torch_profiler: From 4a30b743abc61b9d2b8d1d7e00c583990a080431 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 14:32:35 +0200 Subject: [PATCH 06/28] Raise custom exception upon data depletion I.e., a new `DataloaderStopIteration` that inherits from `StopIteration`. Accordingly, no longer return an optional `bool` to indicate depletion and adapt the remainder of the code to catch the new exception instead. --- torchtitan/train.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index bd382d45bf..a38d2c4ed7 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -32,6 +32,11 @@ ) +class DataloaderStopIteration(StopIteration): + """An exception that indicates dataloader exhaustion.""" + pass + + class Trainer(torch.distributed.checkpoint.stateful.Stateful): job_config: JobConfig gc_handler: utils.GarbageCollection @@ -348,6 +353,17 @@ def batch_generator( yield input_dict, labels + def next_batch( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: + try: + input_dict, labels = next(data_iterator) + except StopIteration as ex: + # If data runs out during gradient accumulation, that + # entire step will not be executed. + raise DataloaderStopIteration() from ex + return input_dict, labels + def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): model_parts = self.model_parts world_mesh = self.world_mesh @@ -401,14 +417,7 @@ def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tens loss.backward() return loss - def train_step( - self, - data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]], - ) -> bool | None: - """ - Execute a training step and return whether the data loader ran - out of data. - """ + def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]): self.optimizers.zero_grad() # Keep these variables local to shorten the code as these are @@ -418,12 +427,7 @@ def train_step( parallel_dims = self.parallel_dims for microbatch in range(self.gradient_accumulation_steps): - try: - input_dict, labels = next(data_iterator) - except StopIteration: - # If data runs out during gradient accumulation, that - # entire step will not be executed. - return True + input_dict, labels = self.next_batch(data_iterator) loss = self.batch_backward(input_dict, labels) self.metrics_processor.accumulated_losses.append(loss.detach()) @@ -488,12 +492,12 @@ def train(self): ), ): data_iterator = self.batch_generator(self.dataloader) - data_ran_out = False - while self.step < job_config.training.steps and not data_ran_out: + while self.step < job_config.training.steps: self.step += 1 self.gc_handler.run(self.step) - data_ran_out = self.train_step(data_iterator) - if data_ran_out: + try: + self.train_step(data_iterator) + except DataloaderStopIteration: logger.info("Ran out of data; last step was canceled.") break self.checkpointer.save(self.step, force=(self.step == job_config.training.steps)) From 6cd2e318d0d195aca81419bc762de983436725b1 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 14:34:08 +0200 Subject: [PATCH 07/28] Rename "batch size" to "local batch size" This concerns only renaming - `--training.batch_size` to `--training.local_batch_size` and - `job_config.training.batch_size` to `job_config.training.local_batch_size`. --- docs/converging.md | 4 ++-- scripts/estimate/estimation.py | 4 ++-- tests/unit_tests/test_dataset_checkpointing.py | 2 +- torchtitan/config_manager.py | 8 ++++---- torchtitan/datasets/hf_datasets.py | 2 +- torchtitan/distributed/pipeline.py | 4 ++-- .../deepseek_v3/train_configs/deepseek_v2.toml | 2 +- torchtitan/experiments/deepseek_v3/train_ds_real.py | 2 +- torchtitan/experiments/flux/dataset/flux_dataset.py | 2 +- .../flux/tests/unit_tests/test_flux_dataloader.py | 2 +- .../experiments/flux/train_configs/debug_model.toml | 2 +- .../experiments/flux/train_configs/flux_dev_model.toml | 2 +- .../flux/train_configs/flux_schnell_model.toml | 2 +- .../experiments/llama4/train_configs/debug_model.toml | 2 +- .../llama4/train_configs/llama4_17bx128e.toml | 2 +- .../llama4/train_configs/llama4_17bx16e.toml | 2 +- torchtitan/experiments/multimodal/check_padding_mm.py | 2 +- torchtitan/experiments/multimodal/mm_dataset.py | 2 +- .../models/llama3/train_configs/debug_model.toml | 2 +- .../models/llama3/train_configs/llama3_405b.toml | 2 +- torchtitan/models/llama3/train_configs/llama3_70b.toml | 2 +- torchtitan/models/llama3/train_configs/llama3_8b.toml | 2 +- torchtitan/train.py | 10 +++++----- 23 files changed, 33 insertions(+), 33 deletions(-) diff --git a/docs/converging.md b/docs/converging.md index 2ba1575bda..3bb62d9353 100644 --- a/docs/converging.md +++ b/docs/converging.md @@ -12,7 +12,7 @@ This note clarifies the recommended practices to follow when testing the loss co ## Guidelines -To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.batch_size`) in the toml config, and at the same time fix the data parallel degree. +To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.local_batch_size`) in the toml config, and at the same time fix the data parallel degree. If the technique is a parallelism (TP/PP/CP/etc) - The control set is a 1D FSDP job on `dp` GPUs (or any other verified setups), with a trusted training config (e.g. those under train_configs). @@ -40,7 +40,7 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor ### Setup - Base config: [torchtitan/models/llama3/train_configs/llama3_8b.toml](../torchtitan/models/llama3/train_configs/llama3_8b.toml) -- `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"` +- `training.local_batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"` - `training.data_parallel_shard_degree = 8`, resulting in global batch size 32 - `training.steps = 3000`, `lr_scheduler.warmup_steps = 600` diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index bef5ad6970..85a058a0fb 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -130,13 +130,13 @@ def estimate_memory(job_config: JobConfig): torch.randint( 0, model_args.vocab_size, - (job_config.training.batch_size, model_args.max_seq_len), + (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), torch.randint( 0, model_args.vocab_size, - (job_config.training.batch_size, model_args.max_seq_len), + (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), ) diff --git a/tests/unit_tests/test_dataset_checkpointing.py b/tests/unit_tests/test_dataset_checkpointing.py index 1eae8f1962..00998fc495 100644 --- a/tests/unit_tests/test_dataset_checkpointing.py +++ b/tests/unit_tests/test_dataset_checkpointing.py @@ -64,7 +64,7 @@ def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank) [ "--training.dataset", dataset_name, - "--training.batch_size", + "--training.local_batch_size", str(batch_size), "--training.seq_len", str(seq_len), diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index 9e9c2387b7..eb34949ea6 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -189,12 +189,12 @@ class Training: loaded from this path instead of downloaded. """ - batch_size: int = 8 - """Batch size""" + local_batch_size: int = 8 + """Local batch size (i.e., per-device batch size)""" global_batch_size: int = -1 """ - Global batch size (defaults to `training.batch_size * data-parallel degree`) + Global batch size (defaults to `training.local_batch_size * data-parallel degree`) """ seq_len: int = 2048 @@ -338,7 +338,7 @@ class Parallelism: pipeline_parallel_microbatch_size: int = 1 """ The size of each pipeline parallel microbatch (default 1). - This value is used to compute the total number of microbatches by dividing batch_size with + This value is used to compute the total number of microbatches by dividing local batch_size with pipeline_parallel_microbatch_size. The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size. """ diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 350feac695..023b4a29e9 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -174,7 +174,7 @@ def build_hf_dataloader( """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size seq_len = job_config.training.seq_len hf_ds = HuggingFaceDataset( diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline.py index 973c58c19c..366021a7fc 100644 --- a/torchtitan/distributed/pipeline.py +++ b/torchtitan/distributed/pipeline.py @@ -150,11 +150,11 @@ def build_pipeline_schedule( looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training if batch_size % microbatch_size != 0: raise ValueError( - f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. " + f"Batch size {job_config.training.local_batch_size} must be divisible by number of microbatches {n_microbatches}. " "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size." ) n_microbatches = batch_size // microbatch_size diff --git a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml index efd54c48f8..1402e57e1e 100644 --- a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml +++ b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml @@ -38,7 +38,7 @@ decay_type = "linear" lr_min = 0.1 [training] -batch_size = 2 # 8 +local_batch_size = 2 # 8 seq_len = 1024 # 2048 max_norm = 1.0 # grad norm clipping steps = 200 diff --git a/torchtitan/experiments/deepseek_v3/train_ds_real.py b/torchtitan/experiments/deepseek_v3/train_ds_real.py index c1bbc4ff0a..398360a6ee 100644 --- a/torchtitan/experiments/deepseek_v3/train_ds_real.py +++ b/torchtitan/experiments/deepseek_v3/train_ds_real.py @@ -145,7 +145,7 @@ def run_full_model( # model.setup_symm_mem(torch.bfloat16, device) torch.manual_seed(ep_rank) - bs = config.training.batch_size # * microbatches # 4 + bs = config.training.local_batch_size # * microbatches # 4 seqlen = config.training.seq_len # 128 # metrics manager diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index 7e1de9dcea..9dcd641498 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -278,7 +278,7 @@ def build_flux_dataloader( """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size t5_tokenizer, clip_tokenizer = build_flux_tokenizer(job_config) diff --git a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py index 8a354eca67..8072032115 100644 --- a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py +++ b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py @@ -40,7 +40,7 @@ def _test_flux_dataloader(self, dataset_name): str(256), "--training.dataset", dataset_name, - "--training.batch_size", + "--training.local_batch_size", str(batch_size), "--training.seed", "0", diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml index 3bc4a4e236..596a513737 100644 --- a/torchtitan/experiments/flux/train_configs/debug_model.toml +++ b/torchtitan/experiments/flux/train_configs/debug_model.toml @@ -33,7 +33,7 @@ warmup_steps = 1 # 10% warmup steps decay_ratio = 0.0 # no decay, stay stable during training [training] -batch_size = 4 +local_batch_size = 4 max_norm = 2.0 # grad norm clipping steps = 10 compile = false diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml index 72e73e10b0..23dbb2b429 100644 --- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml @@ -32,7 +32,7 @@ warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.0 # no decay [training] -batch_size = 32 +local_batch_size = 32 max_norm = 1.0 # grad norm clipping steps = 30_000 compile = false diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml index 2e66170ddf..1e2421e59a 100644 --- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml +++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml @@ -32,7 +32,7 @@ warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps decay_ratio = 0.0 # no decay [training] -batch_size = 64 +local_batch_size = 64 max_norm = 1.0 # grad norm clipping steps = 30_000 compile = false diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml index bb038e60c4..bc48d38090 100644 --- a/torchtitan/experiments/llama4/train_configs/debug_model.toml +++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml @@ -37,7 +37,7 @@ decay_type = "linear" lr_min = 0.1 [training] -batch_size = 8 +local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml index dfaa4114e0..f508968c8d 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml @@ -30,7 +30,7 @@ warmup_steps = 600 lr_min = 0.1 [training] -batch_size = 1 +local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml index 41fc37efde..c899dd5087 100644 --- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml +++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml @@ -30,7 +30,7 @@ warmup_steps = 600 lr_min = 0.1 [training] -batch_size = 8 +local_batch_size = 8 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 diff --git a/torchtitan/experiments/multimodal/check_padding_mm.py b/torchtitan/experiments/multimodal/check_padding_mm.py index cc8b357286..7534dcdccf 100644 --- a/torchtitan/experiments/multimodal/check_padding_mm.py +++ b/torchtitan/experiments/multimodal/check_padding_mm.py @@ -35,7 +35,7 @@ def main( [ "--training.dataset", dataset, - "--training.batch_size", + "--training.local_batch_size", str(batch_size), "--training.seq_len", str(seq_len), diff --git a/torchtitan/experiments/multimodal/mm_dataset.py b/torchtitan/experiments/multimodal/mm_dataset.py index a29627aace..519272c74e 100644 --- a/torchtitan/experiments/multimodal/mm_dataset.py +++ b/torchtitan/experiments/multimodal/mm_dataset.py @@ -240,7 +240,7 @@ def build_mm_dataloader( """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size seq_len = job_config.training.seq_len pad_max_tiles = 4 # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig padding_idx = 128004 # TODO(tj.solergibert) Add `padding_idx` to JobConfig diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml index 60916966cb..4d4426d0d0 100644 --- a/torchtitan/models/llama3/train_configs/debug_model.toml +++ b/torchtitan/models/llama3/train_configs/debug_model.toml @@ -39,7 +39,7 @@ decay_type = "linear" lr_min = 0.0 [training] -batch_size = 8 +local_batch_size = 8 seq_len = 2048 max_norm = 1.0 # grad norm clipping steps = 10 diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml index 6adcd3d90a..175c41fc8f 100644 --- a/torchtitan/models/llama3/train_configs/llama3_405b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml @@ -30,7 +30,7 @@ eps = 1e-8 warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps [training] -batch_size = 2 +local_batch_size = 2 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 3000 diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml index 582f5aff5c..257205c314 100644 --- a/torchtitan/models/llama3/train_configs/llama3_70b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml @@ -30,7 +30,7 @@ eps = 1e-8 warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps [training] -batch_size = 8 +local_batch_size = 8 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml index 6cfe61bc70..aa6f81f803 100644 --- a/torchtitan/models/llama3/train_configs/llama3_8b.toml +++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml @@ -30,7 +30,7 @@ eps = 1e-8 warmup_steps = 200 # lr scheduler warm up [training] -batch_size = 1 +local_batch_size = 1 seq_len = 8192 max_norm = 1.0 # grad norm clipping steps = 1000 diff --git a/torchtitan/train.py b/torchtitan/train.py index a38d2c4ed7..50c962a0c3 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -127,21 +127,21 @@ def __init__(self, job_config: JobConfig): # verify batch sizes if job_config.training.global_batch_size < 0: job_config.training.global_batch_size = ( - job_config.training.batch_size * dp_degree + job_config.training.local_batch_size * dp_degree ) assert job_config.training.global_batch_size > 0 assert ( job_config.training.global_batch_size - % (job_config.training.batch_size * dp_degree) + % (job_config.training.local_batch_size * dp_degree) == 0 ), ( f"global batch size must be multiple of local batch size times " f"data-parallel degree ({job_config.training.global_batch_size} " - f"% ({job_config.training.batch_size} * {dp_degree}) != 0)" + f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" ) self.gradient_accumulation_steps = job_config.training.global_batch_size // ( - job_config.training.batch_size * dp_degree + job_config.training.local_batch_size * dp_degree ) assert self.gradient_accumulation_steps > 0 @@ -323,7 +323,7 @@ def accumulated_loss_fn(*args, **kwargs): logger.info( "Trainer is initialized with " - f"local batch size {job_config.training.batch_size}, " + f"local batch size {job_config.training.local_batch_size}, " f"global batch size {job_config.training.global_batch_size}, " f"gradient accumulation steps {self.gradient_accumulation_steps}, " f"sequence length {job_config.training.seq_len}, " From 45fbe0a3e213ccfed697647d0339c2448645c6f6 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 14:43:04 +0200 Subject: [PATCH 08/28] Rename `batch_backward` to `forward_backward_step` I.e., the method in `Trainer`. --- torchtitan/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 50c962a0c3..fb310d0f8d 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -364,7 +364,7 @@ def next_batch( raise DataloaderStopIteration() from ex return input_dict, labels - def batch_backward(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): + def forward_backward_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): model_parts = self.model_parts world_mesh = self.world_mesh parallel_dims = self.parallel_dims @@ -428,7 +428,7 @@ def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torc for microbatch in range(self.gradient_accumulation_steps): input_dict, labels = self.next_batch(data_iterator) - loss = self.batch_backward(input_dict, labels) + loss = self.forward_backward_step(input_dict, labels) self.metrics_processor.accumulated_losses.append(loss.detach()) dist_utils.clip_grad_norm_( From be231949131aa7c6b95a4a33acb37673891e8c46 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 14:48:34 +0200 Subject: [PATCH 09/28] Refactor loss function gradient accumulation wrap --- torchtitan/components/loss.py | 10 ++++++++++ torchtitan/train.py | 14 ++++---------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index e2df56d52e..5f28ad1d3a 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import functools from typing import Callable, TypeAlias import torch @@ -27,3 +28,12 @@ def build_cross_entropy_loss(job_config: JobConfig): logger.info("Compiling the loss function with torch.compile") loss_fn = torch.compile(loss_fn) return loss_fn + + +def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps): + @functools.wraps(unwrapped_loss_fn) + def accumulated_loss_fn(*args, **kwargs): + loss = unwrapped_loss_fn(*args, **kwargs) + return loss / accumulation_steps + + return accumulated_loss_fn diff --git a/torchtitan/train.py b/torchtitan/train.py index fb310d0f8d..248ca79bcf 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import functools import importlib import os import time @@ -17,6 +16,7 @@ import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.components.metrics import ( build_metrics_processor, ensure_pp_loss_visible, @@ -209,15 +209,9 @@ def __init__(self, job_config: JobConfig): buffer_device = None self.loss_fn = self.train_spec.build_loss_fn(job_config) - - unwrapped_loss_fn = self.loss_fn - - @functools.wraps(unwrapped_loss_fn) - def accumulated_loss_fn(*args, **kwargs): - loss = unwrapped_loss_fn(*args, **kwargs) - return loss / self.gradient_accumulation_steps - - self.loss_fn = accumulated_loss_fn + self.loss_fn = rescale_accumulated_loss( + self.loss_fn, self.gradient_accumulation_steps + ) # apply parallelisms and initialization if parallel_dims.pp_enabled: From 5ae21d72a4136e855224ebaadc983eeed04cf355 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:06:36 +0200 Subject: [PATCH 10/28] Do not modify `job_config.global_batch_size` Instead use a new helper variable `global_batch_size` for all logic. Improves readability. --- torchtitan/train.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 248ca79bcf..cae3b56064 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -125,22 +125,17 @@ def __init__(self, job_config: JobConfig): self.train_spec = train_spec_module.get_train_spec(job_config.model.name) # verify batch sizes - if job_config.training.global_batch_size < 0: - job_config.training.global_batch_size = ( - job_config.training.local_batch_size * dp_degree - ) - assert job_config.training.global_batch_size > 0 - assert ( - job_config.training.global_batch_size - % (job_config.training.local_batch_size * dp_degree) - == 0 - ), ( + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + global_batch_size = job_config.training.local_batch_size * dp_degree + assert global_batch_size > 0 + assert global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0, ( f"global batch size must be multiple of local batch size times " - f"data-parallel degree ({job_config.training.global_batch_size} " + f"data-parallel degree ({global_batch_size} " f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" ) - self.gradient_accumulation_steps = job_config.training.global_batch_size // ( + self.gradient_accumulation_steps = global_batch_size // ( job_config.training.local_batch_size * dp_degree ) assert self.gradient_accumulation_steps > 0 @@ -318,7 +313,7 @@ def __init__(self, job_config: JobConfig): logger.info( "Trainer is initialized with " f"local batch size {job_config.training.local_batch_size}, " - f"global batch size {job_config.training.global_batch_size}, " + f"global batch size {global_batch_size}, " f"gradient accumulation steps {self.gradient_accumulation_steps}, " f"sequence length {job_config.training.seq_len}, " f"total steps {job_config.training.steps} " From 8cf71d6349752eaf155ba84e98d9eaec7bf91c57 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:08:48 +0200 Subject: [PATCH 11/28] Add comment on default gradient accumulation step --- torchtitan/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtitan/train.py b/torchtitan/train.py index cae3b56064..07a612d365 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -127,6 +127,8 @@ def __init__(self, job_config: JobConfig): # verify batch sizes global_batch_size = job_config.training.global_batch_size if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. global_batch_size = job_config.training.local_batch_size * dp_degree assert global_batch_size > 0 assert global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0, ( From b4850127fef91f3c4bd62827c01184283256c459 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:15:39 +0200 Subject: [PATCH 12/28] Move gradient accumulation derivation logic Improve readability. --- torchtitan/train.py | 37 +++++++++++++++++++------------------ 1 file changed, 19 insertions(+), 18 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 07a612d365..7d21a386e2 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -124,24 +124,6 @@ def __init__(self, job_config: JobConfig): ) self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - # verify batch sizes - global_batch_size = job_config.training.global_batch_size - if global_batch_size < 0: - # This global batch size results in 1 gradient accumulation - # step. - global_batch_size = job_config.training.local_batch_size * dp_degree - assert global_batch_size > 0 - assert global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0, ( - f"global batch size must be multiple of local batch size times " - f"data-parallel degree ({global_batch_size} " - f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" - ) - - self.gradient_accumulation_steps = global_batch_size // ( - job_config.training.local_batch_size * dp_degree - ) - assert self.gradient_accumulation_steps > 0 - # build dataloader tokenizer = ( self.train_spec.build_tokenizer_fn(job_config) @@ -206,6 +188,25 @@ def __init__(self, job_config: JobConfig): buffer_device = None self.loss_fn = self.train_spec.build_loss_fn(job_config) + + # verify batch sizes + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + global_batch_size = job_config.training.local_batch_size * dp_degree + assert global_batch_size > 0 + assert global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0, ( + f"global batch size must be multiple of local batch size times " + f"data-parallel degree ({global_batch_size} " + f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + ) + + # calculate gradient accumulation steps + self.gradient_accumulation_steps = global_batch_size // ( + job_config.training.local_batch_size * dp_degree + ) + assert self.gradient_accumulation_steps > 0 self.loss_fn = rescale_accumulated_loss( self.loss_fn, self.gradient_accumulation_steps ) From 12d274d4997557f45a35dfc1783379070a01a744 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:20:47 +0200 Subject: [PATCH 13/28] Remove redundant shortcut variables These were only used in 1 or 2 locations each. --- torchtitan/train.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 7d21a386e2..8aba953916 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -358,7 +358,6 @@ def next_batch( def forward_backward_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): model_parts = self.model_parts - world_mesh = self.world_mesh parallel_dims = self.parallel_dims # apply context parallelism if cp is enabled @@ -366,7 +365,7 @@ def forward_backward_step(self, input_dict: dict[str, torch.Tensor], labels: tor inputs = input_dict["input"] optional_context_parallel_ctx = ( dist_utils.create_context_parallel_ctx( - cp_mesh=world_mesh["cp"], + cp_mesh=self.world_mesh["cp"], cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], cp_seq_dims=[1, 1] + [0 for _ in model_parts], cp_no_restore_buffers={inputs, labels}, @@ -414,8 +413,6 @@ def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torc # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. - model_parts = self.model_parts - world_mesh = self.world_mesh parallel_dims = self.parallel_dims for microbatch in range(self.gradient_accumulation_steps): @@ -424,7 +421,7 @@ def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torc self.metrics_processor.accumulated_losses.append(loss.detach()) dist_utils.clip_grad_norm_( - [p for m in model_parts for p in m.parameters()], + [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None, @@ -455,8 +452,8 @@ def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torc ) ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg), + dist_utils.dist_mean(loss, self.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_max(loss, self.world_mesh["dp_cp"], ft_pg), ) else: global_avg_loss = global_max_loss = loss.detach().item() From d4dd1226e5231ca86c98f220c02dc5c3b82ce2df Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:21:03 +0200 Subject: [PATCH 14/28] Improve readability --- torchtitan/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/train.py b/torchtitan/train.py index 8aba953916..ed41d9ace4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -406,6 +406,7 @@ def forward_backward_step(self, input_dict: dict[str, torch.Tensor], labels: tor # need to free to before bwd to avoid peaking memory del pred loss.backward() + return loss def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]): From a7f4c8066737fd869e8659bbd124ff6a7d2a8ac2 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:28:46 +0200 Subject: [PATCH 15/28] Add `gradient_accumulation_step` to dataclass --- torchtitan/train.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchtitan/train.py b/torchtitan/train.py index ed41d9ace4..86bea339d9 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -44,6 +44,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): parallel_dims: ParallelDims train_spec: train_spec_module.TrainSpec world_mesh: torch.distributed.DeviceMesh + gradient_accumulation_steps: int dataloader: train_spec_module.BaseDataLoader metrics_processor: train_spec_module.MetricsProcessor From 7003eb9f63d0dd8cfce88214ea33be06a0c419e3 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:30:37 +0200 Subject: [PATCH 16/28] Move `accumulated_losses` to `Trainer` ... from `MetricsProcessor`. --- torchtitan/components/metrics.py | 2 -- torchtitan/train.py | 8 +++++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 73ec5c4af5..084c2c4ffe 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -308,7 +308,6 @@ class MetricsProcessor: gpu_peak_flops: int ntokens_since_last_log: int data_loading_times: list[float] - accumulated_losses: list[torch.Tensor] time_last_log: float num_flops_per_token: int @@ -337,7 +336,6 @@ def __init__( ) self.ntokens_since_last_log = 0 self.data_loading_times = [] - self.accumulated_losses = [] self.time_last_log = time.perf_counter() self.device_memory_monitor.reset_peak_stats() diff --git a/torchtitan/train.py b/torchtitan/train.py index 86bea339d9..96ff72fd6c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -45,6 +45,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): train_spec: train_spec_module.TrainSpec world_mesh: torch.distributed.DeviceMesh gradient_accumulation_steps: int + accumulated_losses: list[torch.Tensor] dataloader: train_spec_module.BaseDataLoader metrics_processor: train_spec_module.MetricsProcessor @@ -211,6 +212,7 @@ def __init__(self, job_config: JobConfig): self.loss_fn = rescale_accumulated_loss( self.loss_fn, self.gradient_accumulation_steps ) + self.accumulated_losses = [] # apply parallelisms and initialization if parallel_dims.pp_enabled: @@ -420,7 +422,7 @@ def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torc for microbatch in range(self.gradient_accumulation_steps): input_dict, labels = self.next_batch(data_iterator) loss = self.forward_backward_step(input_dict, labels) - self.metrics_processor.accumulated_losses.append(loss.detach()) + self.accumulated_losses.append(loss.detach()) dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], @@ -433,8 +435,8 @@ def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torc self.lr_schedulers.step() # Reduce the data collected over gradient accumulation steps. - loss = torch.sum(torch.stack(self.metrics_processor.accumulated_losses)) - self.metrics_processor.accumulated_losses.clear() + loss = torch.sum(torch.stack(self.accumulated_losses)) + self.accumulated_losses.clear() # log metrics if not self.metrics_processor.should_log(self.step): From 266cffa06c9d37400f75c5ccf39f727a711db5ae Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:33:59 +0200 Subject: [PATCH 17/28] Apply pre-commit hooks --- torchtitan/train.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 96ff72fd6c..37549ab968 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -34,6 +34,7 @@ class DataloaderStopIteration(StopIteration): """An exception that indicates dataloader exhaustion.""" + pass @@ -198,7 +199,9 @@ def __init__(self, job_config: JobConfig): # step. global_batch_size = job_config.training.local_batch_size * dp_degree assert global_batch_size > 0 - assert global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0, ( + assert ( + global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + ), ( f"global batch size must be multiple of local batch size times " f"data-parallel degree ({global_batch_size} " f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" @@ -359,7 +362,9 @@ def next_batch( raise DataloaderStopIteration() from ex return input_dict, labels - def forward_backward_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ): model_parts = self.model_parts parallel_dims = self.parallel_dims @@ -412,7 +417,9 @@ def forward_backward_step(self, input_dict: dict[str, torch.Tensor], labels: tor return loss - def train_step(self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]): + def train_step( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ): self.optimizers.zero_grad() # Keep these variables local to shorten the code as these are @@ -493,7 +500,9 @@ def train(self): except DataloaderStopIteration: logger.info("Ran out of data; last step was canceled.") break - self.checkpointer.save(self.step, force=(self.step == job_config.training.steps)) + self.checkpointer.save( + self.step, force=(self.step == job_config.training.steps) + ) # signal the profiler that the next profiling step has started if torch_profiler: From 072b9b4e669f080df2b6f1a423c8cfcb77290267 Mon Sep 17 00:00:00 2001 From: janEbert Date: Tue, 3 Jun 2025 15:41:23 +0200 Subject: [PATCH 18/28] Add gradient accumulation integration test --- tests/integration_tests.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index e4518265a2..212643bd84 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -494,6 +494,20 @@ def build_test_list(): "Float8 emulation test", "float8_emulation", ), + OverrideDefinitions( + [ + [ + # Default local batch size = 8, and `ngpu=2`, so + # default global batch size = 8 * 2 = 16. + # To achieve 2 gradient accumulation steps, multiply + # default global batch size by 2. 16 * 2 = 32. + "--training.global_batch_size 32", + ], + ], + "Gradient accumulation", + "gradient_accumulation", + ngpu=2, + ), ] return integration_tests_flavors From 35fdc798e15f44a2a1131312e6e6b8d92fff0905 Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 13:49:48 +0200 Subject: [PATCH 19/28] Fix FLUX trainer --- torchtitan/experiments/flux/train.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index cf83600f07..2903ff9145 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Optional +from typing import Iterable, Optional import torch from torch.distributed.fsdp import FSDPModule @@ -81,7 +81,10 @@ def __init__(self, job_config: JobConfig): job_config=job_config, ) - def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor): + def train_step( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ): + input_dict, labels = self.next_batch(data_iterator) # generate t5 and clip embeddings input_dict["image"] = labels input_dict = preprocess_data( From 12898e471a2c57eb71a7ba3e96671a179c683b66 Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 13:57:10 +0200 Subject: [PATCH 20/28] Refactor FLUX train step ... toward `forward_backward_step` design. --- torchtitan/experiments/flux/train.py | 33 +++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 2903ff9145..09c510b3f6 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -81,10 +81,9 @@ def __init__(self, job_config: JobConfig): job_config=job_config, ) - def train_step( - self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ): - input_dict, labels = self.next_batch(data_iterator) # generate t5 and clip embeddings input_dict["image"] = labels input_dict = preprocess_data( @@ -97,18 +96,11 @@ def train_step( ) labels = input_dict["img_encodings"] - self.optimizers.zero_grad() - # Keep these variables local to shorten the code as these are # the major variables that are used in the training loop. - model_parts = self.model_parts - assert len(self.model_parts) == 1 # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not model = self.model_parts[0] - world_mesh = self.world_mesh - parallel_dims = self.parallel_dims - # image in latent space transformed by self.auto_encoder clip_encodings = input_dict["clip_encodings"] t5_encodings = input_dict["t5_encodings"] @@ -152,6 +144,27 @@ def train_step( del (pred, noise, target) loss.backward() + return loss + + def train_step( + self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] + ): + input_dict, labels = self.next_batch(data_iterator) + + self.optimizers.zero_grad() + + # Keep these variables local to shorten the code as these are + # the major variables that are used in the training loop. + model_parts = self.model_parts + assert len(self.model_parts) == 1 + # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not + model = self.model_parts[0] + + world_mesh = self.world_mesh + parallel_dims = self.parallel_dims + + loss = self.forward_backward_step(input_dict, labels) + dist_utils.clip_grad_norm_( [p for m in model_parts for p in m.parameters()], self.job_config.training.max_norm, From a2d8c26048b688b369b35feb9787e76b01040bc5 Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 13:59:24 +0200 Subject: [PATCH 21/28] Use fixed local batch size --- tests/integration_tests.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/integration_tests.py b/tests/integration_tests.py index 212643bd84..c2d0cf7372 100755 --- a/tests/integration_tests.py +++ b/tests/integration_tests.py @@ -497,10 +497,11 @@ def build_test_list(): OverrideDefinitions( [ [ - # Default local batch size = 8, and `ngpu=2`, so - # default global batch size = 8 * 2 = 16. + # Local batch size = 8, and `ngpu=2`, so default + # global batch size = 8 * 2 = 16. # To achieve 2 gradient accumulation steps, multiply # default global batch size by 2. 16 * 2 = 32. + "--training.local_batch_size 8", "--training.global_batch_size 32", ], ], From 40802ad3b723843e34e59174dcd8f5fa0be97908 Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 14:00:30 +0200 Subject: [PATCH 22/28] Fix typo --- torchtitan/config_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py index eb34949ea6..4300c3bb8b 100644 --- a/torchtitan/config_manager.py +++ b/torchtitan/config_manager.py @@ -338,7 +338,7 @@ class Parallelism: pipeline_parallel_microbatch_size: int = 1 """ The size of each pipeline parallel microbatch (default 1). - This value is used to compute the total number of microbatches by dividing local batch_size with + This value is used to compute the total number of microbatches by dividing local_batch_size with pipeline_parallel_microbatch_size. The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size. """ From af0a5edacc0cce42f8276d39e346b0ab1c724a39 Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 14:02:32 +0200 Subject: [PATCH 23/28] Move custom `StopIteration` exception --- torchtitan/components/dataloader.py | 6 ++++++ torchtitan/train.py | 7 +------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index c3eecc99dc..2b555a91bb 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -17,6 +17,12 @@ from torchtitan.tools.logging import logger +class DataloaderStopIteration(StopIteration): + """An exception that indicates dataloader exhaustion.""" + + pass + + class BaseDataLoader(Stateful, ABC): """Base class for all dataloaders. diff --git a/torchtitan/train.py b/torchtitan/train.py index 37549ab968..17594e7d44 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -16,6 +16,7 @@ import torchtitan.components.ft as ft import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager +from torchtitan.components.dataloader import DataloaderStopIteration from torchtitan.components.loss import rescale_accumulated_loss from torchtitan.components.metrics import ( build_metrics_processor, @@ -32,12 +33,6 @@ ) -class DataloaderStopIteration(StopIteration): - """An exception that indicates dataloader exhaustion.""" - - pass - - class Trainer(torch.distributed.checkpoint.stateful.Stateful): job_config: JobConfig gc_handler: utils.GarbageCollection From a29b59b0d5295adf89f2c18d0eac2eed10239d11 Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 14:03:11 +0200 Subject: [PATCH 24/28] Fix log type --- torchtitan/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 17594e7d44..39c539b085 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -493,7 +493,7 @@ def train(self): try: self.train_step(data_iterator) except DataloaderStopIteration: - logger.info("Ran out of data; last step was canceled.") + logger.warning("Ran out of data; last step was canceled.") break self.checkpointer.save( self.step, force=(self.step == job_config.training.steps) From 6b0efcaf3469eef0d428a2b5ceca3b92bd76fc4e Mon Sep 17 00:00:00 2001 From: janEbert Date: Wed, 4 Jun 2025 14:09:39 +0200 Subject: [PATCH 25/28] Add docstring to rescaled loss function --- torchtitan/components/loss.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py index 5f28ad1d3a..6262564064 100644 --- a/torchtitan/components/loss.py +++ b/torchtitan/components/loss.py @@ -31,6 +31,10 @@ def build_cross_entropy_loss(job_config: JobConfig): def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps): + """Add a mean reduction over `accumulation_steps` to the given + `unwrapped_loss_fn`. + """ + @functools.wraps(unwrapped_loss_fn) def accumulated_loss_fn(*args, **kwargs): loss = unwrapped_loss_fn(*args, **kwargs) From b810950a881372715273ea33cd1ed77504031055 Mon Sep 17 00:00:00 2001 From: janEbert Date: Thu, 5 Jun 2025 15:39:01 +0200 Subject: [PATCH 26/28] Fix missing types --- torchtitan/experiments/flux/train.py | 2 +- torchtitan/train.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index 09c510b3f6..fb667c414a 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -83,7 +83,7 @@ def __init__(self, job_config: JobConfig): def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ): + ) -> torch.Tensor: # generate t5 and clip embeddings input_dict["image"] = labels input_dict = preprocess_data( diff --git a/torchtitan/train.py b/torchtitan/train.py index 39c539b085..dfa84183f4 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -359,7 +359,7 @@ def next_batch( def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ): + ) -> torch.Tensor: model_parts = self.model_parts parallel_dims = self.parallel_dims From 39b5087886b38d841956f066b488508d2b7aefc4 Mon Sep 17 00:00:00 2001 From: janEbert Date: Thu, 5 Jun 2025 15:41:10 +0200 Subject: [PATCH 27/28] Refactor away `next_batch` method We now raise the `DataloaderStopIteration` from inside the `batch_generator` generator method. `next_batch` can thus be removed as its only purpose at this point for raising the custom exception upon iterator exhaustion. --- torchtitan/experiments/flux/train.py | 2 +- torchtitan/train.py | 26 ++++++++++++-------------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py index fb667c414a..e982cb46ed 100644 --- a/torchtitan/experiments/flux/train.py +++ b/torchtitan/experiments/flux/train.py @@ -149,7 +149,7 @@ def forward_backward_step( def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): - input_dict, labels = self.next_batch(data_iterator) + input_dict, labels = next(data_iterator) self.optimizers.zero_grad() diff --git a/torchtitan/train.py b/torchtitan/train.py index dfa84183f4..91a981e8ad 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -329,8 +329,15 @@ def batch_generator( ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]: """Returns an iterator that processes batches from the data iterator.""" device_type = utils.device_type - - for batch in iter(data_iterable): + data_iterator = iter(data_iterable) + + while True: + try: + batch = next(data_iterator) + except StopIteration as ex: + # If data runs out during gradient accumulation, that + # entire step will not be executed. + raise DataloaderStopIteration() from ex data_load_start = time.perf_counter() input_dict, labels = batch self.metrics_processor.ntokens_since_last_log += labels.numel() @@ -346,17 +353,6 @@ def batch_generator( yield input_dict, labels - def next_batch( - self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] - ) -> tuple[dict[str, torch.Tensor], torch.Tensor]: - try: - input_dict, labels = next(data_iterator) - except StopIteration as ex: - # If data runs out during gradient accumulation, that - # entire step will not be executed. - raise DataloaderStopIteration() from ex - return input_dict, labels - def forward_backward_step( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor ) -> torch.Tensor: @@ -421,8 +417,10 @@ def train_step( # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + # If data runs out during gradient accumulation, that + # entire step will not be executed. for microbatch in range(self.gradient_accumulation_steps): - input_dict, labels = self.next_batch(data_iterator) + input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) self.accumulated_losses.append(loss.detach()) From f4af76dd7b8a5c6938945fb21ffd331752d3f892 Mon Sep 17 00:00:00 2001 From: janEbert Date: Thu, 5 Jun 2025 15:46:26 +0200 Subject: [PATCH 28/28] Refactor `accumulated_losses` Move from dataclass attributes to method-local variable. --- torchtitan/train.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index 91a981e8ad..bba1f01e2c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -41,7 +41,6 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): train_spec: train_spec_module.TrainSpec world_mesh: torch.distributed.DeviceMesh gradient_accumulation_steps: int - accumulated_losses: list[torch.Tensor] dataloader: train_spec_module.BaseDataLoader metrics_processor: train_spec_module.MetricsProcessor @@ -210,7 +209,6 @@ def __init__(self, job_config: JobConfig): self.loss_fn = rescale_accumulated_loss( self.loss_fn, self.gradient_accumulation_steps ) - self.accumulated_losses = [] # apply parallelisms and initialization if parallel_dims.pp_enabled: @@ -417,12 +415,13 @@ def train_step( # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. for microbatch in range(self.gradient_accumulation_steps): input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) - self.accumulated_losses.append(loss.detach()) + accumulated_losses.append(loss.detach()) dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], @@ -435,8 +434,7 @@ def train_step( self.lr_schedulers.step() # Reduce the data collected over gradient accumulation steps. - loss = torch.sum(torch.stack(self.accumulated_losses)) - self.accumulated_losses.clear() + loss = torch.sum(torch.stack(accumulated_losses)) # log metrics if not self.metrics_processor.should_log(self.step):