14 changes: 14 additions & 0 deletions torchtrain/config_manager.py
@@ -263,4 +263,18 @@ def init_args_from_command_line(
default="2", # 2 = checkpoint every other layer
help="['int', 'op'] = selective activation checkpointing options, 'int' for every nth layer, or 'op' for op level ac.",
)

# communications library settings
parser.add_argument(
"--comm.timeout_seconds",
type=int,
default=5,
help="Timeout for async communication operations",
)
parser.add_argument(
"--comm.trace_buf_size",
Collaborator: nit: dump_trace_buf_size to avoid confusion about profiling traces?

Contributor Author: ah.. comm.trace is probably more clear than dump_trace if you want to avoid confusion with profile_trace?

Collaborator: Ohh, so there's a comm prefix there, yeah that would work.

            type=int,
            default=20000,
            help="Flight recorder ring buffer size, >0 means recording by default, 0 means disabled",
        )
        return parser.parse_args(args_list)
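Note: a quick standalone illustration of how these dotted option names behave with plain argparse (a sketch only; torchtrain's own JobConfig/toml plumbing may map them differently). argparse stores the value under the literal dotted attribute name, so it is read back with getattr:

# standalone sketch, not torchtrain code
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--comm.timeout_seconds", type=int, default=5)
parser.add_argument("--comm.trace_buf_size", type=int, default=20000)

args = parser.parse_args(["--comm.timeout_seconds", "30"])
print(getattr(args, "comm.timeout_seconds"))  # 30
print(getattr(args, "comm.trace_buf_size"))   # 20000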
40 changes: 40 additions & 0 deletions torchtrain/parallelisms/__init__.py
@@ -1,9 +1,14 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import os

from dataclasses import dataclass
from datetime import timedelta
from functools import cached_property

import torch

from torch.distributed.device_mesh import init_device_mesh
from torchtrain.logging_utils import logger
from torchtrain.parallelisms.parallelize_llama import parallelize_llama
@@ -12,6 +17,41 @@
"llama": parallelize_llama,
}

TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
SKIP_CLEANUP = "3"


def _warn_overwrite_env(env, val):
    if env in os.environ:
        logger.warning(
            f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
        )
    os.environ[env] = val


def init_distributed(job_config):
Collaborator: My feeling is that this "parallelisms" folder is not only about "parallelisms" itself now; maybe we should rename this folder to distributed? wdyt?

Contributor Author: We'd have torchtrain.distributed and torch.distributed, is that good or bad? I don't have a strong preference. I could also see an argument for making something like torchtrain.utils.distributed or .dist_utils and then keeping parallelisms for the things you originally intended. Up to you.

Collaborator: 🤔 yeah, it might cause some confusion with torch.distributed.. I am fine with keeping it as is for now.

    # FlightRecorder is incompatible with TORCH_NCCL_ASYNC_ERROR_HANDLING=1, where the watchdog
    # aborts work; it must be set to 3 (skip cleanup) to get flight recorder dumps.
    # See https://github.com/pytorch/pytorch/issues/121055
    # This could be done only when the flight recorder is enabled, but it's nice to be consistent
    # to avoid subtle behavior differences.
    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

    # enable the torch NCCL flight recorder in the mode that dumps files when a timeout is detected
    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
    if job_config.comm.trace_buf_size > 0:
        # dump on timeout by default if the trace buffer is enabled
        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
        os.makedirs(dump_dir, exist_ok=True)
        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")

    torch.distributed.init_process_group(
        "nccl", timeout=timedelta(seconds=job_config.comm.timeout_seconds)
    )


@dataclass
class ParallelDims:
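For reference, a hedged sketch of the environment that init_distributed above is expected to leave behind with the default comm settings (the SimpleNamespace config is a hypothetical stand-in for torchtrain's JobConfig):

# illustrative sketch, not torchtrain code
from types import SimpleNamespace

job_config = SimpleNamespace(
    comm=SimpleNamespace(timeout_seconds=5, trace_buf_size=20000),
    job=SimpleNamespace(dump_folder="./outputs"),
)

expected_env = {
    "TORCH_NCCL_ASYNC_ERROR_HANDLING": "3",  # skip cleanup so dumps can be written
    "TORCH_NCCL_TRACE_BUFFER_SIZE": str(job_config.comm.trace_buf_size),
    "TORCH_NCCL_DUMP_ON_TIMEOUT": "1",  # only set when trace_buf_size > 0
    "TORCH_NCCL_DEBUG_INFO_TEMP_FILE": f"{job_config.job.dump_folder}/comm_trace/rank_",
}

for key, val in expected_env.items():
    print(f"{key}={val}")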
9 changes: 8 additions & 1 deletion train.py
@@ -25,7 +25,11 @@
from torchtrain.meta_init import meta_model_init
from torchtrain.metrics import build_metric_logger, get_num_params, GPUMemoryMonitor
from torchtrain.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtrain.parallelisms import models_parallelize_fns, ParallelDims
from torchtrain.parallelisms import (
    init_distributed,
    models_parallelize_fns,
    ParallelDims,
)
from torchtrain.profiling import maybe_run_profiler
from torchtrain.utils import Color, dist_max, dist_mean

@@ -100,6 +104,9 @@ def main(job_config: JobConfig):
        world_size=world_size,
        enable_loss_parallel=job_config.training.enable_loss_parallel,
    )
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
Collaborator: Should this set_device call be part of init_distributed?

Contributor Author: I thought about that. Maybe, but on the other hand it is kind of an important detail that might be worth not hiding from the train script (if we think of the train script as showing the important steps and teaching users).

Contributor Author (@wconstab, Mar 15, 2024): I might leave it for now, and hope that instead of hiding it inside init_distributed we can change init_process_group to take a device, so the user no longer has to set the default device.

cc @kwen2501 did you make this change already, or part of it, for enabling eager init?

Collaborator: Sounds good! Passing device_id to init_process_group might trigger eager init of the global PG, which I think we might want to avoid at large scale? i.e. we should only eager-init the sub-group NCCL communicators?

Contributor Author: Actually I think what we want is to do eager init at large scale, but also have init overlap with other Python work by making initialization happen in a non-blocking way. @kwen2501 has some plans for how to do that inside PGNccl. Doing init lazily is a bit less predictable; doing it early but async could overlap it with other things like setting up the dataloader, loading model weights, etc.
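For context, a minimal sketch of the alternative being discussed, assuming a PyTorch build whose init_process_group accepts a device_id argument (this is not what the PR does; the PR keeps the explicit set_device call):

# hedged sketch of the discussed alternative, not part of this PR
import os
from datetime import timedelta

import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group(
    "nccl",
    timeout=timedelta(seconds=5),
    # binding the device here replaces torch.cuda.set_device(local_rank)
    # and may trigger eager init of the default NCCL process group
    device_id=torch.device(f"cuda:{local_rank}"),
)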

    init_distributed(job_config)

    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    model_name = job_config.model.name