# 📌 Implementing Switch Mixture-of-Experts (Top-1 Routing)
The Switch Transformer introduces a scalable version of the Mixture-of-Experts (MoE) architecture, where each token is routed to only one expert (top-1) instead of multiple experts.

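This makes it:
- Efficient → each token runs through a single expert FFN, which cuts per-token expert compute compared with top-k routing
- Scalable → parameter count grows with the number of experts, enabling models with billions of parameters
- Balanced → uses an auxiliary loss to encourage an even token distribution across experts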

In [1]:
# Import Libraries
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

### 💡Defining SwitchMoE Class 
- Implements a lightweight and efficient Mixture-of-Experts (MoE) layer inspired by the Switch Transformer

#### 💡Key points of the implementation:

- Router (Gating Network): A linear layer that routes each token to exactly one expert (top-1 selection)
- Experts: Multiple feed-forward networks (FFNs), each acting as a specialist
- Auxiliary Loss: A balancing loss that encourages tokens to be distributed fairly across experts, avoiding overloading a few experts
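
To make the auxiliary loss concrete, here is a minimal sketch that computes `E * sum_i(f_i * p_i)`, where `f_i` is the fraction of tokens sent to expert `i` and `p_i` is the mean router probability for expert `i`. The standalone function `load_balancing_loss` and the toy tensors are illustrative assumptions; they mirror what the class's `_load_balancing_loss` method computes, but are not part of it.

```python
import torch
import torch.nn.functional as F

def load_balancing_loss(probs, top1_idx):
    """Hypothetical standalone helper: E * sum_i(f_i * p_i)."""
    num_experts = probs.shape[1]
    p = probs.mean(dim=0)                                                  # (E,) mean router prob
    f = F.one_hot(top1_idx, num_classes=num_experts).float().mean(dim=0)  # (E,) token fraction
    return num_experts * torch.sum(p * f)

E = 4

# Balanced case: uniform probabilities, one token per expert.
balanced_probs = torch.full((4, E), 1.0 / E)
balanced_idx = torch.arange(E)

# Collapsed case: every token puts all its probability on expert 0.
collapsed_idx = torch.zeros(4, dtype=torch.long)
collapsed_probs = F.one_hot(collapsed_idx, num_classes=E).float()

balanced = load_balancing_loss(balanced_probs, balanced_idx).item()
collapsed = load_balancing_loss(collapsed_probs, collapsed_idx).item()
print(balanced, collapsed)  # 1.0 4.0
```

The loss is 1.0 when routing is perfectly uniform and rises to `E` when everything collapses onto one expert, so minimizing it pushes the router toward an even spread. In training, this term is usually added to the task loss with a small coefficient (the Switch Transformer paper uses 0.01).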

#### 💡Forward Pass:

- Compute routing probabilities
- Select the top-1 expert for each token
- Dispatch tokens to their chosen experts
- Combine outputs and return them along with the auxiliary loss
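
The steps above can be traced on a toy tensor. This is only a sketch: the random `router_weight` matrix and the toy "experts" that multiply their input by `e + 1` are stand-ins chosen so the result is easy to verify, not the learned modules the layer actually uses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_tokens, d_model, num_experts = 6, 4, 3
x_flat = torch.randn(num_tokens, d_model)

# Hypothetical stand-ins: a random router matrix, and toy experts
# where expert e simply scales its input by (e + 1).
router_weight = torch.randn(d_model, num_experts)
experts = [lambda t, s=e + 1: t * s for e in range(num_experts)]

# 1. Compute routing probabilities.
probs = F.softmax(x_flat @ router_weight, dim=-1)    # (T, E)

# 2. Select the top-1 expert for each token.
top1_idx = probs.argmax(dim=-1)                      # (T,)

# 3. Dispatch each token to its chosen expert and
# 4. write the outputs back into the original token order.
y_flat = torch.zeros_like(x_flat)
for e, expert in enumerate(experts):
    mask = top1_idx == e
    if mask.any():
        y_flat[mask] = expert(x_flat[mask])

# Each token should have been scaled by (its expert index + 1).
expected = x_flat * (top1_idx + 1).unsqueeze(-1).float()
print(torch.allclose(y_flat, expected))
```

Because the boolean masks partition the tokens, every row of `y_flat` is written exactly once, which is why a zero-initialized output buffer is enough to combine the results.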

This design replaces the standard feed-forward network in a Transformer block, making the block more scalable and compute-efficient. Unlike a plain FFN, it returns an auxiliary loss alongside its output.

In [None]:
class SwitchMoE(nn.Module):
    """
    Switch Transformer-style MoE layer:
        - Replaces the dense FFN in a Transformer block
        - Uses a learned router to pick exactly one expert per token (top-1 routing)

    Shapes:
        - x: (batch_size, seq_len, d_model), i.e. batch-first layout
        - returns: y of shape (batch_size, seq_len, d_model) and aux_loss, a scalar tensor
    """
    def __init__(self, d_model: int, num_experts: int, d_ff: int, dropout: float = 0.0, activation: str = 'relu'):
        super().__init__()

        # Router : linear -> logits
        self.router= nn.Linear(d_model, num_experts)
        self.num_experts = num_experts

        # Experts: each is a standard FFN (Linear -> Activation -> Dropout -> Linear)
        act = {"relu": nn.ReLU(), "gelu": nn.GELU()}[activation]
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                act,
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model)
            )

            for _ in range(num_experts)
        ])

    def _compute_routing(self, x_flat):
        """
        x_flat: (T, d_model) where T = batch_size * seq_len
        Returns:
            top1_idx: (T,) index of the chosen expert for each token
            probs: (T, E) softmax probabilities over experts
        """
        logits = self.router(x_flat)                        # (T, E)
        probs = F.softmax(logits, dim=-1)                   # (T, E)
        top1_idx = torch.argmax(probs, dim=-1)              # (T,)
        return top1_idx, probs 

    def _load_balancing_loss(self, probs, top1_idx):
        """
        probs: (T, E) softmax over experts for each token
        top1_idx: (T,) chosen expert per token
        Returns a scalar: E * sum_i(f_i * p_i)
        """
        E = probs.shape[1]
        p_i = probs.mean(dim=0)                              # (E,) mean router probability per expert
        one_hot = F.one_hot(top1_idx, num_classes=E).float() # (T, E)
        f_i = one_hot.mean(dim=0)                            # (E,) fraction of tokens sent to each expert
        loss = E * torch.sum(p_i * f_i)
        return loss
    
    def forward(self, x):
        """
        x: (B, S, d_model)
        returns:
            y: (B, S, d_model)
            aux_loss: scalar tensor
        """
        B, S, D = x.size()
        T = B * S
        x_flat = x.reshape(T, D)                             # (T, d_model)

        # Route once. argmax yields integer indices that carry no gradient,
        # so no torch.no_grad() block is needed; probs stays in the autograd graph.
        top1_idx, probs = self._compute_routing(x_flat)

        # Compute the aux loss (gradients reach the router through probs)
        aux_loss = self._load_balancing_loss(probs, top1_idx)

        # Dispatch tokens to their respective experts
        y_flat = torch.zeros_like(x_flat)
        for e in range(self.num_experts):
            mask = (top1_idx == e)
            if mask.any():
                tokens_e = x_flat[mask]                     # (n_e, d_model)
                out_e = self.experts[e](tokens_e)           # (n_e, d_model)
                y_flat[mask] = out_e
        
        y = y_flat.view(B, S, D)
        return y, aux_loss
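
One thing worth noting about the `forward` above: expert outputs are written back unscaled, so the router receives gradients only through the auxiliary loss. In the Switch Transformer paper, the selected expert's output is multiplied by its router probability, which lets the task loss train the router as well. The sketch below shows that variation as a standalone function; `switch_dispatch_with_gate` and the single-`Linear` toy experts are hypothetical names and simplifications, not part of the class.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def switch_dispatch_with_gate(x_flat, router, experts):
    """Top-1 dispatch with each expert output scaled by its router probability.

    x_flat: (T, d_model); router: nn.Linear(d_model, E); experts: sequence of E modules.
    """
    probs = F.softmax(router(x_flat), dim=-1)        # (T, E)
    gate, top1_idx = probs.max(dim=-1)               # both (T,)
    y_flat = torch.zeros_like(x_flat)
    for e, expert in enumerate(experts):
        mask = top1_idx == e
        if mask.any():
            # Scaling by the gate keeps the router in the task-loss graph.
            y_flat[mask] = gate[mask].unsqueeze(-1) * expert(x_flat[mask])
    return y_flat, top1_idx

torch.manual_seed(0)
d_model, num_experts = 8, 4
router = nn.Linear(d_model, num_experts)
# Toy experts: one Linear each, standing in for full FFNs.
experts = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(num_experts)])
x_flat = torch.randn(10, d_model)

y_flat, top1_idx = switch_dispatch_with_gate(x_flat, router, experts)
y_flat.sum().backward()
print(router.weight.grad is not None)  # the router now gets gradients from the output
```

Whether to adopt this depends on the training setup; with gate scaling, the auxiliary loss acts purely as a balancing regularizer rather than the router's only training signal.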