19 changes: 14 additions & 5 deletions run_train.sh
@@ -10,15 +10,24 @@ set -ex
 # use envs as local overwrites for convenience
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_train.sh
+# DRY_RUN=1 ./run_train.sh # for config validation without GPU
 NGPU=${NGPU:-"8"}
 export LOG_RANK=${LOG_RANK:-0}
 CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"}
 TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"}
+DRY_RUN=${DRY_RUN:-0}

 TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"}

-PYTORCH_ALLOC_CONF="expandable_segments:True" \
-TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
-torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
--m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
+if [ "$DRY_RUN" = "1" ]; then
+    # Dry run mode: validate configuration without GPU/distributed setup
+    echo "Running in DRY RUN mode - configuration validation only"
+    python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@"
+else
+    # Normal training with torchrun
+    PYTORCH_ALLOC_CONF="expandable_segments:True" \
+    TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \
+    torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+        --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+        -m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@"
+fi
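
For reference, the two modes this change enables from the shell (the config path below is just the script's default and can be overridden via CONFIG_FILE like the other variables):

# Validate the default debug config on a machine without GPUs
DRY_RUN=1 ./run_train.sh

# Validate a specific config file instead
DRY_RUN=1 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh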
156 changes: 156 additions & 0 deletions scripts/dry_run.py
Contributor:
nit: Would you consider putting this under torchtitan/tools/dry_run.py (or another 2nd-level directory under torchtitan), or under scripts/dry_run.py?

Contributor (Author):
Okay, let's put it under scripts/dry_run.py for now. We should investigate how to merge it back into train.py with LocalTensor or a fake backend anyway.

@@ -0,0 +1,156 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Dry run trainer for fast configuration validation without GPU/distributed setup.

This module provides a lightweight trainer that validates job configurations,
model architecture, and dataloader setup without requiring GPU resources or
distributed initialization. Useful for rapid iteration on configuration files
and CI/CD validation pipelines.
"""

import os
import sys

# Add parent directory to path to import torchtitan
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch

import torchtitan.protocols.train_spec as train_spec_module
from torchtitan.config import JobConfig, TORCH_DTYPE_MAP
from torchtitan.tools import utils
from torchtitan.tools.logging import logger
from torchtitan.train import main, Trainer


class DryRunTrainer(Trainer):
"""
A lightweight trainer that validates configurations without GPU allocation.

This trainer performs comprehensive validation of the training configuration
without allocating GPU resources or initializing distributed setup. It validates:

- Configuration file parsing and structure
- Model architecture (constructed on meta device)
- Tokenizer initialization
- Dataloader configuration
- Parallelism settings
- Model converters (if specified)

Unlike the regular Trainer, this does not:
- Allocate GPU memory
- Initialize distributed process groups
- Create optimizers or learning rate schedulers
- Set up checkpointing or metrics
- Run any actual training

Args:
job_config: JobConfig containing all training configuration parameters

Note:
Validation completes immediately after initialization. No training loop is executed.
All operations use CPU and meta devices for zero-cost validation.
"""

def __init__(self, job_config: JobConfig):
torch._C._log_api_usage_once("torchtitan.dry_run")

self.job_config = job_config

logger.info(f"Starting job: {job_config.job.description}")
logger.info("DRY RUN MODE - Configuration validation only")

# Use CPU device (no GPU required)
self.device = torch.device("cpu")

# Log and validate config
job_config.maybe_log()
logger.info("Configuration parsed successfully")

# Get train spec
self.train_spec = train_spec_module.get_train_spec(job_config.model.name)
logger.info(f"Train spec loaded for model: {job_config.model.name}")

# Build tokenizer
self.tokenizer = (
self.train_spec.build_tokenizer_fn(job_config)
if self.train_spec.build_tokenizer_fn is not None
else None
)
if self.tokenizer:
logger.info("Tokenizer built successfully")

# Validate model configuration
model_args = self.train_spec.model_args[job_config.model.flavor]
model_args.update_from_config(job_config)
self.model_args = model_args

logger.info(
f"Model args validated: {job_config.model.name} {job_config.model.flavor}"
)

# Build model on meta device (validates architecture without memory allocation)
logger.info("Validating model architecture...")
with (
torch.device("meta"),
utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]),
):
model = self.train_spec.model_cls(model_args)

# Calculate and log model size
model_param_count, _ = model_args.get_nparams_and_flops(
model, job_config.training.seq_len
)
logger.info(
f"Model architecture validated: {job_config.model.name} "
f"with {model_param_count:,} parameters"
)

# Validate dataloader configuration (build with minimal params)
logger.info("Validating dataloader configuration...")
try:
# Use dp_world_size=1 and dp_rank=0 for dry run
dataloader = self.train_spec.build_dataloader_fn(
dp_world_size=1,
dp_rank=0,
tokenizer=self.tokenizer,
job_config=job_config,
)
logger.info("Dataloader configuration validated successfully")
except Exception as e:
logger.warning(f"Dataloader validation encountered issue: {e}")
logger.info(
"Note: Some dataloader issues may only appear with actual data paths"
)

# Validate model converters if specified
if job_config.model.converters:
logger.info(f"Model converters specified: {job_config.model.converters}")

# Validate parallelism configuration
parallelism_config = job_config.parallelism
logger.info(
f"Parallelism config: "
f"DP-shard={parallelism_config.data_parallel_shard_degree}, "
f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, "
f"TP={parallelism_config.tensor_parallel_degree}, "
f"PP={parallelism_config.pipeline_parallel_degree}, "
f"CP={parallelism_config.context_parallel_degree}"
)

# Summary
logger.info("=" * 80)
logger.info("DRY RUN VALIDATION COMPLETE")
logger.info("=" * 80)
logger.info("All configurations validated successfully!")
logger.info("Configuration is ready for training execution.")
logger.info("=" * 80)


if __name__ == "__main__":
main(DryRunTrainer)
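
Usage note: run_train.sh invokes this script in DRY_RUN mode, but it can also be run directly from the repository root (the sys.path insertion at the top makes the torchtitan import work without an install). A minimal invocation against the default debug config:

python scripts/dry_run.py --job.config_file ./torchtitan/models/llama3/train_configs/debug_model.toml

Additional --section.option overrides can be appended, just as run_train.sh forwards them via "$@".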
4 changes: 3 additions & 1 deletion torchtitan/distributed/tensor_parallel.py
@@ -17,7 +17,9 @@ def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh):
         return

     if not (job_config.compile.enable and "model" in job_config.compile.components):
-        raise RuntimeError("Async TP requires --training.compile")
+        raise RuntimeError(
+            "Async TP requires 'model' in --compile.components and --compile.enable"
+        )

     from torch.distributed._symmetric_memory import enable_symm_mem_for_group

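For readers of the new error message: the check reads job_config.compile.enable and job_config.compile.components, which in the job TOML would look roughly like the sketch below (the [compile] table layout is an assumption for illustration, not part of this diff); the message's --compile.enable / --compile.components refer to the equivalent command-line overrides.

[compile]
enable = true
components = ["model"]
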
4 changes: 2 additions & 2 deletions torchtitan/train.py
@@ -698,9 +698,9 @@ def load_state_dict(self, state_dict: dict[str, Any]):
         self.ntokens_seen = state_dict["ntokens_seen"]

     def close(self) -> None:
-        if self.checkpointer:
+        if hasattr(self, "checkpointer") and self.checkpointer:
             self.checkpointer.close()
-        if self.metrics_processor:
+        if hasattr(self, "metrics_processor") and self.metrics_processor:
             self.metrics_processor.close()
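
Why the hasattr guards: the DryRunTrainer added above returns from __init__ without ever creating self.checkpointer or self.metrics_processor, so a close() call during cleanup would otherwise raise AttributeError. A minimal, self-contained sketch of the behavior (illustrative only, not torchtitan code):

class Base:
    def close(self) -> None:
        # Guard tolerates subclasses that skip the full training setup.
        if hasattr(self, "checkpointer") and self.checkpointer:
            self.checkpointer.close()

class DryRun(Base):
    def __init__(self):
        pass  # validation-only init; never sets self.checkpointer

DryRun().close()  # no AttributeError: hasattr short-circuits the check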

