diff --git a/build/builder.py b/build/builder.py
index bd3ef5f4a..d2ee10ea9 100644
--- a/build/builder.py
+++ b/build/builder.py
@@ -21,7 +21,7 @@ from build.model import Transformer
 from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype


-from distributed import parallelize_llama, ParallelDims
+from distributed import parallelize_llama, ParallelDims, init_distributed


 @dataclass
@@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args):
     builder_args.gguf_kwargs = None


+def _init_model_on_meta_device(builder_args):
+    with torch.device("meta"):
+        if builder_args.params_path:
+            return Transformer.from_params(builder_args.params_path)
+        elif builder_args.params_table:
+            return Transformer.from_table(builder_args.params_table)
+        else:
+            return Transformer.from_name(builder_args.checkpoint_path.parent.name)
+
 def _load_model_gguf(builder_args, only_config=False):
     assert builder_args.gguf_path
     if builder_args.gguf_kwargs is None:
@@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False):
 def _load_model_default(builder_args, only_config=False):
     assert not builder_args.gguf_path

-    with torch.device("meta"):
-        if builder_args.params_path:
-            model = Transformer.from_params(builder_args.params_path)
-        elif builder_args.params_table:
-            model = Transformer.from_table(builder_args.params_table)
-        else:
-            model = Transformer.from_name(builder_args.checkpoint_path.parent.name)
-
+    model = _init_model_on_meta_device(builder_args)
     # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True)
     cps = []
     if builder_args.checkpoint_dir is not None:
@@ -357,12 +359,11 @@ def _load_model(builder_args, only_config=False):
             pp=1,
             world_size=world_size,
         )
-        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
-        torch.cuda.set_device(device)
-        init_distributed(job_config)
+        init_distributed()
+        world_mesh = parallel_dims.build_mesh(device_type="cuda")

         print("Applying model parallel to model ...")
-        parallelize_llama(model)
+        parallelize_llama(model, world_mesh, parallel_dims)

     model = model.to(device=builder_args.device, dtype=builder_args.precision)
     return model.eval()
diff --git a/distributed/__init__.py b/distributed/__init__.py
index 64cd5f22d..2c5417404 100644
--- a/distributed/__init__.py
+++ b/distributed/__init__.py
@@ -6,3 +6,4 @@

 from distributed.parallelize_llama import parallelize_llama
 from distributed.parallel_config import ParallelDims
+from distributed.utils import init_distributed
diff --git a/distributed/parallel_config.py b/distributed/parallel_config.py
index d1d8aa9c7..048d4809c 100644
--- a/distributed/parallel_config.py
+++ b/distributed/parallel_config.py
@@ -6,6 +6,7 @@

 from dataclasses import dataclass, field
 from torch.distributed.device_mesh import init_device_mesh
+from distributed.utils import logger

 @dataclass
 class ParallelDims:
diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py
index e2b73d0dd..d1cf8fd80 100644
--- a/distributed/parallelize_llama.py
+++ b/distributed/parallelize_llama.py
@@ -10,12 +10,13 @@
     parallelize_module,
     PrepareModuleInput,
     RowwiseParallel,
-    SequenceParallel,
 )

 import torch.nn as nn
+from torch.distributed._tensor import Replicate, Shard
 from distributed.parallel_config import ParallelDims
 from torch.distributed.device_mesh import DeviceMesh
+from distributed.utils import logger


 def apply_tp(
@@ -43,53 +44,55 @@ def apply_tp(

     tp_mesh = world_mesh["tp"]

-    # 1. Parallelize the first embedding and the last linear proj layer
-    # 2. Parallelize the root norm layer over the sequence dim
-    # 3. Shard the first transformer block's inputs
-    model = parallelize_module(
-        model,
-        tp_mesh,
-        {
-            "tok_embeddings": RowwiseParallel(
-                input_layouts=Replicate(),
-                output_layouts=Shard(1),
-            ),
-            "output": ColwiseParallel(
-                input_layouts=Shard(1),
-                output_layouts=Replicate(),
-                use_local_output=True,
-            ),
-            "norm": SequenceParallel(),
-        },
-    )
-
-    # Apply tensor + sequence parallelism to every transformer block
-    for layer_id, transformer_block in model.layers.items():
+    # TODO: Figure out the TP for the tok_embedding and the linear proj layer.
+    # # 1. Parallelize the first embedding and the last linear proj layer
+    # # 2. Shard the first transformer block's inputs
+    # model = parallelize_module(
+    #     model,
+    #     tp_mesh,
+    #     {
+    #         "tok_embeddings": RowwiseParallel(
+    #             input_layouts=Replicate(),
+    #             output_layouts=Replicate(),
+    #         ),
+    #         "output": ColwiseParallel(
+    #             input_layouts=Shard(1),
+    #             output_layouts=Replicate(),
+    #             use_local_output=True,
+    #         ),
+    #     },
+    # )
+
+    # Apply tensor parallelism to every transformer block
+    for transformer_block in model.layers:
         layer_plan = {
-            "attention": prepare_module_input(
-                input_layouts=(Shard(1), None),
+            "attention": PrepareModuleInput(
+                input_layouts=(Replicate(), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": ColwiseParallel(),
             "attention.wk": ColwiseParallel(),
             "attention.wv": ColwiseParallel(),
-            "attention.wo": RowwiseParallel(output_layouts=Shard(1)),
-            "attention_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
-                input_layouts=(Shard(1),),
+            "attention.wo": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True,
+            ),
+            "feed_forward": PrepareModuleInput(
+                input_layouts=(Replicate(),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": ColwiseParallel(),
-            "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)),
+            "feed_forward.w2": RowwiseParallel(
+                output_layouts=Replicate(),
+                use_local_output=True
+            ),
             "feed_forward.w3": ColwiseParallel(),
-            "ffn_norm": SequenceParallel(),
         }

         # Adjust attention module to use the local number of heads
         attn_layer = transformer_block.attention
         attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size()
         attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size()
-        attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size()

         parallelize_module(
             module=transformer_block,
@@ -125,6 +128,6 @@ def parallelize_llama(
     """

     if parallel_dims.tp_enabled:
-        model = apply_tp(model, world_mesh, parallel_dims)
+        model = apply_tp(model, world_mesh)

     return model
diff --git a/distributed/utils.py b/distributed/utils.py
index 71b68f94a..9a02dd4cd 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -8,6 +8,8 @@
 from datetime import timedelta

 import torch
+import logging
+logger = logging.getLogger()


 def _warn_overwrite_env(env, val):
@@ -25,24 +27,15 @@
 SKIP_CLEANUP = "3"


-def init_distributed(job_config):
+def init_distributed(init_timeout_seconds: int = 120):
     # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup)
     # to get flight recorder dumps. See https://github.com/pytorch/pytorch/issues/121055
     # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle
     # behavior differences
     _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP)

-    # enable torch nccl flight recorder in the mode that would dump files if timeout is detected
-    _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size))
-    if job_config.comm.trace_buf_size > 0:
-        # dump on timeout by default if trace buffer is enabled
-        _warn_overwrite_env(DUMP_ON_TIMEOUT, "1")
-        dump_dir = f"{job_config.job.dump_folder}/comm_trace"
-        os.makedirs(dump_dir, exist_ok=True)
-        _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_")
-
     torch.distributed.init_process_group(
-        "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds)
+        "nccl", timeout=timedelta(seconds=init_timeout_seconds)
     )

     # to mitigate the memory issue that collectives using
diff --git a/generate.py b/generate.py
index 3e042f6b1..acca657b7 100644
--- a/generate.py
+++ b/generate.py
@@ -8,6 +8,7 @@
 import logging
 import sys
 import time
+import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import List, Optional, Tuple
@@ -504,6 +505,12 @@ def _main(
     # print = lambda *args, **kwargs: None

     print(f"Using device={builder_args.device} {get_device_info(builder_args.device)}")
+    # If using distributed inference we cannot just assign device to be cuda
+    # because it will be assigned to cuda:0 by default. We need to explicitly
+    # set the device to be the local rank.
+    if builder_args.use_distributed:
+        device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+        torch.cuda.set_device(device)
     set_precision(builder_args.precision)
     is_speculative = speculative_builder_args.checkpoint_path is not None
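Taken together, the patch splits distributed setup across three places: init_distributed() now only creates the NCCL process group (with a configurable timeout), the device mesh is built from ParallelDims, and the per-rank CUDA device is chosen from LOCAL_RANK in generate.py. The sketch below is illustrative only, assuming a torchrun launch that sets LOCAL_RANK and WORLD_SIZE; the tp value and the commented model step are assumptions for the example, not taken from this diff.

    import os

    import torch

    from distributed import ParallelDims, init_distributed

    # Assumed torchrun environment; both variables are set by the launcher.
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    # generate.py: pin this process to its own GPU before any collectives run.
    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # builder.py: process group first, then the device mesh from ParallelDims.
    init_distributed(init_timeout_seconds=120)
    parallel_dims = ParallelDims(tp=world_size, pp=1, world_size=world_size)  # tp value is illustrative
    world_mesh = parallel_dims.build_mesh(device_type="cuda")

    # model = _init_model_on_meta_device(builder_args)            # placeholder for the real build path
    # model = parallelize_llama(model, world_mesh, parallel_dims)  # as called in _load_model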