In [2]:
# I don't always annotate types, since they're often obvious, but this implementation is relatively
# subtle and the index bookkeeping is a bit involved, so annotations are worth having here.
from typing import Optional, Tuple

import torch, torch.nn as nn, torch.nn.functional as F

# TODO: 
    # support routing each token to its top_k experts (routing is currently top-1 only)
    # then plug this layer into the train_rnn training loop 
        # which will also need router z-loss, qk-norm, and a load-balancing loss

class Router(nn.Module): # [b, s, d] -> [b, s, m]: scores every token against each of the m experts
    """Two-layer MLP producing per-expert routing logits for every token.

    Args:
        d: input (model) width.
        act: activation between the two linear layers; a fresh nn.GELU() if None.
        mult: hidden-width multiplier (hidden size is mult * d).
        m: number of experts, i.e. the size of the output logit dimension.
    """
    def __init__(self, d: int = 512, act: Optional[nn.Module] = None, mult: int = 4, m: int = 16): 
        super().__init__()
        self.w1 = nn.Linear(d, mult * d)
        # build the default here: an `nn.GELU()` default in the signature is evaluated once,
        # so that single module object would be shared by every Router built with defaults
        self.act = nn.GELU() if act is None else act
        self.w2 = nn.Linear(mult * d, m)

    # [b, s, d] -> [b, s, m]; nn.Linear acts on the last dim, so any leading dims work
    def forward(self, x: torch.Tensor) -> torch.Tensor: 
        return self.w2(self.act(self.w1(x)))


class MoEMLP(nn.Module): 
    """Top-1 (Switch-style) mixture-of-experts MLP with a fixed per-expert token capacity.

    Maps [b, s, d] -> [b, s, d]. Each token is sent to the expert with the largest router
    logit. Every expert has cap = int(b*s*c // m) slots per call, filled in flattened (b*s)
    token order. Tokens that arrive after their expert is full are dropped and passed through
    unchanged. Kept tokens are replaced by p * expert(x), where p is the router's softmax
    probability for the chosen expert. argmax alone has no gradient, so without this gate
    the router's parameters would never be updated.

    Args:
        m: number of experts.
        c: capacity factor (c=1 gives each expert int(b*s // m) slots, its even share).
        top_k: stored but not used yet; routing is currently top-1.
        d: model width (input and output size).
        mult: each expert's hidden width is d * mult.
        act: activation used by the experts and the router; a fresh nn.GELU() if None.
    """
    def __init__(self, m: int = 16, c: float = 1.25, top_k: int = 2, d: int = 512, mult: int = 4, act: Optional[nn.Module] = None): 
        super().__init__()
        # build the default per instance instead of sharing one module created at def time
        act = nn.GELU() if act is None else act
        self.m = m
        self.c = c
        self.top_k = top_k
        self.d = d
        self.mult = mult

        # we use nn.Parameter and not nn.Linear because we want a 3-tensor of params across all experts     
            # you could also use an nn.ModuleList to track the nn.Linear across the m experts 
        self.mlp1 = nn.Parameter(torch.randn(m, d, d*mult) / (d ** 0.5)) # make sure scaled by fan_in so matmul outputs are O(1) indep of d
        self.mlp2 = nn.Parameter(torch.randn(m, d*mult, d) / ((d*mult) ** 0.5)) 
        self.act = act 

        self.router = Router(d=d, act=act, mult=mult, m=m)

    def expert_scatter(self, x: torch.Tensor, router_logits: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 
        """Route each token to its argmax expert and pack tokens into per-expert capacity slots.

        Args:
            x: input of shape [b, s, d].
            router_logits: optional precomputed self.router(x), shape [b, s, m]. It is computed
                here if None. Passing it lets forward() reuse the same logits for the gate.

        Returns:
            expert_inputs: [m, cap, d]. Slot j of expert e holds the j-th token routed to e;
                unused slots are zero. This is the input to the per-expert MLPs.
            flat_idx: [N, d] row indices into expert_inputs viewed as [m*cap, d], one row per
                kept token (N = number of kept tokens, data dependent). Used by expert_gather().
            mask: [b*s] bool, True for tokens that fit within their expert's capacity.
        """
        b, s, d = x.shape 
        x_flat = x.reshape(-1, d) # [b*s, d]; reshape (not view) so non-contiguous inputs also work
        c, m = self.c, self.m

        cap = int(b*s*c//m) # token capacity for a single expert

        if router_logits is None:
            router_logits = self.router(x)
        indices = torch.argmax(router_logits, dim=-1).reshape(b*s) # [b, s, m] -> [b, s] -> [b*s]

        # Draft of top-k routing (not active yet; routing above is top-1):
        # # compute router logits and select top_k experts
        # router_logits = self.router(x) # [b, s, m]
        # # select top_k and get scores
        # topk_logits, topk_indices = torch.topk(router_logits, k, dim=-1) # [b, s, k]
        # topk_logits_flat = topk_logits.view(B, k)
        # topk_indices_flat = topk_indices.view(B, k)
        # # compute weights
        # topk_weights = F.softmax(topk_logits_flat, dim=-1) # [B, k]

        # # expand and weight inputs for each expert assignment
        # weighted_src = x_flat.unsqueeze(1) * topk_weights.unsqueeze(2) # [B, k, d]
        # assign_experts_flat = topk_indices_flat.reshape(-1) # [B*k]
        # weighted_src_flat = weighted_src.reshape(-1, d) # [B*k, d]

        one_hot = F.one_hot(indices, num_classes=m) # [b*s, m], we'll call b*s=B to represent "effective batch size over tokens"
        counts = torch.cumsum(one_hot, dim=0) # [B, m] tells us how many tokens assigned to that expert seen so far (vertically)
        
        # this pos computation is one of the most subtle lines, make sure to understand it
            # it uses indexing/broadcasting to map each token to its slot index within [cap]
            # which is precisely the count above less one because we're going from a count to a 0-idx
            # here's how the indexing works
                # an important torch rule: indexing tensor X with tensors of shape
                    #  Y outputs a tensor of shape Y
                # eg. X[1, 2, 3] the tensors Y are scalars, so we get a scalar, ie. the 1-2-3th entry in X

                # here we want, for each row in range(b*s), the indices[row] element
                # we want it to be [B] dim, so lets use range(B) and indices as our indices
                # (arange is created on x's device so both index tensors live with `counts`)
        pos = counts[torch.arange(b*s, device=x.device), indices] - 1

        mask = pos < cap # these are the tokens we'll process at all, this is a [B] tensor of bools
        selected_experts = indices[mask] # [N], where N is the number of indices that are true 
                                                # in the mask (data dependent)
        
        selected_pos = pos[mask] # [N]
        selected_src = x_flat[mask] # [N, d]

        # new_zeros inherits x's dtype and device; torch.zeros would default to float32 and
        # scatter_ would then reject fp16/bf16/fp64 inputs
        flat_inputs = x_flat.new_zeros(m * cap, d) # [m*cap, d]
        flat_idx = selected_experts * cap + selected_pos # for each selected idx, get its index in the target tensor
        flat_idx = flat_idx.unsqueeze(-1).expand(-1, d) # [N] -> [N, d] so it has same shape as selected_src for .scatter_()
        flat_inputs.scatter_(0, flat_idx, selected_src) # in place scatter of rows, everything was building up to be able to do this
        # explicit d instead of -1: a -1 cannot be inferred when cap == 0 (zero-element tensor)
        expert_inputs = flat_inputs.view(m, cap, d) 

        return (expert_inputs, flat_idx, mask) 

    def expert_gather(self, outputs: torch.Tensor, mask: torch.Tensor, flat_idx: torch.Tensor, x: torch.Tensor, gates: Optional[torch.Tensor] = None) -> torch.Tensor: 
        """Invert expert_scatter(): put expert outputs back at their token positions.

        Args:
            outputs: expert outputs, shape [m, cap, d].
            mask: [b*s] bool mask returned by expert_scatter().
            flat_idx: [N, d] row indices returned by expert_scatter().
            x: the original [b, s, d] input. Dropped tokens (mask False) are passed through from it.
            gates: optional per-token weights with b*s elements (e.g. [b, s]). If given, each
                kept token's expert output is multiplied by its gate.

        Returns:
            [b, s, d] tensor: gated expert outputs for kept tokens, x itself for dropped tokens.
        """
        b, s, d = x.shape

        ## now we have outputs, gather them back into [b, s, d] 
        # gather has a similar api to scatter() and so we can re-use flat_idx to invert scatter()
        flat_outputs = outputs.reshape(-1, d) # [m*cap, d]
        gathered = torch.gather(flat_outputs, 0, flat_idx) # [N, d]
        if gates is not None:
            gathered = gathered * gates.reshape(-1)[mask].unsqueeze(-1) # [N, d] * [N, 1]
        result = x.reshape(b*s, d).clone() # [b*s, d], we clone so that unprocessed tokens are pushed through 
        result[mask] = gathered # processed tokens overwritten by the (gated) MLP outputs
        return result.view(b, s, d) # want output to be [b, s, d], like input in an MLP layer 

    def forward(self, x: torch.Tensor) -> torch.Tensor: 
        # score tokens once; the same logits pick the expert (argmax) and give the gate (softmax)
        router_logits = self.router(x) # [b, s, m]

        # assign each token to an expert and arrange tokens ready for expert computation
        expert_inputs, flat_idx, mask = self.expert_scatter(x, router_logits)

        ## now do mlp computation on the inputs with a bmm along the num_experts axis (first, so treated as batch by default)
        # [m, cap, d] @ [m, d, d*mult] -> act -> [m, cap, d*mult] @ [m, d*mult, d] -> [m, cap, d]
        expert_outputs = torch.bmm(self.act(torch.bmm(expert_inputs, self.mlp1)), self.mlp2)

        # gate = router probability of the chosen expert; this is the only path by which
        # gradients reach the router, since argmax is not differentiable
        gates = F.softmax(router_logits, dim=-1).amax(dim=-1) # [b, s]

        # use our scatter() indices to gather everything back into place into [b, s, d]
        return self.expert_gather(expert_outputs, mask, flat_idx, x, gates) # [b, s, d]

# smoke test: a [b, s, d] batch routed through a 4-expert layer should come back as [b, s, d]
b, s, d = 16, 128, 256
m, c = 4, 1.25
moe = MoEMLP(d=d, m=m, c=c)
x = torch.randn(b, s, d)
out_shape = moe(x).shape
print(out_shape)


torch.Size([16, 128, 256])
