# Fine-Tuning a Transformer on a Toy Classification Task

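We'll use the IMDB movie-review sentiment dataset (binary positive/negative labels) and a small Transformer encoder. The model is trained from randomly initialised weights rather than adapted from a pretrained checkpoint, so "fine-tuning" here refers to the overall supervised-training workflow.
This notebook demonstrates the steps involved: data loading, tokenization, training, and evaluation.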

In [None]:
import random

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchtext.legacy import data, datasets

# Data loading uses the `torchtext.legacy` API (Field / BucketIterator), which
# ships with torchtext 0.9-0.11; newer torchtext releases removed it.

SEED = 1234
# Seed Python's RNG as well as torch's: torchtext's Dataset.split and the
# BucketIterator shuffling draw from the `random` module, not from torch.
random.seed(SEED)
torch.manual_seed(SEED)

# include_lengths=True makes `batch.text` a (token_ids, lengths) tuple, as the
# training code expects. spaCy >= 3 no longer resolves the bare 'en' shortcut,
# so the tokenizer model is named explicitly (install it with
# `python -m spacy download en_core_web_sm`).
TEXT = data.Field(
    tokenize='spacy',
    tokenizer_language='en_core_web_sm',
    lower=True,
    include_lengths=True,
)
LABEL = data.LabelField(dtype=torch.float)

train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
# `random_state` must be a Python `random.getstate()` tuple; a torch.Generator
# is rejected when the split shuffles the examples.
train_data, valid_data = train_data.split(split_ratio=0.8, random_state=random.getstate())

# Build vocab (in a real case, much larger)
TEXT.build_vocab(train_data, max_size=10000)
LABEL.build_vocab(train_data)

BATCH_SIZE = 64
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator = data.BucketIterator.splits(
    (train_data, valid_data),
    batch_size=BATCH_SIZE,
    device=device,
)
test_iterator = data.BucketIterator(
    test_data,
    batch_size=BATCH_SIZE,
    device=device,
)

# Minimal Transformer-based classifier
class TransformerClassifier(nn.Module):
    """Transformer-encoder text classifier with learned positional embeddings.

    Token ids are embedded, summed with a learned position embedding, passed
    through a stack of ``nn.TransformerEncoderLayer`` blocks, mean-pooled over
    the sequence and projected to ``output_dim`` logits.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_dim: model width (``d_model``); must be divisible by ``num_heads``.
        num_heads: attention heads per encoder layer.
        hidden_dim: width of each layer's feed-forward sub-block.
        num_layers: number of stacked encoder layers.
        output_dim: number of output logits per example.
        dropout: dropout used inside the encoder layers and before ``fc``.
        pad_idx: token id used for padding. When given, padded positions are
            excluded from attention and from the mean pooling, and their
            embedding row is initialised to zero and not updated. ``None``
            (the default) treats every position as a real token.
        max_len: longest sequence the position-embedding table supports.
    """

    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_layers, output_dim, dropout=0.1,
                 pad_idx=None, max_len=5000):
        super().__init__()
        self.pad_idx = pad_idx
        self.max_len = max_len
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=pad_idx)
        self.pos_embedding = nn.Embedding(max_len, embed_dim)

        # nn.TransformerEncoder takes ONE layer as a template and deep-copies it
        # `num_layers` times. Handing it an nn.Sequential of pre-built layers
        # would square the depth and fail at runtime, because Sequential.forward
        # does not accept the mask keyword arguments passed to each layer.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)

        self.fc = nn.Linear(embed_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, text):
        """Return logits of shape ``[batch_size, output_dim]``.

        Args:
            text: LongTensor of token ids, shape ``[seq_len, batch_size]``
                (sequence-first, the encoder's default layout).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len, _ = text.shape
        if seq_len > self.max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.max_len}; "
                "truncate the input or construct the model with a larger max_len"
            )

        # [seq_len, 1] position ids; their embeddings broadcast across the batch.
        positions = torch.arange(seq_len, device=text.device).unsqueeze(1)
        embedded = self.embedding(text) + self.pos_embedding(positions)
        # embedded shape: [seq_len, batch_size, embed_dim]

        if self.pad_idx is None:
            transformer_out = self.transformer_encoder(embedded)
            # Mean of the sequence outputs as a 'sentence embedding'.
            pooled = transformer_out.mean(dim=0)
        else:
            is_pad = text == self.pad_idx  # [seq_len, batch_size]
            # The encoder expects the key padding mask as [batch_size, seq_len].
            transformer_out = self.transformer_encoder(embedded, src_key_padding_mask=is_pad.t())
            # Average over real tokens only. masked_fill (rather than multiplying
            # by a 0/1 mask) keeps any non-finite value at a padded position out
            # of the sum; clamp avoids dividing by zero for all-padding columns.
            summed = transformer_out.masked_fill(is_pad.unsqueeze(-1), 0.0).sum(dim=0)
            counts = (~is_pad).sum(dim=0).clamp(min=1).unsqueeze(-1).to(summed.dtype)
            pooled = summed / counts
        return self.fc(self.dropout(pooled))

# Initialize model, optimizer and loss on the available device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hyperparameters
INPUT_DIM = len(TEXT.vocab)  # one embedding row per vocabulary entry
EMBED_DIM = 128
NUM_HEADS = 4
HIDDEN_DIM = 256
NUM_LAYERS = 2
OUTPUT_DIM = 1  # binary classification: a single logit per review
DROPOUT = 0.2

# nn.Module.to() moves the parameters in place and returns the same module.
model = TransformerClassifier(
    vocab_size=INPUT_DIM,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    hidden_dim=HIDDEN_DIM,
    num_layers=NUM_LAYERS,
    output_dim=OUTPUT_DIM,
    dropout=DROPOUT,
).to(device)

# BCEWithLogitsLoss applies the sigmoid itself, so the model emits raw logits.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss().to(device)

# Training loop (simplified)


def _batch_text(batch):
    """Return the token-id tensor from a torchtext batch.

    ``batch.text`` is a ``(token_ids, lengths)`` tuple when the TEXT field was
    built with ``include_lengths=True`` and a bare tensor otherwise; accept
    both so the loop does not depend on how the field was configured.
    """
    text = batch.text
    return text[0] if isinstance(text, tuple) else text


def evaluate(model, iterator, criterion):
    """Return ``(mean_batch_loss, accuracy)`` of ``model`` over ``iterator``.

    Runs in eval mode with gradients disabled, so no weights are updated.
    Labels are expected as float 0/1 targets matching ``criterion``.
    """
    model.eval()
    total_loss, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch in iterator:
            logits = model(_batch_text(batch)).squeeze(1)
            total_loss += criterion(logits, batch.label).item()
            # logit > 0  <=>  sigmoid probability > 0.5, i.e. predicted class 1.
            correct += ((logits > 0).float() == batch.label).sum().item()
            seen += batch.label.size(0)
    return total_loss / max(len(iterator), 1), correct / max(seen, 1)


EPOCHS = 2
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in train_iterator:
        optimizer.zero_grad()
        predictions = model(_batch_text(batch)).squeeze(1)
        loss = criterion(predictions, batch.label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Training Loss: {total_loss/len(train_iterator):.4f}")
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
    print(f"Epoch {epoch+1}, Validation Loss: {valid_loss:.4f}, Validation Acc: {valid_acc:.2%}")

test_loss, test_acc = evaluate(model, test_iterator, criterion)
print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2%}")
