44 changes: 35 additions & 9 deletions torchtitan/experiments/deepseek_v3/model.py
@@ -453,6 +453,8 @@ class MoE(nn.Module):
# 1. "torch_all_to_all"
# 2. "symm_mem" (see `setup_symm_mem` below)
shuffle_method = "torch_all_to_all"
# Group GEMM method, "torch" or "torchao"
group_mm = "torch"
Contributor review comment: Tiny nit, but I like just having the selectable options next to the item, like how we do it in the toml:

group_mm = "torch"  # ["torch", "torchao"]


# Symmetric memory buffers shared by all MoE instances across layers
token_send_buf: Optional[torch.Tensor] = None
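
Note: a hypothetical usage sketch (call site and config are assumed, not part of this PR) of how the new group_mm flag could be switched on an MoE instance before the expert weights are combined:

    moe = MoE(config)                 # config assumed to be a valid DeepSeek-V3 model config
    moe.group_mm = "torchao"          # or "torch" (the default class attribute above)
    moe.setup_symm_mem(torch.bfloat16, torch.device("cuda"))  # combines expert weights for the chosen backend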
@@ -490,15 +492,21 @@ def __init__(self, config):
config=config, intermediate_size=intermediate_size
)

def combine_experts(self, submod_name):
def combine_experts(self, submod_name: str):
all_weights = []
for expert in self.experts.values():
lin = expert.get_submodule(submod_name)
all_weights.append(lin.weight)
lin.weight = None

concat_weight = torch.cat(all_weights)
self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
if self.group_mm == "torch":
combined_weight = torch.stack(all_weights)
elif self.group_mm == "torchao":
combined_weight = torch.cat(all_weights)
else:
raise RuntimeError(f"Unknown Group GEMM method: {self.group_mm}")

self.register_parameter(f"{submod_name}_weight", nn.Parameter(combined_weight))
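
A minimal sketch (shapes assumed for illustration, not from this PR) of the layout difference the branch above produces: the "torch" path stacks the per-expert weights into a 3D tensor, while the "torchao" path concatenates them into a 2D tensor:

    import torch

    # 4 experts, each an nn.Linear(in_features=8, out_features=16) -> weight shape [16, 8]
    expert_weights = [torch.randn(16, 8) for _ in range(4)]

    stacked = torch.stack(expert_weights)  # [4, 16, 8] -- layout used by torch._grouped_mm
    concat = torch.cat(expert_weights)     # [64, 8]    -- layout used by torchao's grouped_gemm_forward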

# This function is used to create a symm mem buffer for MoE's. It is for
# shuffling tokens fully "on-device", as compared to traditional torch
@@ -510,7 +518,6 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
self.shuffle_method = "symm_mem"

# Combine expert weights
print("Combining expert weights for Group GEMM")
self.combine_experts("gate_proj")
self.combine_experts("up_proj")
self.combine_experts("down_proj")
@@ -544,6 +551,7 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
device=device,
)
print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
print("Combining expert weights for Group GEMM")

def get_send_buf(self):
# [Why detach?] During a first forward-backward step, the buffer would
@@ -735,7 +743,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
token_send_buf = self.get_send_buf()
token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
# Note: `out=` avoids copy, but it is not differentiable
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
# torch.index_select(x, 0, idxs // topk_ids.shape[1], out=token_send_buf[: idxs.shape[0]])
token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
token_send_buf,
input_splits,
@@ -746,7 +754,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
# This part prepares a 1D tensor `permuted_indices` for such permutation.
# This part doesn't need gradient.
with torch.no_grad():
permuted_indices, m_sizes = generate_permute_indices(
permuted_indices, m_sizes, m_offsets = generate_permute_indices(
tokens_per_expert_group,
self.experts_per_rank,
self.ep_size,
@@ -759,18 +767,36 @@ def moe_on_device(self, x, topk_ids, topk_weight):

# Run the first grouped GEMM
w1 = self.get_parameter("gate_proj_weight")
gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
if self.group_mm == "torchao":
gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
else: # "torch"
gate_proj = torch._grouped_mm(
contig_tokens, w1.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
)

# Run the second grouped GEMM
w3 = self.get_parameter("up_proj_weight")
up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
if self.group_mm == "torchao":
up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
else: # "torch"
up_proj = torch._grouped_mm(
contig_tokens, w3.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
)

# Apply activation
hidden_outputs = MLP.act_fn(gate_proj) * up_proj

# Run the third grouped GEMM
w2 = self.get_parameter("down_proj_weight")
hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
if self.group_mm == "torchao":
hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
else: # "torch"
hidden_outputs = torch._grouped_mm(
hidden_outputs,
w2.transpose(-2, -1),
m_offsets,
out_dtype=torch.bfloat16,
)
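
For readability, a hedged reference sketch (function name and shapes assumed, not part of this PR) of what the three grouped GEMMs above compute, written as an explicit per-expert loop over the stacked "torch"-layout weights:

    def moe_experts_reference(tokens, w1, w3, w2, m_offsets, act_fn):
        # tokens: [total_tokens, hidden]; w1, w3: [E, inter, hidden]; w2: [E, hidden, inter]
        # m_offsets[e] is the end row of expert e's contiguous token slice
        out, start = torch.empty_like(tokens), 0
        for e, end in enumerate(m_offsets.tolist()):
            t = tokens[start:end]
            h = act_fn(t @ w1[e].T) * (t @ w3[e].T)  # SwiGLU-style gating, as with MLP.act_fn above
            out[start:end] = h @ w2[e].T
            start = end
        return out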

# Prepare buffer for tokens processed by experts
# Take necessary space from `token_gather_buf` symm mem because we are
9 changes: 5 additions & 4 deletions torchtitan/experiments/kernels/moe/indices.py
@@ -139,7 +139,8 @@ def generate_permute_indices(
torch.int32
)
# Perform another prefix sum to get the write offset of each expert in `permuted_indices`
write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
m_offsets = torch.cumsum(m_sizes, 0)
write_offsets = m_offsets - m_sizes
# Select the method to fill the permuted indices
fill_fn = fill_indices_cpu if use_cpu else fill_indices
# Fill the permuted indices
@@ -151,7 +152,7 @@ def generate_permute_indices(
num_ranks,
max_len,
)
return permuted_indices, m_sizes
return permuted_indices, m_sizes, m_offsets.to(torch.int32)
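
A small numeric illustration (values assumed) of the relationship between m_sizes, the newly returned m_offsets, and the write_offsets computed above:

    import torch

    m_sizes = torch.tensor([3, 5, 0, 2])
    m_offsets = torch.cumsum(m_sizes, 0)  # [3, 8, 8, 10]: end offset of each expert's token group
    write_offsets = m_offsets - m_sizes   # [0, 3, 8, 8]:  start offset of each expert's token group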


# Below is for testing only
@@ -167,11 +168,11 @@ def test():
max_len = 128
alignment = 32
# Use the GPU kernel
permuted_indices_gpu, m_sizes = generate_permute_indices(
permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
)
# Use the CPU method
permuted_indices_cpu, _ = generate_permute_indices(
permuted_indices_cpu, _, _ = generate_permute_indices(
tokens_per_expert_group,
experts_per_rank,
num_ranks,