From d700a8e6cc7422410b4f5ceafbcdaad966f01b2c Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 27 Jun 2024 21:35:59 -0700 Subject: [PATCH 1/7] [Dist][Inference] U-haul TP and distribute utils code to TorchChat --- distributed/utils.py | 143 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 143 insertions(+) diff --git a/distributed/utils.py b/distributed/utils.py index 71b68f94a..b792f7daf 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -5,9 +5,32 @@ # LICENSE file in the root directory of this source tree. import os +<<<<<<< HEAD from datetime import timedelta import torch +======= +from dataclasses import dataclass +from datetime import timedelta +from typing import Union + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch.distributed.device_mesh import DeviceMesh +from torchtitan.logging_utils import logger +from torchtitan.parallelisms import ParallelDims + + +def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: + tensor = torch.tensor(x).cuda() + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) + + +def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: + tensor = torch.tensor(x).cuda() + return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) +>>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat) def _warn_overwrite_env(env, val): @@ -18,6 +41,54 @@ def _warn_overwrite_env(env, val): os.environ[env] = val +<<<<<<< HEAD +======= +def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: + """ + Returns global rank 0 in non-pipeline-parallel configs, and returns the global + rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. + """ + if parallel_dims.pp_enabled: + assert ( + world_mesh.mesh_dim_names[0] == "pp" + ), "get_metrics_rank assumes pp is the outer mesh dim" + pp_mesh = world_mesh["pp"] + pp_size = pp_mesh.size() + metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) + else: + metrics_log_rank = 0 + + return metrics_log_rank + + +def set_pg_timeouts(timeout, world_mesh): + """ + Sets the timeout for all PGs in the provided mesh, and the default (world) group. + + Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase + otherwise you may face a race where the slow rank has not reached the timeout reduction point + yet due to slow operations permitted under the old timeout value, but other faster ranks may + start issueing collectives under the new shorter timeout and then immediately timeout. + """ + logger.info( + f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}" + ) + # Ensure that all the ranks have reached the point of setting the new timeout- + # otherwise, some ranks may issue collectives with the new/shorter timeout and + # those may time out, before other ranks have finished with initialization done + # under the old/slow timeout. 
+ torch.distributed.barrier() + torch.cuda.synchronize() + + groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] + + # None represents the 'default' PG, not part of the mesh + groups.append(None) + for group in groups: + torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) + + +>>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat) TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" @@ -49,3 +120,75 @@ def init_distributed(job_config): # async_op=True hold memory longer than they should # such as those in tensor parallelism os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" +<<<<<<< HEAD +======= + + +def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: + num_params = sum(p.numel() for p in model.parameters()) + if exclude_embedding: + num_params -= model.tok_embeddings.weight.numel() + return num_params + + +def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int: + l, h, q, t = ( + model_config.n_layers, + model_config.n_heads, + model_config.dim // model_config.n_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. we follow the convention and do not account for sparsity in causal attention + flop_per_token = 6 * num_params + 12 * l * h * q * t + + return flop_per_token + + +# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU +def get_peak_flops(device_name: str) -> int: + if "A100" in device_name: + # data from https://www.nvidia.com/en-us/data-center/a100/ + return 312e12 + elif "H100" in device_name: + # data from https://www.nvidia.com/en-us/data-center/h100/ + # NOTE: Specifications are one-half lower without sparsity. + if "NVL" in device_name: + return 1979e12 + elif "PCIe" in device_name: + return 756e12 + else: # for SXM and other variants + return 989e12 + else: # for other GPU types, assume A100 + return 312e12 + + +@dataclass(frozen=True) +class Color: + black = "\033[30m" + red = "\033[31m" + green = "\033[32m" + yellow = "\033[33m" + blue = "\033[34m" + magenta = "\033[35m" + cyan = "\033[36m" + white = "\033[37m" + reset = "\033[39m" + + +@dataclass(frozen=True) +class NoColor: + black = "" + red = "" + green = "" + yellow = "" + blue = "" + magenta = "" + cyan = "" + white = "" + reset = "" +>>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat) From 8b5ac5ce28e4109363563c7d8ccc1b4d22ed70e3 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Thu, 27 Jun 2024 21:58:34 -0700 Subject: [PATCH 2/7] Remove unnecessary code and add comment --- distributed/utils.py | 143 ------------------------------------------- 1 file changed, 143 deletions(-) diff --git a/distributed/utils.py b/distributed/utils.py index b792f7daf..71b68f94a 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -5,32 +5,9 @@ # LICENSE file in the root directory of this source tree. 
import os -<<<<<<< HEAD from datetime import timedelta import torch -======= -from dataclasses import dataclass -from datetime import timedelta -from typing import Union - -import torch -import torch.distributed._functional_collectives as funcol -import torch.distributed.distributed_c10d as c10d -from torch.distributed.device_mesh import DeviceMesh -from torchtitan.logging_utils import logger -from torchtitan.parallelisms import ParallelDims - - -def dist_max(x: Union[int, float], mesh: DeviceMesh) -> float: - tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.MAX.name, group=mesh) - - -def dist_mean(x: Union[int, float], mesh: DeviceMesh) -> float: - tensor = torch.tensor(x).cuda() - return funcol.all_reduce(tensor, reduceOp=c10d.ReduceOp.AVG.name, group=mesh) ->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat) def _warn_overwrite_env(env, val): @@ -41,54 +18,6 @@ def _warn_overwrite_env(env, val): os.environ[env] = val -<<<<<<< HEAD -======= -def get_metrics_rank(world_mesh: DeviceMesh, parallel_dims: ParallelDims) -> int: - """ - Returns global rank 0 in non-pipeline-parallel configs, and returns the global - rank of the 0th rank in the last pipeline stage when pipeline parallelism is enabled. - """ - if parallel_dims.pp_enabled: - assert ( - world_mesh.mesh_dim_names[0] == "pp" - ), "get_metrics_rank assumes pp is the outer mesh dim" - pp_mesh = world_mesh["pp"] - pp_size = pp_mesh.size() - metrics_log_rank = int((world_mesh.size() // pp_size) * (pp_size - 1)) - else: - metrics_log_rank = 0 - - return metrics_log_rank - - -def set_pg_timeouts(timeout, world_mesh): - """ - Sets the timeout for all PGs in the provided mesh, and the default (world) group. - - Note: synchronizes via a barrier, before changing the timeouts. This is important, becuase - otherwise you may face a race where the slow rank has not reached the timeout reduction point - yet due to slow operations permitted under the old timeout value, but other faster ranks may - start issueing collectives under the new shorter timeout and then immediately timeout. - """ - logger.info( - f"Synchronizing and adjusting timeout for all ProcessGroups to {timeout}" - ) - # Ensure that all the ranks have reached the point of setting the new timeout- - # otherwise, some ranks may issue collectives with the new/shorter timeout and - # those may time out, before other ranks have finished with initialization done - # under the old/slow timeout. 
- torch.distributed.barrier() - torch.cuda.synchronize() - - groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] - - # None represents the 'default' PG, not part of the mesh - groups.append(None) - for group in groups: - torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) - - ->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat) TRACE_BUFFER_SIZE = "TORCH_NCCL_TRACE_BUFFER_SIZE" TRACE_FILE = "TORCH_NCCL_DEBUG_INFO_TEMP_FILE" DUMP_ON_TIMEOUT = "TORCH_NCCL_DUMP_ON_TIMEOUT" @@ -120,75 +49,3 @@ def init_distributed(job_config): # async_op=True hold memory longer than they should # such as those in tensor parallelism os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" -<<<<<<< HEAD -======= - - -def get_num_params(model: torch.nn.Module, exclude_embedding: bool = False) -> int: - num_params = sum(p.numel() for p in model.parameters()) - if exclude_embedding: - num_params -= model.tok_embeddings.weight.numel() - return num_params - - -def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int: - l, h, q, t = ( - model_config.n_layers, - model_config.n_heads, - model_config.dim // model_config.n_heads, - seq_len, - ) - # Reasoning behind the factor of 12 for the self-attention part of the formula: - # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) - # 2. the flash attention does 1 more matmul recomputation in the backward - # but recomputation should not be counted in calculating MFU (+0) - # 3. each matmul performs 1 multiplication and 1 addition (*2) - # 4. we follow the convention and do not account for sparsity in causal attention - flop_per_token = 6 * num_params + 12 * l * h * q * t - - return flop_per_token - - -# hardcoded BF16 type peak flops for NVIDIA A100 and H100 GPU -def get_peak_flops(device_name: str) -> int: - if "A100" in device_name: - # data from https://www.nvidia.com/en-us/data-center/a100/ - return 312e12 - elif "H100" in device_name: - # data from https://www.nvidia.com/en-us/data-center/h100/ - # NOTE: Specifications are one-half lower without sparsity. - if "NVL" in device_name: - return 1979e12 - elif "PCIe" in device_name: - return 756e12 - else: # for SXM and other variants - return 989e12 - else: # for other GPU types, assume A100 - return 312e12 - - -@dataclass(frozen=True) -class Color: - black = "\033[30m" - red = "\033[31m" - green = "\033[32m" - yellow = "\033[33m" - blue = "\033[34m" - magenta = "\033[35m" - cyan = "\033[36m" - white = "\033[37m" - reset = "\033[39m" - - -@dataclass(frozen=True) -class NoColor: - black = "" - red = "" - green = "" - yellow = "" - blue = "" - magenta = "" - cyan = "" - white = "" - reset = "" ->>>>>>> 0c3e7bf ([Dist][Inference] U-haul TP and distribute utils code to TorchChat) From f9052af3e0f72ea845471e903afb23f7c724b657 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 10:01:37 -0700 Subject: [PATCH 3/7] Add Torchrun script and enable distributed for that script --- config/model_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/config/model_config.py b/config/model_config.py index aa6f24e79..2e479beb7 100644 --- a/config/model_config.py +++ b/config/model_config.py @@ -3,6 +3,7 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. 
+ import json from dataclasses import dataclass, field from enum import Enum From c2dbd20d5d1e8a44cf7a0c3feffbfb21c1cfc646 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 10:06:11 -0700 Subject: [PATCH 4/7] Remove unnecessary changes --- config/model_config.py | 1 - 1 file changed, 1 deletion(-) diff --git a/config/model_config.py b/config/model_config.py index 2e479beb7..aa6f24e79 100644 --- a/config/model_config.py +++ b/config/model_config.py @@ -3,7 +3,6 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. - import json from dataclasses import dataclass, field from enum import Enum From 7429672621f109661ce8648f85ab2e6fdbe53264 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Mon, 1 Jul 2024 19:48:49 -0700 Subject: [PATCH 5/7] Add checkpoint loading for meta init model --- build/builder.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/build/builder.py b/build/builder.py index bd3ef5f4a..583a1a52f 100644 --- a/build/builder.py +++ b/build/builder.py @@ -278,6 +278,15 @@ def _unset_gguf_kwargs(builder_args): builder_args.gguf_kwargs = None +def _init_model_on_meta_device(builder_args): + with torch.device("meta"): + if builder_args.params_path: + return Transformer.from_params(builder_args.params_path) + elif builder_args.params_table: + return Transformer.from_table(builder_args.params_table) + else: + return Transformer.from_name(builder_args.checkpoint_path.parent.name) + def _load_model_gguf(builder_args, only_config=False): assert builder_args.gguf_path if builder_args.gguf_kwargs is None: @@ -291,14 +300,7 @@ def _load_model_gguf(builder_args, only_config=False): def _load_model_default(builder_args, only_config=False): assert not builder_args.gguf_path - with torch.device("meta"): - if builder_args.params_path: - model = Transformer.from_params(builder_args.params_path) - elif builder_args.params_table: - model = Transformer.from_table(builder_args.params_table) - else: - model = Transformer.from_name(builder_args.checkpoint_path.parent.name) - + model = _init_model_on_meta_device(builder_args) # checkpoint = torch.load(str(builder_args.checkpoint_path), mmap=True, weights_only=True) cps = [] if builder_args.checkpoint_dir is not None: From ab41031e7f8c0d3e1d613fc2a7ba6848df771d2e Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 2 Jul 2024 11:19:50 -0700 Subject: [PATCH 6/7] [Distributed Inference] Make torch run work for torchchat --- build/builder.py | 9 ++--- distributed/__init__.py | 1 + distributed/parallel_config.py | 1 + distributed/parallelize_llama.py | 66 +++++++++++++++++--------------- distributed/utils.py | 15 ++------ generate.py | 7 ++++ 6 files changed, 52 insertions(+), 47 deletions(-) diff --git a/build/builder.py b/build/builder.py index 583a1a52f..d2ee10ea9 100644 --- a/build/builder.py +++ b/build/builder.py @@ -21,7 +21,7 @@ from build.model import Transformer from build.utils import device_sync, is_cpu_device, is_cuda_or_cpu_device, name_to_dtype -from distributed import parallelize_llama, ParallelDims +from distributed import parallelize_llama, ParallelDims, init_distributed @dataclass @@ -359,12 +359,11 @@ def _load_model(builder_args, only_config=False): pp=1, world_size=world_size, ) - device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") - torch.cuda.set_device(device) - init_distributed(job_config) + init_distributed() + world_mesh = parallel_dims.build_mesh(device_type="cuda") print("Applying model parallel to model ...") - 
parallelize_llama(model) + parallelize_llama(model, world_mesh, parallel_dims) model = model.to(device=builder_args.device, dtype=builder_args.precision) return model.eval() diff --git a/distributed/__init__.py b/distributed/__init__.py index 64cd5f22d..2c5417404 100644 --- a/distributed/__init__.py +++ b/distributed/__init__.py @@ -6,3 +6,4 @@ from distributed.parallelize_llama import parallelize_llama from distributed.parallel_config import ParallelDims +from distributed.utils import init_distributed diff --git a/distributed/parallel_config.py b/distributed/parallel_config.py index d1d8aa9c7..048d4809c 100644 --- a/distributed/parallel_config.py +++ b/distributed/parallel_config.py @@ -6,6 +6,7 @@ from dataclasses import dataclass, field from torch.distributed.device_mesh import init_device_mesh +from distributed.utils import logger @dataclass class ParallelDims: diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index e2b73d0dd..e8570f17f 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -10,12 +10,13 @@ parallelize_module, PrepareModuleInput, RowwiseParallel, - SequenceParallel, ) import torch.nn as nn +from torch.distributed._tensor import Replicate, Shard from distributed.parallel_config import ParallelDims from torch.distributed.device_mesh import DeviceMesh +from distributed.utils import logger def apply_tp( @@ -43,53 +44,56 @@ def apply_tp( tp_mesh = world_mesh["tp"] - # 1. Parallelize the first embedding and the last linear proj layer - # 2. Parallelize the root norm layer over the sequence dim - # 3. Shard the first transformer block's inputs - model = parallelize_module( - model, - tp_mesh, - { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "output": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Replicate(), - use_local_output=True, - ), - "norm": SequenceParallel(), - }, - ) + # TODO: To figure out the TP for the tok_embedding and the linear proj layer. + # # 1. Parallelize the first embedding and the last linear proj layer + # # 2. Parallelize the root norm layer over the sequence dim + # # 3. 
Shard the first transformer block's inputs + # model = parallelize_module( + # model, + # tp_mesh, + # { + # "tok_embeddings": RowwiseParallel( + # input_layouts=Replicate(), + # output_layouts=Replicate(), + # ), + # "output": ColwiseParallel( + # input_layouts=Shard(1), + # output_layouts=Replicate(), + # use_local_output=True, + # ), + # }, + # ) # Apply tensor + sequence parallelism to every transformer block - for layer_id, transformer_block in model.layers.items(): + for transformer_block in model.layers: layer_plan = { - "attention": prepare_module_input( - input_layouts=(Shard(1), None), + "attention": PrepareModuleInput( + input_layouts=(Replicate(), None), desired_input_layouts=(Replicate(), None), ), "attention.wq": ColwiseParallel(), "attention.wk": ColwiseParallel(), "attention.wv": ColwiseParallel(), - "attention.wo": RowwiseParallel(output_layouts=Shard(1)), - "attention_norm": SequenceParallel(), - "feed_forward": prepare_module_input( - input_layouts=(Shard(1),), + "attention.wo": RowwiseParallel( + output_layouts=Replicate(), + use_local_output=True, + ), + "feed_forward": PrepareModuleInput( + input_layouts=(Replicate(),), desired_input_layouts=(Replicate(),), ), "feed_forward.w1": ColwiseParallel(), - "feed_forward.w2": RowwiseParallel(output_layouts=Shard(1)), + "feed_forward.w2": RowwiseParallel( + output_layouts=Replicate(), + use_local_output=True + ), "feed_forward.w3": ColwiseParallel(), - "ffn_norm": SequenceParallel(), } # Adjust attention module to use the local number of heads attn_layer = transformer_block.attention attn_layer.n_heads = attn_layer.n_heads // tp_mesh.size() attn_layer.n_local_heads = attn_layer.n_local_heads // tp_mesh.size() - attn_layer.n_kv_heads = attn_layer.n_kv_heads // tp_mesh.size() parallelize_module( module=transformer_block, @@ -125,6 +129,6 @@ def parallelize_llama( """ if parallel_dims.tp_enabled: - model = apply_tp(model, world_mesh, parallel_dims) + model = apply_tp(model, world_mesh) return model diff --git a/distributed/utils.py b/distributed/utils.py index 71b68f94a..9a02dd4cd 100644 --- a/distributed/utils.py +++ b/distributed/utils.py @@ -8,6 +8,8 @@ from datetime import timedelta import torch +import logging +logger = logging.getLogger() def _warn_overwrite_env(env, val): @@ -25,24 +27,15 @@ def _warn_overwrite_env(env, val): SKIP_CLEANUP = "3" -def init_distributed(job_config): +def init_distributed(init_timeout_seconds: int = 120): # FlightRecorder is incompatible with =1 mode where watchdog aborts work, must use =3 (skipcleanup) # to get flight recorder dumps. 
See https://github.com/pytorch/pytorch/issues/121055 # This could be done only when flight recorder is enabled, but its nice to be consistent to avoid subtle # behavior differences _warn_overwrite_env(ASYNC_ERROR_HANDLING, SKIP_CLEANUP) - # enable torch nccl flight recorder in the mode that would dump files if timeout is detected - _warn_overwrite_env(TRACE_BUFFER_SIZE, str(job_config.comm.trace_buf_size)) - if job_config.comm.trace_buf_size > 0: - # dump on timeout by default if trace buffer is enabled - _warn_overwrite_env(DUMP_ON_TIMEOUT, "1") - dump_dir = f"{job_config.job.dump_folder}/comm_trace" - os.makedirs(dump_dir, exist_ok=True) - _warn_overwrite_env(TRACE_FILE, f"{dump_dir}/rank_") - torch.distributed.init_process_group( - "nccl", timeout=timedelta(seconds=job_config.comm.init_timeout_seconds) + "nccl", timeout=timedelta(seconds=init_timeout_seconds) ) # to mitigate the memory issue that collectives using diff --git a/generate.py b/generate.py index 3e042f6b1..acca657b7 100644 --- a/generate.py +++ b/generate.py @@ -8,6 +8,7 @@ import logging import sys import time +import os from dataclasses import dataclass from pathlib import Path from typing import List, Optional, Tuple @@ -504,6 +505,12 @@ def _main( # print = lambda *args, **kwargs: None print(f"Using device={builder_args.device} {get_device_info(builder_args.device)}") + # If using distributed inference we cannot just assign device to be cuda + # because it will be assigned to cuda:0 by default. We need explicitely set + # the device to be the local rank. + if builder_args.use_distributed: + device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}") + torch.cuda.set_device(device) set_precision(builder_args.precision) is_speculative = speculative_builder_args.checkpoint_path is not None From 655ea0f8fa7ceb994b5b269827532b0a0d4c1b07 Mon Sep 17 00:00:00 2001 From: fduwjj Date: Tue, 2 Jul 2024 13:11:28 -0700 Subject: [PATCH 7/7] Address comments --- distributed/parallelize_llama.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/distributed/parallelize_llama.py b/distributed/parallelize_llama.py index e8570f17f..d1cf8fd80 100644 --- a/distributed/parallelize_llama.py +++ b/distributed/parallelize_llama.py @@ -46,8 +46,7 @@ def apply_tp( # TODO: To figure out the TP for the tok_embedding and the linear proj layer. # # 1. Parallelize the first embedding and the last linear proj layer - # # 2. Parallelize the root norm layer over the sequence dim - # # 3. Shard the first transformer block's inputs + # # 2. Shard the first transformer block's inputs # model = parallelize_module( # model, # tp_mesh, @@ -64,7 +63,7 @@ def apply_tp( # }, # ) - # Apply tensor + sequence parallelism to every transformer block + # Apply tensor parallelism to every transformer block for transformer_block in model.layers: layer_plan = { "attention": PrepareModuleInput(
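
Note on process setup (patch 6): init_distributed() now takes a plain init_timeout_seconds argument (default 120) instead of a torchtitan JobConfig, and generate.py pins each process to cuda:{LOCAL_RANK} before the NCCL process group is used, since every rank would otherwise land on cuda:0. A standalone sketch of that launch pattern follows; init_for_torchrun is an illustrative helper, not a function from the patch, while the environment variables are the standard ones torchrun exports.

    import os
    from datetime import timedelta

    import torch
    import torch.distributed as dist


    def init_for_torchrun(init_timeout_seconds: int = 120) -> torch.device:
        # torchrun exports LOCAL_RANK / RANK / WORLD_SIZE for every spawned process.
        local_rank = int(os.environ["LOCAL_RANK"])
        device = torch.device(f"cuda:{local_rank}")
        # Bind this process to its own GPU before NCCL creates communicators,
        # otherwise every rank would default to cuda:0.
        torch.cuda.set_device(device)
        dist.init_process_group(
            "nccl", timeout=timedelta(seconds=init_timeout_seconds)
        )
        return device

A two-GPU run would then be launched with something like torchrun --nproc_per_node=2 <script>.py, where the script name is a placeholder.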
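
Note on the tensor-parallel plan in distributed/parallelize_llama.py (patches 6-7): it uses the stock DTensor parallelize_module API, with column-parallel input projections (wq/wk/wv, w1/w3) and row-parallel output projections (wo, w2) whose outputs are replicated; the torchtitan SequenceParallel pieces are dropped and the tok_embeddings/output plan is left as a TODO. Below is a minimal sketch of the same per-block pattern applied to a toy feed-forward module; ToyFeedForward, its dimensions, and apply_toy_tp are illustrative stand-ins rather than code from the patch, and the script has to run under torchrun so a CUDA process group exists.

    import torch
    import torch.nn as nn
    from torch.distributed._tensor import Replicate
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.parallel import (
        ColwiseParallel,
        RowwiseParallel,
        parallelize_module,
    )


    class ToyFeedForward(nn.Module):
        """Stand-in for the w1/w2 projections of a transformer feed-forward block."""

        def __init__(self, dim: int = 256, hidden_dim: int = 1024):
            super().__init__()
            self.w1 = nn.Linear(dim, hidden_dim, bias=False)
            self.w2 = nn.Linear(hidden_dim, dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.w2(torch.relu(self.w1(x)))


    def apply_toy_tp(model: nn.Module, tp_size: int) -> nn.Module:
        # 1-D mesh over the TP ranks, roughly what parallel_dims.build_mesh(
        # device_type="cuda")["tp"] produces in the patch.
        tp_mesh = init_device_mesh("cuda", (tp_size,), mesh_dim_names=("tp",))
        # Column-parallel first projection feeding a row-parallel second projection,
        # with the block output replicated as a plain local tensor -- the same layout
        # choices the patch makes for attention.wq/wk/wv/wo and feed_forward.w1/w2/w3.
        plan = {
            "w1": ColwiseParallel(),
            "w2": RowwiseParallel(output_layouts=Replicate(), use_local_output=True),
        }
        return parallelize_module(model, tp_mesh, plan)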
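
Note on patch 5 (_init_model_on_meta_device): constructing the Transformer under torch.device("meta") records parameter shapes and dtypes without allocating storage, so checkpoint weights can be loaded afterwards without first materializing random tensors. A minimal, generic sketch of that pattern follows; ToyModel, its size, and the in-memory state_dict are illustrative, and load_state_dict(..., assign=True) is shown as one common way to materialize meta parameters, not necessarily the exact loading path builder.py uses.

    import torch
    import torch.nn as nn


    class ToyModel(nn.Module):
        """Illustrative stand-in for build.model.Transformer."""

        def __init__(self, dim: int = 16):
            super().__init__()
            self.proj = nn.Linear(dim, dim, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.proj(x)


    # Build on the meta device: parameters carry shape/dtype metadata but no memory.
    with torch.device("meta"):
        model = ToyModel()
    assert model.proj.weight.is_meta

    # Materialize by loading real weights; assign=True swaps the meta tensors out
    # for the loaded ones instead of copying into storage that does not exist.
    state_dict = {"proj.weight": torch.randn(16, 16)}
    model.load_state_dict(state_dict, assign=True)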