diff --git a/torchtitan/experiments/deepseek_v3/model.py b/torchtitan/experiments/deepseek_v3/model.py
index a22666d86d..1b6e19ecd7 100644
--- a/torchtitan/experiments/deepseek_v3/model.py
+++ b/torchtitan/experiments/deepseek_v3/model.py
@@ -453,6 +453,8 @@ class MoE(nn.Module):
     # 1. "torch_all_to_all"
     # 2. "symm_mem" (see `setup_symm_mem` below)
     shuffle_method = "torch_all_to_all"
+    # Group GEMM method, "torch" or "torchao"
+    group_mm = "torch"
 
     # Symmetric memory buffers shared by all MoE instances across layers
     token_send_buf: Optional[torch.Tensor] = None
@@ -490,15 +492,21 @@ def __init__(self, config):
                 config=config, intermediate_size=intermediate_size
             )
 
-    def combine_experts(self, submod_name):
+    def combine_experts(self, submod_name: str):
         all_weights = []
         for expert in self.experts.values():
             lin = expert.get_submodule(submod_name)
             all_weights.append(lin.weight)
             lin.weight = None
 
-        concat_weight = torch.cat(all_weights)
-        self.register_parameter(f"{submod_name}_weight", nn.Parameter(concat_weight))
+        if self.group_mm == "torch":
+            combined_weight = torch.stack(all_weights)
+        elif self.group_mm == "torchao":
+            combined_weight = torch.cat(all_weights)
+        else:
+            raise RuntimeError(f"Unknown Group GEMM method: {self.group_mm}")
+
+        self.register_parameter(f"{submod_name}_weight", nn.Parameter(combined_weight))
 
     # This function is used to create a symm mem buffer for MoE's. It is for
     # shuffling tokens fully "on-device", as compared to traditional torch
@@ -510,7 +518,6 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
         self.shuffle_method = "symm_mem"
 
         # Combine expert weights
-        print("Combining expert weights for Group GEMM")
         self.combine_experts("gate_proj")
         self.combine_experts("up_proj")
         self.combine_experts("down_proj")
@@ -544,6 +551,7 @@ def setup_symm_mem(self, dtype: torch.dtype, device: torch.device):
             device=device,
         )
         print(f"EP rank [{self.ep_rank}]: Created Symmetric Memory for MoE")
+        print("Combining expert weights for Group GEMM")
 
     def get_send_buf(self):
         # [Why detach?] During a first forward-backward step, the buffer would
@@ -735,7 +743,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         token_send_buf = self.get_send_buf()
         token_send_buf[: idxs.shape[0]].copy_(sorted_tokens)
         # Note: `out=` avoids copy, but it is not differentiable
-        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=self.token_send_buf[: idxs.shape[0]])
+        # torch.index_select(x, 0, idxs // topk_ids.shape[1], out=token_send_buf[: idxs.shape[0]])
         token_gather_buf, output_splits = OnDeviceAllToAllV.apply(
             token_send_buf,
             input_splits,
@@ -746,7 +754,7 @@ def moe_on_device(self, x, topk_ids, topk_weight):
         # This part prepares a 1D tensor `permuted_indices` for such permutation.
         # This part doesn't need gradient.
         with torch.no_grad():
-            permuted_indices, m_sizes = generate_permute_indices(
+            permuted_indices, m_sizes, m_offsets = generate_permute_indices(
                 tokens_per_expert_group,
                 self.experts_per_rank,
                 self.ep_size,
@@ -759,18 +767,36 @@ def moe_on_device(self, x, topk_ids, topk_weight):
 
         # Run the first grouped GEMM
         w1 = self.get_parameter("gate_proj_weight")
-        gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
+        if self.group_mm == "torchao":
+            gate_proj = grouped_gemm_forward(contig_tokens, w1, m_sizes)
+        else:  # "torch"
+            gate_proj = torch._grouped_mm(
+                contig_tokens, w1.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
+            )
 
         # Run the second grouped GEMM
         w3 = self.get_parameter("up_proj_weight")
-        up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
+        if self.group_mm == "torchao":
+            up_proj = grouped_gemm_forward(contig_tokens, w3, m_sizes)
+        else:  # "torch"
+            up_proj = torch._grouped_mm(
+                contig_tokens, w3.transpose(-2, -1), m_offsets, out_dtype=torch.bfloat16
+            )
 
         # Apply activation
         hidden_outputs = MLP.act_fn(gate_proj) * up_proj
 
         # Run the third grouped GEMM
         w2 = self.get_parameter("down_proj_weight")
-        hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
+        if self.group_mm == "torchao":
+            hidden_outputs = grouped_gemm_forward(hidden_outputs, w2, m_sizes)
+        else:  # "torch"
+            hidden_outputs = torch._grouped_mm(
+                hidden_outputs,
+                w2.transpose(-2, -1),
+                m_offsets,
+                out_dtype=torch.bfloat16,
+            )
 
         # Prepare buffer for tokens processed by experts
         # Take necessary space from `token_gather_buf` symm mem because we are
diff --git a/torchtitan/experiments/kernels/moe/indices.py b/torchtitan/experiments/kernels/moe/indices.py
index 39d5946ece..be4688a283 100644
--- a/torchtitan/experiments/kernels/moe/indices.py
+++ b/torchtitan/experiments/kernels/moe/indices.py
@@ -139,7 +139,8 @@ def generate_permute_indices(
         torch.int32
     )
     # Perform another prefix sum to get the write offset of each expert in `permuted_indices`
-    write_offsets = torch.cumsum(m_sizes, 0) - m_sizes
+    m_offsets = torch.cumsum(m_sizes, 0)
+    write_offsets = m_offsets - m_sizes
     # Select the method to fill the permuted indices
     fill_fn = fill_indices_cpu if use_cpu else fill_indices
     # Fill the permuted indices
@@ -151,7 +152,7 @@ def generate_permute_indices(
         num_ranks,
         max_len,
     )
-    return permuted_indices, m_sizes
+    return permuted_indices, m_sizes, m_offsets.to(torch.int32)
 
 
 # Below is for testing only
@@ -167,11 +168,11 @@ def test():
     max_len = 128
     alignment = 32
     # Use the GPU kernel
-    permuted_indices_gpu, m_sizes = generate_permute_indices(
+    permuted_indices_gpu, m_sizes, _ = generate_permute_indices(
        tokens_per_expert_group, experts_per_rank, num_ranks, max_len, alignment
     )
     # Use the CPU method
-    permuted_indices_cpu, _ = generate_permute_indices(
+    permuted_indices_cpu, _, _ = generate_permute_indices(
         tokens_per_expert_group,
         experts_per_rank,
         num_ranks,
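
For context, here is a minimal standalone sketch (not part of the patch) of how the new "torch" Group GEMM path consumes the extra `m_offsets` value returned by `generate_permute_indices`. It assumes a recent PyTorch build that exposes the private `torch._grouped_mm` op and a GPU that supports it; the expert count, shapes, and per-expert token counts below are made-up illustrative values, not numbers from the repository.

import torch

# torch._grouped_mm is a private op; availability depends on the PyTorch build and GPU.
if torch.cuda.is_available() and hasattr(torch, "_grouped_mm"):
    device = "cuda"
    num_experts, dim, hidden = 4, 256, 512  # illustrative sizes only

    # With group_mm == "torch", combine_experts() stacks per-expert weights
    # into a single [num_experts, hidden, dim] parameter (torch.stack, not cat).
    gate_proj_weight = torch.randn(
        num_experts, hidden, dim, dtype=torch.bfloat16, device=device
    )

    # Tokens already permuted into expert-contiguous order; `m_offsets` is the
    # prefix sum of the (alignment-padded) per-expert token counts, which is
    # exactly what generate_permute_indices now returns as its third value.
    tokens_per_expert = torch.tensor([32, 32, 32, 32], device=device)
    m_offsets = torch.cumsum(tokens_per_expert, 0).to(torch.int32)
    contig_tokens = torch.randn(
        int(m_offsets[-1]), dim, dtype=torch.bfloat16, device=device
    )

    # Equivalent of the new "torch" branch in moe_on_device(): each expert's
    # slice of `contig_tokens` is multiplied by that expert's weight matrix.
    gate_proj = torch._grouped_mm(
        contig_tokens,
        gate_proj_weight.transpose(-2, -1),
        m_offsets,
        out_dtype=torch.bfloat16,
    )
    print(gate_proj.shape)  # torch.Size([128, 512])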