diff --git a/torchtrain/config_manager.py b/torchtrain/config_manager.py index 48750f7031..8add1c45e0 100644 --- a/torchtrain/config_manager.py +++ b/torchtrain/config_manager.py @@ -263,4 +263,18 @@ def init_args_from_command_line( default="2", # 2 = checkpoint every other layer help="['int', 'op'] = selective activation checkpointing options, 'int' for every nth layer, or 'op' for op level ac.", ) + + # communications library settings + parser.add_argument( + "--comm.timeout_seconds", + type=int, + default=5, + help="Timeout for async communication operations", + ) + parser.add_argument( + "--comm.trace_buf_size", + type=int, + default=20000, + help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled", + ) return parser.parse_args(args_list) diff --git a/torchtrain/parallelisms/__init__.py b/torchtrain/parallelisms/__init__.py index 1226a9c7f7..1c1599fdd8 100644 --- a/torchtrain/parallelisms/__init__.py +++ b/torchtrain/parallelisms/__init__.py @@ -1,9 +1,14 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
# Names of the NCCL-related environment variables that init_distributed manages.
TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
# Value "3" selects the "skip cleanup" async-error-handling mode.
SKIP_CLEANUP = "3"


def _warn_overwrite_env(env, val):
    """Set ``os.environ[env] = val``, warning when an existing value is clobbered.

    Args:
        env: name of the environment variable to set.
        val: string value to assign.
    """
    prior = os.environ.get(env)
    if prior is not None:
        # The job config wins over whatever the launcher/user exported; make
        # that visible in the logs rather than silently overriding.
        logger.warning(
            f"ENV[{env}] = {prior} will be overridden to {val} based on job config"
        )
    os.environ[env] = val


def init_distributed(job_config):
    """Configure NCCL debugging env vars from ``job_config`` and initialize the
    default NCCL process group.

    Side effects: mutates ``os.environ``, may create a ``comm_trace`` dump
    directory under the job dump folder, and calls
    ``torch.distributed.init_process_group``.
    """
    comm_cfg = job_config.comm

    # FlightRecorder is incompatible with =1 mode where watchdog aborts work;
    # =3 (skip cleanup) is required to get flight recorder dumps.
    # See https://github.com/pytorch/pytorch/issues/121055
    # This could be done only when flight recorder is enabled, but it's nice
    # to stay consistent and avoid subtle behavior differences.
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # Enable the torch NCCL flight recorder so it dumps files on detected timeouts.
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(comm_cfg.trace_buf_size))
    if comm_cfg.trace_buf_size > 0:
        # Dump on timeout by default whenever the trace buffer is enabled.
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    torch.distributed.init_process_group(
        "nccl", timeout=timedelta(seconds=comm_cfg.timeout_seconds)
    )
torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config -from torchtrain.parallelisms import models_parallelize_fns, ParallelDims +from torchtrain.parallelisms import ( + init_distributed, + models_parallelize_fns, + ParallelDims, +) from torchtrain.profiling import maybe_run_profiler from torchtrain.utils import Color, dist_max, dist_mean @@ -100,6 +104,9 @@ def main(job_config: JobConfig): world_size=world_size, enable_loss_parallel=job_config.training.enable_loss_parallel, ) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + init_distributed(job_config) + world_mesh = parallel_dims.build_mesh(device_type="cuda") model_name = job_config.model.name