diff --git a/docs/converging.md b/docs/converging.md
index 2ba1575bda..3bb62d9353 100644
--- a/docs/converging.md
+++ b/docs/converging.md
@@ -12,7 +12,7 @@ This note clarifies the recommended practices to follow when testing the loss co
 
 ## Guidelines
 
-To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.batch_size`) in the toml config, and at the same time fix the data parallel degree.
+To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.local_batch_size`) in the toml config, and at the same time fix the data parallel degree.
 
 If the technique is a parallelism (TP/PP/CP/etc)
 - The control set is a 1D FSDP job on `dp` GPUs (or any other verified setups), with a trusted training config (e.g. those under train_configs).
@@ -40,7 +40,7 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor
 
 ### Setup
 - Base config: [torchtitan/models/llama3/train_configs/llama3_8b.toml](../torchtitan/models/llama3/train_configs/llama3_8b.toml)
-- `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
+- `training.local_batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
 - `training.data_parallel_shard_degree = 8`, resulting in global batch size 32
 - `training.steps = 3000`, `lr_scheduler.warmup_steps = 600`
 
diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py
index bef5ad6970..85a058a0fb 100644
--- a/scripts/estimate/estimation.py
+++ b/scripts/estimate/estimation.py
@@ -130,13 +130,13 @@ def estimate_memory(job_config: JobConfig):
         torch.randint(
             0,
             model_args.vocab_size,
-            (job_config.training.batch_size, model_args.max_seq_len),
+            (job_config.training.local_batch_size, model_args.max_seq_len),
             device="cuda",
         ),
         torch.randint(
             0,
             model_args.vocab_size,
-            (job_config.training.batch_size, model_args.max_seq_len),
+            (job_config.training.local_batch_size, model_args.max_seq_len),
             device="cuda",
         ),
     )
diff --git a/tests/integration_tests.py b/tests/integration_tests.py
index e4518265a2..c2d0cf7372 100755
--- a/tests/integration_tests.py
+++ b/tests/integration_tests.py
@@ -494,6 +494,21 @@ def build_test_list():
             "Float8 emulation test",
             "float8_emulation",
         ),
+        OverrideDefinitions(
+            [
+                [
+                    # Local batch size = 8, and `ngpu=2`, so default
+                    # global batch size = 8 * 2 = 16.
+                    # To achieve 2 gradient accumulation steps, multiply
+                    # default global batch size by 2. 16 * 2 = 32.
+                    "--training.local_batch_size 8",
+                    "--training.global_batch_size 32",
+                ],
+            ],
+            "Gradient accumulation",
+            "gradient_accumulation",
+            ngpu=2,
+        ),
     ]
 
     return integration_tests_flavors
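As a quick aside (editorial sketch, not part of the patch): the arithmetic spelled out in the test comment above follows the same formula the patch later adds to `Trainer.__init__`, namely `gradient_accumulation_steps = global_batch_size // (local_batch_size * dp_degree)`.

```python
# Editorial sketch, not part of the patch: the batch-size arithmetic behind the
# "Gradient accumulation" integration test above.
local_batch_size = 8    # --training.local_batch_size 8
dp_degree = 2           # ngpu=2, pure data parallel
global_batch_size = 32  # --training.global_batch_size 32

# Default when training.global_batch_size is left at -1: one accumulation step.
default_global_batch_size = local_batch_size * dp_degree  # 8 * 2 = 16

# Same divisibility rule and formula the patch adds to Trainer.__init__.
assert global_batch_size % (local_batch_size * dp_degree) == 0
gradient_accumulation_steps = global_batch_size // (local_batch_size * dp_degree)
print(default_global_batch_size, gradient_accumulation_steps)  # 16 2
```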
diff --git a/tests/unit_tests/test_dataset_checkpointing.py b/tests/unit_tests/test_dataset_checkpointing.py
index 1eae8f1962..00998fc495 100644
--- a/tests/unit_tests/test_dataset_checkpointing.py
+++ b/tests/unit_tests/test_dataset_checkpointing.py
@@ -64,7 +64,7 @@ def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank)
             [
                 "--training.dataset",
                 dataset_name,
-                "--training.batch_size",
+                "--training.local_batch_size",
                 str(batch_size),
                 "--training.seq_len",
                 str(seq_len),
diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py
index c3eecc99dc..2b555a91bb 100644
--- a/torchtitan/components/dataloader.py
+++ b/torchtitan/components/dataloader.py
@@ -17,6 +17,12 @@
 from torchtitan.tools.logging import logger
 
 
+class DataloaderStopIteration(StopIteration):
+    """An exception that indicates dataloader exhaustion."""
+
+    pass
+
+
 class BaseDataLoader(Stateful, ABC):
     """Base class for all dataloaders.
 
diff --git a/torchtitan/components/loss.py b/torchtitan/components/loss.py
index e2df56d52e..6262564064 100644
--- a/torchtitan/components/loss.py
+++ b/torchtitan/components/loss.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import functools
 from typing import Callable, TypeAlias
 
 import torch
@@ -27,3 +28,16 @@ def build_cross_entropy_loss(job_config: JobConfig):
         logger.info("Compiling the loss function with torch.compile")
         loss_fn = torch.compile(loss_fn)
     return loss_fn
+
+
+def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps):
+    """Add a mean reduction over `accumulation_steps` to the given
+    `unwrapped_loss_fn`.
+    """
+
+    @functools.wraps(unwrapped_loss_fn)
+    def accumulated_loss_fn(*args, **kwargs):
+        loss = unwrapped_loss_fn(*args, **kwargs)
+        return loss / accumulation_steps
+
+    return accumulated_loss_fn
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index d8299f91a9..4300c3bb8b 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -189,8 +189,13 @@ class Training:
     loaded from this path instead of downloaded.
     """
 
-    batch_size: int = 8
-    """Batch size"""
+    local_batch_size: int = 8
+    """Local batch size (i.e., per-device batch size)"""
+
+    global_batch_size: int = -1
+    """
+    Global batch size (defaults to `training.local_batch_size * data-parallel degree`)
+    """
 
     seq_len: int = 2048
     """Sequence length"""
@@ -333,7 +338,7 @@ class Parallelism:
     pipeline_parallel_microbatch_size: int = 1
     """
    The size of each pipeline parallel microbatch (default 1).
-    This value is used to compute the total number of microbatches by dividing batch_size with
+    This value is used to compute the total number of microbatches by dividing local_batch_size with
     pipeline_parallel_microbatch_size.
     The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
     """
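The `rescale_accumulated_loss` wrapper above divides every microbatch loss by the number of accumulation steps, so the per-step sum the trainer later logs stays comparable to a run without accumulation. A small editorial sketch of that property (it assumes the patched `torchtitan.components.loss` is importable; `toy_loss` is a hypothetical stand-in for the real loss function):

```python
import torch

from torchtitan.components.loss import rescale_accumulated_loss


def toy_loss(pred, labels):
    # Hypothetical stand-in for torchtitan's cross-entropy loss.
    return (pred - labels).abs().mean()


steps = 4
scaled_loss = rescale_accumulated_loss(toy_loss, steps)
microbatches = [(torch.randn(2, 8), torch.randn(2, 8)) for _ in range(steps)]

# Summing the rescaled per-microbatch losses equals the mean microbatch loss,
# which is what the trainer accumulates and logs.
total = torch.stack([scaled_loss(p, l) for p, l in microbatches]).sum()
mean = torch.stack([toy_loss(p, l) for p, l in microbatches]).mean()
assert torch.allclose(total, mean)
```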
""" diff --git a/torchtitan/datasets/hf_datasets.py b/torchtitan/datasets/hf_datasets.py index 350feac695..023b4a29e9 100644 --- a/torchtitan/datasets/hf_datasets.py +++ b/torchtitan/datasets/hf_datasets.py @@ -174,7 +174,7 @@ def build_hf_dataloader( """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size seq_len = job_config.training.seq_len hf_ds = HuggingFaceDataset( diff --git a/torchtitan/distributed/pipeline.py b/torchtitan/distributed/pipeline.py index 973c58c19c..366021a7fc 100644 --- a/torchtitan/distributed/pipeline.py +++ b/torchtitan/distributed/pipeline.py @@ -150,11 +150,11 @@ def build_pipeline_schedule( looped_schedule = issubclass(schedule_class, PipelineScheduleMulti) microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size # validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training if batch_size % microbatch_size != 0: raise ValueError( - f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. " + f"Batch size {job_config.training.local_batch_size} must be divisible by number of microbatches {n_microbatches}. " "Update the config arguments for either batch_size or pipeline_parallel_microbatch_size." ) n_microbatches = batch_size // microbatch_size diff --git a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml index efd54c48f8..1402e57e1e 100644 --- a/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml +++ b/torchtitan/experiments/deepseek_v3/train_configs/deepseek_v2.toml @@ -38,7 +38,7 @@ decay_type = "linear" lr_min = 0.1 [training] -batch_size = 2 # 8 +local_batch_size = 2 # 8 seq_len = 1024 # 2048 max_norm = 1.0 # grad norm clipping steps = 200 diff --git a/torchtitan/experiments/deepseek_v3/train_ds_real.py b/torchtitan/experiments/deepseek_v3/train_ds_real.py index c1bbc4ff0a..398360a6ee 100644 --- a/torchtitan/experiments/deepseek_v3/train_ds_real.py +++ b/torchtitan/experiments/deepseek_v3/train_ds_real.py @@ -145,7 +145,7 @@ def run_full_model( # model.setup_symm_mem(torch.bfloat16, device) torch.manual_seed(ep_rank) - bs = config.training.batch_size # * microbatches # 4 + bs = config.training.local_batch_size # * microbatches # 4 seqlen = config.training.seq_len # 128 # metrics manager diff --git a/torchtitan/experiments/flux/dataset/flux_dataset.py b/torchtitan/experiments/flux/dataset/flux_dataset.py index 7e1de9dcea..9dcd641498 100644 --- a/torchtitan/experiments/flux/dataset/flux_dataset.py +++ b/torchtitan/experiments/flux/dataset/flux_dataset.py @@ -278,7 +278,7 @@ def build_flux_dataloader( """Build a data loader for HuggingFace datasets.""" dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path - batch_size = job_config.training.batch_size + batch_size = job_config.training.local_batch_size t5_tokenizer, clip_tokenizer = build_flux_tokenizer(job_config) diff --git a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py index 8a354eca67..8072032115 100644 --- a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py +++ 
diff --git a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py
index 8a354eca67..8072032115 100644
--- a/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py
+++ b/torchtitan/experiments/flux/tests/unit_tests/test_flux_dataloader.py
@@ -40,7 +40,7 @@ def _test_flux_dataloader(self, dataset_name):
                 str(256),
                 "--training.dataset",
                 dataset_name,
-                "--training.batch_size",
+                "--training.local_batch_size",
                 str(batch_size),
                 "--training.seed",
                 "0",
diff --git a/torchtitan/experiments/flux/train.py b/torchtitan/experiments/flux/train.py
index cf83600f07..e982cb46ed 100644
--- a/torchtitan/experiments/flux/train.py
+++ b/torchtitan/experiments/flux/train.py
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
-from typing import Optional
+from typing import Iterable, Optional
 
 import torch
 from torch.distributed.fsdp import FSDPModule
@@ -81,7 +81,9 @@ def __init__(self, job_config: JobConfig):
             job_config=job_config,
         )
 
-    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
         # generate t5 and clip embeddings
         input_dict["image"] = labels
         input_dict = preprocess_data(
@@ -94,18 +96,11 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         )
         labels = input_dict["img_encodings"]
 
-        self.optimizers.zero_grad()
-
         # Keep these variables local to shorten the code as these are
         # the major variables that are used in the training loop.
-        model_parts = self.model_parts
-        assert len(self.model_parts) == 1
         # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not
         model = self.model_parts[0]
 
-        world_mesh = self.world_mesh
-        parallel_dims = self.parallel_dims
-
         # image in latent space transformed by self.auto_encoder
         clip_encodings = input_dict["clip_encodings"]
         t5_encodings = input_dict["t5_encodings"]
@@ -149,6 +144,27 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         del (pred, noise, target)
         loss.backward()
 
+        return loss
+
+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        input_dict, labels = next(data_iterator)
+
+        self.optimizers.zero_grad()
+
+        # Keep these variables local to shorten the code as these are
+        # the major variables that are used in the training loop.
+        model_parts = self.model_parts
+        assert len(self.model_parts) == 1
+        # explicitely convert flux model to be Bfloat16 no matter FSDP is applied or not
+        model = self.model_parts[0]
+
+        world_mesh = self.world_mesh
+        parallel_dims = self.parallel_dims
+
+        loss = self.forward_backward_step(input_dict, labels)
+
         dist_utils.clip_grad_norm_(
             [p for m in model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
diff --git a/torchtitan/experiments/flux/train_configs/debug_model.toml b/torchtitan/experiments/flux/train_configs/debug_model.toml
index 3bc4a4e236..596a513737 100644
--- a/torchtitan/experiments/flux/train_configs/debug_model.toml
+++ b/torchtitan/experiments/flux/train_configs/debug_model.toml
@@ -33,7 +33,7 @@ warmup_steps = 1  # 10% warmup steps
 decay_ratio = 0.0  # no decay, stay stable during training
 
 [training]
-batch_size = 4
+local_batch_size = 4
 max_norm = 2.0  # grad norm clipping
 steps = 10
 compile = false
diff --git a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml
index 72e73e10b0..23dbb2b429 100644
--- a/torchtitan/experiments/flux/train_configs/flux_dev_model.toml
+++ b/torchtitan/experiments/flux/train_configs/flux_dev_model.toml
@@ -32,7 +32,7 @@ warmup_steps = 3_000  # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.0  # no decay
 
 [training]
-batch_size = 32
+local_batch_size = 32
 max_norm = 1.0  # grad norm clipping
 steps = 30_000
 compile = false
diff --git a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml
index 2e66170ddf..1e2421e59a 100644
--- a/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml
+++ b/torchtitan/experiments/flux/train_configs/flux_schnell_model.toml
@@ -32,7 +32,7 @@ warmup_steps = 3_000  # lr scheduler warm up, normally 20% of the train steps
 decay_ratio = 0.0  # no decay
 
 [training]
-batch_size = 64
+local_batch_size = 64
 max_norm = 1.0  # grad norm clipping
 steps = 30_000
 compile = false
diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml
index bb038e60c4..bc48d38090 100644
--- a/torchtitan/experiments/llama4/train_configs/debug_model.toml
+++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -37,7 +37,7 @@ decay_type = "linear"
 lr_min = 0.1
 
 [training]
-batch_size = 8
+local_batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
index dfaa4114e0..f508968c8d 100644
--- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
+++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
@@ -30,7 +30,7 @@ warmup_steps = 600
 lr_min = 0.1
 
 [training]
-batch_size = 1
+local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 3000
diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
index 41fc37efde..c899dd5087 100644
--- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
+++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
@@ -30,7 +30,7 @@ warmup_steps = 600
 lr_min = 0.1
 
 [training]
-batch_size = 8
+local_batch_size = 8
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 3000
diff --git a/torchtitan/experiments/multimodal/check_padding_mm.py b/torchtitan/experiments/multimodal/check_padding_mm.py
index cc8b357286..7534dcdccf 100644
--- a/torchtitan/experiments/multimodal/check_padding_mm.py
+++ b/torchtitan/experiments/multimodal/check_padding_mm.py
@@ -35,7 +35,7 @@ def main(
         [
             "--training.dataset",
             dataset,
-            "--training.batch_size",
+            "--training.local_batch_size",
             str(batch_size),
             "--training.seq_len",
             str(seq_len),
diff --git a/torchtitan/experiments/multimodal/mm_dataset.py b/torchtitan/experiments/multimodal/mm_dataset.py
index a29627aace..519272c74e 100644
--- a/torchtitan/experiments/multimodal/mm_dataset.py
+++ b/torchtitan/experiments/multimodal/mm_dataset.py
@@ -240,7 +240,7 @@ def build_mm_dataloader(
     """Build a data loader for HuggingFace datasets."""
     dataset_name = job_config.training.dataset
     dataset_path = job_config.training.dataset_path
-    batch_size = job_config.training.batch_size
+    batch_size = job_config.training.local_batch_size
     seq_len = job_config.training.seq_len
     pad_max_tiles = 4  # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig
     padding_idx = 128004  # TODO(tj.solergibert) Add `padding_idx` to JobConfig
diff --git a/torchtitan/models/llama3/train_configs/debug_model.toml b/torchtitan/models/llama3/train_configs/debug_model.toml
index 60916966cb..4d4426d0d0 100644
--- a/torchtitan/models/llama3/train_configs/debug_model.toml
+++ b/torchtitan/models/llama3/train_configs/debug_model.toml
@@ -39,7 +39,7 @@ decay_type = "linear"
 lr_min = 0.0
 
 [training]
-batch_size = 8
+local_batch_size = 8
 seq_len = 2048
 max_norm = 1.0  # grad norm clipping
 steps = 10
diff --git a/torchtitan/models/llama3/train_configs/llama3_405b.toml b/torchtitan/models/llama3/train_configs/llama3_405b.toml
index 6adcd3d90a..175c41fc8f 100644
--- a/torchtitan/models/llama3/train_configs/llama3_405b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_405b.toml
@@ -30,7 +30,7 @@ eps = 1e-8
 warmup_steps = 600  # lr scheduler warm up, normally 20% of the train steps
 
 [training]
-batch_size = 2
+local_batch_size = 2
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 3000
diff --git a/torchtitan/models/llama3/train_configs/llama3_70b.toml b/torchtitan/models/llama3/train_configs/llama3_70b.toml
index 582f5aff5c..257205c314 100644
--- a/torchtitan/models/llama3/train_configs/llama3_70b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_70b.toml
@@ -30,7 +30,7 @@ eps = 1e-8
 warmup_steps = 200  # lr scheduler warm up, normally 20% of the train steps
 
 [training]
-batch_size = 8
+local_batch_size = 8
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 1000
diff --git a/torchtitan/models/llama3/train_configs/llama3_8b.toml b/torchtitan/models/llama3/train_configs/llama3_8b.toml
index 6cfe61bc70..aa6f81f803 100644
--- a/torchtitan/models/llama3/train_configs/llama3_8b.toml
+++ b/torchtitan/models/llama3/train_configs/llama3_8b.toml
@@ -30,7 +30,7 @@ eps = 1e-8
 warmup_steps = 200  # lr scheduler warm up
 
 [training]
-batch_size = 1
+local_batch_size = 1
 seq_len = 8192
 max_norm = 1.0  # grad norm clipping
 steps = 1000
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 14edd70ad4..bba1f01e2c 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -16,6 +16,8 @@
 import torchtitan.components.ft as ft
 import torchtitan.protocols.train_spec as train_spec_module
 from torchtitan.components.checkpoint import CheckpointManager
+from torchtitan.components.dataloader import DataloaderStopIteration
+from torchtitan.components.loss import rescale_accumulated_loss
 from torchtitan.components.metrics import (
     build_metrics_processor,
     ensure_pp_loss_visible,
 )
@@ -38,6 +40,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful):
     parallel_dims: ParallelDims
     train_spec: train_spec_module.TrainSpec
     world_mesh: torch.distributed.DeviceMesh
+    gradient_accumulation_steps: int
 
     dataloader: train_spec_module.BaseDataLoader
     metrics_processor: train_spec_module.MetricsProcessor
@@ -183,6 +186,30 @@ def __init__(self, job_config: JobConfig):
 
         self.loss_fn = self.train_spec.build_loss_fn(job_config)
 
+        # verify batch sizes
+        global_batch_size = job_config.training.global_batch_size
+        if global_batch_size < 0:
+            # This global batch size results in 1 gradient accumulation
+            # step.
+            global_batch_size = job_config.training.local_batch_size * dp_degree
+        assert global_batch_size > 0
+        assert (
+            global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0
+        ), (
+            f"global batch size must be multiple of local batch size times "
+            f"data-parallel degree ({global_batch_size} "
+            f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)"
+        )
+
+        # calculate gradient accumulation steps
+        self.gradient_accumulation_steps = global_batch_size // (
+            job_config.training.local_batch_size * dp_degree
+        )
+        assert self.gradient_accumulation_steps > 0
+        self.loss_fn = rescale_accumulated_loss(
+            self.loss_fn, self.gradient_accumulation_steps
+        )
+
         # apply parallelisms and initialization
         if parallel_dims.pp_enabled:
             if not self.train_spec.pipelining_fn:
@@ -287,8 +314,9 @@ def __init__(self, job_config: JobConfig):
         logger.info(
             "Trainer is initialized with "
-            f"local batch size {job_config.training.batch_size}, "
-            f"global batch size {job_config.training.batch_size * dp_degree}, "
+            f"local batch size {job_config.training.local_batch_size}, "
+            f"global batch size {global_batch_size}, "
+            f"gradient accumulation steps {self.gradient_accumulation_steps}, "
             f"sequence length {job_config.training.seq_len}, "
             f"total steps {job_config.training.steps} "
             f"(warmup {job_config.lr_scheduler.warmup_steps})."
         )
@@ -299,8 +327,15 @@ def batch_generator(
     ) -> Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]:
         """Returns an iterator that processes batches from the data iterator."""
         device_type = utils.device_type
-
-        for batch in iter(data_iterable):
+        data_iterator = iter(data_iterable)
+
+        while True:
+            try:
+                batch = next(data_iterator)
+            except StopIteration as ex:
+                # If data runs out during gradient accumulation, that
+                # entire step will not be executed.
+                raise DataloaderStopIteration() from ex
             data_load_start = time.perf_counter()
             input_dict, labels = batch
             self.metrics_processor.ntokens_since_last_log += labels.numel()
@@ -316,13 +351,10 @@
 
             yield input_dict, labels
 
-    def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
-        self.optimizers.zero_grad()
-
-        # Keep these variables local to shorten the code as these are
-        # the major variables that are used in the training loop.
+    def forward_backward_step(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
         model_parts = self.model_parts
-        world_mesh = self.world_mesh
         parallel_dims = self.parallel_dims
 
         # apply context parallelism if cp is enabled
@@ -330,7 +362,7 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
         inputs = input_dict["input"]
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
-                cp_mesh=world_mesh["cp"],
+                cp_mesh=self.world_mesh["cp"],
                 cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                 cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                 cp_no_restore_buffers={inputs, labels},
@@ -372,8 +404,27 @@
         del pred
         loss.backward()
 
+        return loss
+
+    def train_step(
+        self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
+    ):
+        self.optimizers.zero_grad()
+
+        # Keep these variables local to shorten the code as these are
+        # the major variables that are used in the training loop.
+        parallel_dims = self.parallel_dims
+
+        accumulated_losses = []
+        # If data runs out during gradient accumulation, that
+        # entire step will not be executed.
+        for microbatch in range(self.gradient_accumulation_steps):
+            input_dict, labels = next(data_iterator)
+            loss = self.forward_backward_step(input_dict, labels)
+            accumulated_losses.append(loss.detach())
+
         dist_utils.clip_grad_norm_(
-            [p for m in model_parts for p in m.parameters()],
+            [p for m in self.model_parts for p in m.parameters()],
             self.job_config.training.max_norm,
             foreach=True,
             pp_mesh=self.world_mesh["pp"] if parallel_dims.pp_enabled else None,
@@ -382,6 +433,9 @@
         self.optimizers.step()
         self.lr_schedulers.step()
 
+        # Reduce the data collected over gradient accumulation steps.
+        loss = torch.sum(torch.stack(accumulated_losses))
+
         # log metrics
         if not self.metrics_processor.should_log(self.step):
             return
@@ -400,8 +454,8 @@
             )
             ft_pg = self.ft_manager.replicate_pg if use_ft_pg else None
             global_avg_loss, global_max_loss = (
-                dist_utils.dist_mean(loss, world_mesh["dp_cp"], ft_pg),
-                dist_utils.dist_max(loss, world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_mean(loss, self.world_mesh["dp_cp"], ft_pg),
+                dist_utils.dist_max(loss, self.world_mesh["dp_cp"], ft_pg),
             )
         else:
             global_avg_loss = global_max_loss = loss.detach().item()
@@ -428,12 +482,15 @@ def train(self):
                 sync_every=job_config.fault_tolerance.sync_steps,
             ),
         ):
-            for inputs, labels in self.batch_generator(self.dataloader):
-                if self.step >= job_config.training.steps:
-                    break
+            data_iterator = self.batch_generator(self.dataloader)
+            while self.step < job_config.training.steps:
                 self.step += 1
                 self.gc_handler.run(self.step)
-                self.train_step(inputs, labels)
+                try:
+                    self.train_step(data_iterator)
+                except DataloaderStopIteration:
+                    logger.warning("Ran out of data; last step was canceled.")
+                    break
                 self.checkpointer.save(
                     self.step, force=(self.step == job_config.training.steps)
                 )
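A final note on the new training loop (editorial sketch, not part of the patch): if the dataloader runs out partway through an accumulation window, the step is abandoned rather than applied with fewer microbatches, and the outer loop exits on `DataloaderStopIteration`. The control flow, with stand-in names rather than torchtitan APIs:

```python
# Editorial sketch, not part of the patch: the "all-or-nothing" step semantics
# the training loop adopts above. Names are stand-ins, not torchtitan APIs.
class DataloaderStopIteration(StopIteration):
    """Stand-in for torchtitan.components.dataloader.DataloaderStopIteration."""


class Batches:
    """Tiny iterator that signals exhaustion with the custom exception."""

    def __init__(self, items):
        self._items = list(items)

    def __iter__(self):
        return self

    def __next__(self):
        if not self._items:
            raise DataloaderStopIteration()
        return self._items.pop(0)


def train_step(data_iterator, accumulation_steps=2):
    losses = []
    for _ in range(accumulation_steps):
        # stand-in for forward_backward_step; may raise DataloaderStopIteration
        losses.append(next(data_iterator))
    return sum(losses)


it = Batches([1.0, 2.0, 3.0])  # three microbatches, two needed per step
while True:
    try:
        print("step loss:", train_step(it))
    except DataloaderStopIteration:
        print("Ran out of data; last step was canceled.")  # partial step dropped
        break
```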