-
Notifications
You must be signed in to change notification settings - Fork 552
gpt-oss model enablement #1754
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
wwwjn
wants to merge
18
commits into
main
Choose a base branch
from
gpt-oss
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+1,453
−10
Open
gpt-oss model enablement #1754
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
0313a6f
gptoss experimental support
aa2db3f
clean up tentative licensing
04ad2c4
training fixes: expert load balancing, TP for sinks + experts, EP wor…
43563ec
only assert sdpa backends if using sdpa; improve conversion script
21c1679
fixed conversion script with param by param
f3ad331
new lse-based flexattn implementation for sinks
64b5c32
test
wwwjn c3036c8
rebase
wwwjn c02acaf
fix flexattn
wwwjn 1bfdfbb
check and replace rope
wwwjn 29cc72f
FSDP work, TP doesn't work
wwwjn b841303
test
wwwjn 9697286
fix sink
wwwjn daf5a6e
test EP
wwwjn 8aa281d
working on ETP
wwwjn d9d0b05
clean up
wwwjn 10869da
clean up
wwwjn 8bfbf7c
rebase + address comments
wwwjn File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# gpt-oss Model in torchtitan | ||
|
||
## Quick Start | ||
```bash | ||
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh | ||
``` | ||
|
||
## Supported Features | ||
- FSDP/HSDP, TP, EP, ETP | ||
- Grouped matrix multiplication for efficient computation | ||
- SwiGLU activation | ||
- Multi-head attention with sliding window mask and attention sink | ||
wwwjn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
## TODO | ||
1. More parallelism support: CP, PP | ||
2. Conversion between HF weights (StateDictAdapter) | ||
3. Forward parity verification | ||
4. CI support |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# This source code is licensed under the BSD-style license found in the | ||
wwwjn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# LICENSE file in the root directory of this source tree. | ||
|
||
from torchtitan.components.loss import build_cross_entropy_loss | ||
from torchtitan.components.lr_scheduler import build_lr_schedulers | ||
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing | ||
from torchtitan.components.tokenizer import build_hf_tokenizer | ||
from torchtitan.datasets.hf_datasets import build_hf_dataloader | ||
from torchtitan.models.moe import MoEArgs | ||
|
||
from torchtitan.protocols.train_spec import TrainSpec | ||
|
||
from .infra.parallelize import parallelize_gptoss | ||
from .model.args import GptOssModelArgs | ||
from .model.model import GptOssModel | ||
|
||
__all__ = [ | ||
"parallelize_gptoss", | ||
"GptOssModelArgs", | ||
"GptOssModel", | ||
"gptoss_configs", | ||
] | ||
|
||
|
||
gptoss_configs = { | ||
"debugmodel": GptOssModelArgs( | ||
dim=256, | ||
n_layers=4, | ||
moe_args=MoEArgs( | ||
num_experts=8, | ||
num_shared_experts=0, | ||
score_func="softmax", | ||
route_norm=False, | ||
route_scale=1.0, | ||
score_before_experts=False, | ||
top_k=4, | ||
use_grouped_mm=True, | ||
load_balance_coeff=1e-3, | ||
), | ||
attn_mask_type="causal", | ||
), | ||
"20b": GptOssModelArgs( | ||
n_layers=24, | ||
moe_args=MoEArgs( | ||
num_experts=32, | ||
num_shared_experts=0, | ||
score_func="softmax", | ||
route_norm=False, | ||
route_scale=1.0, | ||
score_before_experts=False, | ||
top_k=4, | ||
use_grouped_mm=True, | ||
load_balance_coeff=1e-3, | ||
), | ||
), | ||
"120b": GptOssModelArgs( | ||
n_layers=36, | ||
moe_args=MoEArgs( | ||
num_experts=128, | ||
num_shared_experts=0, | ||
score_func="softmax", | ||
route_norm=False, | ||
route_scale=1.0, | ||
score_before_experts=False, | ||
top_k=4, | ||
use_grouped_mm=True, | ||
load_balance_coeff=1e-3, | ||
), | ||
), | ||
} | ||
|
||
|
||
def get_train_spec() -> TrainSpec: | ||
return TrainSpec( | ||
name="gpt_oss", | ||
model_cls=GptOssModel, | ||
model_args=gptoss_configs, | ||
parallelize_fn=parallelize_gptoss, | ||
pipelining_fn=None, | ||
build_optimizers_fn=build_optimizers_with_moe_load_balancing, | ||
build_lr_schedulers_fn=build_lr_schedulers, | ||
build_dataloader_fn=build_hf_dataloader, | ||
build_tokenizer_fn=build_hf_tokenizer, | ||
build_loss_fn=build_cross_entropy_loss, | ||
) |
194 changes: 194 additions & 0 deletions
194
torchtitan/experiments/gpt_oss/infra/expert_parallel.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,194 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from typing import Callable | ||
|
||
import torch | ||
import torch.nn as nn | ||
from torch.distributed.tensor import ( | ||
DeviceMesh, | ||
distribute_module, | ||
distribute_tensor, | ||
DTensor, | ||
Replicate, | ||
Shard, | ||
) | ||
from torch.distributed.tensor.parallel import ParallelStyle | ||
from torchtitan.distributed.expert_parallel import ExpertParallel | ||
|
||
|
||
# implementation of Tensor Parallel for the GroupedExperts in MoE | ||
class TensorParallel(ParallelStyle): | ||
def _partition_fn(self, name, module, device_mesh): | ||
module.register_parameter( | ||
"mlp1_weight", | ||
nn.Parameter( | ||
distribute_tensor(module.mlp1_weight, device_mesh, [Shard(2)]) | ||
), | ||
) # Column-wise sharding | ||
module.register_parameter( | ||
"mlp1_bias", | ||
nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])), | ||
) # Column-wise sharding | ||
module.register_parameter( | ||
"mlp2_weight", | ||
nn.Parameter( | ||
distribute_tensor(module.mlp2_weight, device_mesh, [Shard(1)]) | ||
), | ||
) # Row-wise sharding | ||
module.register_parameter( | ||
"mlp2_bias", | ||
nn.Parameter( | ||
distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()]) | ||
), | ||
) # Replicate | ||
|
||
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: | ||
return distribute_module( | ||
module, | ||
device_mesh, | ||
self._partition_fn, | ||
) | ||
|
||
|
||
# This class is for dp2ep with TP (without TP we can just use ExpertParallel) | ||
class ExpertTensorParallel(ExpertParallel): | ||
def __init__( | ||
self, | ||
tp_mesh: DeviceMesh, | ||
ep_mesh: DeviceMesh, | ||
): | ||
super().__init__() | ||
# TODO: has to pass in the meshes in addition to the [ep, tp] device_mesh, | ||
# as DeviceMesh doesn't support slicing from a submesh. | ||
self.tp_mesh = tp_mesh | ||
self.ep_mesh = ep_mesh | ||
|
||
def _token_dispatch(self, mod, inputs, device_mesh): | ||
# token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh | ||
return super()._token_dispatch(mod, inputs, self.ep_mesh) | ||
|
||
def _partition_fn_2d(self, name, mod, ep_tp_mesh): | ||
mod.register_parameter( | ||
"mlp1_weight", | ||
nn.Parameter( | ||
distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(2)]) | ||
), | ||
) # Column-wise sharding | ||
mod.register_parameter( | ||
"mlp1_bias", | ||
nn.Parameter( | ||
distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)]) | ||
), | ||
) # Column-wise sharding | ||
mod.register_parameter( | ||
"mlp2_weight", | ||
nn.Parameter( | ||
distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(1)]) | ||
), | ||
) # Row-wise sharding | ||
mod.register_parameter( | ||
"mlp2_bias", | ||
nn.Parameter( | ||
distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()]) | ||
), | ||
) # Replicate | ||
|
||
def _token_combine(self, mod, routed_output, device_mesh): | ||
# token combine happens on the EP mesh, whereas device_mesh is [ep, tp] mesh | ||
return super()._token_combine(mod, routed_output, self.ep_mesh) | ||
|
||
def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: | ||
return distribute_module( | ||
module, | ||
device_mesh, | ||
partition_fn=self._partition_fn_2d, | ||
input_fn=self._token_dispatch, | ||
output_fn=self._token_combine, | ||
) | ||
|
||
|
||
# TODO(jianiw): This need to be merged with expert_parallel | ||
def expert_parallel(func: Callable) -> Callable: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry I'll merge my refactor, and then please rebase There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you referring to #1569 ? |
||
""" | ||
This is a wrapper applied to the GroupedExperts computation, serving | ||
the following three purposes: | ||
1. Convert parameters from DTensors to plain Tensors, to work with | ||
dynamic-shape inputs which cannot be easily expressed as DTensors. | ||
2. In Expert Parallel, apply the generate_permute_indices kernel to | ||
permute the inputs to be ordered by local experts (see the _token_dispatch | ||
function in ExpertParallel) and permute the outputs back. | ||
3. In order to use torch._grouped_mm, we need to make sure the number of | ||
tokens each expert gets is a multiple of ALIGN_SIZE_M. The generate_permute_indices | ||
kernel also helps achieve this via padding, without incurring synchronization | ||
between device and host. Note that this will create side effects when wrapping | ||
the for-loop implementation of GroupedExperts, as it does not need padding. | ||
|
||
Among the above: | ||
1 and 2 are needed only when expert_parallel_degree > 1. | ||
3 is needed even for single-device computation. | ||
2 can be moved to ExpertParallel _token_dispatch if not coupled with 3. | ||
""" | ||
|
||
def wrapper( | ||
mlp1_weight: torch.Tensor, | ||
mlp1_bias: torch.Tensor, | ||
mlp2_weight: torch.Tensor, | ||
mlp2_bias: torch.Tensor, | ||
swiglu_limit: float, | ||
x: torch.Tensor, | ||
num_tokens_per_expert: torch.Tensor | None = None, | ||
) -> torch.Tensor: | ||
if isinstance(mlp1_weight, DTensor): | ||
mlp1_weight = mlp1_weight.to_local() | ||
mlp1_bias = mlp1_bias.to_local() | ||
mlp2_weight = mlp2_weight.to_local() | ||
mlp2_bias = mlp2_bias.to_local() | ||
|
||
if num_tokens_per_expert is not None: | ||
from torchtitan.experiments.kernels.moe.indices import ( | ||
generate_permute_indices, | ||
) | ||
|
||
experts_per_ep_rank = mlp1_weight.shape[0] | ||
num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank | ||
|
||
ALIGN_SIZE_M = 16 | ||
with torch.no_grad(): | ||
( | ||
permuted_indices, | ||
num_tokens_per_expert, | ||
_, # offsets, | ||
) = generate_permute_indices( | ||
num_tokens_per_expert, | ||
experts_per_ep_rank, | ||
num_ep_ranks, | ||
x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M, | ||
ALIGN_SIZE_M, | ||
) | ||
|
||
x = torch.vstack((x, x.new_zeros((x.shape[-1])))) | ||
input_shape = x.shape | ||
x = x[permuted_indices, :] | ||
|
||
out = func( | ||
mlp1_weight, | ||
mlp1_bias, | ||
mlp2_weight, | ||
mlp2_bias, | ||
swiglu_limit, | ||
x, | ||
num_tokens_per_expert, | ||
) | ||
|
||
if num_tokens_per_expert is not None: | ||
out_unpermuted = out.new_empty(input_shape) | ||
out_unpermuted[permuted_indices, :] = out | ||
out = out_unpermuted[:-1] | ||
|
||
return out | ||
|
||
return wrapper |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.