Merged
28 commits
7c55f07
Refactor batch loss and grad calculation
janEbert Dec 3, 2024
83de5dc
Support gradient accumulation
janEbert Dec 3, 2024
c9396c2
Run pre-commit hooks
janEbert May 29, 2025
1da4106
Change `global_batch_size` type to `int`
janEbert May 29, 2025
82b5e5d
Do not save checkpoint when data ran out
janEbert Jun 3, 2025
4a30b74
Raise custom exception upon data depletion
janEbert Jun 3, 2025
6cd2e31
Rename "batch size" to "local batch size"
janEbert Jun 3, 2025
45fbe0a
Rename `batch_backward` to `forward_backward_step`
janEbert Jun 3, 2025
be23194
Refactor loss function gradient accumulation wrap
janEbert Jun 3, 2025
5ae21d7
Do not modify `job_config.global_batch_size`
janEbert Jun 3, 2025
8cf71d6
Add comment on default gradient accumulation step
janEbert Jun 3, 2025
b485012
Move gradient accumulation derivation logic
janEbert Jun 3, 2025
12d274d
Remove redundant shortcut variables
janEbert Jun 3, 2025
d4dd122
Improve readability
janEbert Jun 3, 2025
a7f4c80
Add `gradient_accumulation_step` to dataclass
janEbert Jun 3, 2025
7003eb9
Move `accumulated_losses` to `Trainer`
janEbert Jun 3, 2025
266cffa
Apply pre-commit hooks
janEbert Jun 3, 2025
072b9b4
Add gradient accumulation integration test
janEbert Jun 3, 2025
35fdc79
Fix FLUX trainer
janEbert Jun 4, 2025
12898e4
Refactor FLUX train step
janEbert Jun 4, 2025
a2d8c26
Use fixed local batch size
janEbert Jun 4, 2025
40802ad
Fix typo
janEbert Jun 4, 2025
af0a5ed
Move custom `StopIteration` exception
janEbert Jun 4, 2025
a29b59b
Fix log type
janEbert Jun 4, 2025
6b0efca
Add docstring to rescaled loss function
janEbert Jun 4, 2025
b810950
Fix missing types
janEbert Jun 5, 2025
39b5087
Refactor away `next_batch` method
janEbert Jun 5, 2025
f4af76d
Refactor `accumulated_losses`
janEbert Jun 5, 2025
4 changes: 2 additions & 2 deletions docs/converging.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ This note clarifies the recommended practices to follow when testing the loss co

## Guidelines

To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.batch_size`) in the toml config, and at the same time fix the data parallel degree.
To validate the correctness of a distributed training technique, one should try to **keep the determinism in the input data to minimize the differences it could cause**. To make sure the global batch size and in general #tokens per iteration stay the same, one can fix the local batch size (`training.local_batch_size`) in the toml config, and at the same time fix the data parallel degree.

If the technique is a parallelism (TP/PP/CP/etc)
- The control set is a 1D FSDP job on `dp` GPUs (or any other verified setups), with a trusted training config (e.g. those under train_configs).
Expand Down Expand Up @@ -40,7 +40,7 @@ Results are obtained on 2025/01/21, with the latest `torch`, `torchao`, and `tor

### Setup
- Base config: [torchtitan/models/llama3/train_configs/llama3_8b.toml](../torchtitan/models/llama3/train_configs/llama3_8b.toml)
- `training.batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
- `training.local_batch_size = 4`, which is a minimum for Pipeline Parallel with `pipeline_parallel_degree = 2` and `pipeline_parallel_schedule = "Interleaved1F1B"`
- `training.data_parallel_shard_degree = 8`, resulting in global batch size 32
- `training.steps = 3000`, `lr_scheduler.warmup_steps = 600`
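
As a quick check of the setup arithmetic above, the following sketch recomputes the global batch size and the number of tokens per iteration. `tokens_per_iteration` is a hypothetical helper, not part of torchtitan, and `seq_len = 8192` is an assumption taken from the `llama3_8b.toml` base config on the premise that the experiment did not override it.

```python
def tokens_per_iteration(local_batch_size: int, data_parallel_degree: int, seq_len: int):
    """Return (global batch size, tokens per iteration) for one optimizer step.

    Hypothetical helper for illustration only.
    """
    global_batch_size = local_batch_size * data_parallel_degree
    return global_batch_size, global_batch_size * seq_len


# Values from the setup list; seq_len is assumed from the base config.
global_batch_size, tokens = tokens_per_iteration(
    local_batch_size=4, data_parallel_degree=8, seq_len=8192
)
print(global_batch_size, tokens)  # 32 262144
```

Keeping both numbers fixed between the control run and the run under test is what makes the loss curves comparable.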

Expand Down
4 changes: 2 additions & 2 deletions scripts/estimate/estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,13 @@ def estimate_memory(job_config: JobConfig):
torch.randint(
0,
model_args.vocab_size,
(job_config.training.batch_size, model_args.max_seq_len),
(job_config.training.local_batch_size, model_args.max_seq_len),
device="cuda",
),
torch.randint(
0,
model_args.vocab_size,
(job_config.training.batch_size, model_args.max_seq_len),
(job_config.training.local_batch_size, model_args.max_seq_len),
device="cuda",
),
)
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,21 @@ def build_test_list():
"Float8 emulation test",
"float8_emulation",
),
OverrideDefinitions(
[
[
# Local batch size = 8, and `ngpu=2`, so default
# global batch size = 8 * 2 = 16.
# To achieve 2 gradient accumulation steps, multiply
# default global batch size by 2. 16 * 2 = 32.
"--training.local_batch_size 8",
"--training.global_batch_size 32",
],
],
"Gradient accumulation",
"gradient_accumulation",
ngpu=2,
),
]
return integration_tests_flavors
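
The comment in the new test case derives two gradient accumulation steps from the batch sizes. A minimal sketch of that derivation follows; the function name, the divisibility check, and the error message are assumptions for illustration, since the trainer-side logic is not part of this hunk. The `-1` default and its meaning come from the `global_batch_size` field added in `config_manager.py`.

```python
def derive_gradient_accumulation_steps(
    local_batch_size: int, global_batch_size: int, dp_degree: int
) -> int:
    """Hypothetical helper: number of microbatches accumulated per optimizer step."""
    samples_per_forward = local_batch_size * dp_degree
    if global_batch_size < 0:
        # Default: one forward/backward pass per optimizer step.
        global_batch_size = samples_per_forward
    if global_batch_size % samples_per_forward != 0:
        raise ValueError(
            f"global batch size {global_batch_size} must be divisible by "
            f"local batch size * data-parallel degree = {samples_per_forward}"
        )
    return global_batch_size // samples_per_forward


# Test case from the diff: local batch size 8, ngpu=2, global batch size 32.
print(derive_gradient_accumulation_steps(8, 32, 2))  # 2
# Default global batch size (-1) means no accumulation.
print(derive_gradient_accumulation_steps(8, -1, 2))  # 1
```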

Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/test_dataset_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _build_dataloader(self, dataset_name, batch_size, seq_len, world_size, rank)
[
"--training.dataset",
dataset_name,
"--training.batch_size",
"--training.local_batch_size",
str(batch_size),
"--training.seq_len",
str(seq_len),
Expand Down
6 changes: 6 additions & 0 deletions torchtitan/components/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from torchtitan.tools.logging import logger


class DataloaderStopIteration(StopIteration):
    """An exception that indicates dataloader exhaustion."""

    pass
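
The commit titles "Raise custom exception upon data depletion" and "Do not save checkpoint when data ran out" suggest how this class is meant to be used. The sketch below shows one way a training loop could tell data exhaustion apart from a normal finish; `fetch_batch` and `run_training` are hypothetical names, and the real trainer code is not shown in this hunk.

```python
class DataloaderStopIteration(StopIteration):
    """An exception that indicates dataloader exhaustion."""


def fetch_batch(data_iterator):
    """Hypothetical helper: re-raise plain StopIteration as the custom type."""
    try:
        return next(data_iterator)
    except StopIteration as ex:
        raise DataloaderStopIteration("Ran out of data") from ex


def run_training(batches, max_steps):
    """Hypothetical loop: returns (steps completed, whether data ran out)."""
    data_iterator = iter(batches)
    steps_done = 0
    while steps_done < max_steps:
        try:
            fetch_batch(data_iterator)
        except DataloaderStopIteration:
            # Data ran out mid-run: stop without treating this as a
            # completed step, so no checkpoint would be saved for it.
            return steps_done, True
        steps_done += 1
    return steps_done, False


print(run_training([1, 2, 3], max_steps=5))  # (3, True)
print(run_training([1, 2, 3], max_steps=2))  # (2, False)
```

Note that the helpers are plain functions rather than generators: raising a `StopIteration` subclass inside a generator body is converted to `RuntimeError` by Python.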


class BaseDataLoader(Stateful, ABC):
"""Base class for all dataloaders.

Expand Down
14 changes: 14 additions & 0 deletions torchtitan/components/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
from typing import Callable, TypeAlias

import torch
Expand All @@ -27,3 +28,16 @@ def build_cross_entropy_loss(job_config: JobConfig):
logger.info("Compiling the loss function with torch.compile")
loss_fn = torch.compile(loss_fn)
return loss_fn


def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps):
    """Add a mean reduction over `accumulation_steps` to the given
    `unwrapped_loss_fn`.
    """

    @functools.wraps(unwrapped_loss_fn)
    def accumulated_loss_fn(*args, **kwargs):
        loss = unwrapped_loss_fn(*args, **kwargs)
        return loss / accumulation_steps

    return accumulated_loss_fn
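
To illustrate why the division is needed, the sketch below applies the same wrapper to a toy loss in plain Python. `mean_squared_error` is a hypothetical stand-in for the cross-entropy loss, and the equality only holds under the assumption that all microbatches have the same size.

```python
import functools


def rescale_accumulated_loss(unwrapped_loss_fn, accumulation_steps):
    """Same wrapper as in the diff: divide the loss by `accumulation_steps`."""

    @functools.wraps(unwrapped_loss_fn)
    def accumulated_loss_fn(*args, **kwargs):
        return unwrapped_loss_fn(*args, **kwargs) / accumulation_steps

    return accumulated_loss_fn


def mean_squared_error(preds, labels):
    """Toy loss with a mean reduction (hypothetical stand-in)."""
    return sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(preds)


preds, labels = [1.0, 2.0, 3.0, 4.0], [0.0, 0.0, 0.0, 0.0]
full_batch_loss = mean_squared_error(preds, labels)

# Two equal-sized microbatches, each loss divided by the number of steps.
rescaled = rescale_accumulated_loss(mean_squared_error, accumulation_steps=2)
accumulated_loss = rescaled(preds[:2], labels[:2]) + rescaled(preds[2:], labels[2:])

print(full_batch_loss, accumulated_loss)  # 7.5 7.5
```

Because differentiation is linear, the same argument applies to the gradients that accumulate across the `backward()` calls.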
11 changes: 8 additions & 3 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,13 @@ class Training:
loaded from this path instead of downloaded.
"""

batch_size: int = 8
"""Batch size"""
local_batch_size: int = 8
"""Local batch size (i.e., per-device batch size)"""

global_batch_size: int = -1
"""
Global batch size (defaults to `training.local_batch_size * data-parallel degree`)
"""

seq_len: int = 2048
"""Sequence length"""
Expand Down Expand Up @@ -333,7 +338,7 @@ class Parallelism:
pipeline_parallel_microbatch_size: int = 1
"""
The size of each pipeline parallel microbatch (default 1).
This value is used to compute the total number of microbatches by dividing batch_size with
This value is used to compute the total number of microbatches by dividing local_batch_size with
pipeline_parallel_microbatch_size.
The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size.
"""
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/datasets/hf_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def build_hf_dataloader(
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size
batch_size = job_config.training.local_batch_size
seq_len = job_config.training.seq_len

hf_ds = HuggingFaceDataset(
Expand Down
4 changes: 2 additions & 2 deletions torchtitan/distributed/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,11 @@ def build_pipeline_schedule(

looped_schedule = issubclass(schedule_class, PipelineScheduleMulti)
microbatch_size = job_config.parallelism.pipeline_parallel_microbatch_size
batch_size = job_config.training.batch_size
batch_size = job_config.training.local_batch_size
# validate that the batch size is divisible by the microbatch_size otherwise we'll hang or error during training
if batch_size % microbatch_size != 0:
raise ValueError(
f"Batch size {job_config.training.batch_size} must be divisible by number of microbatches {n_microbatches}. "
f"Batch size {job_config.training.local_batch_size} must be divisible by microbatch size {microbatch_size}. "
"Update the config arguments for either local_batch_size or pipeline_parallel_microbatch_size."
)
n_microbatches = batch_size // microbatch_size
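
The validation and division in this hunk can be read in isolation as the following sketch. `count_microbatches` is a hypothetical standalone function, not a torchtitan API; it only mirrors the check shown above.

```python
def count_microbatches(local_batch_size: int, microbatch_size: int) -> int:
    """Hypothetical helper: split a local batch into pipeline microbatches."""
    if local_batch_size % microbatch_size != 0:
        raise ValueError(
            f"Batch size {local_batch_size} must be divisible by "
            f"microbatch size {microbatch_size}."
        )
    return local_batch_size // microbatch_size


# local_batch_size = 4 with the default microbatch size of 1.
print(count_microbatches(4, 1))  # 4
```

The check uses the local batch size, which is why the pipeline schedule is unaffected by gradient accumulation over a larger global batch.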
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ decay_type = "linear"
lr_min = 0.1

[training]
batch_size = 2 # 8
local_batch_size = 2 # 8
seq_len = 1024 # 2048
max_norm = 1.0 # grad norm clipping
steps = 200
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/deepseek_v3/train_ds_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def run_full_model(
# model.setup_symm_mem(torch.bfloat16, device)

torch.manual_seed(ep_rank)
bs = config.training.batch_size # * microbatches # 4
bs = config.training.local_batch_size # * microbatches # 4
seqlen = config.training.seq_len # 128

# metrics manager
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/dataset/flux_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def build_flux_dataloader(
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size
batch_size = job_config.training.local_batch_size

t5_tokenizer, clip_tokenizer = build_flux_tokenizer(job_config)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def _test_flux_dataloader(self, dataset_name):
str(256),
"--training.dataset",
dataset_name,
"--training.batch_size",
"--training.local_batch_size",
str(batch_size),
"--training.seed",
"0",
Expand Down
34 changes: 25 additions & 9 deletions torchtitan/experiments/flux/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
from typing import Optional
from typing import Iterable, Optional

import torch
from torch.distributed.fsdp import FSDPModule
Expand Down Expand Up @@ -81,7 +81,9 @@ def __init__(self, job_config: JobConfig):
job_config=job_config,
)

def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
def forward_backward_step(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
# generate t5 and clip embeddings
input_dict["image"] = labels
input_dict = preprocess_data(
Expand All @@ -94,18 +96,11 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
)
labels = input_dict["img_encodings"]

self.optimizers.zero_grad()

# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
model_parts = self.model_parts
assert len(self.model_parts) == 1
# explicitly convert flux model to be Bfloat16 whether or not FSDP is applied
model = self.model_parts[0]

world_mesh = self.world_mesh
parallel_dims = self.parallel_dims

# image in latent space transformed by self.auto_encoder
clip_encodings = input_dict["clip_encodings"]
t5_encodings = input_dict["t5_encodings"]
Expand Down Expand Up @@ -149,6 +144,27 @@ def train_step(self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor):
del (pred, noise, target)
loss.backward()

return loss

def train_step(
self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]]
):
input_dict, labels = next(data_iterator)

self.optimizers.zero_grad()

# Keep these variables local to shorten the code as these are
# the major variables that are used in the training loop.
model_parts = self.model_parts
assert len(self.model_parts) == 1
# explicitly convert flux model to be Bfloat16 whether or not FSDP is applied
model = self.model_parts[0]

world_mesh = self.world_mesh
parallel_dims = self.parallel_dims

loss = self.forward_backward_step(input_dict, labels)

dist_utils.clip_grad_norm_(
[p for m in model_parts for p in m.parameters()],
self.job_config.training.max_norm,
Expand Down
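
The split into `forward_backward_step` (one microbatch) and `train_step` (one optimizer step fed by a data iterator) is what makes gradient accumulation possible. The toy trainer below is a sketch of that structure in plain Python, not the FLUX or torchtitan implementation: `ToyTrainer`, its single-weight model `y = w * x`, and the hand-written gradient are all assumptions made for illustration.

```python
from typing import Iterator


class ToyTrainer:
    """Hypothetical trainer that fits y = w * x with plain SGD."""

    def __init__(self, w: float, lr: float, accumulation_steps: int):
        self.w = w
        self.lr = lr
        self.accumulation_steps = accumulation_steps
        self.grad = 0.0

    def forward_backward_step(self, xs: list, ys: list) -> float:
        """Process one microbatch and accumulate its rescaled gradient."""
        residuals = [self.w * x - y for x, y in zip(xs, ys)]
        loss = sum(r * r for r in residuals) / len(xs)
        grad = sum(2.0 * r * x for r, x in zip(residuals, xs)) / len(xs)
        # Rescale so that the sum over microbatches is a mean.
        self.grad += grad / self.accumulation_steps
        return loss / self.accumulation_steps

    def train_step(self, data_iterator: Iterator) -> float:
        """Accumulate over several microbatches, then update once."""
        self.grad = 0.0  # plays the role of optimizers.zero_grad()
        total_loss = 0.0
        for _ in range(self.accumulation_steps):
            xs, ys = next(data_iterator)
            total_loss += self.forward_backward_step(xs, ys)
        self.w -= self.lr * self.grad  # single optimizer step
        return total_loss


# One step over the full batch versus two accumulated microbatches.
full = ToyTrainer(w=0.0, lr=0.01, accumulation_steps=1)
full_loss = full.train_step(iter([([1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])]))

accum = ToyTrainer(w=0.0, lr=0.01, accumulation_steps=2)
accum_loss = accum.train_step(
    iter([([1.0, 2.0], [2.0, 4.0]), ([3.0, 4.0], [6.0, 8.0])])
)
```

Under these assumptions both trainers report the same loss and end up with the same weight, which is the property the gradient accumulation integration test relies on.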
2 changes: 1 addition & 1 deletion torchtitan/experiments/flux/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ warmup_steps = 1 # 10% warmup steps
decay_ratio = 0.0 # no decay, stay stable during training

[training]
batch_size = 4
local_batch_size = 4
max_norm = 2.0 # grad norm clipping
steps = 10
compile = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0 # no decay

[training]
batch_size = 32
local_batch_size = 32
max_norm = 1.0 # grad norm clipping
steps = 30_000
compile = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ warmup_steps = 3_000 # lr scheduler warm up, normally 20% of the train steps
decay_ratio = 0.0 # no decay

[training]
batch_size = 64
local_batch_size = 64
max_norm = 1.0 # grad norm clipping
steps = 30_000
compile = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ decay_type = "linear"
lr_min = 0.1

[training]
batch_size = 8
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ warmup_steps = 600
lr_min = 0.1

[training]
batch_size = 1
local_batch_size = 1
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ warmup_steps = 600
lr_min = 0.1

[training]
batch_size = 8
local_batch_size = 8
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/multimodal/check_padding_mm.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def main(
[
"--training.dataset",
dataset,
"--training.batch_size",
"--training.local_batch_size",
str(batch_size),
"--training.seq_len",
str(seq_len),
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/experiments/multimodal/mm_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ def build_mm_dataloader(
"""Build a data loader for HuggingFace datasets."""
dataset_name = job_config.training.dataset
dataset_path = job_config.training.dataset_path
batch_size = job_config.training.batch_size
batch_size = job_config.training.local_batch_size
seq_len = job_config.training.seq_len
pad_max_tiles = 4 # TODO(tj.solergibert) Add `pad_max_tiles` to JobConfig
padding_idx = 128004 # TODO(tj.solergibert) Add `padding_idx` to JobConfig
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ decay_type = "linear"
lr_min = 0.0

[training]
batch_size = 8
local_batch_size = 8
seq_len = 2048
max_norm = 1.0 # grad norm clipping
steps = 10
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/llama3_405b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ eps = 1e-8
warmup_steps = 600 # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 2
local_batch_size = 2
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 3000
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/llama3_70b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ eps = 1e-8
warmup_steps = 200 # lr scheduler warm up, normally 20% of the train steps

[training]
batch_size = 8
local_batch_size = 8
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 1000
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/models/llama3/train_configs/llama3_8b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ eps = 1e-8
warmup_steps = 200 # lr scheduler warm up

[training]
batch_size = 1
local_batch_size = 1
seq_len = 8192
max_norm = 1.0 # grad norm clipping
steps = 1000
Expand Down