# Implementing Tok-k Importance Routing 

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Expert(nn.Module):
    def __init__(self, d_model, d_hidden, activation, dropout):
        super().__init__()
        act = {"relu":F.relu(), "gelu": F.gelu()}[activation]
        self.ffn = nn.ModuleList(
            nn.Sequential(
                nn.Linear(d_model, d_hidden),
                act,
                nn.Dropout(dropout),
                nn.Linear(d_hidden, d_model)
            )
        )

    def forward(self, x):
        return self.ffn(x)

### MoE Layer With Importance Routing

In [None]:
class MoE(nn.Module):
    def __init__(self, d_model: int, num_experts: int, k=2):
        super().__init__()
        self.num_experts = num_experts
        self.k = k

        # Defining the router 
        self.router = nn.Linear(d_model, num_experts)

    
    def forward(self, x):
        """ 
        x : (B, S, D)
        """
        B, S, D = x.shape
        T = B * S
        x_flat = x.view(T, D)

        # Routing
        logits = self.router(x_flat)                                    # (T, E)
        probs = F.softmax(logits, dim=-1)                               # (T, E)
        topk_prob, topk_idx = torch.topk(probs, k= self.k, dim=-1)      # (T, k) (T, k)
        
        # Normalise topk values as sum is not equal to 1
        topk_prob = topk_prob / (topk_prob.sum(dim=-1, keepdim=True) + 1e-9)

        # Dispatch tokens to experts 
        y_flat = torch.zeros_like(x_flat)
        for i in range(self.k):
            







SyntaxError: incomplete input (3155914306.py, line 1)

In [16]:
import torch

x = torch.tensor([0.1, 0.8, 0.05, 0.3, 0.6])
topk_vals, _ = torch.topk(x, 3)
topk_vals

tensor([0.8000, 0.6000, 0.3000])

In [17]:
topk_vals.sum(-1, keepdim=True)

tensor([1.7000])