From edde72ea9ceede1330940a29e7f4919ec9adea99 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Fri, 23 Aug 2024 23:22:31 -0700 Subject: [PATCH 1/3] Initial add of distributed model Use parallelize_module in model [ghstack-poisoned] --- build/model_dist.py | 267 ++++++++++++++++++++++++++++++++++++++++++++ dist_run.py | 35 ++++++ 2 files changed, 302 insertions(+) create mode 100644 build/model_dist.py create mode 100644 dist_run.py diff --git a/build/model_dist.py b/build/model_dist.py new file mode 100644 index 000000000..8e5d72fde --- /dev/null +++ b/build/model_dist.py @@ -0,0 +1,267 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +from pathlib import Path +from typing import Dict, Optional + +import torch +import torch.nn as nn + +from torch import Tensor +from torch.nn import functional as F +from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel + +from build.utils import find_multiple + +from build.model import TransformerArgs, KVCache, apply_rotary_emb, precompute_freqs_cis + +config_path = Path(f"{str(Path(__file__).parent)}/known_model_params") + + +# Use DTensor as output, by default +Colwise = ColwiseParallel(use_local_output=False) +Rowwise = RowwiseParallel(use_local_output=False) + + +class Transformer(nn.Module): + def __init__(self, config: TransformerArgs) -> None: + super().__init__() + self.config = config + + tok_embeddings = nn.Embedding(config.vocab_size, config.dim) + self.tok_embeddings = parallelize_module( + tok_embeddings, + plan=RowwiseParallel(input_layouts=Replicate()), + ) + self.layers = nn.ModuleList( + TransformerBlock(config) for _ in range(config.n_layers) + ) + self.norm = RMSNorm(config.dim, eps=config.norm_eps) + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) + + # self.freqs_cis: Optional[Tensor] = None + # self.mask_cache: Optional[Tensor] = None + self.max_batch_size = -1 + self.max_seq_length = -1 + + def setup_caches(self, max_batch_size, max_seq_length): + if ( + self.max_seq_length >= max_seq_length + and self.max_batch_size >= max_batch_size + ): + return + head_dim = self.config.dim // self.config.n_heads + max_seq_length = find_multiple(max_seq_length, 8) + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + for b in self.layers: + b.attention.kv_cache = KVCache( + max_batch_size, max_seq_length, self.config.n_local_heads, head_dim + ) + + freqs_cis = precompute_freqs_cis( + self.config.dim // self.config.n_heads, + self.config.block_size * 2, + self.config.rope_base, + use_scaled = self.config.use_scaled_rope, + ) + self.register_buffer("freqs_cis", freqs_cis, persistent=True) + causal_mask = torch.tril( + torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool) + ) + self.register_buffer("causal_mask", causal_mask, persistent=True) + + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: + assert self.freqs_cis is not None, "Caches must be initialized first" + mask = self.causal_mask[None, None, input_pos] + freqs_cis = self.freqs_cis[input_pos] + x: DTensor = self.tok_embeddings(idx) + # TODO: sequence parallelize this + + for _, layer in enumerate(self.layers): + x = layer(x, input_pos, freqs_cis, mask) + x = self.norm(x) + logits = self.output(x) + # print(f"logits shape: {logits.shape}") + return 
logits + + @classmethod + def from_name(cls, name: str): + return cls(TransformerArgs.from_name(name)) + + @classmethod + def from_table(cls, name: str): + return cls(TransformerArgs.from_table(name)) + + @classmethod + def from_params(cls, params_path: str): + return cls(TransformerArgs.from_params(params_path)) + + @classmethod + def from_gguf(cls, gguf_path: str, **kwargs): + from build.gguf_loader import load_model_and_state_dict + + model, state_dict = load_model_and_state_dict(gguf_path, **kwargs) + if state_dict != {}: + model.load_state_dict(state_dict, assign=True) + return model + + +class TransformerBlock(nn.Module): + def __init__(self, config: TransformerArgs) -> None: + super().__init__() + self.attention = Attention(config) + self.feed_forward = FeedForward(config) + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) + self.attention_norm = RMSNorm(config.dim, config.norm_eps) + + def forward( + self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor + ) -> Tensor: + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) + out = h + self.feed_forward(self.ffn_norm(h)) + return out + + +class Attention(nn.Module): + def __init__(self, config: TransformerArgs): + super().__init__() + assert config.dim % config.n_heads == 0 + + # key, query, value projections for all heads, but in a batch + # total_head_dim = (config.n_heads + 2 * config.n_local_heads) * config.head_dim + # self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) + wq = nn.Linear(config.dim, config.n_heads * config.head_dim, bias=False) + wk = nn.Linear( + config.dim, config.n_local_heads * config.head_dim, bias=False + ) + wv = nn.Linear( + config.dim, config.n_local_heads * config.head_dim, bias=False + ) + wo = nn.Linear(config.dim, config.dim, bias=False) + + self.wq = parallelize_module(wq, plan=Colwise) + self.wk = parallelize_module(wk, plan=Colwise) + self.wv = parallelize_module(wv, plan=Colwise) + self.wo = parallelize_module(wo, plan=Rowwise) + + self.kv_cache = None + + self.n_heads = config.n_heads + self.head_dim = config.head_dim + self.n_local_heads = config.n_local_heads + self.dim = config.dim + self._register_load_state_dict_pre_hook(self.load_hook) + + def load_hook(self, state_dict, prefix, *args): + # if prefix + "wq.weight" in state_dict: + # wq = state_dict.pop(prefix + "wq.weight") + # wk = state_dict.pop(prefix + "wk.weight") + # wv = state_dict.pop(prefix + "wv.weight") + # state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) + + if prefix + "wqkv.weight" in state_dict: + wqkv = state_dict.pop(prefix + "wqkv.weight") + q_size = self.n_heads * self.head_dim + kv_size = self.n_local_heads * self.head_dim + wq, wk, wv = torch.split(wqkv, (q_size, kv_size, kv_size), dim=0) + state_dict[prefix + "wq.weight"] = wq + state_dict[prefix + "wk.weight"] = wk + state_dict[prefix + "wv.weight"] = wv + + return + + def _unfuse_wqkv_state_dict( + state_dict: Dict[str, torch.Tensor], + dim: int, + ): + for key in list(state_dict): + if key.endswith("wqkv.weight"): + tensor = state_dict[key] + wq_key = key.replace("wqkv.weight", "wq.weight") + state_dict[wq_key] = tensor[:dim] + wk_key = key.replace("wqkv.weight", "wk.weight") + wv_key = key.replace("wqkv.weight", "wv.weight") + wk, wv = tensor[dim:].chunk(2, 0) + state_dict[wk_key] = wk + state_dict[wv_key] = wv + state_dict.pop(key) + else: + continue + + _unfuse_wqkv_state_dict(state_dict, self.dim) + + def forward( + self, + x: Tensor, + freqs_cis: Tensor, + mask: Tensor, + input_pos: Optional[Tensor] 
= None, + ) -> Tensor: + bsz, seqlen, _ = x.shape + + q: DTensor = self.wq(x) + k: DTensor = self.wk(x) + v: DTensor = self.wv(x) + # We use `to_local()` to convert DTensor back to regular Tensor + q, k, v = q.to_local(), k.to_local(), v.to_local() + # kv_size = self.n_local_heads * self.head_dim + # q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) + + q = q.view(bsz, seqlen, -1, self.head_dim) + k = k.view(bsz, seqlen, -1, self.head_dim) + v = v.view(bsz, seqlen, -1, self.head_dim) + + q = apply_rotary_emb(q, freqs_cis) + k = apply_rotary_emb(k, freqs_cis) + + q, k, v = (x.transpose(1, 2) for x in (q, k, v)) + + #if self.kv_cache is not None: + # k, v = self.kv_cache.update(input_pos, k, v) + + k = k.repeat_interleave(self.n_heads // self.n_local_heads, dim=1) + v = v.repeat_interleave(self.n_heads // self.n_local_heads, dim=1) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, -1) + + y: DTensor = self.wo(y) + # TODO: sequence parallelize this + return y.full_tensor() + + +class FeedForward(nn.Module): + def __init__(self, config: TransformerArgs) -> None: + super().__init__() + w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) + w2 = nn.Linear(config.hidden_dim, config.dim, bias=False) + w3 = nn.Linear(config.dim, config.hidden_dim, bias=False) + self.w1 = parallelize_module(w1, plan=Colwise) + self.w2 = parallelize_module(w2, plan=Rowwise) + self.w3 = parallelize_module(w3, plan=Colwise) + + def forward(self, x: Tensor) -> Tensor: + y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x)) + # y is a DTensor with Partial placement; + # we convert its placement to Replicate and convert it back to a regular + # Tensor. `full_tensor` is the API that does both. + # TODO: sequence parallelize this + return y.full_tensor() + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, eps: float = 1e-5): + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) + + def forward(self, x: Tensor) -> Tensor: + output = self._norm(x.float()).type_as(x) + return output * self.weight diff --git a/dist_run.py b/dist_run.py new file mode 100644 index 000000000..7a4677dc8 --- /dev/null +++ b/dist_run.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch +import torch.distributed as dist + +from build.model import TransformerArgs +from build.model_dist import Transformer + +# Model config +config = TransformerArgs.from_name("Transformer-2-7b-chat-hf") +print(config) + +# Construct a device mesh with available devices (multi-host or single host) +device_mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) +rank = dist.get_rank() +device = torch.device(f"cuda:{rank}") + +# Create parallel model with device_mesh context +with device: + with device_mesh: + model = Transformer(config) + model.setup_caches(1, 4096) + +print(model) + +# Distributed run +input_ids = torch.randint(0, config.vocab_size, (1, 4096), device=device) +input_pos = torch.arange(0, 4096, device=device) +output = model(input_ids, input_pos) +dist.destroy_process_group() +print(f"Rank {rank} completes.") From 98c0f92c54693106291ecac0c1a7c880411fb636 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Sun, 25 Aug 2024 10:21:18 -0700 Subject: [PATCH 2/3] Update on "Initial add of distributed model" Use `parallelize_module` in model. Added files: `model_dist.py`: a mirror of model.py with Tensor Parallelism baked in. `dist_run.py`: toy example of how to run the model in distributed way. Test: `torchrun --nproc-per-node 2 dist_run.py` [ghstack-poisoned] --- build/model_dist.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/build/model_dist.py b/build/model_dist.py index 8e5d72fde..1fa71498a 100644 --- a/build/model_dist.py +++ b/build/model_dist.py @@ -12,6 +12,7 @@ from torch import Tensor from torch.nn import functional as F from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel, RowwiseParallel from build.utils import find_multiple @@ -25,16 +26,25 @@ Colwise = ColwiseParallel(use_local_output=False) Rowwise = RowwiseParallel(use_local_output=False) +# Device mesh context +device_mesh = None + class Transformer(nn.Module): def __init__(self, config: TransformerArgs) -> None: super().__init__() self.config = config + # Get device mesh + global device_mesh + if device_mesh is None: + device_mesh = _mesh_resources.get_current_mesh() + tok_embeddings = nn.Embedding(config.vocab_size, config.dim) self.tok_embeddings = parallelize_module( tok_embeddings, - plan=RowwiseParallel(input_layouts=Replicate()), + device_mesh, + RowwiseParallel(input_layouts=Replicate()), ) self.layers = nn.ModuleList( TransformerBlock(config) for _ in range(config.n_layers) @@ -143,10 +153,10 @@ def __init__(self, config: TransformerArgs): ) wo = nn.Linear(config.dim, config.dim, bias=False) - self.wq = parallelize_module(wq, plan=Colwise) - self.wk = parallelize_module(wk, plan=Colwise) - self.wv = parallelize_module(wv, plan=Colwise) - self.wo = parallelize_module(wo, plan=Rowwise) + self.wq = parallelize_module(wq, device_mesh, Colwise) + self.wk = parallelize_module(wk, device_mesh, Colwise) + self.wv = parallelize_module(wv, device_mesh, Colwise) + self.wo = parallelize_module(wo, device_mesh, Rowwise) self.kv_cache = None @@ -240,9 +250,9 @@ def __init__(self, config: TransformerArgs) -> None: w1 = nn.Linear(config.dim, config.hidden_dim, bias=False) w2 = nn.Linear(config.hidden_dim, config.dim, bias=False) w3 = nn.Linear(config.dim, config.hidden_dim, bias=False) - self.w1 = parallelize_module(w1, plan=Colwise) - self.w2 = parallelize_module(w2, plan=Rowwise) - self.w3 = 
parallelize_module(w3, plan=Colwise) + self.w1 = parallelize_module(w1, device_mesh, Colwise) + self.w2 = parallelize_module(w2, device_mesh, Rowwise) + self.w3 = parallelize_module(w3, device_mesh, Colwise) def forward(self, x: Tensor) -> Tensor: y: DTensor = self.w2(F.silu(self.w1(x)) * self.w3(x)) From 4b8d5d859b44a17f7b5219f01346586de3a63d23 Mon Sep 17 00:00:00 2001 From: Ke Wen Date: Mon, 26 Aug 2024 09:52:23 -0700 Subject: [PATCH 3/3] Update on "Initial add of distributed model" Use `parallelize_module` in model. Added files: `model_dist.py`: a mirror of model.py with Tensor Parallelism baked in. `dist_run.py`: toy example of how to run the model in distributed way. Test: `torchrun --nproc-per-node 2 dist_run.py` [ghstack-poisoned] --- build/model_dist.py | 1 + dist_run.py | 48 ++++++++++++++++++++++++--------------------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/build/model_dist.py b/build/model_dist.py index 1fa71498a..b46c68989 100644 --- a/build/model_dist.py +++ b/build/model_dist.py @@ -230,6 +230,7 @@ def forward( q, k, v = (x.transpose(1, 2) for x in (q, k, v)) + # TODO: enable kv cache #if self.kv_cache is not None: # k, v = self.kv_cache.update(input_pos, k, v) diff --git a/dist_run.py b/dist_run.py index 7a4677dc8..d1f4f213f 100644 --- a/dist_run.py +++ b/dist_run.py @@ -11,25 +11,29 @@ from build.model_dist import Transformer # Model config -config = TransformerArgs.from_name("Transformer-2-7b-chat-hf") -print(config) - -# Construct a device mesh with available devices (multi-host or single host) -device_mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) -rank = dist.get_rank() -device = torch.device(f"cuda:{rank}") - -# Create parallel model with device_mesh context -with device: - with device_mesh: - model = Transformer(config) - model.setup_caches(1, 4096) - -print(model) - -# Distributed run -input_ids = torch.randint(0, config.vocab_size, (1, 4096), device=device) -input_pos = torch.arange(0, 4096, device=device) -output = model(input_ids, input_pos) -dist.destroy_process_group() -print(f"Rank {rank} completes.") +def main(): + config = TransformerArgs.from_name("Transformer-2-7b-chat-hf") + print(config) + + # Construct a device mesh with available devices (multi-host or single host) + device_mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",)) + rank = dist.get_rank() + device = torch.device(f"cuda:{rank}") + + # Create parallel model with device_mesh context + with device: + with device_mesh: + model = Transformer(config) + model.setup_caches(1, 4096) + + print(model) + + # Distributed run + input_ids = torch.randint(0, config.vocab_size, (1, 4096), device=device) + input_pos = torch.arange(0, 4096, device=device) + output = model(input_ids, input_pos) + dist.destroy_process_group() + print(f"Rank {rank} completes.") + +if __name__ == "__main__": + main()
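
For reference, below is a minimal sketch (not part of the patches above) of how the Colwise/Rowwise plans used in model_dist.py compose on a single FeedForward-style block. It assumes two CUDA devices; the file name tp_feedforward_sketch.py, the layer sizes, and the input shape are made up for illustration. It can be launched the same way the patch description suggests, e.g. `torchrun --nproc-per-node 2 tp_feedforward_sketch.py`.

# tp_feedforward_sketch.py -- illustrative only; assumes 2 CUDA devices.
# Run with: torchrun --nproc-per-node 2 tp_feedforward_sketch.py
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class FeedForwardSketch(nn.Module):
    """Same wiring as FeedForward in model_dist.py, with hypothetical sizes."""

    def __init__(self, dim: int, hidden_dim: int) -> None:
        super().__init__()
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # w1/w3 are column-sharded and w2 is row-sharded, so the intermediate
        # activations stay sharded across the "tp" mesh dimension;
        # full_tensor() materializes the result as a regular, replicated Tensor.
        y = self.w2(F.silu(self.w1(x)) * self.w3(x))
        return y.full_tensor()


def main() -> None:
    mesh = dist.init_device_mesh("cuda", (2,), mesh_dim_names=("tp",))
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")

    with device:
        block = FeedForwardSketch(dim=64, hidden_dim=256)

    # use_local_output=False keeps DTensor outputs flowing between the
    # column-wise and row-wise shards, mirroring the Colwise/Rowwise plans
    # defined at the top of model_dist.py.
    colwise = ColwiseParallel(use_local_output=False)
    rowwise = RowwiseParallel(use_local_output=False)
    parallelize_module(block, mesh, {"w1": colwise, "w3": colwise, "w2": rowwise})

    # Seed so every rank feeds the same (replicated) input.
    torch.manual_seed(0)
    x = torch.randn(1, 16, 64, device=device)
    y = block(x)
    print(f"Rank {rank}: output shape {tuple(y.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()

The column-wise-then-row-wise pairing is the same pattern the patch applies to wq/wk/wv and wo in Attention: partial results stay sharded through the block and are only combined once, at the output projection.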