Shorten nccl comm timeout and enable flight recorder dumping #103
@@ -1,9 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property

import torch

from torch.distributed.device_mesh import init_device_mesh
from torchtrain.logging_utils import logger
from torchtrain.parallelisms.parallelize_llama import parallelize_llama
@@ -12,6 +17,41 @@
"llama": parallelize_llama, | ||
} | ||
|
||
TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" | ||
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" | ||
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" | ||
ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING" | ||
SKIP_CLEANUP = "3" | ||
|
||
|
||
def _warn_overwrite_env(env, val): | ||
if env in os.environ: | ||
logger.warning( | ||
f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config" | ||
) | ||
os.environ[env] = val | ||
|
||
|
||
def init_distributed(job_config): | ||

Review discussion on init_distributed:

- Comment: My feeling is that this "parallelisms" folder is not only about "parallelisms" itself now; maybe we should rename this folder to "distributed"?
- Reply: We'd have torchtrain.distributed and torch.distributed, is that good or bad? I don't have a strong preference. I could also see an argument for making something like…
- Reply: 🤔 Yeah, it might cause some confusion with torch.distributed. I am fine with keeping it as is for now.

    # FlightRecorder is incompatible with =1 mode where watchdog aborts work; must use =3 (skipcleanup)
    # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
    # This could be done only when flight recorder is enabled, but it's nice to be consistent to avoid
    # subtle behavior differences.
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
    if job_config.comm.trace_buf_size > 0:
        # dump on timeout by default if trace buffer is enabled
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    torch.distributed.init_process_group(
        "nccl", timeout=timedelta(seconds=job_config.comm.timeout_seconds)
    )


@dataclass
class ParallelDims:
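
For reference, a minimal standalone sketch of what this env-var setup amounts to on one rank, assuming a torchrun-style launch and a hypothetical config with comm.trace_buf_size = 20000, comm.timeout_seconds = 100, and job.dump_folder = "./outputs" (these concrete values are illustrative, not defaults from this PR):

```python
import os
from datetime import timedelta

import torch

# Flight recorder dumps require async error handling in skip-cleanup mode (=3)
# rather than the watchdog-abort mode (=1); see pytorch/pytorch#121055.
os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "3"

# A non-zero buffer size enables the flight recorder; dumping on timeout writes
# the ring buffer out when a collective hangs.
os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "20000"
os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1"

# Each rank appends its rank id to this file prefix when writing its dump.
os.makedirs("./outputs/comm_trace", exist_ok=True)
os.environ["TORCH_NCCL_DEBUG_INFO_TEMP_FILE"] = "./outputs/comm_trace/rank_"

# Shortened collective timeout so hangs surface (and dump) quickly.
# Assumes RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT are already set by torchrun.
torch.distributed.init_process_group("nccl", timeout=timedelta(seconds=100))
```

With this in place, a timed-out collective should leave a per-rank trace file under ./outputs/comm_trace/ rather than only aborting.
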
@@ -25,7 +25,11 @@
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor
from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
from torchtrain.parallelisms import (
    init_distributed,
    models_parallelize_fns,
    ParallelDims,
)
from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import Color, dist_max, dist_mean

@@ -100,6 +104,9 @@ def main(job_config: JobConfig):
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

Review discussion on the set_device call:

- Comment: Should this set_device call be part of init_distributed?
- Reply: I thought about that. Maybe, but on the other hand it is kind of an important detail that might be worth not hiding from the train script (if we think of the train script as showing the important steps and teaching users).
- Reply: I might leave it for now, and hope that instead of hiding it inside init_distributed we can actually change init_process_group to take a device, and then not require the user to set the default device. cc @kwen2501 did you make this change already, or part of it, for enabling eager init?
- Reply: Sounds good! Passing the device to init_process_group…
- Reply: Actually I think what we want is to do eager init at large scale, but also have init be overlapped with other Python work by making initialization happen in a non-blocking way. @kwen2501 has some plans for how to do that inside PGNccl. Doing init lazily is a bit less predictable, while doing it early but async could overlap it with other things like setting up the dataloader, loading model weights, etc.

A hedged sketch of the pass-the-device alternative appears after the end of this diff hunk.

    init_distributed(job_config)

    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    model_name = job_config.model.name
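
Related to the set_device thread above, a hedged sketch of the direction discussed there, assuming a PyTorch build whose init_process_group accepts a device_id argument; the timeout value is illustrative:

```python
import os
from datetime import timedelta

import torch

local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

# Binding the process group to a device at init time is the alternative raised in
# the review thread: it could stand in for an explicit torch.cuda.set_device()
# call and lets NCCL create its communicator eagerly rather than lazily on the
# first collective (availability depends on the installed torch version).
torch.distributed.init_process_group(
    "nccl",
    timeout=timedelta(seconds=100),
    device_id=device,
)
```

Whether eager or lazy init is preferable is exactly the trade-off discussed above: eager-but-asynchronous init could overlap communicator setup with dataloader construction and weight loading.
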

Review discussion on the trace_buf_size config name:

- Comment: nit: dump_trace_buf_size, to avoid confusion about profiling traces?
- Reply: Ah, comm.trace is probably more clear than dump_trace if you want to avoid confusion with profile_trace?
- Reply: Ohh, so there's a comm prefix there; yeah, that would work.