# Module 2 Project 3: Mixture of Experts
- Implement a [MoE model](https://huggingface.co/blog/AviSoori1x/makemoe-from-scratch) using everything covered in the previous modules
- This is a standard transformer model, except that the single feed forward (FF) layer in each block is replaced by several feed forward networks (the experts) plus a Router / Gating network that decides which experts each token is sent to.

## STEP 1: IMPORTS
- We need the usual `torch` imports, as well as `os` and `re`
- We use `requests` to pull our data from Project Gutenberg

In [None]:
import requests
import re
import os
import torch
import torch.nn as nn
from torch.nn import functional as F

## STEP 2: HYPERPARAMETERS
- Setting up our hyperparameters - we choose 8 heads of attention with a head size of 16
- Our context size here is 32 tokens, and our embed dimension is 128, so this is a small network trained on one text file
- We also choose 8 experts, the feed forward networks that replace the single FF layer in each block
- We set top k to 2, so each token is routed to the top 2 of the 8 experts

In [None]:
# Hyperparameters
learning_rate = 1e-3
n_embed = 128
n_head = 8
n_layer = 8
head_size = 16
dropout = 0.1
context_size = 32
num_experts = 8
top_k = 2
device = "cpu" # or "cuda"
batch_size = 16
eval_interval = 100
eval_iters = 400
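
To see what `num_experts = 8` and `top_k = 2` mean in terms of size, here is a rough arithmetic sketch. It assumes each expert is a two-layer MLP with a 4x hidden expansion (as built in Step 7); the helper name `expert_param_count` is hypothetical and not part of the notebook.

```python
# Hyperparameters copied from the cell above
n_embed = 128
n_head = 8
num_experts = 8
top_k = 2

def expert_param_count(n_embed: int) -> int:
    """Weights + biases of Linear(n_embed, 4*n_embed) -> Linear(4*n_embed, n_embed)."""
    hidden = 4 * n_embed
    up = n_embed * hidden + hidden      # first Linear: weights + biases
    down = hidden * n_embed + n_embed   # second Linear: weights + biases
    return up + down

per_head = n_embed // n_head                       # size of each attention head
per_expert = expert_param_count(n_embed)           # parameters in one expert
total_expert_params = num_experts * per_expert     # stored per Transformer block
active_expert_params = top_k * per_expert          # used per token per block

print(per_head)              # 16
print(per_expert)            # 131712
print(total_expert_params)   # 1053696
print(active_expert_params)  # 263424
```

Under these assumptions, each block stores 8 experts' worth of parameters but only 2 experts' worth are applied to any given token.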

## STEP 3: DATA
- Now we can collect our data - Moby Dick from Project Gutenberg
- Download the text, keep only the part between the Project Gutenberg start and end markers, and print the first 1000 characters

In [None]:
# Moby Dick text file from Project Gutenberg site
resp = requests.get("https://www.gutenberg.org/cache/epub/2701/pg2701.txt")
resp.raise_for_status()  # fail early if the download did not succeed

# Tags for text filtering
start = "*** START OF THE PROJECT GUTENBERG EBOOK MOBY DICK; OR, THE WHALE ***"
end = "*** END OF THE PROJECT GUTENBERG EBOOK MOBY DICK; OR, THE WHALE ***"

result = resp.text[resp.text.find(start):resp.text.find(end)]

print(result[:1000])

## STEP 4: TOKENIZATION
- Now we can set up our tokenization, which is the basic character level tokenizer we have been using in this module
- We get the vocab size as the number of distinct characters in the data, as well as creating encoding and decoding functions to convert the data into tokens and back
- Lastly, we encode the data and split into training and validation sets (90/10)

In [None]:
# Get vocab size
chars = sorted(list(set(result)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

# Basic tokenization as we've done before
stoi = {ch:i for i, ch in enumerate(chars)}
itos = {i:ch for i, ch in enumerate(chars)}

encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Test the tokenizer
print(encode("Hello test!"))
print(decode(encode("Hello test!")))

# Encode our dataset
data = torch.tensor(encode(result), dtype=torch.long)

# Train and validation split
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]
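
As a small illustration of the character-level tokenizer, the sketch below builds the same kind of `stoi` / `itos` tables on a toy string (an assumption for the example, not the Moby Dick text). It also shows one limitation worth knowing: a character that never appears in the corpus has no token id, so encoding it raises a `KeyError`.

```python
# Toy corpus standing in for the real text (assumption for illustration)
text = "call me ishmael"

chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    # Map each character to its integer id
    return [stoi[c] for c in s]

def decode(ids):
    # Map integer ids back to characters
    return "".join(itos[i] for i in ids)

round_trip = decode(encode(text))
print(len(chars), round_trip == text)

# 'A' is not in the toy corpus, so it has no id
try:
    encode("Ahab")
    unknown_char_rejected = False
except KeyError:
    unknown_char_rejected = True
print(unknown_char_rejected)
```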

## STEP 5: UTIL METHODS
- We need to create some util methods to use with our model when training
- Namely we need a method to load a batch of data, and a method to estimate our loss
- Neither is strictly required, but they keep the training loop cleaner at the end

In [None]:
# Method to get a batch (from Project 1)
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - context_size, (batch_size,))
    x = torch.stack([data[i:i+context_size] for i in ix])
    y = torch.stack([data[i+1:i+context_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

# Method to estimate loss (also from Project 1)
@torch.no_grad()
def estimate_loss(model):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out
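
The batching logic can be checked on toy data. This sketch uses `torch.arange(100)` as a stand-in dataset (an assumption for illustration) and confirms that each target row is the input row shifted one position to the right, which is what next-token prediction needs.

```python
import torch

torch.manual_seed(0)

# Stand-in dataset: token i is just the integer i
data = torch.arange(100)
context_size = 8
batch_size = 4

# Same sampling scheme as get_batch: random start offsets, then two shifted windows
ix = torch.randint(len(data) - context_size, (batch_size,))
x = torch.stack([data[i:i + context_size] for i in ix])
y = torch.stack([data[i + 1:i + context_size + 1] for i in ix])

print(x.shape, y.shape)
# Position t of y holds the token that follows position t of x
print(torch.equal(x[:, 1:], y[:, :-1]))
print(torch.equal(y, x + 1))  # only true because the toy data is 0, 1, 2, ...
```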

## STEP 6: ATTENTION
- We can now implement Multi Head Attention much like we have in the previous modules
- We create our AttentionHead module, and then combine heads in our MHA implementation
- We don't need to dive too far into this here, as the structure is unchanged from the previous projects

In [None]:
# Attention implementation is re-used from the last few projects
class AttentionHead(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.k = nn.Linear(n_embed, head_size, bias=False)
        self.q = nn.Linear(n_embed, head_size, bias=False)
        self.v = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(context_size, context_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B,T,C = x.shape
        key = self.k(x)
        query = self.q(x)
        # Scale by the key dimension (head_size), as in standard scaled dot-product attention
        weights = query @ key.transpose(-2,-1) * key.shape[-1]**-0.5
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(weights, dim=-1)
        weights = self.dropout(weights)
        value = self.v(x)
        result = weights @ value
        return result
    
# Same as above, MHA is copied exactly from Projects 1 and 2 - no need to change anything :-)
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([AttentionHead(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out
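
As a reminder of what the causal mask does, here is a minimal sketch on random queries and keys (toy shapes chosen for the example). It uses the key dimension for scaling, as in standard scaled dot-product attention, and shows that after masking and softmax each position only attends to itself and earlier positions.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

T, head_size = 4, 16
q = torch.randn(1, T, head_size)
k = torch.randn(1, T, head_size)

# Scaled dot-product scores, shape (1, T, T)
scores = q @ k.transpose(-2, -1) * head_size ** -0.5

# Lower-triangular mask: position t may only look at positions <= t
tril = torch.tril(torch.ones(T, T))
scores = scores.masked_fill(tril == 0, float("-inf"))

weights = F.softmax(scores, dim=-1)
print(weights[0])
```

Every entry above the diagonal is exactly 0, each row sums to 1, and the first token can only attend to itself.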

## STEP 7: MLP EXPERT AND ROUTER
- Now is where the MoE-specific changes are needed.
- We first create a Multi Layer Perceptron - our 'Expert' - that is the same as our usual MLP layer in classical Transformers
- We then build the Top K Router, with k=2, to choose the top 2 experts to route each token to. This is simply a linear projection from our embed dimension to our # of experts
- We mask the non-top-k logits to negative infinity and apply softmax, so the top 2 experts get gating weights that sum to 1 and the remaining experts get a weight of 0.

In [None]:
# Here is where our FFN implementation changes, we rename it to our Expert module
# It is still an MLP with ReLU activation, there are now num_experts of them
class MLPExpert(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.block(x)
    
# Our Router or 'gating' module is used to determine which expert(s) receive a given token
# Simply a linear projection from embedding dimension to num_experts
# Non-top-k logits are set to negative infinity, then softmax is applied
# The softmax returns sparse gating weights: top-k entries sum to 1.0, the rest are 0
class TopkRouter(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(TopkRouter, self).__init__()
        self.top_k = top_k
        self.linear = nn.Linear(n_embed, num_experts)
    
    def forward(self, attention_output):
        logits = self.linear(attention_output)
        top_k_logits, indices = logits.topk(self.top_k, dim=-1)
        zeros = torch.full_like(logits, float('-inf'))
        sparse_logits = zeros.scatter(-1, indices, top_k_logits)
        router_output = F.softmax(sparse_logits, dim=-1)
        return router_output, indices
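
The top-k masking step is easier to follow on concrete numbers. The sketch below skips the linear projection and starts from hand-written router logits for two tokens and four experts (values made up for illustration), then applies the same `topk`, `scatter`, and `softmax` sequence.

```python
import torch
from torch.nn import functional as F

# Hand-written router logits: 2 tokens, 4 experts (illustrative values)
logits = torch.tensor([[2.0, 0.5, 1.0, -1.0],
                       [0.0, 3.0, 0.1, 2.9]])
top_k = 2

# Keep the top-k logits per token and remember which experts they belong to
top_vals, indices = logits.topk(top_k, dim=-1)

# Start from -inf everywhere, then write the top-k logits back in place
masked = torch.full_like(logits, float("-inf")).scatter(-1, indices, top_vals)

# exp(-inf) = 0, so non-selected experts get a gating weight of exactly 0
gates = F.softmax(masked, dim=-1)

print(indices)
print(gates)
```

The first token is routed to experts 0 and 2, the second to experts 1 and 3, and each row of `gates` has exactly two non-zero entries that sum to 1.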

## STEP 8: SPARSE MOE
- Now we can build our Sparse MoE module, which combines the experts with the gating network
- It pairs the router from the last step with a list of `num_experts` MLPExpert modules
- In the forward pass, we loop over the experts one at a time, creating a mask that selects the tokens for which the current expert is in the top-k
- If the mask selects any tokens, we pass them through the expert, scale the outputs by the gating weights, and add them into the final output
- After all experts have been processed, we return the final output

In [None]:
# We can now build our MoE module for our Transformer
class SparseMoE(nn.Module):
    def __init__(self, n_embed, num_experts, top_k):
        super(SparseMoE, self).__init__()
        # Set up router, and make a list of num_experts MLP modules
        self.router = TopkRouter(n_embed, num_experts, top_k)
        self.experts = nn.ModuleList([MLPExpert(n_embed) for _ in range(num_experts)])
        self.top_k = top_k

    def forward(self, x):
        gating_output, indices = self.router(x)
        final_output = torch.zeros_like(x)
        flat_x = x.view(-1, x.size(-1))
        flat_gating_output = gating_output.view(-1, gating_output.size(-1))

        # Iterate over experts
        for i, expert in enumerate(self.experts):

            expert_mask = (indices == i).any(dim=-1)
            flat_mask = expert_mask.view(-1)

            # If any token was routed to this expert, run it and weight its output
            if flat_mask.any():
                expert_input = flat_x[flat_mask]
                expert_output = expert(expert_input)

                gating_scores = flat_gating_output[flat_mask, i].unsqueeze(1)
                weighted_output = expert_output * gating_scores

                # Accumulate this expert's weighted contribution for the selected tokens
                final_output[expert_mask] += weighted_output

        return final_output
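
One way to sanity-check the masked loop is to compare it against a dense computation that runs every expert on every token and weights the results by the gating output. Because non-selected experts have a gating weight of 0, the two should agree. This is a sketch under simplifying assumptions: the experts are plain `nn.Linear` layers standing in for `MLPExpert`, and the sizes are tiny.

```python
import torch
import torch.nn as nn
from torch.nn import functional as F

torch.manual_seed(0)

n_embed, num_experts, top_k = 8, 4, 2
experts = nn.ModuleList([nn.Linear(n_embed, n_embed) for _ in range(num_experts)])
router = nn.Linear(n_embed, num_experts)
x = torch.randn(2, 5, n_embed)  # (B, T, n_embed)

with torch.no_grad():
    # Top-k gating, same steps as the router above
    logits = router(x)
    top_vals, indices = logits.topk(top_k, dim=-1)
    masked = torch.full_like(logits, float("-inf")).scatter(-1, indices, top_vals)
    gates = F.softmax(masked, dim=-1)  # (B, T, num_experts)

    # Sparse path: each expert only sees the tokens routed to it
    sparse_out = torch.zeros_like(x)
    flat_x = x.view(-1, n_embed)
    flat_gates = gates.view(-1, num_experts)
    for i, expert in enumerate(experts):
        mask = (indices == i).any(dim=-1)  # (B, T)
        flat_mask = mask.view(-1)
        if flat_mask.any():
            out = expert(flat_x[flat_mask])
            sparse_out[mask] += out * flat_gates[flat_mask, i].unsqueeze(1)

    # Dense path: every expert sees every token, then weight by the gates
    all_out = torch.stack([e(x) for e in experts], dim=-1)   # (B, T, n_embed, E)
    dense_out = (all_out * gates.unsqueeze(-2)).sum(dim=-1)  # (B, T, n_embed)

print(torch.allclose(sparse_out, dense_out, atol=1e-5))
```

The sparse path gives the same result while running each expert on only a subset of tokens, which is where the compute savings of top-k routing come from.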

## STEP 9: TRANSFORMER BLOCK
- Finally, we can build our Transformer block, replacing the classic MLP with the SparseMoE module we created above
- Otherwise, the Transformer should be built and function the same way.
- We pass our inputs through the MHA block, and they are then passed to the Router and Expert modules
- Inputs are normalized before each module, and skip (residual) connections are used as well before returning the output.

In [None]:
# Transformer block like we've done before, but using our SparseMoE module instead of a FF layer
class TransformerBlock(nn.Module):
    def __init__(self, n_embed, n_head, num_experts, top_k):
        super().__init__()
        head_size = n_embed // n_head
        self.attention = MultiHeadAttention(n_head, head_size)
        self.moe = SparseMoE(n_embed, num_experts, top_k)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.attention(self.ln1(x))
        x = x + self.moe(self.ln2(x))
        return x
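
The pre-norm residual pattern used twice in the block (`x + sublayer(norm(x))`) can be sketched in isolation. The wrapper class `PreNormResidual` below is a hypothetical name for illustration, not something the notebook defines. Zeroing the sublayer shows a useful property: when the sublayer contributes nothing, the input passes through the skip connection unchanged.

```python
import torch
import torch.nn as nn

class PreNormResidual(nn.Module):
    """Hypothetical wrapper: normalize, apply a sublayer, add the skip connection."""
    def __init__(self, dim, sublayer):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))

torch.manual_seed(0)
dim = 8

# A sublayer with all-zero parameters outputs zeros for any input
sublayer = nn.Linear(dim, dim)
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)

block = PreNormResidual(dim, sublayer)
x = torch.randn(2, 3, dim)
y = block(x)

print(y.shape)
print(torch.allclose(y, x))
```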

## STEP 10: LLM
- We can build our Language Model using our Transformer block created above.
- We set up our token and positional embeddings to process inputs
- We also build our Transformer layers and add a final layer normalization before the LM head (linear projection from embed dimension to our vocab size)
- The forward process is the same as we usually do with Transformer language models
- We also add a generate method for inference: it calls the model without targets (so no loss is computed) and appends one sampled token at a time in a loop

In [None]:
# Finally, we build our MoE language model
class SparseMoELanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Token and positional embeddings
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.position_embedding_table = nn.Embedding(context_size, n_embed)

        # n_layer Transformer blocks, each with n_head attention heads and num_experts MLP experts
        self.blocks = nn.Sequential(*[TransformerBlock(n_embed, n_head=n_head, num_experts=num_experts, top_k=top_k) for _ in range(n_layer)])
        self.ln_f = nn.LayerNorm(n_embed) # final layer norm

        # LM head for output projection to vocab (token prediction)
        self.lm_head = nn.Linear(n_embed, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        # Embed and sum
        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))
        x = tok_emb + pos_emb

        # Pass through our building block layers
        x = self.blocks(x)
        x = self.ln_f(x)

        # Output of LM head is logits
        logits = self.lm_head(x)

        # No targets means inference (generation), so there is no loss to compute
        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    # Generate method to return new tokens up to max_new_tokens
    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -context_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)
        return idx
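
The reshape before `F.cross_entropy` in the forward pass treats every position in every sequence as an independent classification example. The sketch below, on random toy tensors, checks that the flattened form matches two equivalent formulations: passing the class dimension second, and averaging the negative log-probability of each target by hand.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

B, T, C = 2, 3, 5  # batch, time, vocab size (toy values)
logits = torch.randn(B, T, C)
targets = torch.randint(C, (B, T))

# Form used in the model: flatten batch and time into one axis
flat_loss = F.cross_entropy(logits.view(B * T, C), targets.view(B * T))

# Equivalent: cross_entropy also accepts (B, C, T) inputs with (B, T) targets
alt_loss = F.cross_entropy(logits.transpose(1, 2), targets)

# Equivalent: mean negative log-probability of the correct token
log_probs = F.log_softmax(logits, dim=-1)
manual_loss = -log_probs.gather(-1, targets.unsqueeze(-1)).mean()

print(flat_loss.item(), alt_loss.item(), manual_loss.item())
```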

## STEP 11: TRAIN
- Before training, we initialize the weights of every linear layer with [Xavier / Glorot initialization](https://365datascience.com/tutorials/machine-learning-tutorials/what-is-xavier-initialization/)
- We set max iterations to 5000, and initialize our model, printing the # of parameters
- We then create our AdamW optimizer and start our training loop
- We print the train and validation loss every 100 iterations (`eval_interval`), and once more on the final iteration
- Each iteration samples a batch, computes the loss, backpropagates, and updates the weights

In [None]:
# Xavier initialization
def init_weights(m):
    if isinstance(m, nn.Linear):
        # In-place variant; the version without the trailing underscore is deprecated
        torch.nn.init.xavier_uniform_(m.weight)

max_iters = 5000
    
# Instantiate and initialize weights    
model = SparseMoELanguageModel()
model.apply(init_weights)

# Get # of parameters
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# Training loop
for iter in range(max_iters):

    # Every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0 or iter == max_iters - 1:
        losses = estimate_loss(model)
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")


    # Sample a batch of data
    xb, yb = get_batch('train')

    # Evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
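
For reference, Xavier uniform initialization draws weights from a uniform distribution on `[-a, a]` with `a = sqrt(6 / (fan_in + fan_out))` (with the default gain of 1). The sketch below checks this on one linear layer whose sizes match the first layer of an expert; it is an illustration of the formula, not part of the training code.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Same shape as the first Linear in an expert: n_embed -> 4 * n_embed
layer = nn.Linear(128, 512)
nn.init.xavier_uniform_(layer.weight)

# nn.Linear stores its weight as (out_features, in_features)
fan_out, fan_in = layer.weight.shape
bound = math.sqrt(6.0 / (fan_in + fan_out))

max_abs = layer.weight.abs().max().item()
print(bound, max_abs)
```

The largest absolute weight should sit just under the bound, since the layer has enough entries to sample close to both ends of the interval.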

## STEP 12: GENERATE
- Now that training is done, we can use our generate method to display outputs
- We initialize the input as a 1x1 tensor holding token index 0, and pass it to the model to generate 2000 additional tokens
- Once generation completes, we decode the token indices back into text and print the result

In [None]:
# Generation step: sample 2000 new tokens starting from a 1x1 tensor holding token index 0
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))
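
The generation loop can be exercised without a trained model. In this sketch, `dummy_model` is a hypothetical stand-in that returns all-zero logits (so every token is equally likely); it is there only to show how the context is cropped to `context_size` and how the sequence grows by one token per step.

```python
import torch
from torch.nn import functional as F

torch.manual_seed(0)

context_size = 4
vocab_size = 6

def dummy_model(idx):
    # Hypothetical stand-in for the language model: uniform logits
    B, T = idx.shape
    assert T <= context_size, "input must be cropped to the context window"
    return torch.zeros(B, T, vocab_size)

def generate(idx, max_new_tokens):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]          # keep only the last context_size tokens
        logits = dummy_model(idx_cond)[:, -1, :]   # logits for the final position
        probs = F.softmax(logits, dim=-1)
        idx_next = torch.multinomial(probs, num_samples=1)  # sample one token id
        idx = torch.cat((idx, idx_next), dim=1)    # append to the running sequence
    return idx

start = torch.zeros((1, 1), dtype=torch.long)
out = generate(start, max_new_tokens=10)
print(out.shape)
```

The output keeps the starting token and adds `max_new_tokens` sampled ids, so its length is 11 here even though the model never sees more than 4 tokens at once.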