# Report: Adversarial Training of a GPT-2 MOE Model with a GAN

This report describes the design and operation of a codebase that implements adversarial training for a GPT-2 style generator enhanced with a Mixture-of-Experts (MOE) architecture (SwitchHead) using a Generative Adversarial Network (GAN) framework. The goal is to train a generator to produce realistic text while a discriminator is trained to distinguish between real text and generated text. The adversarial feedback from the discriminator is used to improve the generator.

---

## 1. Overview

The code integrates two main components:

1. **Generator:**  
   A GPT-2–like language model that uses a SwitchHead MOE mechanism in its feedforward layers. It is designed to generate text by predicting the next token given a prompt.

2. **Discriminator:**  
   A binary classifier (implemented as a bidirectional LSTM) that distinguishes between real text from a dataset and text generated by the generator.

The training procedure follows a GAN paradigm where:
- The **discriminator** is trained to assign a high score (label 1) to real text and a low score (label 0) to generated (fake) text.
- The **generator** is updated with a REINFORCE-style gradient, using the discriminator’s feedback as a reward signal to encourage the production of text that the discriminator considers real.

---

## 2. Components

### 2.1 Generator (GPT-2 MOE Model)

The generator is a GPT-2 style model that employs a Transformer architecture with the following characteristics:

- **Token and Positional Embeddings:**  
  Each input token is mapped to a dense vector, and positional information is added to provide context on the token order.

- **Transformer Blocks with MOE (SwitchHead):**  
  Each Transformer block contains:
  - **Multi-Head Self-Attention:**  
    Computes attention using multiple heads to capture diverse relationships between tokens.
  - **Feedforward Network with Mixture-of-Experts:**  
    Instead of a standard feedforward network, a SwitchHead MOE module is used. It employs several experts (each a deep feedforward network) and a gating mechanism to dynamically route tokens to different experts. This reduces the number of attention matrix computations and can improve efficiency.

- **Autoregressive Generation:**  
  The generator has a `generate` method that takes an initial prompt (a sequence of token IDs) and generates text one token at a time. At each time step, the model computes a probability distribution over the vocabulary for the next token, samples a token (with temperature scaling), and appends it to the sequence. Generation stops when the `<EOS>` token is produced or after a predefined maximum length.

### 2.2 Discriminator

The discriminator is implemented as a simple LSTM-based binary classifier:
- **Embedding Layer:**  
  Converts token IDs into dense vectors.
- **Bidirectional LSTM:**  
  Processes the sequence to capture contextual information from both forward and backward directions.
- **Fully Connected Layer:**  
  Aggregates the final hidden states from the LSTM and produces a single logit. This logit is used to classify the sequence as real (1) or generated/fake (0).

### 2.3 Data Handling

- **Tokenizer:**  
  A simple whitespace tokenizer builds a vocabulary from the training corpus and supports encoding/decoding between text and token IDs.
- **Dataset:**  
  The real text dataset is constructed by loading text (e.g., from WikiText-2) and splitting it into fixed-length sequences. This dataset is used to provide real samples for training the discriminator.

---

## 3. GAN Training Procedure

The training loop follows a two-step update process in each iteration:

### 3.1 Discriminator Training

1. **Real Data Processing:**  
   - Real text sequences from the dataset are passed through the discriminator.
   - The discriminator computes logits and the binary cross-entropy (BCE) loss is calculated against the target label 1.

2. **Fake Data Generation and Processing:**  
   - For each batch, the generator produces fake sequences given a fixed prompt (e.g., "In a village").
   - For each generated sequence, a simplified computation of the log probability is performed.
   - The fake sequences are passed through the discriminator, and BCE loss is computed against the target label 0.

3. **Total Discriminator Loss:**  
   The losses from the real and fake data are summed and used to update the discriminator parameters.

### 3.2 Generator Training

1. **Reward Signal from Discriminator:**  
   - The discriminator’s output (after applying a sigmoid function) is interpreted as a reward signal. A higher value indicates that the discriminator considers the generated sequence as more real.
   
2. **REINFORCE-Style Update:**  
   - The generator’s loss is computed as the negative product of the log probability of the generated sequence and the reward. This encourages the generator to adjust its parameters to produce text that yields a higher reward from the discriminator.
   - The loss is averaged over the batch and used to update the generator’s parameters.

### 3.3 Alternating Updates

- The training loop alternates between updating the discriminator and the generator.
- Discriminator updates aim to improve the ability to distinguish real and fake text.
- Generator updates aim to generate text that fools the discriminator.

---

## 4. Mathematical Formulation

### Discriminator Loss

For a batch of size $B$, let:
- $x^{(i)}_{\text{real}}$ be real sequences with target label $y=1$.
- $x^{(i)}_{\text{fake}}$ be generated sequences with target label $y=0$.

The binary cross-entropy loss for the discriminator is:
$$
\mathcal{L}_D = -\frac{1}{B} \sum_{i=1}^{B} \left[ \log D(x^{(i)}_{\text{real}}) + \log (1 - D(x^{(i)}_{\text{fake}})) \right]
$$

### Generator Loss

For the generator, using REINFORCE, let $\log p(x^{(i)}_{\text{fake}})$ be the log probability of the generated sequence, and let $r^{(i)} = D(x^{(i)}_{\text{fake}})$ be the reward from the discriminator. The generator loss is:
$$
\mathcal{L}_G = -\frac{1}{B} \sum_{i=1}^{B} r^{(i)} \cdot \log p(x^{(i)}_{\text{fake}})
$$

---

## 5. Summary

The provided code implements a GAN for text generation where:
- The **generator** is a GPT-2 MOE model that leverages a SwitchHead mechanism to efficiently compute self-attention and generate text.
- The **discriminator** is an LSTM-based classifier that differentiates between real and generated text.
- **Adversarial training** is conducted by alternately updating the discriminator (to improve classification accuracy) and the generator (using a REINFORCE-style gradient based on discriminator rewards).

This setup aims to refine the generator’s capability to produce text that closely resembles real text by providing a discriminative signal, thereby improving the quality and realism of the generated language.


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset

# ----------------------------
# Simple Tokenizer Definition
# ----------------------------
class SimpleTokenizer:
    def __init__(self, corpus):
        # Build vocabulary by splitting on whitespace.
        tokens = corpus.split()
        unique_tokens = set(tokens)
        # Reserve indices for special tokens.
        self.vocab = {"<PAD>": 0, "<UNK>": 1, "<MASK>": 2, "<EOS>": 3}
        for i, token in enumerate(unique_tokens):
            self.vocab[token] = i + 4
        self.inv_vocab = {i: token for token, i in self.vocab.items()}
        self.mask_token_id = self.vocab["<MASK>"]

    def encode(self, text):
        tokens = text.split()
        token_ids = [self.vocab.get(token, self.vocab["<UNK>"]) for token in tokens] + [self.vocab["<EOS>"]]
        return token_ids

    def decode(self, token_ids):
        tokens = [self.inv_vocab.get(i, "<UNK>") for i in token_ids]
        return " ".join(tokens)

# ----------------------------
# SwitchFeedForward (MOE) and Transformer Block
# ----------------------------
class SwitchFeedForward(nn.Module):
    def __init__(self, d_model, num_experts=4, hidden_layers=10):
        """
        d_model: model embedding dimension.
        num_experts: number of feedforward experts.
        hidden_layers: number of (Linear+ReLU) layers per expert.
        """
        super(SwitchFeedForward, self).__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList()
        for _ in range(num_experts):
            layers = []
            for _ in range(hidden_layers):
                layers.append(nn.Linear(d_model, d_model))
                layers.append(nn.ReLU())
            self.experts.append(nn.Sequential(*layers))
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        # x: (seq_length, batch_size, d_model)
        gate_logits = self.gate(x)  # (seq_length, batch_size, num_experts)
        selected_expert = gate_logits.argmax(dim=-1)  # (seq_length, batch_size)
        output = torch.zeros_like(x)
        seq_length, batch_size, d_model = x.shape
        for expert_idx, expert in enumerate(self.experts):
            mask = (selected_expert == expert_idx)  # (seq_length, batch_size)
            if mask.sum() == 0:
                continue
            x_expert = x[mask].view(-1, d_model)
            expert_output = expert(x_expert)
            output[mask] = expert_output
        return output

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1, use_moe=False, num_experts=4, hidden_layers=10):
        super(TransformerBlock, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, dropout=dropout)
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
        if use_moe:
            self.feedforward = SwitchFeedForward(d_model, num_experts, hidden_layers)
        else:
            ff_layers = []
            for _ in range(hidden_layers):
                ff_layers.append(nn.Linear(d_model, d_model))
                ff_layers.append(nn.ReLU())
            self.feedforward = nn.Sequential(*ff_layers)
        
    def forward(self, x):
        attn_output, _ = self.attn(x, x, x)
        x = self.layernorm1(x + self.dropout(attn_output))
        ff_output = self.feedforward(x)
        x = self.layernorm2(x + self.dropout(ff_output))
        return x

# ----------------------------
# GPT-2-like Generator with MOE (SwitchHead)
# ----------------------------
class GPT2Model(nn.Module):
    def __init__(self, vocab_size, d_model=255, n_heads=5, n_layers=4, max_seq_length=128, dropout=0.1, use_moe=False):
        super(GPT2Model, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = nn.Embedding(max_seq_length, d_model)
        self.layers = nn.ModuleList([
            TransformerBlock(d_model, n_heads, dropout, use_moe=use_moe) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.max_seq_length = max_seq_length

    def forward(self, input_ids):
        # input_ids: (batch_size, seq_length)
        batch_size, seq_length = input_ids.size()
        positions = torch.arange(0, seq_length, device=input_ids.device).unsqueeze(0).expand(batch_size, seq_length)
        x = self.token_embedding(input_ids) + self.positional_embedding(positions)
        x = x.transpose(0, 1)  # (seq_length, batch_size, d_model)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        x = x.transpose(0, 1)  # (batch_size, seq_length, d_model)
        logits = self.head(x)
        return logits

    def generate(self, prompt_ids, max_length=50, temperature=1.0, device="cpu"):
        self.eval()
        input_ids = prompt_ids.unsqueeze(0).to(device)  # shape: (1, seq_length)
        generated = input_ids.tolist()[0]
        with torch.no_grad():
            for _ in range(max_length):
                input_seq = input_ids[:, -self.max_seq_length:]
                logits = self.forward(input_seq)
                logits = logits[:, -1, :] / temperature
                probabilities = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probabilities, num_samples=1)
                # Debug: print next token index if it is out-of-range
                if torch.max(next_token).item() >= self.head.out_features:
                    print(f"Warning: Out-of-range token generated: {torch.max(next_token).item()} (vocab size: {self.head.out_features})")
                # Clamp next token to ensure it is within valid range
                next_token = torch.clamp(next_token, 0, self.head.out_features - 1)
                input_ids = torch.cat([input_ids, next_token], dim=1)
                generated.append(next_token.item())
                if next_token.item() == 3:  # <EOS>
                    break
        return torch.tensor(generated, device=device)

# ----------------------------
# Discriminator: A simple LSTM-based binary classifier
# ----------------------------
class Discriminator(nn.Module):
    def __init__(self, vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.1):
        super(Discriminator, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.lstm = nn.LSTM(embed_dim, hidden_dim, num_layers, batch_first=True, dropout=dropout, bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, 1)  # bidirectional LSTM
        self.dropout = nn.Dropout(dropout)

    def forward(self, input_ids):
        # input_ids: (batch_size, seq_length)
        emb = self.embedding(input_ids)
        out, _ = self.lstm(emb)
        # Use the last hidden state from both directions
        out = self.dropout(out[:, -1, :])
        logits = self.fc(out)
        return logits

# ----------------------------
# Dataset Definition (real text)
# ----------------------------
class TextDataset(Dataset):
    def __init__(self, text, tokenizer, seq_length=128):
        self.tokenizer = tokenizer
        tokens = tokenizer.encode(text)
        self.sequences = []
        for i in range(0, len(tokens) - seq_length, seq_length):
            self.sequences.append(tokens[i:i+seq_length])
        self.seq_length = seq_length

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq = self.sequences[idx]
        return torch.tensor(seq, dtype=torch.long)

# ----------------------------
# GAN Training Setup
# ----------------------------
def train_gan(generator, discriminator, real_dataloader, gen_optimizer, disc_optimizer, tokenizer, device, 
              num_epochs=3):
    bce_loss = nn.BCEWithLogitsLoss()

    generator.train()
    discriminator.train()
    
    for epoch in range(num_epochs):
        for batch_idx, real_batch in enumerate(real_dataloader):
            real_batch = real_batch.to(device)
            batch_size = real_batch.size(0)

            ### --- Discriminator Training ---
            disc_optimizer.zero_grad()
            # Real text processing
            real_logits = discriminator(real_batch)
            real_labels = torch.ones((batch_size, 1), device=device)
            real_loss = bce_loss(real_logits, real_labels)

            # Generate fake sequences using the generator with a fixed prompt.
            prompt_text = "In a village"
            encoded_prompt = tokenizer.encode(prompt_text)
            print(f"Encoded prompt: {encoded_prompt}")
            prompt = torch.tensor(encoded_prompt, dtype=torch.long, device=device)
            
            fake_sequences = []
            log_probs = []  # store log probabilities for REINFORCE update
            for _ in range(batch_size):
                gen_seq = generator.generate(prompt, max_length=generator.max_seq_length, device=device)
                fake_sequences.append(gen_seq)
                # Compute log probabilities for generated sequence (simplified)
                with torch.no_grad():
                    logits = generator(gen_seq.unsqueeze(0))
                    # Gather log probabilities for each token in the sequence
                    log_probs_seq = F.log_softmax(logits, dim=-1)
                    # Reshape to (1, seq_length, vocab_size) and gather the corresponding token probabilities
                    token_ids = gen_seq.unsqueeze(0).unsqueeze(-1)
                    gen_log_prob = log_probs_seq.gather(2, token_ids).sum()
                    log_probs.append(gen_log_prob)
            fake_sequences = nn.utils.rnn.pad_sequence(fake_sequences, batch_first=True, 
                                                       padding_value=tokenizer.vocab.get("<PAD>", 0))
            fake_logits = discriminator(fake_sequences)
            fake_labels = torch.zeros((batch_size, 1), device=device)
            fake_loss = bce_loss(fake_logits, fake_labels)

            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            disc_optimizer.step()

            ### --- Generator Training ---
            gen_optimizer.zero_grad()
            # Get rewards from discriminator for generated sequences
            rewards = torch.sigmoid(discriminator(fake_sequences)).detach()  # rewards in [0,1]
            gen_loss = 0.0
            for r, lp in zip(rewards, log_probs):
                gen_loss += -r * lp
            gen_loss = gen_loss / batch_size
            gen_loss.backward()
            gen_optimizer.step()

            if batch_idx % 10 == 0:
                print(f"Epoch {epoch+1} Batch {batch_idx}: Disc Loss = {disc_loss.item():.4f}, Gen Loss = {gen_loss.item():.4f}")

# ----------------------------
# Main Execution
# ----------------------------
if __name__ == "__main__":
    # Load WikiText-2 raw training corpus from Hugging Face.
    dataset_hf = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    corpus = "\n".join(dataset_hf["text"])
    print("Corpus loaded. Corpus length:", len(corpus))

    # Initialize tokenizer and dataset.
    tokenizer = SimpleTokenizer(corpus)
    real_dataset = TextDataset(corpus, tokenizer, seq_length=128)
    real_dataloader = DataLoader(real_dataset, batch_size=8, shuffle=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    vocab_size = len(tokenizer.vocab)

    # Instantiate generator (GPT-2 MOE model with SwitchHead) and discriminator.
    generator = GPT2Model(vocab_size, d_model=255, n_heads=5, n_layers=4, max_seq_length=128, dropout=0.1, use_moe=True)
    generator.to(device)
    discriminator = Discriminator(vocab_size, embed_dim=128, hidden_dim=256, num_layers=2, dropout=0.1)
    discriminator.to(device)

    # Define optimizers.
    gen_optimizer = optim.Adam(generator.parameters(), lr=1e-4)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=1e-4)

    # Train the GAN.
    train_gan(generator, discriminator, real_dataloader, gen_optimizer, disc_optimizer, tokenizer, device, num_epochs=3)

    # Generate sample text from the generator after GAN training.
    prompt_text = "In a village"
    encoded_prompt = tokenizer.encode(prompt_text)
    prompt = torch.tensor(encoded_prompt, dtype=torch.long, device=device)
    gen_ids = generator.generate(prompt, max_length=50, temperature=1.0, device=device)
    print("\nGenerated Text:")
    print(tokenizer.decode(gen_ids.tolist()))


Corpus loaded. Corpus length: 10929707
Encoded prompt: [58660, 39094, 23037, 3]


../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0], thread: [96,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0], thread: [97,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0], thread: [98,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0], thread: [99,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0], thread: [100,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0], thread: [101,0,0] Assertion `srcIndex < srcSelectDimSize` failed.
../aten/src/ATen/native/cuda/Indexing.cu:1308: indexSelectLargeIndex: block: [86,0,0],

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
