From d8d42b53426c0237039151225c9497cda515cbeb Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 31 Aug 2023 23:15:38 -0700 Subject: [PATCH 1/4] Add DTensor LLaMA inference model: simple_gpt --- torchbenchmark/models/ADDING_MODELS.md | 24 +- torchbenchmark/models/simple_gpt/__init__.py | 104 +++++ torchbenchmark/models/simple_gpt/install.py | 9 + .../models/simple_gpt/metadata.yaml | 5 + torchbenchmark/models/simple_gpt/model.py | 376 ++++++++++++++++++ torchbenchmark/models/simple_gpt/origin | 1 + .../models/simple_gpt/requirements.txt | 1 + torchbenchmark/models/simple_gpt/utils.py | 5 + 8 files changed, 513 insertions(+), 12 deletions(-) create mode 100644 torchbenchmark/models/simple_gpt/__init__.py create mode 100644 torchbenchmark/models/simple_gpt/install.py create mode 100644 torchbenchmark/models/simple_gpt/metadata.yaml create mode 100644 torchbenchmark/models/simple_gpt/model.py create mode 100644 torchbenchmark/models/simple_gpt/origin create mode 100644 torchbenchmark/models/simple_gpt/requirements.txt create mode 100644 torchbenchmark/models/simple_gpt/utils.py diff --git a/torchbenchmark/models/ADDING_MODELS.md b/torchbenchmark/models/ADDING_MODELS.md index 737bed5243..9681fe49ec 100644 --- a/torchbenchmark/models/ADDING_MODELS.md +++ b/torchbenchmark/models/ADDING_MODELS.md @@ -9,16 +9,16 @@ ## Detailed steps ### Adding the model code -The intent is to preserve the original user code as much as possible while +The intent is to preserve the original user code as much as possible while adding support for a standardized interface to the benchmark suite and making sure the code can run from any directory and in a process with other models. In many case it is fine to simply copy the entire original repo into a subdirectory -as a starting point, paying attention to avoid the .git folder, and not to add any +as a starting point, paying attention to avoid the .git folder, and not to add any large unnecessary data files unintentionally. The subdirectory name should be a valid Python identifier because it will become a module in Python and needs to be importable. -Create a new file 'origin' that contains the url to the git repo you're copying, +Create a new file 'origin' that contains the url to the git repo you're copying, so it's easy to trace the code back to where it came from. #### Wrapping your model in \_\_init\_\_.py @@ -34,22 +34,22 @@ Take care to set the random seed like [here](https://github.com/pytorch/benchmar #### A minimal new model addition A bare miminum example you can follow is https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/phlippe_resnet -The functions you specifically need to implement are +The functions you specifically need to implement are 1. `__init__()` which is responsible for initalizing your `nn.Module` 2. `get_module()` which is responsible for returning the initialized `nn.Module` and an example input 3. `train()` which is a training loop, you can return a `NotImplementedError()` if your example is inference only. If your training loop can be encapsulated by a `forward()`, `backward()`, and `optimizer_step()`, you need not redefine `train()`. Instead, please make sure your model provides functions `forward()`, `backward()`, and `optimizer_step()` along with an - attribute `self.optimizer` which will be chained together for testing, see `invoke_staged_train_test()` for details. + attribute `self.optimizer` which will be chained together for testing, see `invoke_staged_train_test()` for details. 4. `eval()` which showcases a simple inference -Optionally, if you would like to be able to customize different optimizers for your model, feel free +Optionally, if you would like to be able to customize different optimizers for your model, feel free to override the BenchmarkModel's base class' default `get_optimizer()` and `set_optimizer(optimizer)` -methods. +methods. ### Preparing install.py and dependencies Simply put, install.py should be a one stop shop to install all the dependencies -for your model, __except torch, torchvision, torchaudio__ which should be assumed to +for your model, __except torch, torchvision, torchaudio__ which should be assumed to have been installed by an outsider (the benchmark CI). - Avoid pinning packages to specific versions with == without good reason, as the @@ -65,7 +65,7 @@ not easy to build, there may be easier models to target. [Example install.py](BERT_pytorch/install.py) ### Mini-dataset -By the time install.py script runs, a miniature version of the dataset is expected to be +By the time install.py script runs, a miniature version of the dataset is expected to be staged and ready for use. It's fine to use install.py to download and prepare the data if the download is quick. Otherwise, prepare the dataset manually, checking in the required artifacts and modifying the \_\_init\_\_.py script as needed to use them. @@ -109,11 +109,11 @@ version. ### Test -After you've submitted your new model, suppose it was called `new_model` make sure the tests pass locally. Your model name is equivalent to the new folder you'd have created in `torchbenchmark/models` +After you've submitted your new model, suppose it was called `` make sure the tests pass locally. Your model name is equivalent to the new folder you'd have created in `torchbenchmark/models` 1. `cd benchmark` 2. `python install.py` -3. `python run.py model -d cuda` and `python run.py model -d cpu` -3. `python test.py -k "model_"` following the format from here https://github.com/pytorch/benchmark#using-testpy +3. `python run.py -d cuda` and `python run.py -d cpu` +3. `python test.py -k "test__"` following the format from here https://github.com/pytorch/benchmark#using-testpy And thank you for contributing to torchbench! diff --git a/torchbenchmark/models/simple_gpt/__init__.py b/torchbenchmark/models/simple_gpt/__init__.py new file mode 100644 index 0000000000..8eaa59165b --- /dev/null +++ b/torchbenchmark/models/simple_gpt/__init__.py @@ -0,0 +1,104 @@ +import os + +import torch +import lightning as L +from torch.distributed._tensor import DeviceMesh +from torch.distributed.tensor.parallel import parallelize_module +from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel +from torchbenchmark.tasks import NLP + +from ...util.model import BenchmarkModel +from .model import LLaMA +from .utils import LOCAL_RANK, LOCAL_WORLD_SIZE + + +class Model(BenchmarkModel): + task = NLP.GENERATION + DEFAULT_EVAL_BSIZE = 1 + + def __init__(self, test, device, batch_size=None, extra_args=[]): + super().__init__( + test=test, + device=device, + batch_size=batch_size, + extra_args=extra_args, + ) + + error = self.validate_environment() + if error: + # per ADDING_MODELS.md, convention is to fail silently in __init__ and raise in eval + return + + fabric = L.Fabric(devices=[LOCAL_RANK], precision="bf16-true") + with fabric.init_module(empty_init=True): + self.model = LLaMA.from_name("7B") + + # Tensor parallelism using DTensor + mesh = DeviceMesh("cuda", list(range(LOCAL_WORLD_SIZE))) + for block in self.model.transformer.h: + # prepare attention weights to be parallelized + block.attn.prepare_qkv_for_dtensor_tp() + + parallelize_module( + module=block, + device_mesh=mesh, + parallelize_plan={ + "attn.c_attn_q": ColwiseParallel(), + "attn.c_attn_k": ColwiseParallel(), + "attn.c_attn_v": ColwiseParallel(), + "attn.c_proj": RowwiseParallel(), + "mlp.c_fc1": ColwiseParallel(), + "mlp.c_fc2": ColwiseParallel(), + "mlp.c_proj": RowwiseParallel(), + }, + tp_mesh_dim=0, + ) + + max_batch_size = 1 + self.model.setup_caches( + max_batch_size=max_batch_size, max_seq_length=self.model.config.block_size + ) + + prompt_size = 10 + idx = torch.randint( + self.model.config.vocab_size, + (max_batch_size, prompt_size), + dtype=torch.int32, + device=device, + ) + input_pos = torch.arange(prompt_size, device=device) + self.example_inputs = [idx, input_pos] + + def get_module(self): + return self.model, self.example_inputs + + def train(self): + raise NotImplementedError("Training not supported for this model") + + def validate_environment(self): + if not torch.cuda.is_available() or "cuda" not in self.device: + return NotImplementedError("Model requires CUDA") + + if not torch.cuda.is_bf16_supported(): + return NotImplementedError("Model requires BF16") + + if LOCAL_WORLD_SIZE != torch.cuda.device_count(): + return NotImplementedError( + f"DTensor and all local GPUs to be within the device mesh. {torch.cuda.device_count()} local GPUs, but only world size is only {LOCAL_WORLD_SIZE}." + ) + + return None + + def eval(self): + error = self.validate_environment() + if error: + raise error + + with torch.no_grad(): + out = self.model(*self.example_inputs) + return (out,) + + +if __name__ == "__main__": + model = Model(test="eval", device="cuda") + model.eval() diff --git a/torchbenchmark/models/simple_gpt/install.py b/torchbenchmark/models/simple_gpt/install.py new file mode 100644 index 0000000000..be308ead48 --- /dev/null +++ b/torchbenchmark/models/simple_gpt/install.py @@ -0,0 +1,9 @@ +import subprocess +import sys + + +def pip_install_requirements(): + subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt']) + +if __name__ == '__main__': + pip_install_requirements() diff --git a/torchbenchmark/models/simple_gpt/metadata.yaml b/torchbenchmark/models/simple_gpt/metadata.yaml new file mode 100644 index 0000000000..31de3c8f55 --- /dev/null +++ b/torchbenchmark/models/simple_gpt/metadata.yaml @@ -0,0 +1,5 @@ +eval_benchmark: false +eval_deterministic: false +eval_nograd: true +train_benchmark: false +train_deterministic: false diff --git a/torchbenchmark/models/simple_gpt/model.py b/torchbenchmark/models/simple_gpt/model.py new file mode 100644 index 0000000000..5e77c8c082 --- /dev/null +++ b/torchbenchmark/models/simple_gpt/model.py @@ -0,0 +1,376 @@ +"""Full definition of a LLaMA Language Model, all of it in this single file. + +Based on the nanoGPT implementation: https://github.com/karpathy/nanoGPT. +""" +# mypy: ignore-errors +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union +from typing_extensions import Self + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from .utils import LOCAL_WORLD_SIZE + + + +MaskCache = torch.Tensor +RoPECache = torch.Tensor +KVCache = Tuple[torch.Tensor, torch.Tensor] + +def find_multiple(n: int, k: int) -> int: + if n % k == 0: + return n + return n + k - (n % k) + +class LinearInt8(torch.nn.Module): + __constants__ = ['in_features', 'out_features'] + in_features: int + out_features: int + weight: torch.Tensor + + def __init__(self, in_features: int, out_features: int, bias: bool = True, + device=None, dtype=None) -> None: + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8)) + # if bias: + # self.register_buffer("bias", torch.empty(out_features, **factory_kwargs, dtype=torch.int8)) + # else: + # self.bias('bias', None) + + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return F.linear(input, self.weight.to(dtype=input.dtype)) + +# nn.Linear = LinearInt8 + +@dataclass +class LLaMAConfig: + block_size: int = 2048 + vocab_size: int = 32000 + padded_vocab_size: Optional[int] = None + n_layer: int = 32 + n_head: int = 32 + n_embd: int = 4096 + + def __post_init__(self): + if self.padded_vocab_size is None: + self.padded_vocab_size = find_multiple(self.vocab_size, 64) + + @classmethod + def from_name(cls, name: str) -> Self: + return cls(**llama_configs[name]) + + +llama_configs = { + "7B": dict(n_layer=32, n_head=32, n_embd=4096), + "13B": dict(n_layer=40, n_head=40, n_embd=5120), + "30B": dict(n_layer=60, n_head=52, n_embd=6656), + "65B": dict(n_layer=80, n_head=64, n_embd=8192), +} + +class KVCache(nn.Module): + def __init__(self, max_batch_size, max_seq_length, n_heads, head_size, device='cuda', dtype=torch.bfloat16): + super().__init__() + cache_shape = (max_batch_size, n_heads, max_seq_length, head_size) + self.k_cache = torch.nn.Parameter(torch.zeros(cache_shape, device=device, dtype=dtype)) + self.v_cache = torch.nn.Parameter(torch.zeros(cache_shape, device=device, dtype=dtype)) + + def update(self, input_pos, k_val, v_val): + # input_pos: [S], k_val: [B, H, S, D] + assert input_pos.shape[0] == k_val.shape[2] + + self.k_cache[:, :, input_pos] = k_val + self.v_cache[:, :, input_pos] = v_val + + return self.k_cache, self.v_cache + +class KVCacheAggregator(nn.Module): + def __init__(self): + super().__init__() + self.kv_caches = nn.ModuleList([]) + + def initialize(self,layers, max_batch_size, max_seq_length, n_heads, head_size, device='cuda', dtype=torch.bfloat16): + cache_shape = (max_batch_size, n_heads, max_seq_length, head_size) + self.kv_caches = nn.ModuleList([KVCache(max_batch_size, max_seq_length, n_heads, head_size) for _ in range(layers)]) + + def __getitem__(self, idx): + return self.kv_caches[idx] + + def clear(self): + self.kv_caches = nn.ParameterList([]) + +class LLaMA(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + assert config.padded_vocab_size is not None + self.config = config + + self.lm_head = nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) + self.transformer = nn.ModuleDict( + dict( + wte=nn.Embedding(config.padded_vocab_size, config.n_embd), + h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + ln_f=RMSNorm(config.n_embd), + ) + ) + + self.rope_cache: Optional[RoPECache] = None + self.mask_cache: Optional[MaskCache] = None + self.kv_caches = KVCacheAggregator() + self.max_batch_size = None + self.max_seq_length = None + + def setup_caches(self, max_batch_size, max_seq_length, device='cuda', dtype=torch.bfloat16): + n_embd = self.config.n_embd // LOCAL_WORLD_SIZE + n_head = self.config.n_head // LOCAL_WORLD_SIZE + head_size = n_embd // n_head + + self.max_seq_length = max_seq_length + self.max_batch_size = max_batch_size + self.kv_caches.initialize(layers=self.config.n_layer, max_batch_size=max_batch_size, max_seq_length=max_seq_length, n_heads=n_head, head_size=head_size) + + self.rope_cache = build_rope_cache( + seq_len=self.config.block_size, + n_elem=head_size, + dtype=dtype, + device=device, + ) + ones = torch.ones((self.config.block_size, self.config.block_size), device=device, dtype=torch.bool) + self.mask_cache = torch.tril(ones).unsqueeze(0).unsqueeze(0) + + def _init_weights(self, module: nn.Module) -> None: + if isinstance(module, nn.Linear): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) + elif isinstance(module, nn.Embedding): + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer)) + + def forward( + self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[KVCache]]]: + B, T = idx.size() + assert self.rope_cache is not None, "Caches must be initialized first" + + block_size = self.config.block_size + max_seq_length = self.max_seq_length + if max_seq_length is None: + max_seq_length = block_size + + assert T <= max_seq_length, f"Cannot forward sequence of length {T}, max seq length is only {max_seq_length}" + assert max_seq_length <= block_size, f"Cannot attend to {max_seq_length}, block size is only {block_size}" + assert T <= block_size, f"Cannot forward sequence of length {T}, block size is only {block_size}" + + rope = self.rope_cache.index_select(0, input_pos) + mask = self.mask_cache.index_select(2, input_pos) + mask = mask[:, :, :, :max_seq_length] + + # forward the model itself + x = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + for i, block in enumerate(self.transformer.h): + x, new_kv_cache = block(x, rope, mask, max_seq_length, input_pos, self.kv_caches[i]) + + x = self.transformer.ln_f(x) + + logits = self.lm_head(x) # (b, t, vocab_size) + + return logits + + @classmethod + def from_name(cls, name: str) -> Self: + return cls(LLaMAConfig.from_name(name)) + + def reset_cache(self) -> None: + self.kv_caches.clear() + + +class Block(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + self.rms_1 = RMSNorm(config.n_embd) + self.attn = CausalSelfAttention(config) + self.rms_2 = RMSNorm(config.n_embd) + self.mlp = MLP(config) + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + mask: MaskCache, + max_seq_length: int, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + h, new_kv_cache = self.attn(self.rms_1(x), rope, mask, max_seq_length, input_pos, kv_cache) + x = x + h + x = x + self.mlp(self.rms_2(x)) + return x, new_kv_cache + + +class CausalSelfAttention(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + assert config.n_embd % config.n_head == 0 + self.config = config + + # key, query, value projections for all heads, but in a batch + self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False) + # output projection + self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False) + + self.n_head = config.n_head + self.n_embd = config.n_embd + self.block_size = config.block_size + + def forward( + self, + x: torch.Tensor, + rope: RoPECache, + mask: MaskCache, + max_seq_length: int, + input_pos: Optional[torch.Tensor] = None, + kv_cache: Optional[KVCache] = None, + ) -> Tuple[torch.Tensor, Optional[KVCache]]: + B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + _C = C // LOCAL_WORLD_SIZE + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.c_attn_q(x) + k = self.c_attn_k(x) + v = self.c_attn_v(x) + + n_head = self.n_head // LOCAL_WORLD_SIZE + head_size = _C // n_head + k = k.view(B, T, n_head, head_size) + q = q.view(B, T, n_head, head_size) + v = v.view(B, T, n_head, head_size) + + q = apply_rope(q, rope) + k = apply_rope(k, rope) + + k = k.transpose(1, 2) # (B, nh, T, hs) + q = q.transpose(1, 2) # (B, nh, T, hs) + v = v.transpose(1, 2) # (B, nh, T, hs) + + if kv_cache is not None: + k, v = kv_cache.update(input_pos, k, v) + + # efficient attention using Flash Attention CUDA kernels + # y = F.scaled_dot_product_attention(q, k, v) + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) + + y = y.transpose(1, 2).contiguous().view(B, T, _C) # re-assemble all head outputs side by side + + # output projection + y = self.c_proj(y) + + return y, kv_cache + + def prepare_qkv_for_dtensor_tp(self): + attn = self.c_attn + + assert attn.in_features % LOCAL_WORLD_SIZE == 0 # q, k, v must be shardeable + attn.out_features = attn.out_features // LOCAL_WORLD_SIZE + # Shard on dim 0 since attn.weight is transposed + # Shard q, k, v separately + q, k, v = attn.weight.split(self.config.n_embd, dim=0) # (C, C) + + self.c_attn_q = nn.Linear(self.config.n_embd, self.config.n_embd, bias=False) + self.c_attn_q.weight = nn.Parameter(q) + self.c_attn_k = nn.Linear(self.config.n_embd, self.config.n_embd, bias=False) + self.c_attn_k.weight = nn.Parameter(k) + self.c_attn_v = nn.Linear(self.config.n_embd, self.config.n_embd, bias=False) + self.c_attn_v.weight = nn.Parameter(v) + + del self.c_attn + + +class MLP(nn.Module): + def __init__(self, config: LLaMAConfig) -> None: + super().__init__() + hidden_dim = 4 * config.n_embd + n_hidden = int(2 * hidden_dim / 3) + n_hidden = find_multiple(n_hidden, 256) + + self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False) + self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False) + self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = F.silu(self.c_fc1(x)) * self.c_fc2(x) + x = self.c_proj(x) + return x + + +class RMSNorm(nn.Module): + """Root Mean Square Layer Normalization. + + Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License: + https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE. + """ + + def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None: + super().__init__() + self.scale = nn.Parameter(torch.ones(size)) + self.eps = eps + self.dim = dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # NOTE: the original RMSNorm paper implementation is not equivalent + # norm_x = x.norm(2, dim=self.dim, keepdim=True) + # rms_x = norm_x * d_x ** (-1. / 2) + # x_normed = x / (rms_x + self.eps) + norm_x = torch.mean(x * x, dim=self.dim, keepdim=True) + x_normed = x * torch.rsqrt(norm_x + self.eps) + return self.scale * x_normed + + +def build_rope_cache( + seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000 +) -> RoPECache: + """Enhanced Transformer with Rotary Position Embedding. + + Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/ + transformers/rope/__init__.py. MIT License: + https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license. + """ + # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$ + theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem)) + + # Create position indexes `[0, 1, ..., seq_len - 1]` + seq_idx = torch.arange(seq_len, dtype=dtype, device=device) + + # Calculate the product of position index and $\theta_i$ + idx_theta = torch.outer(seq_idx, theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + + # this is to mimic the behaviour of complex32, else we will get different results + if dtype in (torch.float16, torch.bfloat16, torch.int8): + cache = cache.half() + return cache + + +def apply_rope(x: torch.Tensor, rope_cache: RoPECache) -> torch.Tensor: + # truncate to support variable sizes + T = x.size(1) + rope_cache = rope_cache[:T] + + # cast because the reference does + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) + rope_cache = rope_cache.view(1, xshaped.size(1), 1, xshaped.size(3), 2) + x_out2 = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + x_out2 = x_out2.flatten(3) + return x_out2.type_as(x) diff --git a/torchbenchmark/models/simple_gpt/origin b/torchbenchmark/models/simple_gpt/origin new file mode 100644 index 0000000000..11465ed49d --- /dev/null +++ b/torchbenchmark/models/simple_gpt/origin @@ -0,0 +1 @@ +https://github.com/pytorch-labs/simple_gpt/ diff --git a/torchbenchmark/models/simple_gpt/requirements.txt b/torchbenchmark/models/simple_gpt/requirements.txt new file mode 100644 index 0000000000..c0605af1f8 --- /dev/null +++ b/torchbenchmark/models/simple_gpt/requirements.txt @@ -0,0 +1 @@ +lightning diff --git a/torchbenchmark/models/simple_gpt/utils.py b/torchbenchmark/models/simple_gpt/utils.py new file mode 100644 index 0000000000..7279408e3e --- /dev/null +++ b/torchbenchmark/models/simple_gpt/utils.py @@ -0,0 +1,5 @@ +import os + +# provided by torchrun +LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0")) +LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", "1")) From f365e24509f9067e6b3ca8561cb495a553e1b1ca Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Fri, 1 Sep 2023 14:05:13 -0700 Subject: [PATCH 2/4] Rework model to be launched via dynamo runner --- torchbenchmark/models/simple_gpt/__init__.py | 23 +++++-------- torchbenchmark/models/simple_gpt/model.py | 35 +++++++++++--------- torchbenchmark/models/simple_gpt/utils.py | 5 --- torchbenchmark/util/extra_args.py | 6 ++++ 4 files changed, 34 insertions(+), 35 deletions(-) delete mode 100644 torchbenchmark/models/simple_gpt/utils.py diff --git a/torchbenchmark/models/simple_gpt/__init__.py b/torchbenchmark/models/simple_gpt/__init__.py index 8eaa59165b..5604766bdd 100644 --- a/torchbenchmark/models/simple_gpt/__init__.py +++ b/torchbenchmark/models/simple_gpt/__init__.py @@ -9,7 +9,6 @@ from ...util.model import BenchmarkModel from .model import LLaMA -from .utils import LOCAL_RANK, LOCAL_WORLD_SIZE class Model(BenchmarkModel): @@ -24,17 +23,17 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): extra_args=extra_args, ) + fabric = L.Fabric(devices=[self._rank], precision="bf16-true") + with fabric.init_module(empty_init=True): + self.model = LLaMA.from_name("7B", self._world_size) + error = self.validate_environment() if error: # per ADDING_MODELS.md, convention is to fail silently in __init__ and raise in eval return - fabric = L.Fabric(devices=[LOCAL_RANK], precision="bf16-true") - with fabric.init_module(empty_init=True): - self.model = LLaMA.from_name("7B") - # Tensor parallelism using DTensor - mesh = DeviceMesh("cuda", list(range(LOCAL_WORLD_SIZE))) + mesh = DeviceMesh("cuda", list(range(self._world_size))) for block in self.model.transformer.h: # prepare attention weights to be parallelized block.attn.prepare_qkv_for_dtensor_tp() @@ -54,7 +53,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): tp_mesh_dim=0, ) - max_batch_size = 1 + max_batch_size = self.DEFAULT_EVAL_BSIZE self.model.setup_caches( max_batch_size=max_batch_size, max_seq_length=self.model.config.block_size ) @@ -82,14 +81,15 @@ def validate_environment(self): if not torch.cuda.is_bf16_supported(): return NotImplementedError("Model requires BF16") - if LOCAL_WORLD_SIZE != torch.cuda.device_count(): + if self._world_size != torch.cuda.device_count(): return NotImplementedError( - f"DTensor and all local GPUs to be within the device mesh. {torch.cuda.device_count()} local GPUs, but only world size is only {LOCAL_WORLD_SIZE}." + f"DTensor and all local GPUs to be within the device mesh. {torch.cuda.device_count()} local GPUs, but only world size is only {self._world_size}." ) return None def eval(self): + # Note: Not called by dynamo runner error = self.validate_environment() if error: raise error @@ -97,8 +97,3 @@ def eval(self): with torch.no_grad(): out = self.model(*self.example_inputs) return (out,) - - -if __name__ == "__main__": - model = Model(test="eval", device="cuda") - model.eval() diff --git a/torchbenchmark/models/simple_gpt/model.py b/torchbenchmark/models/simple_gpt/model.py index 5e77c8c082..df9bf041b4 100644 --- a/torchbenchmark/models/simple_gpt/model.py +++ b/torchbenchmark/models/simple_gpt/model.py @@ -12,9 +12,6 @@ import torch.nn as nn from torch.nn import functional as F -from .utils import LOCAL_WORLD_SIZE - - MaskCache = torch.Tensor RoPECache = torch.Tensor @@ -75,12 +72,14 @@ def from_name(cls, name: str) -> Self: } class KVCache(nn.Module): + @torch.no_grad() def __init__(self, max_batch_size, max_seq_length, n_heads, head_size, device='cuda', dtype=torch.bfloat16): super().__init__() cache_shape = (max_batch_size, n_heads, max_seq_length, head_size) self.k_cache = torch.nn.Parameter(torch.zeros(cache_shape, device=device, dtype=dtype)) self.v_cache = torch.nn.Parameter(torch.zeros(cache_shape, device=device, dtype=dtype)) + @torch.no_grad() def update(self, input_pos, k_val, v_val): # input_pos: [S], k_val: [B, H, S, D] assert input_pos.shape[0] == k_val.shape[2] @@ -106,8 +105,10 @@ def clear(self): self.kv_caches = nn.ParameterList([]) class LLaMA(nn.Module): - def __init__(self, config: LLaMAConfig) -> None: + def __init__(self, config: LLaMAConfig, world_size: int) -> None: super().__init__() + self.world_size = world_size + assert config.padded_vocab_size is not None self.config = config @@ -115,7 +116,7 @@ def __init__(self, config: LLaMAConfig) -> None: self.transformer = nn.ModuleDict( dict( wte=nn.Embedding(config.padded_vocab_size, config.n_embd), - h=nn.ModuleList(Block(config) for _ in range(config.n_layer)), + h=nn.ModuleList(Block(config, self.world_size) for _ in range(config.n_layer)), ln_f=RMSNorm(config.n_embd), ) ) @@ -127,8 +128,8 @@ def __init__(self, config: LLaMAConfig) -> None: self.max_seq_length = None def setup_caches(self, max_batch_size, max_seq_length, device='cuda', dtype=torch.bfloat16): - n_embd = self.config.n_embd // LOCAL_WORLD_SIZE - n_head = self.config.n_head // LOCAL_WORLD_SIZE + n_embd = self.config.n_embd // self.world_size + n_head = self.config.n_head // self.world_size head_size = n_embd // n_head self.max_seq_length = max_seq_length @@ -182,18 +183,18 @@ def forward( return logits @classmethod - def from_name(cls, name: str) -> Self: - return cls(LLaMAConfig.from_name(name)) + def from_name(cls, name: str, world_size: int) -> Self: + return cls(LLaMAConfig.from_name(name), world_size) def reset_cache(self) -> None: self.kv_caches.clear() class Block(nn.Module): - def __init__(self, config: LLaMAConfig) -> None: + def __init__(self, config: LLaMAConfig, world_size: int) -> None: super().__init__() self.rms_1 = RMSNorm(config.n_embd) - self.attn = CausalSelfAttention(config) + self.attn = CausalSelfAttention(config, world_size) self.rms_2 = RMSNorm(config.n_embd) self.mlp = MLP(config) @@ -213,8 +214,10 @@ def forward( class CausalSelfAttention(nn.Module): - def __init__(self, config: LLaMAConfig) -> None: + def __init__(self, config: LLaMAConfig, world_size: int) -> None: super().__init__() + self.world_size = world_size + assert config.n_embd % config.n_head == 0 self.config = config @@ -237,14 +240,14 @@ def forward( kv_cache: Optional[KVCache] = None, ) -> Tuple[torch.Tensor, Optional[KVCache]]: B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) - _C = C // LOCAL_WORLD_SIZE + _C = C // self.world_size # calculate query, key, values for all heads in batch and move head forward to be the batch dim q = self.c_attn_q(x) k = self.c_attn_k(x) v = self.c_attn_v(x) - n_head = self.n_head // LOCAL_WORLD_SIZE + n_head = self.n_head // self.world_size head_size = _C // n_head k = k.view(B, T, n_head, head_size) q = q.view(B, T, n_head, head_size) @@ -274,8 +277,8 @@ def forward( def prepare_qkv_for_dtensor_tp(self): attn = self.c_attn - assert attn.in_features % LOCAL_WORLD_SIZE == 0 # q, k, v must be shardeable - attn.out_features = attn.out_features // LOCAL_WORLD_SIZE + assert attn.in_features % self.world_size == 0 # q, k, v must be shardeable + attn.out_features = attn.out_features // self.world_size # Shard on dim 0 since attn.weight is transposed # Shard q, k, v separately q, k, v = attn.weight.split(self.config.n_embd, dim=0) # (C, C) diff --git a/torchbenchmark/models/simple_gpt/utils.py b/torchbenchmark/models/simple_gpt/utils.py deleted file mode 100644 index 7279408e3e..0000000000 --- a/torchbenchmark/models/simple_gpt/utils.py +++ /dev/null @@ -1,5 +0,0 @@ -import os - -# provided by torchrun -LOCAL_RANK = int(os.environ.get("LOCAL_RANK", "0")) -LOCAL_WORLD_SIZE = int(os.environ.get("LOCAL_WORLD_SIZE", "1")) diff --git a/torchbenchmark/util/extra_args.py b/torchbenchmark/util/extra_args.py index e9e2f8b539..b5161cb0a9 100644 --- a/torchbenchmark/util/extra_args.py +++ b/torchbenchmark/util/extra_args.py @@ -125,10 +125,16 @@ def apply_decoration_args(model: 'torchbenchmark.util.model.BenchmarkModel', dar def parse_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', opt_args: List[str]) -> argparse.Namespace: parser = argparse.ArgumentParser() parser.add_argument("--backend", choices=list_backends(), help="enable backends") + parser.add_argument("--rank", help="rank of current process") + parser.add_argument("--world_size", help="world size of multiprocess") args, extra_args = parser.parse_known_args(opt_args) if args.backend: backend = BACKENDS[args.backend] model._enable_backend, extra_args = backend(model, backend_args=extra_args) + if args.rank: + model._rank = int(args.rank) + if args.world_size: + model._world_size = int(args.world_size) return args, extra_args def apply_opt_args(model: 'torchbenchmark.util.model.BenchmarkModel', args: argparse.Namespace): From 6efde91707f77fbec4985afcb2532a4c45dc2365 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 5 Sep 2023 15:31:18 -0700 Subject: [PATCH 3/4] remove lightning dep, use default model param values --- torchbenchmark/models/simple_gpt/__init__.py | 5 +---- torchbenchmark/models/simple_gpt/install.py | 9 --------- torchbenchmark/models/simple_gpt/requirements.txt | 1 - 3 files changed, 1 insertion(+), 14 deletions(-) delete mode 100644 torchbenchmark/models/simple_gpt/install.py delete mode 100644 torchbenchmark/models/simple_gpt/requirements.txt diff --git a/torchbenchmark/models/simple_gpt/__init__.py b/torchbenchmark/models/simple_gpt/__init__.py index 5604766bdd..37c089c844 100644 --- a/torchbenchmark/models/simple_gpt/__init__.py +++ b/torchbenchmark/models/simple_gpt/__init__.py @@ -1,7 +1,6 @@ import os import torch -import lightning as L from torch.distributed._tensor import DeviceMesh from torch.distributed.tensor.parallel import parallelize_module from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel @@ -23,9 +22,7 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): extra_args=extra_args, ) - fabric = L.Fabric(devices=[self._rank], precision="bf16-true") - with fabric.init_module(empty_init=True): - self.model = LLaMA.from_name("7B", self._world_size) + self.model = LLaMA.from_name("7B", self._world_size).to(device=device, dtype=torch.bfloat16) error = self.validate_environment() if error: diff --git a/torchbenchmark/models/simple_gpt/install.py b/torchbenchmark/models/simple_gpt/install.py deleted file mode 100644 index be308ead48..0000000000 --- a/torchbenchmark/models/simple_gpt/install.py +++ /dev/null @@ -1,9 +0,0 @@ -import subprocess -import sys - - -def pip_install_requirements(): - subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', '-r', 'requirements.txt']) - -if __name__ == '__main__': - pip_install_requirements() diff --git a/torchbenchmark/models/simple_gpt/requirements.txt b/torchbenchmark/models/simple_gpt/requirements.txt deleted file mode 100644 index c0605af1f8..0000000000 --- a/torchbenchmark/models/simple_gpt/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -lightning From 0de95249653a872e7a65bc5a36c3bb351487f989 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Tue, 5 Sep 2023 18:47:02 -0700 Subject: [PATCH 4/4] raise NotImplementedError from __init__ and update convention in wiki --- torchbenchmark/models/ADDING_MODELS.md | 4 +- torchbenchmark/models/simple_gpt/__init__.py | 47 +++++++++----------- 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/torchbenchmark/models/ADDING_MODELS.md b/torchbenchmark/models/ADDING_MODELS.md index 9681fe49ec..8249fc176a 100644 --- a/torchbenchmark/models/ADDING_MODELS.md +++ b/torchbenchmark/models/ADDING_MODELS.md @@ -95,8 +95,8 @@ This file should define two things: - `__main__` function, which exercises the model APIs for local testing Important: be deliberate about support for cpu/gpu and jit/no-jit. In the case that -your model is instantiated in an unsupported configuration, the convention is to return -a model object from \_\_init\_\_ but raise NotImplementedError() from all its methods. +your model is instantiated in an unsupported configuration, the convention is to raise +NotImplementedError from \_\_init\_\_. See the [BenchmarkModel API](https://github.com/pytorch/benchmark/blob/master/torchbenchmark/util/model.py) to get started. The [BERT_pytorch](BERT_pytorch/__init__.py) benchmark can serve as a good example. diff --git a/torchbenchmark/models/simple_gpt/__init__.py b/torchbenchmark/models/simple_gpt/__init__.py index 37c089c844..de3f43c63f 100644 --- a/torchbenchmark/models/simple_gpt/__init__.py +++ b/torchbenchmark/models/simple_gpt/__init__.py @@ -14,6 +14,23 @@ class Model(BenchmarkModel): task = NLP.GENERATION DEFAULT_EVAL_BSIZE = 1 + def validate_environment(self): + if not torch.cuda.is_available() or "cuda" not in self.device: + return NotImplementedError("Model requires CUDA") + + if not torch.cuda.is_bf16_supported(): + return NotImplementedError("Model requires BF16") + + if not hasattr(self, "_world_size"): + return NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters") + + if self._world_size != torch.cuda.device_count(): + return NotImplementedError( + f"DTensor and all local GPUs to be within the device mesh. {torch.cuda.device_count()} local GPUs, but only world size is only {self._world_size}" + ) + + return None + def __init__(self, test, device, batch_size=None, extra_args=[]): super().__init__( test=test, @@ -22,12 +39,11 @@ def __init__(self, test, device, batch_size=None, extra_args=[]): extra_args=extra_args, ) - self.model = LLaMA.from_name("7B", self._world_size).to(device=device, dtype=torch.bfloat16) - error = self.validate_environment() if error: - # per ADDING_MODELS.md, convention is to fail silently in __init__ and raise in eval - return + raise error + + self.model = LLaMA.from_name("7B", self._world_size).to(device=device, dtype=torch.bfloat16) # Tensor parallelism using DTensor mesh = DeviceMesh("cuda", list(range(self._world_size))) @@ -71,26 +87,5 @@ def get_module(self): def train(self): raise NotImplementedError("Training not supported for this model") - def validate_environment(self): - if not torch.cuda.is_available() or "cuda" not in self.device: - return NotImplementedError("Model requires CUDA") - - if not torch.cuda.is_bf16_supported(): - return NotImplementedError("Model requires BF16") - - if self._world_size != torch.cuda.device_count(): - return NotImplementedError( - f"DTensor and all local GPUs to be within the device mesh. {torch.cuda.device_count()} local GPUs, but only world size is only {self._world_size}." - ) - - return None - def eval(self): - # Note: Not called by dynamo runner - error = self.validate_environment() - if error: - raise error - - with torch.no_grad(): - out = self.model(*self.example_inputs) - return (out,) + raise NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")