# Session 4, Part 1 – Training a Transformer-Based Text Generator

In this session, we will gain **hands-on experience** training a **Transformer-based text generator** (e.g., a GPT-like model) on a small dataset. We’ll cover:

- **Model Architecture**: building a minimal decoder-only Transformer with causal masking.  
- **Training Routines**: batch scheduling, gradient accumulation, and optimization.  
- **Inference Techniques**: *greedy decoding*, *temperature scaling*, and *top-k* sampling.

By the end, you’ll be able to **train** a basic text-generating Transformer and **experiment** with different hyperparameters and inference strategies to observe how they affect text quality.

---

## Table of Contents

1. [Introduction to Text Generation](#introduction)
2. [Dataset Preparation](#dataset-prep)
3. [Building a GPT-like Model](#gpt-model)
   - [Causal Masking](#causal-masking)
   - [Model Definition](#model-def)
   - [Training Routine (Optimization & Strategies)](#training-routine)
4. [Inference Techniques](#inference)
   - [Greedy Decoding vs. Random Sampling](#greedy-random)
   - [Temperature Scaling](#temperature)
   - [Top-k Sampling](#topk)
5. [Practical Exercises](#exercises)
6. [Conclusion](#conclusion)


## <a id="introduction"></a>1. Introduction to Text Generation

**Text generation** is a fundamental language modeling task. We model $P(x_{t} \mid x_{0}, \dots, x_{t-1})$ to predict the next token given the previous tokens. **Transformer**-based **decoder-only** architectures (like GPT) excel at *autoregressive* generation for tasks like writing paragraphs, summarizing, or code generation.

**Key Insight**:  
- We mask out future tokens so that at each time step $t$, the model can’t “see” tokens $t+1, t+2, \dots$.  
- This is often referred to as **causal language modeling** or **autoregressive** language modeling.
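
Written out, the autoregressive view factorizes the probability of a whole sequence into next-token predictions, and training minimizes the average negative log-likelihood of each next token. The symbols $\theta$ (model parameters), $T$ (number of predicted positions), and $\mathcal{L}$ (loss) are notation introduced here for illustration:

$$
P(x_0, \dots, x_T) = \prod_{t=0}^{T} P(x_t \mid x_0, \dots, x_{t-1}), \qquad
\mathcal{L}(\theta) = -\frac{1}{T} \sum_{t=1}^{T} \log P_\theta(x_t \mid x_0, \dots, x_{t-1})
$$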



### Why Use a Transformer (GPT-like) for Generation?

1. **Parallel Computation** of the hidden states (although we generate tokens one-by-one at inference time, training can process full sequences in parallel).
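2. **Long-range Dependencies**: Self-attention connects any two positions in the context directly, so it captures long-range relationships more easily than RNNs, which must carry information forward step by step.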
3. **Scalability**: We can scale up to large models (e.g., GPT-2, GPT-3) and achieve impressive quality.

## <a id="dataset-prep"></a>2. Dataset Preparation

To demonstrate, we’ll use a **small text dataset**: **Tiny Shakespeare**. You can adapt these steps to any text corpus.

### 2.1 Download and Load the Dataset

In [11]:

import os
import requests

# Download the tiny shakespeare dataset if not present
if not os.path.exists("tiny_shakespeare.txt"):
    url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"
    r = requests.get(url, timeout=30)
    r.raise_for_status()  # fail loudly instead of saving an error page to disk
    with open("tiny_shakespeare.txt", "wb") as f:
        f.write(r.content)

with open("tiny_shakespeare.txt", "r", encoding="utf-8") as f:
    text_data = f.read()

print("Length of text:", len(text_data))
print("Sample:\n", text_data[:250])

Length of text: 1115394
Sample:
 First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.




### 2.2 Character-Level Tokenization

We’ll keep it simple: **character-level** tokens.


In [13]:
chars = sorted(list(set(text_data)))
vocab_size = len(chars)
print("Vocab size:", vocab_size)
print("Chars:", chars[:50])

char2idx = {ch: i for i, ch in enumerate(chars)}
idx2char = {i: ch for ch, i in char2idx.items()}

# Convert entire text to indices
data_indices = [char2idx[ch] for ch in text_data]
print("Example of mapped indices:", data_indices[:50])

Vocab size: 65
Chars: ['\n', ' ', '!', '$', '&', "'", ',', '-', '.', '3', ':', ';', '?', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k']
Example of mapped indices: [18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44, 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63, 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56]


### 2.3 Creating (x, y) Sequences

We want to split data into sequences of a fixed length (`seq_length`), with target `y` being the next character for each position in `x`.


In [15]:
import torch
from torch.utils.data import Dataset, DataLoader

class CharDataset(Dataset):
    def __init__(self, data, seq_len=64):
        self.data = data
        self.seq_len = seq_len
        self.num_samples = len(self.data) // self.seq_len - 1
    
    def __len__(self):
        return self.num_samples
    
    def __getitem__(self, idx):
        start_idx = idx * self.seq_len
        x = self.data[start_idx : start_idx + self.seq_len]
        y = self.data[start_idx + 1 : start_idx + self.seq_len + 1]
        return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)

seq_length = 64
dataset = CharDataset(data_indices, seq_len=seq_length)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, drop_last=True)
print("Number of batches:", len(dataloader))


Number of batches: 544
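
To make the one-character shift between `x` and `y` concrete, here is a minimal sketch on a toy string. The names `toy_text` and `toy_seq_len` are made up for this illustration and are not part of the notebook's pipeline.

```python
# Toy illustration of the (x, y) shift used by CharDataset.
# `toy_text` and `toy_seq_len` are hypothetical names for this example only.
toy_text = "hello world"
toy_seq_len = 4

x = toy_text[0:toy_seq_len]          # input window
y = toy_text[1:toy_seq_len + 1]      # same window shifted right by one character

# At position t the model sees x[:t+1] and is trained to predict y[t].
for t in range(toy_seq_len):
    print(repr(x[: t + 1]), "->", repr(y[t]))
```

Each position in a single window therefore yields its own training signal, which is why one sequence of length `seq_len` provides `seq_len` prediction targets.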


## <a id="gpt-model"></a>3. Building a GPT-like Model

### <a id="causal-masking"></a>3.1 Causal Masking

A **causal mask** ensures each position in the sequence can only attend to itself and previous positions. This is typically implemented with a **triangular** (lower-triangular) matrix.

At each time step $t$, the model sees positions $\leq t$.
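
For reference, the masked attention computed by each head can be written as below. Here $M$ is the additive form of the causal mask and $d_k$ is the per-head dimension; $Q$, $K$, $V$, $M$, and $d_k$ are notation introduced here for illustration:

$$
\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}} + M\right)V, \qquad
M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}
$$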

In [17]:
import math
import torch.nn as nn

def generate_causal_mask(seq_len):
    """
    Returns a (seq_len, seq_len) mask
    where positions j>i are set to False, preventing 'future' attention.
    """
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask == 1  # boolean
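
As a quick sanity check, the sketch below builds the same kind of mask for a length-4 sequence and applies it to uniform attention scores. It is a standalone illustration (the mask is rebuilt inline rather than imported from the cell above), and the all-zero `scores` tensor is an assumption chosen so the effect of the mask is easy to read.

```python
import torch

seq_len = 4
# Same construction as generate_causal_mask: True on and below the diagonal.
mask = torch.tril(torch.ones(seq_len, seq_len)) == 1

# Uniform raw scores, so any difference in the output comes from the mask alone.
scores = torch.zeros(seq_len, seq_len)
scores = scores.masked_fill(~mask, float("-inf"))  # hide future positions
attn = torch.softmax(scores, dim=-1)

print(mask)
print(attn)  # row i spreads its weight evenly over positions 0..i
```

Because every row keeps at least its diagonal entry, the softmax never sees a row of only `-inf` values, so no NaNs appear.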

### <a id="model-def"></a>3.2 Model Definition

**GPT** is essentially:
1. An **embedding** layer (tokens + optional positions).
2. A stack of **decoder blocks** with causal self-attention.
3. A final linear layer that outputs a distribution over the vocabulary.

#### 3.2.1 Self-Attention (Causal)

In [18]:

class MultiHeadCausalSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        
        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, x, mask=None):
        B, T, C = x.shape
        # Project to queries, keys, values
        q = self.Wq(x).view(B, T, self.num_heads, self.head_dim).permute(0,2,1,3)
        k = self.Wk(x).view(B, T, self.num_heads, self.head_dim).permute(0,2,1,3)
        v = self.Wv(x).view(B, T, self.num_heads, self.head_dim).permute(0,2,1,3)
        
        # Scaled dot-product
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # mask shape => (T, T), broadcast to (B, num_heads, T, T)
            # positions with mask = False => set to -inf
            scores = scores.masked_fill(~mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        
        attn = torch.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        
        # reshape back
        out = out.permute(0,2,1,3).contiguous().view(B, T, C)
        out = self.out(out)
        return out

#### 3.2.2 Decoder Block


In [19]:
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadCausalSelfAttention(d_model, num_heads)
        self.norm1 = nn.LayerNorm(d_model)
        
        self.ff = nn.Sequential(
            nn.Linear(d_model, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, d_model),
        )
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, x, mask=None):
        # Self-attention
        attn_out = self.self_attn(x, mask=mask)
        x = x + self.dropout(attn_out)
        x = self.norm1(x)
        
        # Feed-forward
        ff_out = self.ff(x)
        x = x + self.dropout(ff_out)
        x = self.norm2(x)
        return x

#### 3.2.3 GPT Model


In [20]:
class GPT(nn.Module):
    def __init__(self, vocab_size, d_model=256, num_layers=4, num_heads=4, ff_dim=1024, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_len, d_model)
        
        self.blocks = nn.ModuleList([
            DecoderBlock(d_model, num_heads, ff_dim) for _ in range(num_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, vocab_size, bias=False)
    
    def forward(self, x):
        B, T = x.shape
        # Token + positional embeddings
        tok_emb = self.embedding(x)
        positions = torch.arange(0, T, device=x.device).unsqueeze(0)
        pos_emb = self.pos_embedding(positions)
        
        hidden = tok_emb + pos_emb
        
        mask = generate_causal_mask(T).to(x.device)
        for block in self.blocks:
            hidden = block(hidden, mask=mask)
        
        hidden = self.ln_f(hidden)
        logits = self.fc_out(hidden)  # shape (B, T, vocab_size)
        return logits

### <a id="training-routine"></a>3.3 Training Routine (Optimization & Strategies)

We’ll use the **AdamW** optimizer and **cross-entropy** loss (the standard objective for language modeling). For demonstration, we’ll run a few epochs on the small dataset.


**Possible Training Extensions**:
- **Gradient Accumulation** if memory is limited (accumulate gradients across multiple batches before `optimizer.step()`).
- **Learning Rate Schedules** (cosine decay, linear warmup, etc.).
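
The two extensions above can be sketched together on a toy problem. This is a minimal illustration under stated assumptions, not the notebook's training loop: the linear `toy_model`, the random `batches`, and the values of `accum_steps` and `warmup_updates` are all made up for the example. The schedule uses `torch.optim.lr_scheduler.LambdaLR`, which multiplies the base learning rate by whatever the supplied function returns.

```python
import math
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumptions for this sketch): a tiny classifier and random data.
toy_model = nn.Linear(8, 4)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=1e-3)
toy_criterion = nn.CrossEntropyLoss()
batches = [(torch.randn(16, 8), torch.randint(0, 4, (16,))) for _ in range(32)]

accum_steps = 4                                # micro-batches per optimizer step
total_updates = len(batches) // accum_steps    # number of optimizer steps
warmup_updates = 2

def lr_lambda(update):
    """Multiplier on the base LR: linear warmup, then cosine decay to zero."""
    if update < warmup_updates:
        return (update + 1) / warmup_updates
    progress = (update - warmup_updates) / max(1, total_updates - warmup_updates)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = torch.optim.lr_scheduler.LambdaLR(toy_optimizer, lr_lambda)

toy_optimizer.zero_grad()
num_updates = 0
for step, (xb, yb) in enumerate(batches, start=1):
    loss = toy_criterion(toy_model(xb), yb)
    # Divide so the accumulated gradient is the mean over micro-batches.
    (loss / accum_steps).backward()
    if step % accum_steps == 0:
        toy_optimizer.step()       # one update per accum_steps micro-batches
        scheduler.step()           # advance the schedule once per update
        toy_optimizer.zero_grad()
        num_updates += 1

print("Optimizer updates:", num_updates)
```

Dividing the loss by `accum_steps` keeps the gradient scale comparable to a single batch that is `accum_steps` times larger, so the learning rate does not need retuning.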


In [21]:
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

model = GPT(vocab_size, d_model=256, num_layers=4, num_heads=4).to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for x_batch, y_batch in dataloader:
        x_batch, y_batch = x_batch.to(device), y_batch.to(device)
        
        optimizer.zero_grad()
        logits = model(x_batch)  # shape (B, T, vocab_size)
        
        # Flatten
        B, T, V = logits.shape
        loss = criterion(logits.view(B*T, V), y_batch.view(B*T))
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / len(dataloader)
    print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}")

Using device: cuda
Epoch 1/5 - Loss: 2.0577
Epoch 2/5 - Loss: 1.6229
Epoch 3/5 - Loss: 1.5208
Epoch 4/5 - Loss: 1.4666
Epoch 5/5 - Loss: 1.4302
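
One way to read these numbers: `nn.CrossEntropyLoss` reports the average negative log-likelihood in nats, so exponentiating it gives a per-character perplexity, and the loss of a uniform guess over the vocabulary gives a baseline. The sketch below does that arithmetic with the values printed above; note that this is training loss, not a held-out estimate.

```python
import math

vocab_size = 65        # from the tokenization step
final_loss = 1.4302    # epoch-5 training loss reported above

uniform_loss = math.log(vocab_size)   # loss of a model that guesses uniformly
perplexity = math.exp(final_loss)     # effective number of equally likely choices

print(f"Uniform-guess loss: {uniform_loss:.3f} nats")
print(f"Final perplexity:   {perplexity:.2f}")
```

A perplexity of roughly 4 means the model is, on average, about as uncertain as if it were choosing among four equally likely characters, compared with 65 for a uniform guess.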


## <a id="inference"></a>4. Inference Techniques

Once trained, we can generate text. Key methods:

1. **Greedy Decoding**: pick `argmax` at each step.  
2. **Random Sampling**: sample from the probability distribution (adds variety).  
3. **Temperature Scaling**: modifies distribution sharpness or spread.  
4. **Top-k** (or **Top-p**) sampling: restricts sampling to the k most probable tokens (top-p instead keeps the smallest set of tokens whose cumulative probability reaches p).

### <a id="greedy-random"></a>4.1 Greedy Decoding vs. Random Sampling

- **Greedy Decoding**:  
  - At each time step, select the token with the highest probability.  
  - Pros: deterministic; every step follows the locally most probable token.  
  - Cons: can get stuck in repetitive loops.

- **Random Sampling**:  
  - Sample next token according to the predicted probability distribution.  
  - Pros: more diverse text, can escape repetitive loops.  
  - Cons: might lead to incoherent tangents if probabilities are too spread out.
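
The contrast between the two strategies can be seen on a single next-token distribution. The sketch below uses only the standard library; the three-token distribution `probs` is a made-up example, not model output.

```python
import random
from collections import Counter

# Hypothetical next-token distribution (an assumption for this illustration).
probs = {"a": 0.5, "b": 0.3, "c": 0.2}

# Greedy decoding: always the single most probable token.
greedy = max(probs, key=probs.get)

# Random sampling: draw tokens in proportion to their probabilities.
rng = random.Random(0)  # fixed seed so the run is repeatable
samples = rng.choices(list(probs), weights=list(probs.values()), k=1000)
counts = Counter(samples)

print("Greedy choice:", greedy)
print("Sample counts:", dict(counts))
```

Greedy decoding returns the same token on every call, while sampling visits all three tokens at frequencies close to their probabilities.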


### <a id="temperature"></a>4.2 Temperature Scaling

$$
p_i^\text{(scaled)} \propto \exp\left(\frac{\log p_i}{\text{temp}}\right)
$$

- `temp < 1.0` => sharper (peakier) distribution; as `temp` approaches 0, sampling approaches greedy decoding.  
- `temp > 1.0` => flatter distribution, hence more randomness and variety.
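
In practice, temperature is usually applied by dividing the logits before the softmax. The sketch below checks numerically that this matches the formula above, which rescales log-probabilities, and shows how the top probability moves with `temp`. The helper names `softmax` and `apply_temperature` and the example `logits` are assumptions made for this illustration.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def apply_temperature(probs, temp):
    """Rescale probabilities as p_i ** (1 / temp), then renormalize."""
    scaled = [math.exp(math.log(p) / temp) for p in probs]
    total = sum(scaled)
    return [s / total for s in scaled]

logits = [2.0, 1.0, 0.0]   # made-up logits for three tokens
probs = softmax(logits)

for temp in (0.5, 1.0, 2.0):
    via_probs = apply_temperature(probs, temp)
    via_logits = softmax([z / temp for z in logits])
    # Both routes give the same distribution (up to rounding).
    assert all(abs(a - b) < 1e-12 for a, b in zip(via_probs, via_logits))
    print(f"temp={temp}: top probability = {via_logits[0]:.3f}")
```

The two routes agree because the log-probabilities differ from the logits only by a constant, and a constant shift cancels when the distribution is renormalized.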


### <a id="topk"></a>4.3 Top-k Sampling

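- Rank the vocabulary by logit (equivalently, by probability).
- Keep only the k highest-ranked tokens; set the probability of all others to 0 (equivalently, set their logits to -inf).
- Renormalize over the kept tokens and sample.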

*(**Top-p** / nucleus sampling is similar but chooses a dynamic set of tokens until their cumulative probability >= p.)*
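
Both filters can be sketched on a plain list of probabilities. This is a standard-library illustration under assumptions: the helper names `top_k_filter` and `top_p_filter` and the four-token distribution are invented for the example, and the notebook's own generation code works on tensors instead.

```python
def top_k_filter(probs, k):
    """Keep the k most probable entries, zero the rest, renormalize."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep = set(order[:k])
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

def top_p_filter(probs, p):
    """Keep the smallest top-ranked set whose cumulative mass reaches p."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)
        cumulative += probs[i]
        if cumulative >= p:
            break
    total = sum(probs[i] for i in keep)
    return [probs[i] / total if i in keep else 0.0 for i in range(len(probs))]

probs = [0.5, 0.25, 0.15, 0.10]   # made-up distribution over four tokens

top2 = top_k_filter(probs, 2)       # keeps tokens 0 and 1
nucleus = top_p_filter(probs, 0.8)  # keeps tokens 0, 1, 2 (cumulative 0.9)

print("top-k (k=2):  ", [round(v, 3) for v in top2])
print("top-p (p=0.8):", [round(v, 3) for v in nucleus])
```

With top-k the number of candidates is fixed, whereas with top-p it grows when the distribution is flat and shrinks when the model is confident.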


### Sample Generation Code

**Observations**:
- Vary **temperature**: higher => more diverse but less coherent.
- Vary **top_k**: small k => conservative, large k => more variety.


In [36]:
@torch.no_grad()  # inference only: skip building the autograd graph
def generate_text(
    model, 
    start_text="ROMEO:", 
    max_new_tokens=100, 
    temperature=1.0, 
    top_k=None
):
    model.eval()
    
    # Convert start text to indices
    input_ids = torch.tensor([char2idx[ch] for ch in start_text], dtype=torch.long).unsqueeze(0).to(device)
    
    for _ in range(max_new_tokens):
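        # Forward pass on the last seq_length tokens only: positions beyond the
        # training length have positional embeddings that were never trained
        context = input_ids[:, -seq_length:]
        logits = model(context)  # shape: (1, context_len, vocab_size)
        logits = logits[:, -1, :]  # last timestep => shape (1, vocab_size)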
        
        # Scale by temperature
        logits = logits / temperature
        
        # (Optional) top-k
        if top_k is not None:
            v, ix = torch.topk(logits, top_k)
            probs = torch.zeros_like(logits).scatter_(1, ix, torch.softmax(v, dim=-1))
        else:
            probs = torch.softmax(logits, dim=-1)
        
        # Sample from distribution
        next_id = torch.multinomial(probs, 1).item()
        
        # Append to input
        input_ids = torch.cat([input_ids, torch.tensor([[next_id]], device=device)], dim=1)
    
    out_seq = input_ids[0].tolist()
    return "".join(idx2char[idx] for idx in out_seq)

# Example usage after training:
generated = generate_text(model, start_text="ROMEO:", max_new_tokens=200, temperature=0.8, top_k=3)
print("Generated Text:\n", generated)

Generated Text:
 ROMEO:
The more sometime and man.
I cannot to the married of thisenckene oronous ouner thes sthe warous therereres, me ale ones ouchest te thalen wishes thaldous wanigedoumathere sthare w winereranous alere


## <a id="exercises"></a>5. Practical Exercises

### Exercise 1: Hyperparameter Tweaks
1. Change `seq_length`, `d_model`, `num_layers`, `num_heads`.  
2. Observe how training speed and memory usage are affected.  
3. Compare generated text for different settings.

### Exercise 2: Gradient Accumulation
1. If you have limited GPU memory, implement a small gradient accumulation loop (accumulate gradients over N mini-batches before calling `optimizer.step()`).  
2. Verify you get similar results to a larger batch size (without accumulation).

### Exercise 3: Inference Experiments
1. Generate text using **greedy decoding**. Save a sample.  
2. Generate text using **random sampling** with `temperature=1.0`. Compare the style.  
3. Try `temperature=0.6` vs. `temperature=1.2` to see how it changes coherence and creativity.  
4. Set `top_k=5` vs. `top_k=50`, observe differences in repetition or diversity.


## <a id="conclusion"></a>6. Conclusion

**Summary**:  
In **Session 4, Part 1**, we built a **Transformer-based text generator** from scratch, focusing on:
- **Decoder-only architecture** with **causal masking**.
- **Autoregressive training** using `(x,y)` pairs shifted by one token.
- **Inference methods** (greedy vs. random sampling, temperature, top-k) for controlling generation style.

**Key Takeaways**:
1. The **causal mask** is critical for ensuring the model only attends to past tokens.  
2. **Hyperparameters** (hidden size, layers, heads) and training routines (optimizer, batch size) can drastically affect both training speed and generated text quality.  
3. **Inference sampling** parameters can move your model from repetitive to more creative outputs.

**Next Steps**:
- If you’d like to refine generation quality, train longer or on a bigger dataset. 
- For advanced features, see **Session 4, Part 2** on **large-scale pretraining** and **fine-tuning** strategies (pipeline parallelism, mixed precision, etc.).

