# Fine-tuning for Text Classification: Spam Detection

This notebook demonstrates how to fine-tune a pre-trained language model for text classification tasks, specifically focusing on spam detection. By the end of this tutorial, you'll understand the key concepts and implementation steps involved in adapting large language models for classification problems.

## Overview

Fine-tuning a pre-trained language model for text classification involves the following main steps:

1. **Data preparation**: Creating a balanced, tokenized dataset
2. **Model adaptation**: Modifying a pre-trained model for classification tasks
3. **Selective fine-tuning**: Choosing which model parameters to train
4. **Training and evaluation**: Fine-tuning the model and measuring its performance
5. **Inference**: Using the fine-tuned model to classify new texts

Let's implement these steps for a spam classification example.

## Step 1: Setup and Imports

First, let's import the necessary libraries and set up our environment.

In [None]:
import os
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm.notebook import tqdm
import urllib.error  # exception types caught when the primary download URL fails
import urllib.request
import zipfile
from pathlib import Path
import time
import tiktoken  # OpenAI's tokenizer

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

## Step 2: Download and Prepare the Dataset

We'll use the SMS Spam Collection dataset, which contains text messages labeled as spam or ham (not spam).

In [None]:
# Download and extract the dataset
def download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path):
    if data_file_path.exists():
        print(f"{data_file_path} already exists. Skipping download and extraction.")
        return

    # Downloading the file
    print(f"Downloading dataset from {url}...")
    with urllib.request.urlopen(url) as response:
        with open(zip_path, "wb") as out_file:
            out_file.write(response.read())

    # Unzipping the file
    print("Extracting dataset...")
    with zipfile.ZipFile(zip_path, "r") as zip_ref:
        zip_ref.extractall(extracted_path)

    # Add .tsv file extension
    original_file_path = Path(extracted_path) / "SMSSpamCollection"
    os.rename(original_file_path, data_file_path)
    print(f"File downloaded and saved as {data_file_path}")

# Set up paths
url = "https://archive.ics.uci.edu/static/public/228/sms+spam+collection.zip"
zip_path = "sms_spam_collection.zip"
extracted_path = "sms_spam_collection"
data_file_path = Path(extracted_path) / "SMSSpamCollection.tsv"

# Create the extraction directory if it doesn't exist
os.makedirs(extracted_path, exist_ok=True)

# Download and extract data
try:
    download_and_unzip_spam_data(url, zip_path, extracted_path, data_file_path)
except (urllib.error.HTTPError, urllib.error.URLError, TimeoutError) as e:
    print(f"Primary URL failed: {e}. Trying backup URL...")
    backup_url = "https://f001.backblazeb2.com/file/LLMs-from-scratch/sms%2Bspam%2Bcollection.zip"
    download_and_unzip_spam_data(backup_url, zip_path, extracted_path, data_file_path)

### Load and explore the dataset

In [None]:
# Load the dataset
df = pd.read_csv(data_file_path, sep="\t", header=None, names=["Label", "Text"])

# Display the first few samples
print(f'Dataset shape: {df.shape}')
df.head()

In [None]:
# Check class distribution
class_distribution = df["Label"].value_counts()
print(class_distribution)

# Visualize class distribution
plt.figure(figsize=(8, 5))
class_distribution.plot(kind='bar')
plt.title('Class Distribution')
plt.xlabel('Class')
plt.ylabel('Count')
plt.xticks(rotation=0)
plt.show()

### Create a balanced dataset

The dataset is imbalanced, with many more non-spam (ham) messages than spam messages. To keep the classifier from simply favoring the majority class, we'll undersample: randomly select as many ham messages as there are spam messages and discard the remaining ham messages.
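
Before the pandas version, here is a minimal plain-Python sketch of the same undersampling idea. The function name `undersample` and the toy messages are hypothetical and only serve to illustrate the mechanics.

```python
import random


def undersample(rows, majority="ham", minority="spam", seed=123):
    """Keep every minority row plus an equally sized random sample of majority rows."""
    minority_rows = [row for row in rows if row[0] == minority]
    majority_rows = [row for row in rows if row[0] == majority]
    rng = random.Random(seed)  # fixed seed so the sample is reproducible
    sampled_majority = rng.sample(majority_rows, len(minority_rows))
    balanced_rows = sampled_majority + minority_rows
    rng.shuffle(balanced_rows)  # avoid all spam rows sitting at the end
    return balanced_rows


# Hypothetical toy data: 8 ham messages and 2 spam messages
toy_rows = [("ham", f"ham message {i}") for i in range(8)] + [
    ("spam", "win a prize now"),
    ("spam", "free entry, text back"),
]

balanced = undersample(toy_rows)
counts = {label: sum(1 for row in balanced if row[0] == label) for label in ("ham", "spam")}
print(counts)  # {'ham': 2, 'spam': 2}
```

The trade-off is that most ham messages are thrown away, which is acceptable here because the goal is a small, balanced fine-tuning set.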

In [None]:
def create_balanced_dataset(df):
    # Count the instances of "spam"
    num_spam = df[df["Label"] == "spam"].shape[0]
    
    # Randomly sample "ham" instances to match the number of "spam" instances
    ham_subset = df[df["Label"] == "ham"].sample(num_spam, random_state=123)
    
    # Combine the ham subset with all spam messages
    balanced_df = pd.concat([ham_subset, df[df["Label"] == "spam"]])

    # Shuffle the balanced dataset
    balanced_df = balanced_df.sample(frac=1, random_state=123).reset_index(drop=True)
    
    return balanced_df

balanced_df = create_balanced_dataset(df)
print(balanced_df["Label"].value_counts())

In [None]:
# Convert the string labels to numeric values (0 for ham, 1 for spam)
balanced_df["Label"] = balanced_df["Label"].map({"ham": 0, "spam": 1})
balanced_df.head()

### Split the dataset into training, validation, and test sets

In [None]:
def random_split(df, train_frac, validation_frac):
    # Shuffle the entire DataFrame
    df = df.sample(frac=1, random_state=123).reset_index(drop=True)

    # Calculate split indices
    train_end = int(len(df) * train_frac)
    validation_end = train_end + int(len(df) * validation_frac)

    # Split the DataFrame
    train_df = df[:train_end]
    validation_df = df[train_end:validation_end]
    test_df = df[validation_end:]

    return train_df, validation_df, test_df

# Split data: 70% training, 10% validation, 20% test
train_df, validation_df, test_df = random_split(balanced_df, 0.7, 0.1)

# Save the splits to CSV files for convenience (index=False omits the row index column)
train_df.to_csv("train.csv", index=False)
validation_df.to_csv("validation.csv", index=False)
test_df.to_csv("test.csv", index=False)

print(f"Training set: {len(train_df)} examples")
print(f"Validation set: {len(validation_df)} examples")
print(f"Test set: {len(test_df)} examples")
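
Because `random_split` truncates with `int()`, the split sizes are not always exact fractions, and the test set absorbs whatever is left over. The sketch below repeats that arithmetic for a hypothetical dataset of 1,494 rows; `split_sizes` is an illustrative helper, not part of the notebook.

```python
def split_sizes(n, train_frac, validation_frac):
    """Return (train, validation, test) sizes using the same truncation as random_split."""
    train_end = int(n * train_frac)
    validation_end = train_end + int(n * validation_frac)
    return train_end, validation_end - train_end, n - validation_end


# Hypothetical row count, chosen so that the fractions do not divide evenly
sizes = split_sizes(1494, 0.7, 0.1)
print(sizes)  # (1045, 149, 300)
```

Here the test set ends up slightly above 20% of the rows because both rounding remainders fall into it.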

## Step 3: Tokenize and Prepare the Data

Now we need to tokenize the text data and prepare it for the model. Since text messages have different lengths, we'll need to pad shorter messages or truncate longer ones to a consistent length.
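
As a small illustration of the padding and truncation rule, the sketch below works on plain lists of token IDs. The helper name `pad_or_truncate` and the example IDs are hypothetical; 50256 is the GPT-2 `<|endoftext|>` ID that this notebook uses for padding.

```python
PAD_TOKEN_ID = 50256  # GPT-2 <|endoftext|> token, reused as the padding token


def pad_or_truncate(token_ids, max_length, pad_token_id=PAD_TOKEN_ID):
    """Cut sequences longer than max_length and right-pad shorter ones."""
    clipped = token_ids[:max_length]
    return clipped + [pad_token_id] * (max_length - len(clipped))


short_example = pad_or_truncate([11, 22, 33], max_length=5)
long_example = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], max_length=5)

print(short_example)  # [11, 22, 33, 50256, 50256]
print(long_example)   # [1, 2, 3, 4, 5]
```

Every sequence comes out with exactly `max_length` entries, which is what allows a batch to be stacked into a single tensor.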

In [None]:
# Initialize the tokenizer (using the GPT-2 tokenizer from tiktoken)
tokenizer = tiktoken.get_encoding("gpt2")

# Verify the padding token ID (typically <|endoftext|>)
pad_token_id = tokenizer.encode("<|endoftext|>", allowed_special={"<|endoftext|>"})[0]
print(f"Padding token ID: {pad_token_id}")

In [None]:
# Create a PyTorch Dataset for the spam classification task
class SpamDataset(Dataset):
    def __init__(self, csv_file, tokenizer, max_length=None, pad_token_id=50256):
        self.data = pd.read_csv(csv_file)

        # Pre-tokenize texts
        self.encoded_texts = [
            tokenizer.encode(text) for text in self.data["Text"]
        ]

        # Find the maximum length if not provided
        if max_length is None:
            self.max_length = self._longest_encoded_length()
        else:
            self.max_length = max_length
            # Truncate sequences if they are longer than max_length
            self.encoded_texts = [
                encoded_text[:self.max_length]
                for encoded_text in self.encoded_texts
            ]

        # Right-pad every sequence to max_length
        self.encoded_texts = [
            encoded_text + [pad_token_id] * (self.max_length - len(encoded_text))
            for encoded_text in self.encoded_texts
        ]

    def __getitem__(self, index):
        encoded = self.encoded_texts[index]
        label = self.data.iloc[index]["Label"]
        return (
            torch.tensor(encoded, dtype=torch.long),
            torch.tensor(label, dtype=torch.long)
        )

    def __len__(self):
        return len(self.data)

    def _longest_encoded_length(self):
        return max(len(encoded_text) for encoded_text in self.encoded_texts)

In [None]:
# Create datasets for training, validation, and testing
train_dataset = SpamDataset(
    csv_file="train.csv",
    tokenizer=tokenizer,
    max_length=None  # Will find the longest sequence in training data
)

print(f"Maximum sequence length in training data: {train_dataset.max_length}")

# Use the same max_length for validation and test for consistency
val_dataset = SpamDataset(
    csv_file="validation.csv",
    tokenizer=tokenizer,
    max_length=train_dataset.max_length
)

test_dataset = SpamDataset(
    csv_file="test.csv",
    tokenizer=tokenizer,
    max_length=train_dataset.max_length
)

In [None]:
# Create DataLoaders for training, validation, and testing
batch_size = 8

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,  # Drop the last incomplete batch to ensure all batches have the same size
)

val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    num_workers=0,
    drop_last=False,
)

test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    num_workers=0,
    drop_last=False,
)

print(f"{len(train_loader)} training batches")
print(f"{len(val_loader)} validation batches")
print(f"{len(test_loader)} test batches")

## Step 4: Set Up the Pre-trained Model for Classification

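We'll adapt a GPT-style model for text classification. In a real-world scenario, you would load an existing pre-trained model such as GPT-2. To keep this tutorial self-contained, we instead define a simplified GPT-2-style architecture from scratch and leave its weights randomly initialized, so the code demonstrates the fine-tuning mechanics rather than the benefit of pre-training. Note that the blocks below apply layer normalization after each residual connection, whereas GPT-2 applies it before each sub-layer, so this is an approximation of the architecture rather than a drop-in replacement.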

In [None]:
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads, dropout=0.1):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        
        # Create query, key, value projections
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        batch_size = x.shape[0]
        
        # Linear projections and reshape for multi-head attention
        q = self.W_query(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.W_key(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.W_value(x).view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        
        # Calculate attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        
        # Apply causal mask (for auto-regressive behavior)
        seq_length = scores.size(-1)
        causal_mask = torch.triu(torch.ones(seq_length, seq_length, device=x.device) * float('-inf'), diagonal=1)
        scores = scores + causal_mask
        
        # Apply padding mask if provided
        if mask is not None:
            # Expand mask to match attention scores dimensions
            expanded_mask = mask.unsqueeze(1).unsqueeze(2)
            scores = scores.masked_fill(expanded_mask == 0, float('-inf'))
        
        # Apply softmax and get weighted values
        attention_weights = F.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)
        context = torch.matmul(attention_weights, v)
        
        # Reshape and project back to original dimensions
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        output = self.out_proj(context)
        
        return output
        
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout)
        )
        
    def forward(self, x):
        return self.layers(x)

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.feed_forward = FeedForward(d_model, d_ff, dropout)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x, mask=None):
        # Self-attention block with residual connection and layer normalization
        attn_output = self.attention(x, mask)
        x = x + self.dropout(attn_output)
        x = self.norm1(x)
        
        # Feed-forward block with residual connection and layer normalization
        ff_output = self.feed_forward(x)
        x = x + ff_output
        x = self.norm2(x)
        
        return x

class GPTModel(nn.Module):
    def __init__(self, vocab_size=50257, d_model=768, n_layers=12, n_heads=12, 
                 d_ff=3072, max_seq_len=1024, dropout=0.1):
        super().__init__()
        
        # Token and position embeddings
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.drop_emb = nn.Dropout(dropout)
        
        # Transformer blocks (ModuleList, because forward() passes the mask to each block explicitly)
        self.trf_blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff, dropout) 
            for _ in range(n_layers)
        ])
        
        # Final layer norm
        self.final_norm = nn.LayerNorm(d_model)
        
        # Output head (initially for next token prediction, will be replaced for classification)
        self.out_head = nn.Linear(d_model, vocab_size, bias=False)
        
        # Initialize weights
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)
            
    def forward(self, idx):
        # Get sequence length and create position indices
        seq_len = idx.size(1)
        pos = torch.arange(0, seq_len, dtype=torch.long, device=idx.device).unsqueeze(0)
        
        # Get token and position embeddings and add them
        tok_embeddings = self.tok_emb(idx)
        pos_embeddings = self.pos_emb(pos)
        x = self.drop_emb(tok_embeddings + pos_embeddings)
        
        # Create attention mask (1 for tokens, 0 for padding)
        mask = (idx != 50256).to(idx.device)  # Use pad_token_id (50256 for GPT-2)
        
        # Pass through transformer blocks
        for block in self.trf_blocks:
            x = block(x, mask)
            
        # Apply final layer norm
        x = self.final_norm(x)
        
        # Apply output head
        logits = self.out_head(x)
        
        return logits

In [None]:
# Initialize the model
# In a real-world scenario, you would load pre-trained weights here
model = GPTModel(
    vocab_size=50257,     # GPT-2 vocabulary size
    d_model=768,          # Embedding dimension
    n_layers=12,          # Number of transformer layers 
    n_heads=12,           # Number of attention heads
    d_ff=3072,            # Feed-forward dimension
    max_seq_len=train_dataset.max_length,  # Max sequence length from our dataset
    dropout=0.1           # Dropout rate
)

# In this tutorial, we're not loading actual pre-trained weights for simplicity.
# In practice, you would load weights from a pre-trained model.

### Adapt the model for text classification

Now we'll modify the model for our binary classification task (spam vs. not spam). We'll:
1. Freeze most of the model parameters
2. Replace the output layer with a binary classification head
3. Unfreeze the last transformer block and final layer norm

In [None]:
# First, freeze all parameters
for param in model.parameters():
    param.requires_grad = False

# Replace the output layer with a binary classification head
num_classes = 2  # Binary classification: spam or not spam
model.out_head = nn.Linear(in_features=768, out_features=num_classes)

# Unfreeze the last transformer block and final layer norm for fine-tuning
for param in model.trf_blocks[-1].parameters():
    param.requires_grad = True

for param in model.final_norm.parameters():
    param.requires_grad = True

# Count trainable parameters vs. total parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params*100:.2f}%)")
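
As a cross-check on the trainable-parameter count printed above, the figure can be derived by hand from the layer shapes. The sketch below assumes the configuration used in this notebook (`d_model=768`, `d_ff=3072`, two output classes, biases on every linear layer in the block and head); the helper names `linear_params` and `layer_norm_params` are illustrative and not part of PyTorch.

```python
def linear_params(in_features, out_features, bias=True):
    """Parameter count of a fully connected layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)


def layer_norm_params(d_model):
    """LayerNorm holds one scale and one shift value per feature."""
    return 2 * d_model


d_model, d_ff, num_classes = 768, 3072, 2

# One transformer block: query, key, value, and output projections
attention = 4 * linear_params(d_model, d_model)
# Two linear layers in the feed-forward network
feed_forward = linear_params(d_model, d_ff) + linear_params(d_ff, d_model)
# Two LayerNorm modules per block
last_block = attention + feed_forward + 2 * layer_norm_params(d_model)

final_norm = layer_norm_params(d_model)
classification_head = linear_params(d_model, num_classes)

expected_trainable = last_block + final_norm + classification_head
print(f"{expected_trainable:,}")  # 7,090,946
```

If the printed count in the notebook differs from this number, a layer was probably left frozen or unfrozen by mistake.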

### Check model output format

Let's test our model with a small batch of data to make sure the output format is correct for our classification task.

In [None]:
# Move model to the appropriate device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Get a sample batch from the training data loader
sample_inputs, sample_labels = next(iter(train_loader))
sample_inputs = sample_inputs.to(device)

# Set model to evaluation mode
model.eval()

# Forward pass
with torch.no_grad():
    outputs = model(sample_inputs)

print(f"Input shape: {sample_inputs.shape}")  # [batch_size, sequence_length]
print(f"Output shape: {outputs.shape}")       # [batch_size, sequence_length, num_classes]

# For classification, we typically use the output corresponding to the last token
last_token_output = outputs[:, -1, :]
print(f"Last token output shape: {last_token_output.shape}")  # [batch_size, num_classes]

## Step 5: Define Loss and Evaluation Functions

We'll need functions to calculate loss and evaluate the model's accuracy on our dataset.
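
To make the two quantities concrete, here is a minimal pure-Python sketch of what cross-entropy and argmax accuracy compute for two-class logits on individual examples. The helper names and the logit values are hypothetical; the notebook itself relies on `F.cross_entropy` and `torch.argmax`.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities (shifted by the max for numerical stability)."""
    largest = max(logits)
    exps = [math.exp(value - largest) for value in logits]
    total = sum(exps)
    return [value / total for value in exps]


def cross_entropy(logits, target):
    """Negative log-probability assigned to the correct class."""
    return -math.log(softmax(logits)[target])


def accuracy(batch_logits, targets):
    """Fraction of examples whose highest-scoring class equals the target."""
    predictions = [max(range(len(row)), key=lambda i: row[i]) for row in batch_logits]
    return sum(p == t for p, t in zip(predictions, targets)) / len(targets)


undecided = cross_entropy([0.0, 0.0], target=1)  # equal logits give ln(2), about 0.6931
confident = cross_entropy([0.0, 4.0], target=1)  # a large margin for the right class gives a small loss
acc = accuracy([[2.0, -1.0], [0.5, 1.5], [1.0, 0.0]], targets=[0, 1, 1])

print(round(undecided, 4), confident < undecided, acc)
```

A loss near ln(2) on a balanced two-class problem is a sign that the model is not yet distinguishing the classes.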

In [None]:
def calc_loss_batch(input_batch, target_batch, model, device):
    """Calculate loss for a single batch"""
    input_batch, target_batch = input_batch.to(device), target_batch.to(device)
    logits = model(input_batch)[:, -1, :]  # Logits of last output token
    loss = F.cross_entropy(logits, target_batch)
    return loss

def calc_loss_loader(data_loader, model, device, num_batches=None):
    """Calculate average loss over multiple batches"""
    total_loss = 0.
    
    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
        
    for i, (input_batch, target_batch) in enumerate(data_loader):
        if i < num_batches:
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            total_loss += loss.item()
        else:
            break
            
    return total_loss / num_batches

def calc_accuracy_loader(data_loader, model, device, num_batches=None):
    """Calculate classification accuracy"""
    model.eval()
    correct_predictions, num_examples = 0, 0

    if num_batches is None:
        num_batches = len(data_loader)
    else:
        num_batches = min(num_batches, len(data_loader))
        
    with torch.no_grad():
        for i, (input_batch, target_batch) in enumerate(data_loader):
            if i < num_batches:
                input_batch, target_batch = input_batch.to(device), target_batch.to(device)
                logits = model(input_batch)[:, -1, :]  # Logits of last output token
                predicted_labels = torch.argmax(logits, dim=1)
                num_examples += predicted_labels.shape[0]
                correct_predictions += (predicted_labels == target_batch).sum().item()
            else:
                break
                
    return correct_predictions / num_examples

def evaluate_model(model, train_loader, val_loader, device, eval_iter):
    """Evaluate model on training and validation sets"""
    model.eval()
    with torch.no_grad():
        train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter)
        val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter)
    model.train()
    return train_loss, val_loss

In [None]:
# Calculate initial (pre-training) accuracy
train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=10)
val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=10)
test_accuracy = calc_accuracy_loader(test_loader, model, device, num_batches=10)

print(f"Initial training accuracy: {train_accuracy*100:.2f}%")
print(f"Initial validation accuracy: {val_accuracy*100:.2f}%")
print(f"Initial test accuracy: {test_accuracy*100:.2f}%")

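Because the classification head is freshly initialized, the initial accuracy should be close to 50%, which is chance level for a balanced binary dataset. Only 10 batches per split are sampled here, so expect some noise in these numbers.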

## Step 6: Fine-tune the Model

Now we'll fine-tune our model on the spam classification dataset.

In [None]:
def train_classifier(model, train_loader, val_loader, optimizer, device, num_epochs,
                     eval_freq, eval_iter):
    """Train the classifier model"""
    # Initialize lists to track losses and examples seen
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    examples_seen, global_step = 0, -1

    # Main training loop
    for epoch in range(num_epochs):
        model.train()  # Set model to training mode
        
        # Training phase
        for input_batch, target_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            optimizer.zero_grad()  # Reset gradients
            loss = calc_loss_batch(input_batch, target_batch, model, device)
            loss.backward()  # Calculate gradients
            optimizer.step()  # Update weights
            examples_seen += input_batch.shape[0]
            global_step += 1

            # Evaluate periodically
            if global_step % eval_freq == 0:
                train_loss, val_loss = evaluate_model(model, train_loader, val_loader, device, eval_iter)
                train_losses.append(train_loss)
                val_losses.append(val_loss)
                print(f"Step {global_step:06d}: Train loss {train_loss:.4f}, Val loss {val_loss:.4f}")

        # Calculate accuracy after each epoch
        train_accuracy = calc_accuracy_loader(train_loader, model, device, num_batches=eval_iter)
        val_accuracy = calc_accuracy_loader(val_loader, model, device, num_batches=eval_iter)
        train_accs.append(train_accuracy)
        val_accs.append(val_accuracy)
        
        print(f"Epoch {epoch+1}/{num_epochs}:")
        print(f"  Train accuracy: {train_accuracy*100:.2f}%")
        print(f"  Val accuracy: {val_accuracy*100:.2f}%")

    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs,
        'examples_seen': examples_seen
    }

In [None]:
# Training parameters
num_epochs = 5
learning_rate = 5e-5
weight_decay = 0.1

# Set up optimizer
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),  # Only optimize trainable parameters
    lr=learning_rate,
    weight_decay=weight_decay
)

# Train the model
start_time = time.time()
training_results = train_classifier(
    model, train_loader, val_loader, optimizer, device,
    num_epochs=num_epochs, eval_freq=50, eval_iter=5
)
end_time = time.time()

print(f"Training completed in {(end_time - start_time) / 60:.2f} minutes")
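
To interpret the recorded losses, it helps to know how many evaluation points the loop produces. The sketch below reproduces the step counting of `train_classifier` under the settings above, assuming a hypothetical training set of 1,045 examples; `evaluation_steps` is an illustrative helper.

```python
def evaluation_steps(num_examples, batch_size, num_epochs, eval_freq):
    """List the global steps at which the training loop records train/val loss."""
    batches_per_epoch = num_examples // batch_size  # drop_last=True discards the remainder
    total_steps = batches_per_epoch * num_epochs
    # global_step is 0 for the first batch, so step 0 is evaluated as well
    return [step for step in range(total_steps) if step % eval_freq == 0]


# Hypothetical training-set size; the other values match the cell above
steps = evaluation_steps(num_examples=1045, batch_size=8, num_epochs=5, eval_freq=50)
print(len(steps), steps[0], steps[-1])  # 13 0 600
```

Under these assumptions the loss lists hold 13 entries, while the accuracy lists hold one entry per epoch.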

## Step 7: Evaluate the Model

Let's visualize our training progress and evaluate the model's performance on the test set.

In [None]:
# Plot training and validation loss
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(training_results['train_losses'], label='Training Loss')
plt.plot(training_results['val_losses'], label='Validation Loss')
plt.xlabel('Evaluation Step')
plt.ylabel('Loss')
plt.title('Training and Validation Loss')
plt.legend()

# Plot training and validation accuracy
plt.subplot(1, 2, 2)
plt.plot(training_results['train_accs'], label='Training Accuracy')
plt.plot(training_results['val_accs'], label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Evaluate on the full datasets
train_accuracy = calc_accuracy_loader(train_loader, model, device)
val_accuracy = calc_accuracy_loader(val_loader, model, device)
test_accuracy = calc_accuracy_loader(test_loader, model, device)

print(f"Final training accuracy: {train_accuracy*100:.2f}%")
print(f"Final validation accuracy: {val_accuracy*100:.2f}%")
print(f"Final test accuracy: {test_accuracy*100:.2f}%")

In [None]:
# Generate classification report and confusion matrix for the test set
from sklearn.metrics import classification_report, confusion_matrix

def get_predictions(data_loader, model, device):
    model.eval()
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for input_batch, target_batch in data_loader:
            input_batch = input_batch.to(device)
            logits = model(input_batch)[:, -1, :]
            predictions = torch.argmax(logits, dim=1).cpu().numpy()
            all_predictions.extend(predictions)
            all_labels.extend(target_batch.numpy())
    
    return np.array(all_predictions), np.array(all_labels)

# Get predictions for the test set
y_pred, y_true = get_predictions(test_loader, model, device)

# Print classification report
print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=['Ham', 'Spam']))

# Generate confusion matrix
cm = confusion_matrix(y_true, y_pred)

# Plot confusion matrix
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(2)
plt.xticks(tick_marks, ['Ham', 'Spam'])
plt.yticks(tick_marks, ['Ham', 'Spam'])

# Add text annotations
thresh = cm.max() / 2
for i in range(cm.shape[0]):
    for j in range(cm.shape[1]):
        plt.text(j, i, format(cm[i, j], 'd'),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
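
The precision, recall, and F1 values in the classification report follow directly from the confusion-matrix counts. The sketch below computes them by hand for the spam class on a small set of hypothetical labels; `binary_report` is an illustrative helper, not a scikit-learn function.

```python
def binary_report(y_true, y_pred, positive=1):
    """Compute precision, recall, and F1 for the positive class from raw labels."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)  # spam caught
    fp = sum(t != positive and p == positive for t, p in pairs)  # ham flagged as spam
    fn = sum(t == positive and p != positive for t, p in pairs)  # spam missed
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


# Hypothetical labels: 0 = ham, 1 = spam
toy_true = [0, 0, 0, 0, 1, 1, 1, 1]
toy_pred = [0, 0, 0, 1, 1, 1, 1, 0]

report = binary_report(toy_true, toy_pred)
print(report)
```

For a spam filter, precision reflects how often a flagged message really is spam, and recall reflects how much of the spam is caught.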

## Step 8: Use the Fine-tuned Model for Inference

Now let's use our fine-tuned model to classify some example messages.

In [None]:
def classify_message(text, model, tokenizer, device, max_length):
    """Classify a message as spam or ham"""
    model.eval()
    
    # Tokenize the text
    tokens = tokenizer.encode(text)[:max_length]
    tokens = tokens + [50256] * (max_length - len(tokens))  # Pad to max_length
    input_ids = torch.tensor([tokens], dtype=torch.long).to(device)
    
    # Forward pass
    with torch.no_grad():
        logits = model(input_ids)[:, -1, :]
        probs = F.softmax(logits, dim=1)
        prediction = torch.argmax(probs, dim=1).item()
    
    # Translate prediction to label
    label = "spam" if prediction == 1 else "ham"
    confidence = probs[0][prediction].item()
    
    return label, confidence

In [None]:
# Test on some example messages
examples = [
    "Hey, are we still on for dinner tonight?",
    "URGENT! You have won a $1000 gift card. Click here to claim your prize: http://bit.ly/claim",
    "Remember to pick up milk on your way home.",
    "Congratulations! You've been selected for a FREE iPhone 14. Just pay shipping and handling.",
    "The meeting has been moved to 3pm tomorrow.",
    "FINAL NOTICE: Your car warranty is about to expire. Call now to extend!"
]

for text in examples:
    label, confidence = classify_message(
        text, model, tokenizer, device, max_length=train_dataset.max_length
    )
    print(f"Message: {text}")
    print(f"Classification: {label.upper()} (confidence: {confidence:.4f})")
    print("-" * 80)

## Step 9: Save and Load the Model

Finally, let's save our fine-tuned model so we can reuse it later.

In [None]:
# Save the model
save_path = "spam_classifier_model.pt"
torch.save({
    'model_state_dict': model.state_dict(),
    'max_length': train_dataset.max_length,
    'vocab_size': 50257,
    'd_model': 768,
    'n_heads': 12,
    'n_layers': 12,
    'd_ff': 3072
}, save_path)

print(f"Model saved to {save_path}")

In [None]:
# Function to load the saved model
def load_spam_classifier(model_path, device):
    # Load the checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    
    # Create a new model with the same architecture
    model = GPTModel(
        vocab_size=checkpoint['vocab_size'],
        d_model=checkpoint['d_model'],
        n_heads=checkpoint['n_heads'],
        n_layers=checkpoint['n_layers'],
        d_ff=checkpoint['d_ff'],
        max_seq_len=checkpoint['max_length']
    )
    
    # Replace the output head for binary classification
    model.out_head = nn.Linear(in_features=checkpoint['d_model'], out_features=2)
    
    # Load the saved state dict
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()
    
    return model, checkpoint['max_length']

# Load the saved model and test it
loaded_model, max_length = load_spam_classifier(save_path, device)

# Test the loaded model on an example
example = "URGENT: Your account has been compromised. Call this number to verify your identity."
label, confidence = classify_message(example, loaded_model, tokenizer, device, max_length)

print(f"Message: {example}")
print(f"Classification: {label.upper()} (confidence: {confidence:.4f})")

## Conclusion

In this tutorial, we walked through the workflow for fine-tuning a GPT-style language model for text classification, using spam detection as an example. We covered:

1. **Data preparation**: Loading, balancing, splitting, and tokenizing the dataset
2. **Model adaptation**: Modifying a pre-trained model for classification by replacing the output layer
3. **Selective fine-tuning**: Freezing most parameters and training only the last transformer block, the final layer norm, and the classification head
4. **Training and evaluation**: Fine-tuning the model and measuring its performance
5. **Inference**: Using the fine-tuned model to classify new text messages

Key insights from this tutorial:

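- **Transfer learning**: The workflow is built around a pre-trained language model, because starting from pre-trained weights reuses general language knowledge and is far cheaper than training from scratch. Note that the model in this notebook was randomly initialized for simplicity, so load real pre-trained weights to get that benefit.
- **Selective fine-tuning**: Training only the last transformer block, the final layer norm, and the classification head updates a small fraction of the parameters, which keeps training cheap and, with pre-trained weights, preserves the model's general language understanding.
- **Causal attention**: Under a causal mask, each position attends only to itself and earlier positions, so the last position is the only one that can draw on the entire input sequence. That is why we classify from the last token's output.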
- **Classification head**: We replaced the language modeling head with a simple binary classification head.
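
The point about causal attention can be checked with a small plain-Python sketch that combines the causal rule with the padding mask used in this notebook. The function name `visible_positions` and the token IDs are hypothetical.

```python
def visible_positions(token_ids, pad_token_id=50256):
    """For each query position, list the key positions it is allowed to attend to.

    A key is visible if it is not in the future (causal mask)
    and is not a padding token (padding mask).
    """
    return [
        [key for key in range(query + 1) if token_ids[key] != pad_token_id]
        for query in range(len(token_ids))
    ]


# Hypothetical sequence: three real tokens followed by two padding tokens
tokens = [101, 102, 103, 50256, 50256]
visible = visible_positions(tokens)

for query, keys in enumerate(visible):
    print(query, keys)
# 0 [0]
# 1 [0, 1]
# 2 [0, 1, 2]
# 3 [0, 1, 2]
# 4 [0, 1, 2]
```

Even though the last position holds a padding token, it attends to all real tokens, which is why its output can serve as the summary of the message.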

This approach can be extended to other text classification tasks by adapting the dataset and possibly the model architecture. For more complex tasks, you might want to fine-tune more layers or use larger pre-trained models.