# Fine-Tuning a Transformer (DistilBERT) for Text Classification with Ignite

This tutorial demonstrates how to fine-tune a pre-trained Transformer model for text classification using PyTorch Ignite.

Transformer models such as BERT and DistilBERT use a mechanism called **Self-Attention** to model the contextual meaning of each word based on the entire sentence, which makes them highly effective for Natural Language Processing (NLP) tasks.
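To make the idea concrete, here is a minimal pure-Python sketch of single-head scaled dot-product self-attention, softmax(QKᵀ/√d)·V. It assumes the query, key and value vectors are given directly as small hypothetical lists. A real Transformer such as DistilBERT first computes them with learned linear projections and runs several attention heads in parallel.

```python
import math

def softmax(scores):
    # Subtract the max before exponentiating for numerical stability
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def self_attention(queries, keys, values):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V."""
    d = len(keys[0])
    outputs, weights = [], []
    for q in queries:
        # Score this query against every key, scaled by sqrt(d)
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        w = softmax(scores)
        weights.append(w)
        # Output = attention-weighted average of all value vectors
        dim_v = len(values[0])
        outputs.append([sum(wj * v[i] for wj, v in zip(w, values)) for i in range(dim_v)])
    return outputs, weights

# Toy 2-dim "token" vectors (hypothetical). In self-attention Q, K and V come from the
# same sequence; this sketch skips the learned projections and uses the vectors directly.
tokens = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
out, attn = self_attention(tokens, tokens, tokens)
```

Each row of `attn` sums to 1, so every output vector is a context-dependent mixture of the whole sequence: the first token attends equally to itself and to the third token (both have dot product 1 with it) and less to the second.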

In this notebook, we will classify IMDB movie reviews to predict whether a review is positive or negative. We will use:
* **Hugging Face `transformers`**: To load the pre-trained DistilBERT model and Tokenizer.
* **Hugging Face `datasets`**: To efficiently download and process the IMDB dataset.
* **PyTorch Ignite**: To manage the training loop, metrics, early stopping, and model checkpointing with minimal boilerplate code.

```python
!pip install pytorch-ignite datasets transformers
```

## 1. Import Libraries & Setup

We begin by importing the necessary modules. We also configure our `device` to utilize a GPU if available, which significantly accelerates Transformer training. Finally, we set a manual seed to ensure our results are reproducible.

```python
import torch
from torch.optim import AdamW
from torch.utils.data import DataLoader
from torch.amp import GradScaler

from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding
from datasets import load_dataset

from ignite.engine import Engine, Events
from ignite.metrics import Accuracy, Loss, RunningAverage
# ProgressBar is imported from ignite.handlers (the older ignite.contrib.handlers path is deprecated)
from ignite.handlers import ModelCheckpoint, EarlyStopping, ProgressBar, global_step_from_engine
from ignite.utils import manual_seed

# Setup device: use a GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed Python's random, PyTorch and (if installed) NumPy for reproducibility
manual_seed(42)
```
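Seeding matters in this notebook mainly because the training `DataLoader` uses `shuffle=True` and the new classification head is randomly initialized. According to Ignite's documentation, `manual_seed` seeds Python's `random`, PyTorch and, if it can be imported, NumPy. The minimal sketch below uses only Python's standard `random` module (with a hypothetical `shuffled_order` helper) to illustrate the principle: the same seed reproduces the same shuffle order.

```python
import random

def shuffled_order(n, seed):
    # A dedicated generator seeded with `seed`: same seed -> same permutation
    rng = random.Random(seed)
    order = list(range(n))
    rng.shuffle(order)
    return order

run_1 = shuffled_order(10, seed=42)
run_2 = shuffled_order(10, seed=42)
```

`run_1` and `run_2` are identical permutations of `0..9`. Note that some GPU kernels are nondeterministic, so seeding alone may not make CUDA runs bit-for-bit reproducible.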

## 2. Data Processing

Preparing text data for a Transformer involves three main steps:
1. **Loading Data:** We download the IMDB dataset (25,000 labeled training reviews and 25,000 labeled test reviews) using `load_dataset`.
2. **Tokenization:** Neural networks cannot read raw text. We use `AutoTokenizer` to split each review into subword tokens and convert them into integer token IDs from the model's vocabulary. With `truncation=True`, reviews longer than the model's maximum input length (512 tokens for DistilBERT) are cut off.
3. **Dynamic Padding:** Models process data in batches, and all sequences in a batch must have the same length. Instead of padding every review to the length of the longest review in the whole dataset, `DataCollatorWithPadding` pads each batch only to the length of the longest sequence in *that batch*. This can save substantial memory and compute. The tokenizer already returns an `attention_mask` (1 for every real token); the collator pads it with 0s alongside the `input_ids`, so the model knows which positions are padding and ignores them.
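The padding step can be sketched in plain Python. This is a simplified stand-in for what `DataCollatorWithPadding` does, not its implementation: it right-pads each sequence in a batch to the batch's longest sequence and builds the matching `attention_mask`. The token IDs, the helper names and `pad_id=0` are assumptions for illustration (check `tokenizer.pad_token_id` for the real padding ID).

```python
def pad_batch(sequences, pad_id=0):
    """Right-pad token-ID lists to the longest sequence in this batch."""
    max_len = max(len(seq) for seq in sequences)
    input_ids, attention_mask = [], []
    for seq in sequences:
        n_pad = max_len - len(seq)
        input_ids.append(seq + [pad_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore
        attention_mask.append([1] * len(seq) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

def padded_token_count(batches, target_len=None):
    """Tokens processed when padding to target_len, or to each batch's own max if None."""
    total = 0
    for batch in batches:
        length = target_len if target_len is not None else max(len(s) for s in batch)
        total += length * len(batch)
    return total

# Hypothetical token IDs for two short reviews
padded = pad_batch([[101, 2023, 3185, 102], [101, 2307, 102]])

# Two toy batches of sequence lengths: one of short reviews, one containing a 512-token review
toy_batches = [[[1] * 5, [1] * 7], [[1] * 6, [1] * 512]]
dynamic = padded_token_count(toy_batches)                 # 7*2 + 512*2 = 1038
static = padded_token_count(toy_batches, target_len=512)  # 512*4 = 2048
```

In this toy case, padding per batch processes roughly half as many tokens as padding everything to 512; the actual saving depends on how review lengths are distributed across batches.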

```python
# 1. Load the IMDB dataset
raw_datasets = load_dataset("imdb")

# 2. Load Tokenizer
checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# 3. Tokenize function
def tokenize_function(example):
    return tokenizer(example["text"], truncation=True)

# 4. Apply tokenization to the entire dataset
tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)

# 5. Format for PyTorch
# We remove the raw text and rename 'label' to 'labels' as expected by HF models
tokenized_datasets = tokenized_datasets.remove_columns(["text"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

# 6. Data Collator (Handles dynamic padding)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# 7. Create DataLoaders
train_dataloader = DataLoader(
    tokenized_datasets["train"], shuffle=True, batch_size=32, collate_fn=data_collator
)
test_dataloader = DataLoader(
    tokenized_datasets["test"], batch_size=32, collate_fn=data_collator
)
```
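Conceptually, a `DataLoader` groups dataset indices into chunks of `batch_size` and hands each chunk to `collate_fn`. The following simplified sketch (a hypothetical `simple_batches` helper, not PyTorch's implementation) shows that with 25,000 training reviews and `batch_size=32`, one epoch has ⌈25000 / 32⌉ = 782 batches, the last of which holds only 8 reviews.

```python
import math
import random

def simple_batches(dataset, batch_size, collate_fn, shuffle=False, seed=0):
    """Yield collated batches, loosely mirroring DataLoader's default drop_last=False behaviour."""
    indices = list(range(len(dataset)))
    if shuffle:
        # A fresh permutation of the indices; DataLoader draws a new one every epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Gather one batch worth of items and let collate_fn assemble them
        items = [dataset[i] for i in indices[start:start + batch_size]]
        yield collate_fn(items)

# Stand-in dataset with 25,000 items, the size of the IMDB training split
dataset = list(range(25_000))
batch_sizes = [len(b) for b in simple_batches(dataset, 32, collate_fn=list, shuffle=True, seed=42)]
expected = math.ceil(len(dataset) / 32)
```

In the real pipeline, `collate_fn=data_collator` is where the per-batch dynamic padding happens.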

## 3. Model Setup

We instantiate `AutoModelForSequenceClassification`. This downloads the pre-trained DistilBERT architecture and its weights, but replaces the original top layer with a randomly initialized classification head tailored for our task (binary classification, hence `num_labels=2`).
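A simplified, pure-Python sketch of what a binary classification head does: project the hidden vector of the first token onto `num_labels=2` logits and apply softmax to obtain class probabilities. To our knowledge, DistilBERT's actual head in Hugging Face adds an extra dense layer, a ReLU and dropout before this final projection; the dimensions, weights, input vector and helper names below are illustrative assumptions.

```python
import math
import random

def init_linear(in_dim, out_dim, seed=0):
    # Small random weights and zero biases, standing in for the freshly initialized head
    rng = random.Random(seed)
    weights = [[rng.gauss(0.0, 0.02) for _ in range(in_dim)] for _ in range(out_dim)]
    biases = [0.0] * out_dim
    return weights, biases

def classify(hidden, weights, biases):
    # One logit per label: dot(w, h) + b
    logits = [sum(w * h for w, h in zip(row, hidden)) + b for row, b in zip(weights, biases)]
    # Softmax converts the logits into class probabilities
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    return logits, probs

# Hypothetical 4-dim vector for the first ([CLS]) token; DistilBERT's hidden size is 768
cls_vector = [0.5, -1.2, 0.3, 0.8]
head_w, head_b = init_linear(in_dim=4, out_dim=2)   # out_dim = num_labels = 2
logits, probs = classify(cls_vector, head_w, head_b)
predicted_label = probs.index(max(probs))           # IMDB labels: 0 = negative, 1 = positive
```

Because the head starts from random weights, its predictions are meaningless until fine-tuning; the first token's vector can still summarize the whole review because self-attention lets it attend to every other token.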

We also initialize the `AdamW` optimizer, which is the standard choice for fine-tuning Transformers.
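AdamW is Adam with *decoupled* weight decay: the decay shrinks the weights directly instead of being added to the gradient. The scalar sketch below writes one update step from the published AdamW rule; `torch.optim.AdamW` should match it up to implementation details. The default `weight_decay=0.01` matches PyTorch's AdamW default; `adamw_step` itself is a hypothetical helper.

```python
import math

def adamw_step(param, grad, m, v, t, lr=5e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter at step t (t starts at 1); returns (param, m, v)."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the parameter directly, independent of the gradient
    param = param * (1 - lr * weight_decay)
    # Exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    # Bias correction for the zero-initialized averages
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

# First step with weight decay off: the update is about lr * sign(grad), whatever the gradient's size
p, m, v = adamw_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1, lr=0.1, weight_decay=0.0)
# Same step with decay on: the parameter is additionally scaled by (1 - lr * weight_decay)
p_wd, _, _ = adamw_step(param=1.0, grad=0.5, m=0.0, v=0.0, t=1, lr=0.1, weight_decay=0.01)
```

Here `p` ends up at about 0.9 and `p_wd` at about 0.899. The large `lr=0.1` only makes the arithmetic easy to follow; the tutorial uses `lr=5e-5`.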

```python
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)
model.to(device)

optimizer = AdamW(model.parameters(), lr=5e-5)
```

## 4. Training Loop with Ignite

Here we define the core `process_function` (for training) and `eval_function` (for validation) that Ignite's `Engine` will run.

* **Dictionary Inputs:** Hugging Face models expect a dictionary of inputs (e.g., `input_ids`, `attention_mask`) and automatically calculate the loss internally if `labels` are provided. We unpack our batch directly into the model: `model(**batch)`.
* **Automatic Mixed Precision (AMP):** We use `torch.amp.autocast()` to run eligible operations in half precision (`float16`) where it is numerically safe, and `GradScaler` to scale the loss so that small gradients do not underflow to zero in `float16`. On modern GPUs this typically speeds up training noticeably and reduces memory usage.
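Why is a `GradScaler` needed? Very small gradients underflow to zero in `float16`. The sketch below uses Python's `struct` module (format code `'e'`, IEEE 754 half precision) to round values to `float16`. It shows that a gradient of 1e-8 is lost when stored directly but survives if it is first multiplied by a scale factor and divided back out in higher precision. The scale of 65536 (2**16) is chosen for illustration; it is also `GradScaler`'s default initial scale, which the scaler then adjusts during training.

```python
import struct

def to_fp16(x):
    # Round-trip through IEEE 754 half precision; values this tiny underflow to 0.0
    return struct.unpack("e", struct.pack("e", x))[0]

true_grad = 1e-8
scale = 2.0 ** 16

unscaled = to_fp16(true_grad)          # stored directly in float16: underflows
scaled = to_fp16(true_grad * scale)    # loss (and so every gradient) scaled first
recovered = scaled / scale             # unscaled in higher precision before the optimizer step
```

`unscaled` is `0.0`, while `recovered` matches `1e-8` to within float16's relative precision. In addition, `scaler.step()` skips optimizer steps whose gradients contain `inf` or `NaN`, and `scaler.update()` then lowers the scale.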

```python
# Mixed precision only helps on a CUDA GPU; on CPU, autocast and the scaler are disabled
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

def process_function(engine, batch):
    model.train()
    optimizer.zero_grad()

    # Move batch dictionary to device
    batch = {k: v.to(device) for k, v in batch.items()}

    # Forward pass with AMP (the model computes the loss because the batch contains "labels")
    with torch.amp.autocast(device_type=device.type, enabled=use_amp):
        outputs = model(**batch)
        loss = outputs.loss

    # Backward pass on the scaled loss; step() unscales the gradients before updating weights
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

def eval_function(engine, batch):
    model.eval()
    with torch.no_grad():
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits

        # Return (y_pred, y) for Ignite metrics to consume
        return logits, batch["labels"]

# Instantiate Engines
trainer = Engine(process_function)
# train_evaluator is not run below; it can be used to compute training-set metrics
train_evaluator = Engine(eval_function)
validation_evaluator = Engine(eval_function)
```
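The `(logits, labels)` pair returned by `eval_function` is what the `Accuracy` and `Loss` metrics consume. Below is a pure-Python sketch of both computations, accumulated batch by batch and finalized once per run, which is how Ignite metrics are organized (update per batch, compute at the end). The `RunningMetrics` class and the logits are made up for illustration. `CrossEntropyLoss` applies log-softmax to raw logits itself, which is why `eval_function` returns logits rather than probabilities.

```python
import math

def log_softmax(logits):
    # log p_i = z_i - logsumexp(z), computed stably
    m = max(logits)
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]

class RunningMetrics:
    """Accumulates accuracy and mean cross-entropy over batches (simplified sketch)."""

    def __init__(self):
        self.correct = 0
        self.loss_sum = 0.0
        self.count = 0

    def update(self, batch_logits, batch_labels):
        for logits, label in zip(batch_logits, batch_labels):
            # Prediction = index of the largest logit (argmax)
            pred = max(range(len(logits)), key=lambda i: logits[i])
            self.correct += int(pred == label)
            # Cross-entropy = negative log-probability of the true class
            self.loss_sum += -log_softmax(logits)[label]
            self.count += 1

    def compute(self):
        return {"accuracy": self.correct / self.count, "nll": self.loss_sum / self.count}

metrics_sketch = RunningMetrics()
metrics_sketch.update([[2.0, -1.0], [0.2, 0.4]], [0, 1])   # hypothetical batch 1
metrics_sketch.update([[-0.5, 1.5]], [0])                  # hypothetical batch 2
results = metrics_sketch.compute()
```

Two of the three predictions are correct, so `results["accuracy"]` is 2/3, and `results["nll"]` is the mean of the three per-sample cross-entropies.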

## 5. Metrics and Handlers

PyTorch Ignite's event-driven system allows us to attach metrics and handlers without cluttering our training loop.
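Under the hood, an `Engine` is essentially a loop that calls your process function on each batch and fires events at fixed points. The class below is a heavily simplified, hypothetical `MiniEngine` that only mimics this idea. It is not Ignite's implementation and omits engine state, termination, event filtering and more, but it shows where handlers registered for `ITERATION_COMPLETED` and `EPOCH_COMPLETED` run.

```python
from collections import defaultdict

class MiniEngine:
    """Toy event-driven training loop in the spirit of ignite.engine.Engine (simplified)."""

    def __init__(self, process_function):
        self.process_function = process_function
        self.handlers = defaultdict(list)
        self.epoch = 0
        self.iteration = 0
        self.output = None

    def on(self, event):
        # Decorator that registers a handler for the named event
        def register(fn):
            self.handlers[event].append(fn)
            return fn
        return register

    def fire(self, event):
        for fn in self.handlers[event]:
            fn(self)

    def run(self, data, max_epochs=1):
        for _ in range(max_epochs):
            self.epoch += 1
            self.fire("EPOCH_STARTED")
            for batch in data:
                self.iteration += 1
                self.output = self.process_function(self, batch)
                self.fire("ITERATION_COMPLETED")
            self.fire("EPOCH_COMPLETED")
        return self

log = []
engine = MiniEngine(lambda eng, batch: sum(batch))

@engine.on("EPOCH_COMPLETED")
def record(eng):
    # Runs once per epoch, like log_validation_results in this tutorial
    log.append((eng.epoch, eng.iteration, eng.output))

engine.run([[1, 2], [3, 4]], max_epochs=2)
```

After the run, `log` holds `(epoch, total iterations so far, last output)` for each epoch. Ignite's real metrics and handlers hook into the same kind of events.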

* **RunningAverage:** Smooths the loss output during training for easier tracking.
* **Metrics:** We attach `Accuracy` and `Loss` to our evaluators to measure model performance at the end of each epoch.
* **ProgressBar:** Provides a visual tqdm progress bar.
* **EarlyStopping:** Stops training early if the validation accuracy does not improve for 2 consecutive epochs, preventing overfitting.
* **ModelCheckpoint:** Automatically saves the best-performing model weights to disk.
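The patience logic behind early stopping, and the choice of which epoch a best-only checkpoint keeps, can be simulated in a few lines. This is a simplified sketch of the idea (a hypothetical `simulate_early_stopping` helper), not Ignite's code; the validation accuracies are made up.

```python
def simulate_early_stopping(val_scores, patience=2, min_delta=0.0):
    """Return (epochs actually run, best epoch) for a sequence of per-epoch validation scores."""
    best_score, best_epoch, counter = None, None, 0
    for epoch, score in enumerate(val_scores, start=1):
        if best_score is None or score > best_score + min_delta:
            # Improvement: reset the counter; a best-only checkpoint would save this epoch
            best_score, best_epoch, counter = score, epoch, 0
        else:
            counter += 1
            if counter >= patience:
                return epoch, best_epoch  # training is terminated after this epoch
    return len(val_scores), best_epoch

# Hypothetical validation accuracies for a 5-epoch run
stopped_at, best = simulate_early_stopping([0.90, 0.92, 0.91, 0.915, 0.93])
```

Training stops after epoch 4 (two epochs without beating 0.92), so the later 0.93 is never reached, and the kept checkpoint is from epoch 2. With `max_epochs=3` and `patience=2`, as in this tutorial, early stopping can only trigger after the third epoch, when training ends anyway; it becomes useful for longer runs.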

```python
# 1. Running Average of the per-iteration training loss returned by process_function
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

# 2. Accuracy and Loss Metrics
metrics = {
    'accuracy': Accuracy(),
    'nll': Loss(torch.nn.CrossEntropyLoss())
}

for name, metric in metrics.items():
    metric.attach(train_evaluator, name)
    metric.attach(validation_evaluator, name)

# 3. Progress Bar
pbar = ProgressBar(persist=True, bar_format="")
pbar.attach(trainer, ['loss'])

eval_pbar = ProgressBar(desc="Evaluating", persist=False)
eval_pbar.attach(validation_evaluator)

# 4. Run validation and log the results at the end of every epoch
@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
    validation_evaluator.run(test_dataloader)
    val_metrics = validation_evaluator.state.metrics  # distinct name from the `metrics` dict above
    avg_accuracy = val_metrics['accuracy']
    avg_nll = val_metrics['nll']

    pbar.log_message(
        f"Validation Results - Epoch: {engine.state.epoch}  "
        f"Avg accuracy: {avg_accuracy:.4f} Avg loss: {avg_nll:.4f}"
    )

# 5. Early Stopping on validation accuracy
def score_function(engine):
    return engine.state.metrics['accuracy']

handler = EarlyStopping(patience=2, score_function=score_function, trainer=trainer)
validation_evaluator.add_event_handler(Events.COMPLETED, handler)

# 6. Model Checkpoint: keep only the best model by validation accuracy
checkpointer = ModelCheckpoint(
    dirname='/tmp/models',
    filename_prefix='distilbert_imdb',
    n_saved=1,
    create_dir=True,
    require_empty=False,
    score_function=score_function,
    score_name="val_acc",
    global_step_transform=global_step_from_engine(trainer)
)

validation_evaluator.add_event_handler(Events.COMPLETED, checkpointer, {'model': model})
```
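The `RunningAverage` attached to the trainer smooths the noisy per-batch loss with an exponential moving average. According to Ignite's documentation, it uses a smoothing factor `alpha` (default 0.98). The sketch below, using a hypothetical `running_average` helper and made-up losses, shows the update rule: the first value is taken as-is, then `alpha * previous + (1 - alpha) * new`.

```python
def running_average(values, alpha=0.98):
    """Exponential moving average: first value as-is, then alpha*previous + (1-alpha)*new."""
    smoothed, current = [], None
    for x in values:
        current = x if current is None else alpha * current + (1 - alpha) * x
        smoothed.append(current)
    return smoothed

# Hypothetical noisy per-batch training losses
raw_losses = [0.70, 0.40, 0.65, 0.30, 0.55]
smoothed = running_average(raw_losses)
```

With `alpha` close to 1, each new batch moves the displayed loss only slightly, which is why the progress-bar value changes smoothly while individual batch losses fluctuate.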

## 6. Run Training

Finally, we run the training. Because DistilBERT was pre-trained on a large text corpus, it already encodes general language structure. We only need to *fine-tune* it for sentiment analysis, which typically reaches high accuracy within 2 to 3 epochs.

```python
# Run for 3 epochs
trainer.run(train_dataloader, max_epochs=3)
```