# Implementing Top-k Importance Routing

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Expert(nn.Module):
    """A single feed-forward expert for the MoE layer.

    Layout of ``self.ffn``:
        Linear(d_model -> d_hidden) -> activation -> Dropout -> Linear(d_hidden -> d_model)

    Args:
        d_model: Input and output feature size.
        d_hidden: Width of the hidden layer.
        activation: Either ``"relu"`` or ``"gelu"``; any other name raises ``KeyError``.
        dropout: Dropout probability applied right after the activation.
    """

    def __init__(self, d_model, d_hidden, activation, dropout):
        super().__init__()
        # Look up the activation class by name, then instantiate only that one.
        activation_cls = {"relu": nn.ReLU, "gelu": nn.GELU}[activation]
        layers = [
            nn.Linear(d_model, d_hidden),
            activation_cls(),
            nn.Dropout(dropout),
            nn.Linear(d_hidden, d_model),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the expert network to ``x``; the last dimension must equal ``d_model``."""
        hidden = self.ffn(x)
        return hidden

### MoE Layer With Importance Routing

In [None]:
class MoE(nn.Module):
    """Mixture-of-Experts feed-forward layer with top-k softmax routing.

    A linear router scores every token against every expert. Each token is sent
    to its ``k`` highest-probability experts. The expert outputs are combined
    using those ``k`` probabilities, renormalised to sum to 1 per token.
    ``forward`` also returns an auxiliary load-balancing loss computed from the
    full router distribution and each token's top-1 expert.

    Args:
        d_model: Token feature size (last dimension of the input).
        d_hidden: Hidden width of each expert.
        num_experts: Number of experts in the pool.
        activation: Activation name forwarded to every ``Expert``.
        dropout: Dropout probability forwarded to every ``Expert``.
        k: Number of experts each token is routed to; must satisfy
            ``1 <= k <= num_experts``.

    Raises:
        ValueError: If ``k`` is outside ``[1, num_experts]``.
    """

    def __init__(self, d_model: int, d_hidden: int, num_experts: int, activation: str, dropout: float, k=2):
        super().__init__()
        # Fail fast here rather than inside torch.topk on the first forward pass.
        if not 1 <= k <= num_experts:
            raise ValueError(f"k must be in [1, num_experts={num_experts}], got {k}")
        self.num_experts = num_experts
        self.k = k

        # Router: one logit per expert for every token.
        self.router = nn.Linear(d_model, num_experts)

        # Pool of experts.
        self.experts = nn.ModuleList(
            [Expert(d_model, d_hidden, activation, dropout) for _ in range(num_experts)]
        )

    def _load_balancing_loss(self, probs: torch.Tensor, top1_idx: torch.Tensor) -> torch.Tensor:
        """Return the auxiliary load-balancing loss ``E * sum_i(f_i * p_i)``.

        Args:
            probs: Router probabilities, shape (T, E).
            top1_idx: Index of each token's highest-probability expert, shape (T,).

        Returns:
            A scalar tensor. It equals 1.0 when both the mean router
            probabilities and the top-1 assignment fractions are uniform (1/E).
        """
        num_experts = probs.shape[-1]

        # p_i: mean router probability given to each expert over all tokens.
        p_i = probs.mean(dim=0)                                                   # (E,)

        # f_i: fraction of tokens whose top-1 choice is each expert.
        f_i = F.one_hot(top1_idx, num_classes=num_experts).float().mean(dim=0)   # (E,)

        return num_experts * torch.sum(f_i * p_i)

    def forward(self, x: torch.Tensor):
        """Route every token through its top-k experts.

        Args:
            x: Input of shape (B, S, D). More generally, any shape (..., D) with
                D == d_model is accepted, and the input need not be contiguous.

        Returns:
            ``(y, aux_loss)``. ``y`` has the same shape as ``x``. ``aux_loss`` is
            the scalar load-balancing loss.
        """
        d_model = x.shape[-1]
        # reshape (not view): view raises on non-contiguous inputs such as
        # transposed tensors.
        x_flat = x.reshape(-1, d_model)                                           # (T, D)

        # Routing
        logits = self.router(x_flat)                                              # (T, E)
        probs = F.softmax(logits, dim=-1)                                         # (T, E)
        topk_prob, topk_idx = torch.topk(probs, k=self.k, dim=-1)                 # (T, k), (T, k)

        # Top-1 expert per token; used only by the load-balancing loss.
        top1_idx = torch.argmax(probs, dim=-1)                                    # (T,)

        # The k selected probabilities do not sum to 1, so renormalise them.
        topk_prob = topk_prob / (topk_prob.sum(dim=-1, keepdim=True) + 1e-9)

        # Dispatch: run each expert once on every token that picked it in any
        # of its k slots, then scatter-add the weighted outputs back.
        y_flat = torch.zeros_like(x_flat)
        for e, expert in enumerate(self.experts):
            token_idx, slot_idx = torch.where(topk_idx == e)                      # (n_e,), (n_e,)
            if token_idx.numel() == 0:
                continue
            out_e = expert(x_flat[token_idx])                                     # (n_e, D)
            weight = topk_prob[token_idx, slot_idx].unsqueeze(-1)                 # (n_e, 1)
            # topk returns distinct experts per token, so token_idx has no
            # duplicates here and index_add_ acts as a plain scatter-add.
            y_flat.index_add_(0, token_idx, weight * out_e)

        y = y_flat.reshape(x.shape)

        # Auxiliary load-balancing loss over the full router distribution.
        aux_load = self._load_balancing_loss(probs, top1_idx)

        return y, aux_load

### Example Run

In [None]:
# Smoke test: push a random (B, S, D) batch through the MoE layer and report
# the output shape and the auxiliary load-balancing loss.
torch.manual_seed(42)  # seed before building the layer so weights and x are reproducible

B, S, D = 4, 10, 32  # batch size, sequence length, model width

moe = MoE(
    d_model=D,
    d_hidden=64,
    num_experts=4,
    k=2,
    activation="relu",
    dropout=0.1,
)

x = torch.randn(B, S, D)
y, aux_loss = moe(x)

print("Output shape:", y.shape)  # expected: (B, S, D)
print("Aux loss:", aux_loss.item())