# Report: SwitchHead MOE for Transformers – Architecture and Methodology

This report summarizes the SwitchHead method as described in the paper *"SwitchHead: Accelerating Transformers with Mixture-of-Experts Attention"* by Róbert Csordás, Piotr Piękos, Kazuki Irie, and Jürgen Schmidhuber (Accepted to NeurIPS 2024, [arXiv:2312.07987](https://doi.org/10.48550/arXiv.2312.07987)). The method introduces an efficient Mixture-of-Experts (MoE) mechanism for the self-attention layer, aiming to reduce both compute and memory requirements while matching the performance of standard Transformers.

---

## 1. Overview

Standard Transformer models compute self-attention for every head, leading to a large number of attention matrices and significant computational cost. While several recent works have applied MoE techniques to feedforward layers, applying MoE to the self-attention layer has proven challenging. **SwitchHead** is a novel approach that successfully applies an MoE mechanism to self-attention, reducing the number of attention computations by up to 8 times compared to the standard approach. This results in significant wall-clock speedups and reduced memory usage, without sacrificing language modeling performance.

---

## 2. Motivation and Background

In conventional self-attention, for an input sequence $x \in \mathbb{R}^{L \times d}$, each head computes:
$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right)V,
$$
where the queries, keys, and values are derived as:
$$
Q = xW^Q,\quad K = xW^K,\quad V = xW^V,
$$
with $W^Q, W^K, W^V \in \mathbb{R}^{d \times d_k}$ and $d_k = d / H$ for $H$ heads.

The computational cost grows with the number of heads $H$ and the sequence length $L$. Previous MoE approaches for attention attempted to split the computation across experts but failed to achieve parity with a parameter-matched baseline.

SwitchHead addresses these issues by **routing tokens selectively to fewer attention computations**, thereby reducing the number of matrices computed without degrading performance.

---

## 3. SwitchHead Mechanism

### 3.1 Standard Self-Attention Recap

For a Transformer with $H$ attention heads, the model computes $H$ sets of attention matrices. For each head $h$, the output is:
$$
\text{head}_h(x) = \text{Attention}(xW_h^Q, xW_h^K, xW_h^V),
$$
and the concatenated output is projected:
$$
\text{MultiHead}(x) = \text{Concat}(\text{head}_1, \dots, \text{head}_H)W^O.
$$

### 3.2 The SwitchHead Approach

SwitchHead modifies the attention computation by introducing a gating mechanism that dynamically routes tokens to a subset of attention experts. Instead of computing an attention matrix for every head independently, SwitchHead computes **fewer attention matrices** based on the routing decisions.

1. **Gating Function:**  
   For each token representation $x_i$, a gating network computes a score over a set of experts:
   $$
   g(x_i) = \text{softmax}(W_g x_i),
   $$
   where $W_g \in \mathbb{R}^{d \times E}$ and $E$ is the number of experts.

2. **Expert Selection:**  
   For each token, the expert with the highest gate score is selected:
   $$
   e_i = \arg\max g(x_i).
   $$
   This selection means that tokens sharing similar properties may be processed together by the same expert.

3. **Reduced Attention Computations:**  
   Instead of computing $H$ full attention heads, SwitchHead **groups tokens** and computes only a subset of attention matrices. Let $M$ be the number of distinct attention matrices computed (with $M \ll H$). For each group (or expert), the standard attention computation is performed:
   $$
   \text{Attention}_{\text{expert}}(Q_e, K_e, V_e) = \text{softmax}\left(\frac{Q_e K_e^\top}{\sqrt{d_k}}\right)V_e,
   $$
   where $Q_e, K_e, V_e$ are the queries, keys, and values for tokens routed to expert $e$.

4. **Aggregation:**  
   The outputs from each expert are then combined to form the final multi-head attention output. This aggregation ensures that despite fewer computations, the model still captures diverse interactions across tokens.

By routing tokens and computing fewer attention matrices, SwitchHead achieves significant reductions in computation and memory usage while maintaining model performance.

---

## 4. Efficiency Gains: Compute and Memory Reduction

The SwitchHead method is reported to compute up to 8 times fewer attention matrices compared to the standard Transformer. This is achieved by:
- **Selective Routing:** Only computing the attention matrix for tokens that are grouped under the same expert.
- **Dynamic Allocation:** Adjusting the computation based on the input, so that many tokens can share an expert's attention computation.

Mathematically, if a standard Transformer computes:
$$
\text{TotalAttentionCost} \propto H \times L^2,
$$
SwitchHead reduces this cost approximately to:
$$
\text{TotalAttentionCost}_{\text{SwitchHead}} \propto M \times L^2,
$$
with $M \ll H$. The paper reports that for a model with 262M parameters, SwitchHead can achieve comparable perplexity to a standard model using only 44% of the compute and 27% of the memory.

---

## 5. Integration with MoE Feedforward Layers: "SwitchAll"

SwitchHead can be further extended by combining it with MoE feedforward layers. This fully-MoE Transformer, sometimes referred to as **SwitchAll**, applies mixture-of-experts mechanisms to both the self-attention and feedforward components. The combined approach offers additional efficiency gains while preserving—or even improving—the performance on language modeling and downstream tasks.

---

## 6. Summary

SwitchHead represents a significant step forward in making Transformer models more resource-efficient. By applying a Mixture-of-Experts approach to the self-attention layer, it:
- **Reduces compute and memory usage:** Up to 8 times fewer attention matrix computations.
- **Maintains high performance:** Matches the perplexity of standard Transformers while using less compute.
- **Offers flexibility:** Can be combined with MoE feedforward layers to build fully-MoE Transformers ("SwitchAll").

The SwitchHead mechanism is mathematically grounded in the use of gating functions and selective computation, ensuring that only the most relevant attention matrices are computed. This approach not only accelerates the Transformer but also opens new avenues for scalable and efficient language models.



In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# ----------------------------
# Simple Tokenizer Definition
# ----------------------------
class SimpleTokenizer:
    def __init__(self, corpus):
        # Build vocabulary by splitting on whitespace.
        tokens = corpus.split()
        unique_tokens = set(tokens)
        # Reserve indices for special tokens.
        self.vocab = {"<PAD>": 0, "<UNK>": 1, "<MASK>": 2, "<EOS>": 3}
        for i, token in enumerate(unique_tokens):
            self.vocab[token] = i + 4
        self.inv_vocab = {i: token for token, i in self.vocab.items()}
        self.mask_token_id = self.vocab["<MASK>"]

    def encode(self, text):
        tokens = text.split()
        token_ids = [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens] + [self.vocab["<EOS>"]]
        return token_ids

    def decode(self, token_ids):
        tokens = [self.inv_vocab.get(i, "<UNK>") for i in token_ids]
        return " ".join(tokens)

# ----------------------------
# Switch FeedForward (Mixture-of-Experts)
# ----------------------------
class SwitchFeedForward(nn.Module):
    def __init__(self, d_model, num_experts=4, hidden_layers=10):
        """
        d_model: model embedding dimension.
        num_experts: number of feedforward experts.
        hidden_layers: number of (Linear+ReLU) layers per expert.
        """
        super(SwitchFeedForward, self).__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList()
        for _ in range(num_experts):
            layers = []
            for _ in range(hidden_layers):
                layers.append(nn.Linear(d_model, d_model))
                layers.append(nn.ReLU())
            self.experts.append(nn.Sequential(*layers))
        # The gating network: projects each token's representation to a distribution over experts.
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x: (seq_length, batch_size, d_model)
        gate_logits = self.gate(x)  # shape: (seq_length, batch_size, num_experts)
        # For each token, select the expert with the highest logit.
        selected_expert = gate_logits.argmax(dim=-1)  # shape: (seq_length, batch_size)

        # Prepare output tensor.
        output = torch.zeros_like(x)
        seq_length, batch_size, d_model = x.shape

        # Route tokens to the appropriate expert.
        for expert_idx, expert in enumerate(self.experts):
            # Create a mask for tokens routed to this expert.
            mask = (selected_expert == expert_idx)  # shape: (seq_length, batch_size)
            if mask.sum() == 0:
                continue
            # Extract tokens for this expert.
            # x_expert shape: (num_tokens, d_model)
            x_expert = x[mask].view(-1, d_model)
            # Process tokens through the expert.
            expert_output = expert(x_expert)
            # Place the expert outputs back into the output tensor.
            output[mask] = expert_output
        return output

# ----------------------------
# Transformer Block with MOE Option
# ----------------------------
class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, use_moe=False, num_experts=4, hidden_layers=10):
        super(TransformerBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
        if use_moe:
            self.feedforward = SwitchFeedForward(d_model, num_experts, hidden_layers)
        else:
            # Standard feedforward network: 10 layers of Linear+ReLU.
            ff_layers = []
            for _ in range(hidden_layers):
                ff_layers.append(nn.Linear(d_model, d_model))
                ff_layers.append(nn.ReLU())
            self.feedforward = nn.Sequential(*ff_layers)
        
    def forward(self, x):
        # x: (seq_length, batch_size, d_model)
        attn_output, _ = self.attn(x, x, x)
        x = self.layernorm1(x + self.dropout(attn_output))
        ff_output = self.feedforward(x)
        x = self.layernorm2(x + self.dropout(ff_output))
        return x

# ----------------------------
# GPT-2–like Model Definition with MOE Option
# ----------------------------
class GPT2Model(nn.Module):
    def __init__(self, vocab_size, d_model=255, n_heads=5, n_layers=4, max_seq_length=128, dropout=0.1, use_moe=False):
        super(GPT2Model, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = nn.Embedding(max_seq_length, d_model)
        
        # Create a stack of transformer blocks.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout, use_moe=use_moe) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_length = max_seq_length

    def forward(self, input_ids):
        # input_ids: (batch_size, seq_length)
        batch_size, seq_length = input_ids.size()
        positions = torch.arange(0, seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_length)
        x = self.token_embedding(input_ids) + self.positional_embedding(positions)
        x = x.transpose(0, 1)  # (seq_length, batch_size, d_model)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        x = x.transpose(0, 1)  # back to (batch_size, seq_length, d_model)
        logits = self.head(x)
        return logits

# ----------------------------
# Dataset Definition
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, text, tokenizer, seq_length=128, mask_prob=0.15):
        self.tokenizer = tokenizer
        tokens = tokenizer.encode(text)
        self.sequences = []
        for i in range(0, len(tokens) - seq_length, seq_length):
            self.sequences.append(tokens[i:i+seq_length])
        self.seq_length = seq_length
        self.mask_prob = mask_prob

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        input_seq = []
        target_seq = []
        for token in seq:
            if random.random() < self.mask_prob:
                input_seq.append(self.tokenizer.mask_token_id)
            else:
                input_seq.append(token)
            target_seq.append(token)
        return (torch.tensor(input_seq, dtype=torch.long),
                torch.tensor(target_seq, dtype=torch.long))

# ----------------------------
# Training Setup
# ----------------------------
def train(model, dataloader, optimizer, device, epochs=3):
    model.train()
    loss_fn = nn.CrossEntropyLoss(ignore_index=0)
    for epoch in range(epochs):
        total_loss = 0.0
        for batch_idx, (input_seq, target_seq) in enumerate(dataloader):
            input_seq = input_seq.to(device)
            target_seq = target_seq.to(device)
            optimizer.zero_grad()
            logits = model(input_seq)  # (batch_size, seq_length, vocab_size)
            loss = loss_fn(logits.view(-1, logits.size(-1)), target_seq.view(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1} Batch {batch_idx}: Loss = {loss.item():.4f}")
        print(f"Epoch {epoch+1} Average Loss: {total_loss / len(dataloader):.4f}")

# ----------------------------
# Text Generation Function
# ----------------------------
def generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0, device="cpu"):
    model.eval()
    input_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    generated = input_ids.tolist()[0]
    with torch.no_grad():
        for _ in range(max_length):
            input_seq = input_ids[:, -model.max_seq_length:]
            logits = model(input_seq)
            logits = logits[:, -1, :] / temperature
            probabilities = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probabilities, num_samples=1)
            input_ids = torch.cat([input_ids, next_token], dim=1)
            generated.append(next_token.item())
            if next_token.item() == tokenizer.vocab["<EOS>"]:
                break
    return tokenizer.decode(generated)

# ----------------------------
# Main Execution with MOE
# ----------------------------
if __name__ == "__main__":
    # Load the WikiText-2 raw training corpus from Hugging Face.
    dataset_hf = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    corpus = "\n".join(dataset_hf["text"])
    print("Corpus loaded. Corpus length:", len(corpus))

    tokenizer = SimpleTokenizer(corpus)
    dataset_obj = TextDataset(corpus, tokenizer, seq_length=128, mask_prob=0.15)
    dataloader = DataLoader(dataset_obj, batch_size=8, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = len(tokenizer.vocab)
    # Set use_moe=True to enable the Switch Head Mixture-of-Experts in the transformer blocks.
    model = GPT2Model(vocab_size, d_model=255, n_heads=5, n_layers=4, max_seq_length=128, dropout=0.1, use_moe=True)
    model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=1e-4)

    # Train the model.
    train(model, dataloader, optimizer, device, epochs=3)

    # Generate sample text.
    prompt = "In a village"
    sample = generate_text(model, tokenizer, prompt, max_length=50, temperature=1.0, device=device)
    print("\nGenerated Text:")
    print(sample)


Corpus loaded. Corpus length: 10929707
Epoch 1 Batch 0: Loss = 11.4639
Epoch 1 Batch 10: Loss = 11.1988
Epoch 1 Batch 20: Loss = 10.7967
Epoch 1 Batch 30: Loss = 10.1381
Epoch 1 Batch 40: Loss = 9.8889
Epoch 1 Batch 50: Loss = 9.5045
Epoch 1 Batch 60: Loss = 9.2401
Epoch 1 Batch 70: Loss = 8.9139
Epoch 1 Batch 80: Loss = 8.6633
Epoch 1 Batch 90: Loss = 8.2594
Epoch 1 Batch 100: Loss = 8.1619
Epoch 1 Batch 110: Loss = 8.1151
Epoch 1 Batch 120: Loss = 7.5400
Epoch 1 Batch 130: Loss = 7.6229
Epoch 1 Batch 140: Loss = 7.4033
Epoch 1 Batch 150: Loss = 7.1092
Epoch 1 Batch 160: Loss = 6.9897
Epoch 1 Batch 170: Loss = 6.8470
Epoch 1 Batch 180: Loss = 6.9177
Epoch 1 Batch 190: Loss = 6.7889
Epoch 1 Batch 200: Loss = 6.7306
Epoch 1 Batch 210: Loss = 6.6031
Epoch 1 Batch 220: Loss = 6.4492
Epoch 1 Batch 230: Loss = 6.5301
Epoch 1 Batch 240: Loss = 6.2601
Epoch 1 Batch 250: Loss = 6.4749
Epoch 1 Batch 260: Loss = 6.1158
Epoch 1 Batch 270: Loss = 6.2580
Epoch 1 Batch 280: Loss = 6.0104
Epoch 1 Bat