From cd47ff82ae76e033f0ef0afb3170cea9c3aae1fc Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 13:35:33 -0800 Subject: [PATCH 1/9] Update [ghstack-poisoned] --- run_train.sh | 2 +- scripts/dry_run.py | 159 -------------------------------- torchtitan/config/job_config.py | 3 + torchtitan/distributed/utils.py | 38 +++++++- torchtitan/train.py | 20 +++- 5 files changed, 57 insertions(+), 165 deletions(-) delete mode 100644 scripts/dry_run.py diff --git a/run_train.sh b/run_train.sh index 83319816fe..019ae2d7af 100755 --- a/run_train.sh +++ b/run_train.sh @@ -22,7 +22,7 @@ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} if [ "$DRY_RUN" = "1" ]; then # Dry run mode: validate configuration without GPU/distributed setup echo "Running in DRY RUN mode - configuration validation only" - python scripts/dry_run.py --job.config_file ${CONFIG_FILE} "$@" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.local_tensor_mode else # Normal training with torchrun PYTORCH_ALLOC_CONF="expandable_segments:True" \ diff --git a/scripts/dry_run.py b/scripts/dry_run.py deleted file mode 100644 index fa8e1b4c17..0000000000 --- a/scripts/dry_run.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Dry run trainer for fast configuration validation without GPU/distributed setup. - -This module provides a lightweight trainer that validates job configurations, -model architecture, and dataloader setup without requiring GPU resources or -distributed initialization. Useful for rapid iteration on configuration files -and CI/CD validation pipelines. -""" - -import os -import sys - -# Add parent directory to path to import torchtitan -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -import torch - -import torchtitan.protocols.train_spec as train_spec_module -from torchtitan.config import JobConfig, TORCH_DTYPE_MAP -from torchtitan.tools import utils -from torchtitan.tools.logging import logger -from torchtitan.train import main, Trainer - - -class DryRunTrainer(Trainer): - """ - A lightweight trainer that validates configurations without GPU allocation. - - This trainer performs comprehensive validation of the training configuration - without allocating GPU resources or initializing distributed setup. It validates: - - - Configuration file parsing and structure - - Model architecture (constructed on meta device) - - Tokenizer initialization - - Dataloader configuration - - Parallelism settings - - Model converters (if specified) - - Unlike the regular Trainer, this does not: - - Allocate GPU memory - - Initialize distributed process groups - - Create optimizers or learning rate schedulers - - Set up checkpointing or metrics - - Run any actual training - - Args: - job_config: JobConfig containing all training configuration parameters - - Note: - Validation completes immediately after initialization. No training loop is executed. - All operations use CPU and meta devices for zero-cost validation. 
- """ - - def __init__(self, job_config: JobConfig): - torch._C._log_api_usage_once("torchtitan.dry_run") - - self.job_config = job_config - - logger.info(f"Starting job: {job_config.job.description}") - logger.info("DRY RUN MODE - Configuration validation only") - - # Use CPU device (no GPU required) - self.device = torch.device("cpu") - - # Log and validate config - job_config.maybe_log() - logger.info("Configuration parsed successfully") - - # Get train spec - self.train_spec = train_spec_module.get_train_spec(job_config.model.name) - logger.info(f"Train spec loaded for model: {job_config.model.name}") - - # Build tokenizer - self.tokenizer = ( - self.train_spec.build_tokenizer_fn(job_config) - if self.train_spec.build_tokenizer_fn is not None - else None - ) - if self.tokenizer: - logger.info("Tokenizer built successfully") - - # Validate model configuration - model_args = self.train_spec.model_args[job_config.model.flavor] - model_args.update_from_config(job_config) - self.model_args = model_args - - logger.info( - f"Model args validated: {job_config.model.name} {job_config.model.flavor}" - ) - - # Build model on meta device (validates architecture without memory allocation) - logger.info("Validating model architecture...") - with ( - torch.device("meta"), - utils.set_default_dtype(TORCH_DTYPE_MAP[job_config.training.dtype]), - ): - model = self.train_spec.model_cls(model_args) - - # Calculate and log model size - model_param_count, _ = model_args.get_nparams_and_flops( - model, job_config.training.seq_len - ) - logger.info( - f"Model architecture validated: {job_config.model.name} " - f"with {model_param_count:,} parameters" - ) - - # Validate dataloader configuration (build with minimal params) - logger.info("Validating dataloader configuration...") - try: - # Use dp_world_size=1 and dp_rank=0 for dry run - dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=1, - dp_rank=0, - tokenizer=self.tokenizer, - job_config=job_config, - ) - logger.info("Dataloader configuration validated successfully") - except Exception as e: - logger.warning(f"Dataloader validation encountered issue: {e}") - logger.info( - "Note: Some dataloader issues may only appear with actual data paths" - ) - - # Validate model converters if specified - if job_config.model.converters: - logger.info(f"Model converters specified: {job_config.model.converters}") - - # Validate parallelism configuration - parallelism_config = job_config.parallelism - logger.info( - f"Parallelism config: " - f"DP-shard={parallelism_config.data_parallel_shard_degree}, " - f"DP-replicate={parallelism_config.data_parallel_replicate_degree}, " - f"TP={parallelism_config.tensor_parallel_degree}, " - f"PP={parallelism_config.pipeline_parallel_degree}, " - f"CP={parallelism_config.context_parallel_degree}" - ) - - # Summary - logger.info("=" * 80) - logger.info("DRY RUN VALIDATION COMPLETE") - logger.info("=" * 80) - logger.info("All configurations validated successfully!") - logger.info("Configuration is ready for training execution.") - logger.info("=" * 80) - - def train(self): - return - - -if __name__ == "__main__": - main(DryRunTrainer) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 95588d2c3b..4c01e1b2b8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -791,6 +791,9 @@ class Comm: save_traces_file_prefix: str = "rank_" """Flight recorder trace files prefix""" + local_tensor_mode: bool = False + """Local tensor mode, for debugging purposes. 
This is an experimental feature.""" + @dataclass class MemoryEstimation: diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index b209ddfd68..c8cfd7bc25 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -14,6 +14,7 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch import distributed as dist +from torch.distributed import _local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor @@ -258,12 +259,45 @@ def maybe_enable_amp( ) +def init_local_tensor_mode(world_size: int) -> int: + """Initialize local tensor mode for debugging purposes. + + Args: + world_size: The number of GPUs to simulate + + Returns: + The world size + """ + torch.distributed.init_process_group( + "fake", + rank=0, + world_size=world_size, + ) + lm = _local_tensor.LocalTensorMode(world_size) + lm.__enter__() + return world_size + + def init_distributed( comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "", ranks: list[int] | None = None, -): +) -> int: + if comm_config.local_tensor_mode: + ngpu_str = os.environ.get("NGPU") + if ngpu_str is None: + raise ValueError( + "NGPU environment variable must be set when using local_tensor_mode" + ) + try: + world_size = int(ngpu_str) + except ValueError as e: + raise ValueError( + f"NGPU environment variable must be a valid integer, got: {ngpu_str}" + ) from e + return init_local_tensor_mode(world_size) + def _warn_overwrite_env(env, val): if env in os.environ: logger.warning( @@ -309,6 +343,8 @@ def _get_distributed_backend(enable_cpu_backend): _ranks=ranks if ranks is not None else [], ) + return torch.distributed.get_world_size() + def set_pg_timeouts(timeout, world_mesh): """ diff --git a/torchtitan/train.py b/torchtitan/train.py index 5cfab998b2..ec1ab5b008 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,9 +12,9 @@ import torch -from torch.distributed.elastic.multiprocessing.errors import record - import torchtitan.protocols.train_spec as train_spec_module + +from torch.distributed.elastic.multiprocessing.errors import record from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training @@ -208,6 +208,12 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) + # TODO(local_tensor): Remove this early return once LocalTensor supports + # init_weights().Currently skipping parallelism setup and model initialization + # in local tensor mode. + if job_config.comm.local_tensor_mode: + return + # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: @@ -360,13 +366,12 @@ def __init__(self, job_config: JobConfig): def init_distributed(self) -> ParallelDims: job_config = self.job_config - dist_utils.init_distributed( + world_size = dist_utils.init_distributed( job_config.comm, enable_cpu_backend=job_config.training.enable_cpu_offload, base_folder=job_config.job.dump_folder, ) - world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism return ParallelDims( @@ -718,6 +723,13 @@ def main(trainer_class: type[Trainer]) -> None: try: trainer = trainer_class(config) + # TODO(local_tensor): Remove this special case once LocalTensor supports + # init_weights(). 
In local tensor mode, skip training/checkpointing as the + # model is not fully initialized + if config.comm.local_tensor_mode: + logger.info("Local tensor mode enabled - skipping training execution") + return + if config.checkpoint.create_seed_checkpoint: assert ( int(os.environ["WORLD_SIZE"]) == 1 From ec6a36bd4e0a4f3415be345abbbf84412c11b689 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 13:35:33 -0800 Subject: [PATCH 2/9] Update (base update) [ghstack-poisoned] From 7bfd2102fb0bf0405cf6dc9d6122b9cf9ae7688a Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 13:36:21 -0800 Subject: [PATCH 3/9] Update [ghstack-poisoned] --- torchtitan/train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchtitan/train.py b/torchtitan/train.py index ec1ab5b008..90a0b2f2a3 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -12,9 +12,9 @@ import torch -import torchtitan.protocols.train_spec as train_spec_module - from torch.distributed.elastic.multiprocessing.errors import record + +import torchtitan.protocols.train_spec as train_spec_module from torchtitan.components.checkpoint import CheckpointManager from torchtitan.components.dataloader import DataloaderExhaustedError from torchtitan.components.ft import FTManager, maybe_semi_sync_training From 2f0a6e29d8f2a2ff4abf49583cc1bf17b4d9eb93 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Tue, 18 Nov 2025 16:30:02 -0800 Subject: [PATCH 4/9] Update [ghstack-poisoned] --- run_train.sh | 2 +- torchtitan/config/job_config.py | 8 +++++++- torchtitan/distributed/utils.py | 11 ++++------- torchtitan/train.py | 8 +++++++- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/run_train.sh b/run_train.sh index 019ae2d7af..59f9a8f960 100755 --- a/run_train.sh +++ b/run_train.sh @@ -22,7 +22,7 @@ TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} if [ "$DRY_RUN" = "1" ]; then # Dry run mode: validate configuration without GPU/distributed setup echo "Running in DRY RUN mode - configuration validation only" - NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.local_tensor_mode + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.fake_backend --training.steps=1 else # Normal training with torchrun PYTORCH_ALLOC_CONF="expandable_segments:True" \ diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 4c01e1b2b8..4c0e337a8a 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -791,8 +791,14 @@ class Comm: save_traces_file_prefix: str = "rank_" """Flight recorder trace files prefix""" + fake_backend: bool = False + """Fake comm backend for dry run mode only""" + local_tensor_mode: bool = False - """Local tensor mode, for debugging purposes. This is an experimental feature.""" + """ + Local tensor mode, for debugging purposes. This is an experimental feature. + fake_backend should be set to True as well if local_tensor_mode is True. 
+ """ @dataclass diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index c8cfd7bc25..a279a4fe4c 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -14,7 +14,6 @@ import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d from torch import distributed as dist -from torch.distributed import _local_tensor from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor @@ -259,8 +258,8 @@ def maybe_enable_amp( ) -def init_local_tensor_mode(world_size: int) -> int: - """Initialize local tensor mode for debugging purposes. +def init_fake_mode(world_size: int) -> int: + """Initialize fake backend Args: world_size: The number of GPUs to simulate @@ -273,8 +272,6 @@ def init_local_tensor_mode(world_size: int) -> int: rank=0, world_size=world_size, ) - lm = _local_tensor.LocalTensorMode(world_size) - lm.__enter__() return world_size @@ -284,7 +281,7 @@ def init_distributed( base_folder: str = "", ranks: list[int] | None = None, ) -> int: - if comm_config.local_tensor_mode: + if comm_config.fake_backend: ngpu_str = os.environ.get("NGPU") if ngpu_str is None: raise ValueError( @@ -296,7 +293,7 @@ def init_distributed( raise ValueError( f"NGPU environment variable must be a valid integer, got: {ngpu_str}" ) from e - return init_local_tensor_mode(world_size) + return init_fake_mode(world_size) def _warn_overwrite_env(env, val): if env in os.environ: diff --git a/torchtitan/train.py b/torchtitan/train.py index 90a0b2f2a3..33302d939f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,6 +11,7 @@ from typing import Any, Generator, Iterable import torch +from torch.distributed import _local_tensor from torch.distributed.elastic.multiprocessing.errors import record @@ -372,8 +373,13 @@ def init_distributed(self) -> ParallelDims: base_folder=job_config.job.dump_folder, ) - parallelism_config = job_config.parallelism + if job_config.comm.local_tensor_mode: + if not job_config.comm.fake_backend: + raise ValueError("LocalTensor can only be used with fake backend.") + lm = _local_tensor.LocalTensorMode(world_size) + lm.__enter__() + parallelism_config = job_config.parallelism return ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, From 7e121a693c3bca25394ed81c940b372fe2746bc8 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Wed, 19 Nov 2025 18:16:08 -0800 Subject: [PATCH 5/9] Update [ghstack-poisoned] --- torchtitan/config/job_config.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 4c0e337a8a..04256de4a3 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -796,8 +796,17 @@ class Comm: local_tensor_mode: bool = False """ - Local tensor mode, for debugging purposes. This is an experimental feature. - fake_backend should be set to True as well if local_tensor_mode is True. + Local tensor mode for debugging purposes. There will be only one process + regardless of the number of GPUs. LocalTensor will simulate the + computation by running one rank after another. While the performance will + be slow, the numerics should be the same. This enables us to verify + numerics with fewer GPUs. For example, we can directly run 5D + parallelisms within a single node to reduce the combinations we need to + use in integration tests. 
+ + NOTE: This is an experimental feature. + + NOTE: fake_backend should be set to True when local_tensor_mode is True. """ From 5983b61c43141a0c009ce57033672bc62c6a015c Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 20 Nov 2025 10:48:41 -0800 Subject: [PATCH 6/9] Update [ghstack-poisoned] --- run_train.sh | 14 ++++++++------ torchtitan/config/job_config.py | 24 +++++++++++------------- torchtitan/distributed/utils.py | 18 ++++++++++++++---- torchtitan/train.py | 21 +++++++-------------- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/run_train.sh b/run_train.sh index 59f9a8f960..b24e9047d3 100755 --- a/run_train.sh +++ b/run_train.sh @@ -10,19 +10,21 @@ set -ex # use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh -# DRY_RUN=1 ./run_train.sh # for config validation without GPU +# COMM_MODE="fake_backend" ./run_train.sh # for config validation without GPU +# COMM_MODE="local_tensor" ./run_train.sh # for local tensor debugging mode NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"} -DRY_RUN=${DRY_RUN:-0} +# COMM_MODE options: "fake_backend" (dry run), "local_tensor" (debug mode), or empty for normal training +COMM_MODE=${COMM_MODE:-""} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -if [ "$DRY_RUN" = "1" ]; then - # Dry run mode: validate configuration without GPU/distributed setup - echo "Running in DRY RUN mode - configuration validation only" - NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.fake_backend --training.steps=1 +if [ -n "$COMM_MODE" ]; then + # Communication mode specified: validate configuration or run in debug mode + echo "Running with comm_mode=${COMM_MODE}" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.comm_mode=${COMM_MODE} --training.steps=1 else # Normal training with torchrun PYTORCH_ALLOC_CONF="expandable_segments:True" \ diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 04256de4a3..b9d600b164 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -791,22 +791,20 @@ class Comm: save_traces_file_prefix: str = "rank_" """Flight recorder trace files prefix""" - fake_backend: bool = False - """Fake comm backend for dry run mode only""" - - local_tensor_mode: bool = False + comm_mode: Literal["default", "fake_backend", "local_tensor"] = "default" """ - Local tensor mode for debugging purposes. There will be only one process - regardless of the number of GPUs. LocalTensor will simulate the - computation by running one rank after another. While the performance will - be slow, the numerics should be the same. This enables us to verify - numerics with fewer GPUs. For example, we can directly run 5D - parallelisms within a single node to reduce the combinations we need to - use in integration tests. + Communication mode for distributed training. - NOTE: This is an experimental feature. + Options: + - "default": Normal distributed training with real communication + - "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU) + - "local_tensor": Local tensor mode for debugging purposes. There will be only one process + regardless of the number of GPUs. LocalTensor will simulate the computation by running one + rank after another. 
While the performance will be slow, the numerics should be the same. + This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D + parallelisms within a single node to reduce the combinations we need to use in integration tests. - NOTE: fake_backend should be set to True when local_tensor_mode is True. + NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally. """ diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index a279a4fe4c..6ed58bf875 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -258,11 +258,12 @@ def maybe_enable_amp( ) -def init_fake_mode(world_size: int) -> int: +def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"): """Initialize fake backend Args: world_size: The number of GPUs to simulate + comm_mode: Communication mode ("fake_backend" or "local_tensor") Returns: The world size @@ -272,6 +273,14 @@ def init_fake_mode(world_size: int) -> int: rank=0, world_size=world_size, ) + + # If local_tensor mode is enabled, initialize LocalTensorMode context + if comm_mode == "local_tensor": + from torch.distributed import _local_tensor + + lm = _local_tensor.LocalTensorMode(world_size) + lm.__enter__() + return world_size @@ -281,11 +290,11 @@ def init_distributed( base_folder: str = "", ranks: list[int] | None = None, ) -> int: - if comm_config.fake_backend: + if comm_config.comm_mode in ("fake_backend", "local_tensor"): ngpu_str = os.environ.get("NGPU") if ngpu_str is None: raise ValueError( - "NGPU environment variable must be set when using local_tensor_mode" + f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}" ) try: world_size = int(ngpu_str) @@ -293,7 +302,8 @@ def init_distributed( raise ValueError( f"NGPU environment variable must be a valid integer, got: {ngpu_str}" ) from e - return init_fake_mode(world_size) + init_fake_mode(world_size, comm_config.comm_mode) + return world_size def _warn_overwrite_env(env, val): if env in os.environ: diff --git a/torchtitan/train.py b/torchtitan/train.py index 33302d939f..dcbabf435c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -11,7 +11,6 @@ from typing import Any, Generator, Iterable import torch -from torch.distributed import _local_tensor from torch.distributed.elastic.multiprocessing.errors import record @@ -209,10 +208,10 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) - # TODO(local_tensor): Remove this early return once LocalTensor supports - # init_weights().Currently skipping parallelism setup and model initialization - # in local tensor mode. - if job_config.comm.local_tensor_mode: + # TODO(local_tensor): Remove this special case once LocalTensor supports + # init_weights(). Currently it fails occasionally. 
+ if job_config.comm.comm_mode == "local_tensor": + logger.info("Local tensor mode enabled - skipping training execution") return # apply parallelisms and initialization @@ -373,12 +372,6 @@ def init_distributed(self) -> ParallelDims: base_folder=job_config.job.dump_folder, ) - if job_config.comm.local_tensor_mode: - if not job_config.comm.fake_backend: - raise ValueError("LocalTensor can only be used with fake backend.") - lm = _local_tensor.LocalTensorMode(world_size) - lm.__enter__() - parallelism_config = job_config.parallelism return ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, @@ -730,9 +723,9 @@ def main(trainer_class: type[Trainer]) -> None: trainer = trainer_class(config) # TODO(local_tensor): Remove this special case once LocalTensor supports - # init_weights(). In local tensor mode, skip training/checkpointing as the - # model is not fully initialized - if config.comm.local_tensor_mode: + # init_weights() and foreach_allgather. In local tensor mode, skip + # training/checkpointing as the # model is not fully initialized + if config.comm.comm_mode == "local_tensor": logger.info("Local tensor mode enabled - skipping training execution") return From 6f2e7f499d9904faf6dc0865ff6b292d989addc7 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Thu, 20 Nov 2025 13:48:46 -0800 Subject: [PATCH 7/9] Update [ghstack-poisoned] --- torchtitan/distributed/utils.py | 3 +++ torchtitan/train.py | 6 ------ 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 6ed58bf875..664d23aed2 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -281,6 +281,9 @@ def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"): lm = _local_tensor.LocalTensorMode(world_size) lm.__enter__() + # TODO: remove this once the root cause is figured out + torch.manual_seed(42) + return world_size diff --git a/torchtitan/train.py b/torchtitan/train.py index dcbabf435c..1914429398 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -208,12 +208,6 @@ def __init__(self, job_config: JobConfig): self.loss_fn, self.gradient_accumulation_steps ) - # TODO(local_tensor): Remove this special case once LocalTensor supports - # init_weights(). Currently it fails occasionally. 
- if job_config.comm.comm_mode == "local_tensor": - logger.info("Local tensor mode enabled - skipping training execution") - return - # apply parallelisms and initialization if parallel_dims.pp_enabled: if not self.train_spec.pipelining_fn: From d560e50c99fb3f948c77689d7755d10b4f4129a8 Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 24 Nov 2025 13:07:57 -0800 Subject: [PATCH 8/9] Update [ghstack-poisoned] --- torchtitan/distributed/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 7fb46a4cca..86fd17c8dc 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -281,8 +281,6 @@ def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"): lm = _local_tensor.LocalTensorMode(world_size) lm.__enter__() - return world_size - def init_distributed( comm_config: CommConfig, @@ -290,11 +288,11 @@ def init_distributed( base_folder: str = "", ranks: list[int] | None = None, ) -> int: - if comm_config.mode in ("fake_backend", "local_tensor"): + if comm_config.comm_mode in ("fake_backend", "local_tensor"): ngpu_str = os.environ.get("NGPU") if ngpu_str is None: raise ValueError( - f"NGPU environment variable must be set when using comm_mode={comm_config.mode}" + f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}" ) try: world_size = int(ngpu_str) @@ -302,7 +300,7 @@ def init_distributed( raise ValueError( f"NGPU environment variable must be a valid integer, got: {ngpu_str}" ) from e - init_fake_mode(world_size, comm_config.mode) + init_fake_mode(world_size, comm_config.comm_mode) return world_size def _warn_overwrite_env(env, val): From 73799490935787dbed3c4219e81aa5b7c9d79b9e Mon Sep 17 00:00:00 2001 From: Chien-Chin Huang Date: Mon, 24 Nov 2025 15:03:33 -0800 Subject: [PATCH 9/9] Update [ghstack-poisoned] --- torchtitan/distributed/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 86fd17c8dc..6a73ffd083 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -288,11 +288,11 @@ def init_distributed( base_folder: str = "", ranks: list[int] | None = None, ) -> int: - if comm_config.comm_mode in ("fake_backend", "local_tensor"): + if comm_config.mode in ("fake_backend", "local_tensor"): ngpu_str = os.environ.get("NGPU") if ngpu_str is None: raise ValueError( - f"NGPU environment variable must be set when using comm_mode={comm_config.comm_mode}" + f"NGPU environment variable must be set when using comm_mode={comm_config.mode}" ) try: world_size = int(ngpu_str) @@ -300,7 +300,7 @@ def init_distributed( raise ValueError( f"NGPU environment variable must be a valid integer, got: {ngpu_str}" ) from e - init_fake_mode(world_size, comm_config.comm_mode) + init_fake_mode(world_size, comm_config.mode) return world_size def _warn_overwrite_env(env, val):
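
With this stack applied, COMM_MODE="fake_backend" ./run_train.sh validates a configuration without GPUs, and COMM_MODE="local_tensor" NGPU=8 ./run_train.sh runs the single-process rank simulation. The sketch below is a minimal, standalone illustration of what the final init_fake_mode path amounts to; it assumes a recent PyTorch build that ships the "fake" process-group backend and the experimental torch.distributed._local_tensor module (which this stack relies on), and the helper name init_fake_world plus the __main__ driver are illustrative only, not part of the patches.

    import os

    import torch.distributed as dist


    def init_fake_world(world_size: int, comm_mode: str = "fake_backend") -> int:
        # A single process pretends to be rank 0 of a world_size-rank job;
        # no NCCL/Gloo communicators are created.
        dist.init_process_group("fake", rank=0, world_size=world_size)

        if comm_mode == "local_tensor":
            # LocalTensorMode simulates the job by running each operation for one
            # rank after another, so numerics should match a real multi-GPU run,
            # only much slower.
            from torch.distributed import _local_tensor

            _local_tensor.LocalTensorMode(world_size).__enter__()

        return world_size


    if __name__ == "__main__":
        # Mirrors run_train.sh: NGPU selects the simulated world size.
        ws = init_fake_world(int(os.environ.get("NGPU", "8")))
        print(f"simulated world size: {ws}, rank {dist.get_rank()}")

One consequence of folding the earlier fake_backend/local_tensor_mode booleans into a single comm_mode literal is that the invalid combination (local_tensor without the fake backend) can no longer be expressed, which is why the explicit ValueError check added to train.py in patch 4/9 is removed again in patch 6/9.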