28 changes: 14 additions & 14 deletions torchbenchmark/models/ADDING_MODELS.md
@@ -9,16 +9,16 @@
## Detailed steps

### Adding the model code
The intent is to preserve the original user code as much as possible while
adding support for a standardized interface to the benchmark suite and making sure
the code can run from any directory and in a process with other models.

In many cases it is fine to simply copy the entire original repo into a subdirectory
as a starting point, paying attention to avoid the .git folder, and not to add any
large unnecessary data files unintentionally. The subdirectory name should be a valid
Python identifier because it will become a module in Python and needs to be importable.

Create a new file 'origin' that contains the URL of the git repo you're copying,
so it's easy to trace the code back to where it came from.
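
Putting the pieces together, a typical model directory ends up looking roughly like this (the exact set of files varies by model):

```
torchbenchmark/models/<new_model>/
├── __init__.py     # wraps the model in the benchmark interface (next section)
├── install.py      # installs dependencies and stages data (see below)
├── origin          # URL of the upstream repo the code was copied from
└── metadata.yaml   # per-model benchmark flags
```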

#### Wrapping your model in \_\_init\_\_.py
@@ -34,22 +34,22 @@ Take care to set the random seed like [here](https://github.com/pytorch/benchmar
#### A minimal new model addition
A bare minimum example you can follow is https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/phlippe_resnet

The functions you specifically need to implement are
1. `__init__()`, which is responsible for initializing your `nn.Module`
2. `get_module()`, which is responsible for returning the initialized `nn.Module` and an example input
3. `train()`, which is a training loop; you can raise a `NotImplementedError` if your example is inference only. If your
training loop can be encapsulated by a `forward()`, `backward()`, and `optimizer_step()`, you need not redefine `train()`.
Instead, please make sure your model provides the functions `forward()`, `backward()`, and `optimizer_step()` along with an
attribute `self.optimizer`, which will be chained together for testing; see `invoke_staged_train_test()` for details.
4. `eval()`, which showcases a simple inference pass

Optionally, if you would like to be able to customize different optimizers for your model, feel free
to override the BenchmarkModel base class's default `get_optimizer()` and `set_optimizer(optimizer)`
methods.
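
For orientation, here is a rough sketch of what such an `__init__.py` can look like. The toy linear module, loss, and optimizer are illustrative stand-ins for your real model, and the exact base-class behavior (e.g., how `self.batch_size` is populated) is defined by the `BenchmarkModel` API linked below:

```python
import torch
from torchbenchmark.tasks import COMPUTER_VISION

from ...util.model import BenchmarkModel


class Model(BenchmarkModel):
    task = COMPUTER_VISION.CLASSIFICATION
    DEFAULT_TRAIN_BSIZE = 64
    DEFAULT_EVAL_BSIZE = 64

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device,
                         batch_size=batch_size, extra_args=extra_args)
        # A stand-in for the real network copied from the upstream repo.
        self.model = torch.nn.Linear(32, 10).to(device)
        self.example_inputs = (torch.randn(self.batch_size, 32, device=device),)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.01)

    def get_module(self):
        # Return the module plus inputs it can be called with directly.
        return self.model, self.example_inputs

    def train(self):
        self.optimizer.zero_grad()
        loss = self.model(*self.example_inputs).sum()  # illustrative loss
        loss.backward()
        self.optimizer.step()

    def eval(self):
        with torch.no_grad():
            out = self.model(*self.example_inputs)
        return (out,)
```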

### Preparing install.py and dependencies
Simply put, install.py should be a one-stop shop to install all the dependencies
for your model, __except torch, torchvision, and torchaudio__, which should be assumed to
have been installed already (e.g., by the benchmark CI).

- Avoid pinning packages to specific versions with == without good reason, as the
@@ -65,7 +65,7 @@ not easy to build, there may be easier models to target.
[Example install.py](BERT_pytorch/install.py)
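
A minimal sketch of such an install.py, assuming the model's dependencies live in a requirements.txt next to it:

```python
import subprocess
import sys


def pip_install_requirements():
    # Install everything except torch/torchvision/torchaudio, which the
    # benchmark CI is assumed to have installed already.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"]
    )


if __name__ == "__main__":
    pip_install_requirements()
```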

### Mini-dataset
By the time the install.py script has run, a miniature version of the dataset is expected to be
staged and ready for use. It's fine to use install.py to download and prepare the data
if the download is quick. Otherwise, prepare the dataset manually, checking in the required
artifacts and modifying the \_\_init\_\_.py script as needed to use them.
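
If the download is quick, staging the mini-dataset can be folded into install.py. A sketch of that approach, where the URL and paths are placeholders for your own:

```python
import os
import urllib.request

# Placeholder URL; point this at wherever your mini-dataset is hosted.
DATA_URL = "https://example.com/mini_dataset.tar.gz"
DATA_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), ".data")


def stage_mini_dataset():
    # Download the archive once and cache it alongside the model code.
    os.makedirs(DATA_DIR, exist_ok=True)
    archive = os.path.join(DATA_DIR, "mini_dataset.tar.gz")
    if not os.path.exists(archive):
        urllib.request.urlretrieve(DATA_URL, archive)
```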
@@ -95,8 +95,8 @@ This file should define two things:
- `__main__` function, which exercises the model APIs for local testing

Important: be deliberate about support for cpu/gpu and jit/no-jit. In the case that
your model is instantiated in an unsupported configuration, the convention is to raise
NotImplementedError from \_\_init\_\_.

See the [BenchmarkModel API](https://github.com/pytorch/benchmark/blob/master/torchbenchmark/util/model.py) to get started. The [BERT_pytorch](BERT_pytorch/__init__.py) benchmark can serve as a good example.
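
A combined sketch of both conventions; the CPU check and the `__main__` smoke test here are illustrative, so adapt them to the configurations your model actually supports:

```python
from ...util.model import BenchmarkModel


class Model(BenchmarkModel):
    def __init__(self, test, device, batch_size=None, extra_args=[]):
        if device == "cpu":
            # Unsupported configuration: fail fast from __init__.
            raise NotImplementedError("CPU is not supported for this model")
        super().__init__(test=test, device=device,
                         batch_size=batch_size, extra_args=extra_args)
        ...  # build self.model and self.example_inputs as usual


if __name__ == "__main__":
    # Local smoke test that exercises the model APIs.
    m = Model(test="eval", device="cuda")
    module, example_inputs = m.get_module()
    module(*example_inputs)
    m.eval()
```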

@@ -109,11 +109,11 @@ version.

### Test

After you've added your new model (call it `<new_model>`), make sure the tests pass locally. Your model name is the name of the new folder you created in `torchbenchmark/models`:

1. `cd benchmark`
2. `python install.py`
3. `python run.py <new_model> -d cuda` and `python run.py <new_model> -d cpu`
4. `python test.py -k "test_<new_model>_"`, following the format described at https://github.com/pytorch/benchmark#using-testpy

And thank you for contributing to torchbench!
91 changes: 91 additions & 0 deletions torchbenchmark/models/simple_gpt/__init__.py
@@ -0,0 +1,91 @@
import os

import torch
from torch.distributed._tensor import DeviceMesh
from torch.distributed.tensor.parallel import parallelize_module
from torch.distributed.tensor.parallel.style import ColwiseParallel, RowwiseParallel
from torchbenchmark.tasks import NLP

from ...util.model import BenchmarkModel
from .model import LLaMA


class Model(BenchmarkModel):
    task = NLP.GENERATION
    DEFAULT_EVAL_BSIZE = 1

    def validate_environment(self):
        if not torch.cuda.is_available() or "cuda" not in self.device:
            return NotImplementedError("Model requires CUDA")

        if not torch.cuda.is_bf16_supported():
            return NotImplementedError("Model requires BF16")

        if not hasattr(self, "_world_size"):
            return NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")

        if self._world_size != torch.cuda.device_count():
            return NotImplementedError(
                f"DTensor requires all local GPUs to be within the device mesh; "
                f"found {torch.cuda.device_count()} local GPUs, but world size is only {self._world_size}"
            )

        return None

    def __init__(self, test, device, batch_size=None, extra_args=[]):
        super().__init__(
            test=test,
            device=device,
            batch_size=batch_size,
            extra_args=extra_args,
        )

        error = self.validate_environment()
        if error:
            raise error

        self.model = LLaMA.from_name("7B", self._world_size).to(device=device, dtype=torch.bfloat16)

        # Tensor parallelism using DTensor: shard the attention and MLP
        # projections column-wise on the way in and row-wise on the way out,
        # so each rank holds a slice of every transformer block's weights.
        mesh = DeviceMesh("cuda", list(range(self._world_size)))
        for block in self.model.transformer.h:
            # prepare attention weights to be parallelized
            block.attn.prepare_qkv_for_dtensor_tp()

            parallelize_module(
                module=block,
                device_mesh=mesh,
                parallelize_plan={
                    "attn.c_attn_q": ColwiseParallel(),
                    "attn.c_attn_k": ColwiseParallel(),
                    "attn.c_attn_v": ColwiseParallel(),
                    "attn.c_proj": RowwiseParallel(),
                    "mlp.c_fc1": ColwiseParallel(),
                    "mlp.c_fc2": ColwiseParallel(),
                    "mlp.c_proj": RowwiseParallel(),
                },
                tp_mesh_dim=0,
            )

        max_batch_size = self.DEFAULT_EVAL_BSIZE
        self.model.setup_caches(
            max_batch_size=max_batch_size, max_seq_length=self.model.config.block_size
        )

        # A short random prompt plus its positions serve as example inputs.
        prompt_size = 10
        idx = torch.randint(
            self.model.config.vocab_size,
            (max_batch_size, prompt_size),
            dtype=torch.int32,
            device=device,
        )
        input_pos = torch.arange(prompt_size, device=device)
        self.example_inputs = [idx, input_pos]

    def get_module(self):
        return self.model, self.example_inputs

    def train(self):
        raise NotImplementedError("Training not supported for this model")

    def eval(self):
        raise NotImplementedError("Model needs to be run via dynamo torchbench and be provided distributed parameters")
5 changes: 5 additions & 0 deletions torchbenchmark/models/simple_gpt/metadata.yaml
@@ -0,0 +1,5 @@
eval_benchmark: false
eval_deterministic: false
eval_nograd: true
train_benchmark: false
train_deterministic: false