diff --git a/build/builder.py b/build/builder.py
index 409013ceb..bd3ef5f4a 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -21,6 +21,7 @@
 from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype
+from distributed import init_distributed, parallelize_llama, ParallelDims


 @dataclass
@@ -36,7 +37,7 @@ class BuilderArgs:
     device: Optional[str] = None
     precision: torch.dtype = torch.float32
     setup_caches: bool = False
-    use_tp: bool = False
+    use_distributed: bool = False
     is_chat_model: bool = False
     prefill_possible: bool = False

@@ -141,7 +142,7 @@ def from_args(cls, args): # -> BuilderArgs:
             device=args.device,
             precision=dtype,
             setup_caches=(args.output_dso_path or args.output_pte_path),
-            use_tp=False,
+            use_distributed=args.distributed,
             is_chat_model=is_chat_model,
         )

@@ -346,11 +347,23 @@ def _load_model(builder_args, only_config=False):
     else:
         model = _load_model_default(builder_args)

-    if builder_args.use_tp:
-        from tp import apply_tp
+    # TODO: ongoing work to support loading the model from a checkpoint
+    if builder_args.use_distributed:
+        # init distributed
+        world_size = int(os.environ["WORLD_SIZE"])
+        # TODO: make the tp and pp degrees configurable
+        parallel_dims = ParallelDims(
+            tp=8,
+            pp=1,
+            world_size=world_size,
+        )
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        torch.cuda.set_device(device)
+        init_distributed(job_config)
+        world_mesh = parallel_dims.build_mesh(device_type="cuda")

-        print("Applying tensor parallel to model ...")
-        apply_tp(model)
+        print("Applying model parallelism to the model ...")
+        parallelize_llama(model, world_mesh, parallel_dims)

     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
diff --git a/cli.py b/cli.py
index 24f6d6ed0..3c8a503d7 100644
--- a/cli.py
+++ b/cli.py
@@ -56,6 +56,11 @@ def add_arguments_for_verb(parser, verb: str):
         action="store_true",
         help="Whether to start an interactive chat session",
     )
+    parser.add_argument(
+        "--distributed",
+        action="store_true",
+        help="Whether to enable distributed inference",
+    )
     parser.add_argument(
         "--gui",
         action="store_true",
diff --git a/distributed/__init__.py b/distributed/__init__.py
new file mode 100644
index 000000000..64cd5f22d
--- /dev/null
+++ b/distributed/__init__.py
@@ -0,0 +1,9 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from distributed.parallelize_llama import parallelize_llama
+from distributed.parallel_config import ParallelDims
+from distributed.utils import init_distributed
diff --git a/distributed/parallel_config.py b/distributed/parallel_config.py
new file mode 100644
index 000000000..d1d8aa9c7
--- /dev/null
+++ b/distributed/parallel_config.py
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+from dataclasses import dataclass, field
+
+from torch.distributed.device_mesh import init_device_mesh
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ParallelDims:
+    tp: int
+    pp: int
+    world_size: int
+
+    def __post_init__(self):
+        self._validate()
+
+    def _validate(self):
+        tp, pp = self.tp, self.pp
+        assert tp >= 1, tp
+        assert pp >= 1, pp
+        assert (
+            tp * pp == self.world_size
+        ), f"Invalid parallel dims: tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+
+    def build_mesh(self, device_type):
+        dims = []
+        names = []
+        for d, name in zip(
+            [self.pp, self.tp], ["pp", "tp"], strict=True
+        ):
+            if d > 1:
+                dims.append(d)
+                names.append(name)
+        logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
+        names = tuple(names)
+        return init_device_mesh(device_type, dims, mesh_dim_names=names)
+
+    @property
+    def tp_enabled(self):
+        return self.tp > 1
+
+    @property
+    def pp_enabled(self):
+        return self.pp > 1
diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py
new file mode 100644
index 000000000..e2b73d0dd
--- /dev/null
+++ b/distributed/parallelize_llama.py
@@ -0,0 +1,133 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+
+import torch.nn as nn
+from torch.distributed._tensor import Replicate, Shard
+from torch.distributed.device_mesh import DeviceMesh
+from torch.distributed.tensor.parallel import (
+    ColwiseParallel,
+    parallelize_module,
+    PrepareModuleInput,
+    RowwiseParallel,
+    SequenceParallel,
+)
+
+from distributed.parallel_config import ParallelDims
+
+logger = logging.getLogger(__name__)
+
+
+def apply_tp(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+) -> nn.Module:
+    """
+    Apply tensor parallelism to the given model. More details can be
+    found in https://pytorch.org/tutorials/intermediate/TP_tutorial.html.
+
+    NOTE: This TP plan assumes a LLaMA-style model. For other architectures,
+    the ``parallelize_plan`` passed to the TP API needs to be adjusted
+    accordingly.
+
+    Args:
+        model (:class:`nn.Module`):
+            Module to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+    Return:
+        A tensor-parallelized :class:`nn.Module`.
+    """
+
+    tp_mesh = world_mesh["tp"]
+
+    # 1. Parallelize the first embedding and the last linear proj layer
+    # 2. Parallelize the root norm layer over the sequence dim
+    # 3. Shard the first transformer block's inputs
+    model = parallelize_module(
+        model,
+        tp_mesh,
+        {
+            "tok_embeddings": RowwiseParallel(
+                input_layouts=Replicate(),
+                output_layouts=Shard(1),
+            ),
+            "output": ColwiseParallel(
+                input_layouts=Shard(1),
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+            "norm": SequenceParallel(),
+        },
+    )
+
+    # Apply tensor + sequence parallelism to every transformer block
+    for layer_id, transformer_block in model.layers.items():
+        layer_plan = {
+            "attention": PrepareModuleInput(
+                input_layouts=(Shard(1), None),
+                desired_input_layouts=(Replicate(), None),
+            ),
+            "attention.wq": ColwiseParallel(),
+            "attention.wk": ColwiseParallel(),
+            "attention.wv": ColwiseParallel(),
+            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Shard(1),),
+                desired_input_layouts=(Replicate(),),
+            ),
+            "feed_forward.w1": ColwiseParallel(),
+            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w3": ColwiseParallel(),
+            "ffn_norm": SequenceParallel(),
+        }
+
+        # Adjust attention module to use the local number of heads
+        attn_layer = transformer_block.attention
+        attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
+        attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
+        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()
+
+        parallelize_module(
+            module=transformer_block,
+            device_mesh=tp_mesh,
+            parallelize_plan=layer_plan,
+        )
+
+    logger.info("Applied Tensor Parallelism to the model")
+    return model
+
+
+def parallelize_llama(
+    model: nn.Module,
+    world_mesh: DeviceMesh,
+    parallel_dims: ParallelDims,
+) -> nn.Module:
+    """
+    Apply tensor parallelism and other parallelisms (TODO) to the model for inference.
+
+    NOTE: The passed-in model should preferably be on the meta device. Otherwise,
+    the model must fit into GPU or CPU memory.
+
+    Args:
+        model (:class:`nn.Module`):
+            Module to be parallelized.
+        world_mesh (:class:`DeviceMesh`):
+            Object which describes the mesh topology
+            of devices for the DTensor.
+        parallel_dims (:class:`ParallelDims`):
+            Utility object holding the degree of each parallelism.
+    Return:
+        A parallelized :class:`nn.Module`.
+    """
+
+    if parallel_dims.tp_enabled:
+        model = apply_tp(model, world_mesh)
+
+    return model
diff --git a/distributed/run_dist_inference.sh b/distributed/run_dist_inference.sh
new file mode 100755
index 000000000..3268750d9
--- /dev/null
+++ b/distributed/run_dist_inference.sh
@@ -0,0 +1,31 @@
+#!/usr/bin/bash
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+set -ex
+
+# libuv is a scalable backend for the TCPStore used in process group
+# rendezvous. It is the recommended backend for distributed jobs.
+export USE_LIBUV=1
+
+# Use envs as local overrides for convenience, e.g.:
+#
+# LOG_RANK=0,1 NGPU=4 ./run_dist_inference.sh
+
+NGPU=${NGPU:-"8"}
+
+# TODO: We need to decide how to log for inference.
+# By default, log only rank 0 output.
+LOG_RANK=${LOG_RANK:-0}
+
+overrides=""
+if [ $# -ne 0 ]; then
+    overrides="$*"
+fi
+
+torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \
+--local-ranks-filter ${LOG_RANK} --role rank --tee 3 \
+torchchat.py chat llama3 --distributed $overrides
diff --git a/distributed/utils.py b/distributed/utils.py
new file mode 100644
index 000000000..71b68f94a
--- /dev/null
+++ b/distributed/utils.py
@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+import logging
+import os
+from datetime import timedelta
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def _warn_overwrite_env(env, val):
+    if env in os.environ:
+        logger.warning(
+            f"ENV[{env}] = {os.environ[env]} will be overridden to {val} based on job config"
+        )
+    os.environ[env] = val
+
+
+TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE"
+TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE"
+DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT"
+ASYNC_ERROR_HANDLING = "TORCH_NCCL_ASYNC_ERROR_HANDLING"
+SKIP_CLEANUP = "3"
+
+
+def init_distributed(job_config):
+    # The NCCL flight recorder is incompatible with TORCH_NCCL_ASYNC_ERROR_HANDLING=1,
+    # where the watchdog aborts work; =3 (skip cleanup) is required to get flight
+    # recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
+    # This could be done only when the flight recorder is enabled, but it's nice to be
+    # consistent and avoid subtle behavior differences.
+    _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)
+
+    # Enable the torch NCCL flight recorder in the mode that dumps files when a timeout is detected
+    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
+    if job_config.comm.trace_buf_size > 0:
+        # dump on timeout by default if the trace buffer is enabled
+        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
+        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
+        os.makedirs(dump_dir, exist_ok=True)
+        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")
+
+    torch.distributed.init_process_group(
+        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+    )
+
+    # Mitigate the memory issue where collectives launched with async_op=True
+    # hold memory longer than they should, such as those in tensor parallelism.
+    os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
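
Note: the sketch below is illustrative only (not part of the diff). It shows how these pieces are intended to compose on a single torchrun rank: ParallelDims validates the parallel degrees, build_mesh creates the DTensor device mesh, and parallelize_llama applies the TP plan. It assumes the WORLD_SIZE/LOCAL_RANK environment variables set by torchrun, more than one GPU so the "tp" mesh dimension exists, a tensor-parallel-only layout (tp = world size, pp = 1), and a hypothetical build_model() helper standing in for however the caller constructs the LLaMA-style model; the job_config plumbing for init_distributed is not shown.

# Illustrative sketch: wiring ParallelDims, the device mesh, and
# parallelize_llama together for one inference rank.
import os

import torch
import torch.distributed as dist

from distributed import ParallelDims, parallelize_llama


def load_parallel_model(build_model):
    # torchrun sets WORLD_SIZE and LOCAL_RANK for every rank
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)
    if not dist.is_initialized():
        # Stand-in for distributed.utils.init_distributed, which also
        # configures the NCCL flight recorder via env vars.
        dist.init_process_group("nccl")

    # tp * pp must equal WORLD_SIZE (checked by ParallelDims._validate);
    # here every rank is used for tensor parallelism.
    parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    # build_model is a hypothetical helper that returns the (ideally meta-device)
    # LLaMA-style model expected by the TP plan in parallelize_llama.
    model = build_model()
    model = parallelize_llama(model, world_mesh, parallel_dims)
    return model.to(device=device).eval()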