# nanoGPT + MoE Demo
We follow the MoE variant used in [Mixtral](https://mistral.ai/news/mixtral-of-experts/).

---

### MoE Modules
The following code snippet is the same as `moe_modules/moe.py`.

#### Router

In [12]:
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F

from functools import partial
from jaxtyping import Float, Int
from typing import Tuple

from einops import rearrange, repeat

from model import LayerNorm, CausalSelfAttention
from moe_modules.moe import Top2MLP

class Router(nn.Module):
    """Linear gating network that scores every expert for every token.

    The router only produces raw (unnormalised) logits; softmax, top-k
    selection and weighting of expert outputs are left to the caller.
    """

    def __init__(self,
                 embed_dim: int,
                 n_exp: int = 8,
                 n_exp_per_token: int = 2,
                 ):
        """
        Args:
            embed_dim: size of the last (embedding) dimension of the input.
            n_exp: number of experts, i.e. number of logits produced per token.
            n_exp_per_token: number of experts selected per token. The router
                itself does not use it; it is accepted and stored so callers
                can pass the full routing configuration (``MoELayer`` does).
        """
        # Bug fix: the original ``super.__init__()`` invoked ``__init__`` of the
        # builtin ``super`` type (raising TypeError) instead of initialising
        # ``nn.Module``, so the router could never be constructed.
        super().__init__()
        self.n_exp = n_exp
        self.n_exp_per_token = n_exp_per_token
        # Bias-free projection embed_dim -> n_exp: one logit per expert.
        self.gate_net = nn.Linear(embed_dim, n_exp, bias=False)

    def forward(self,
                embeds: "Float[Tensor, '*batch token embed']",
                ) -> "Float[Tensor, '*batch token exp']":
        """Return routing logits with the last dim replaced by ``n_exp``."""
        logits_exp = self.gate_net(embeds)
        return logits_exp

#### MoE Layer

In [13]:
class MoELayer(nn.Module):
    """Sparse mixture-of-experts feed-forward layer with Mixtral-style top-k routing.

    Each token is scored by a linear router. The ``n_exp_per_token``
    highest-probability experts are run on it, and their outputs are summed,
    weighted by the selected router probabilities renormalised to sum to 1.
    """

    def __init__(self,
            config,
            ):
        """
        Args:
            config: model config; reads ``n_exp_per_token``,
                ``router_jitter_noise``, ``n_embd`` and ``n_exp``. It is also
                passed unchanged to each ``Top2MLP`` expert.
        """
        super().__init__()

        # Configs
        self.n_exp_per_token = config.n_exp_per_token
        self.jitter_noise = config.router_jitter_noise
        # The config field is ``n_embd`` (as read by MoEBlock for the same
        # config); the previous ``config.n_embed`` lookup raised AttributeError.
        self.embed_dim = config.n_embd
        self.n_exp = config.n_exp

        # Set up experts
        self.exps = nn.ModuleList([Top2MLP(config) for _ in range(config.n_exp)])

        # Set up router (only the arguments every Router version accepts).
        self.router = Router(embed_dim=self.embed_dim, n_exp=self.n_exp)

    def forward(self,
                x: "Float[Tensor, '*batch token embed']"):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: token embeddings with any number of leading dims and last dim
                ``embed_dim``.

        Returns:
            ``(out, logits_exps)``. ``out`` has the same shape as ``x``.
            ``logits_exps`` is the router output of shape ``(num_tokens, n_exp)``,
            with all leading dims of ``x`` flattened; it is intended for the
            balancing loss.
        """
        orig_shape = x.shape
        hidden_dim = orig_shape[-1]

        # Multiplicative jitter on the input during training. Done
        # out-of-place so the caller's tensor is not modified (an in-place
        # ``*=`` also fails on leaf tensors that require grad).
        if self.training and self.jitter_noise > 0:
            x = x * torch.empty_like(x).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)

        # Flatten every leading dim: one row per token.
        x = x.reshape(-1, hidden_dim)

        # Routing: softmax in float32, keep the top-k experts and renormalise
        # their probabilities so they sum to 1 (Mixtral: Softmax(TopK(logits))).
        logits_exps = self.router(x)
        weights_exps = F.softmax(logits_exps, dim=-1, dtype=torch.float)
        weights_selected_exps, idx_selected_exps = torch.topk(weights_exps, self.n_exp_per_token, dim=-1)
        weights_selected_exps = weights_selected_exps / weights_selected_exps.sum(dim=-1, keepdim=True)
        weights_selected_exps = weights_selected_exps.to(x.dtype)

        # Exps forwarding
        final_x = torch.zeros_like(x)

        # one-hot (tokens, k, n_exp) -> (n_exp, k, tokens): mask_exps[e] marks
        # the (top-k slot, token) pairs that selected expert e.
        mask_exps = F.one_hot(idx_selected_exps, num_classes=self.n_exp).permute(2, 1, 0)

        for exp_id, exp in enumerate(self.exps):
            slot_idx, token_idx = torch.where(mask_exps[exp_id])

            # Bug fix: weight by the selected top-k probability at
            # (token, slot), i.e. this expert's weight. The old code indexed the
            # full softmax ``weights_exps`` with the slot index, which returned
            # the probability of expert 0/1 rather than of expert ``exp_id``.
            x_curr_hidden = exp(x[token_idx]) * weights_selected_exps[token_idx, slot_idx, None]

            final_x.index_add_(0, token_idx, x_curr_hidden.to(x.dtype))

        return final_x.reshape(orig_shape), logits_exps

#### MoE Block

In [14]:
class MoEBlock(nn.Module):
    """Pre-LayerNorm transformer block whose MLP is replaced by a MoE layer.

    ``forward`` returns the updated hidden states together with the router
    logits produced by the MoE layer, so the caller can add a balancing loss.
    """

    def __init__(self, config):
        super().__init__()
        width, use_bias = config.n_embd, config.bias
        # Registration order (ln_1, attn, ln_2, moe) is kept as before.
        self.ln_1 = LayerNorm(width, bias=use_bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(width, bias=use_bias)
        self.moe = MoELayer(config)

    def forward(self, x):
        # Residual branch 1: causal self-attention on the normalised input.
        hidden = x + self.attn(self.ln_1(x))

        # Residual branch 2: expert feed-forward; keep the router logits.
        expert_out, router_logits = self.moe(self.ln_2(hidden))
        return hidden + expert_out, router_logits

#### Balance Loss

In [15]:
def balancing_loss_func(
    all_router_logits: "Tuple[Float[Tensor, '*batch token exps'], ...]",
    n_exp: int,
    n_exp_per_token: int = 2,
) -> Tensor:
    """Auxiliary load-balancing loss over the router logits of all MoE layers.

    The routing is recomputed exactly like ``MoELayer``: float32 softmax
    followed by top-k. The loss is ``n_exp * sum(f * P)``, where ``f`` is the
    fraction of tokens sent to each expert (per top-k slot) and ``P`` is the
    mean router probability of each expert.

    Args:
        all_router_logits: iterable of router-logit tensors, one per layer.
            Each has last dim ``n_exp``; leading dims are flattened.
        n_exp: number of experts.
        n_exp_per_token: number of experts selected per token (top-k).
            Previously the default was mistakenly the type ``int`` itself.

    Returns:
        Scalar loss tensor (float32). If no logits are given, this is ``0.0``.
    """
    # Flatten each layer to (tokens, n_exp) so inputs with extra leading dims
    # are handled. Move everything to one device before concatenating.
    logits_list = [layer_logits.reshape(-1, n_exp) for layer_logits in all_router_logits]
    if not logits_list:
        return torch.tensor(0.0)
    device = logits_list[0].device
    all_logits = torch.cat([layer_logits.to(device) for layer_logits in logits_list], dim=0)

    # Same as routing in MoELayer (float32 softmax, then top-k).
    weights_exps = F.softmax(all_logits, dim=-1, dtype=torch.float)

    _, idx_selected_exps = torch.topk(weights_exps, n_exp_per_token, dim=-1)

    # (tokens, k, n_exp) one-hot of the selected experts.
    mask_exps = F.one_hot(idx_selected_exps, n_exp)

    # (k, n_exp): fraction of tokens routed to each expert, per top-k slot.
    n_tokens_per_exp = torch.mean(mask_exps.float(), dim=0)

    # (n_exp,): mean router probability assigned to each expert.
    prob_per_exp = torch.mean(weights_exps, dim=0)

    overall_loss = torch.sum(n_tokens_per_exp * prob_per_exp.unsqueeze(0))
    return overall_loss * n_exp

### Experiments

#### Baseline GPT

In [6]:
!python train.py config/train_shakespeare_char.py --compile=False --max_iters=1000

Overriding config with config/train_shakespeare_char.py:
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equal to max_iters usually
min_lr = 1e-4 # learning_rate / 10 usually
beta2 = 0.99 # make a bit bigger because number of 

#### MoE
Settings (the `--n_exp=4` command-line flag overrides the `n_exp = 8` in the config file printed below):
```python
# MoE Setting
use_MoE = True 
alpha_balance_loss = 0.01
n_exp = 4
n_exp_per_token = 2
router_jitter_noise = -1.0
```

In [10]:
!python train.py config/train_shakespeare_char_moe.py --compile=False --max_iters=1000 --n_exp=4

Overriding config with config/train_shakespeare_char_moe.py:
# train a miniature character-level shakespeare model
# good for debugging and playing on macbooks and such

out_dir = 'out-shakespeare-char'
eval_interval = 250 # keep frequent because we'll overfit
eval_iters = 200
log_interval = 10 # don't print too too often

# we expect to overfit on this small dataset, so only save when val improves
always_save_checkpoint = False

wandb_log = False # override via command line if you like
wandb_project = 'shakespeare-char'
wandb_run_name = 'mini-gpt'

dataset = 'shakespeare_char'
gradient_accumulation_steps = 1
batch_size = 64
block_size = 256 # context of up to 256 previous characters

# baby GPT model :)
n_layer = 6
n_head = 6
n_embd = 384
dropout = 0.2

# MoE Setting
use_MoE = True 
alpha_balance_loss = 0.01
n_exp = 8
n_exp_per_token = 2
router_jitter_noise = -1.0

learning_rate = 1e-3 # with baby networks can afford to go a bit higher
max_iters = 5000
lr_decay_iters = 5000 # make equ

MoE training is slower because each layer now has four expert MLPs (each token is processed by two of them) plus an additional router. However, it reaches a lower validation loss. The balance loss is not added during evaluation, to keep the comparison fair.