# 04 - BERT Fine-Tuning for Assignment

This notebook fine-tunes DistilBERT for 3-class support ticket classification.

**For Assignment Submission:**
- Loads pre-trained DistilBERT from HuggingFace
- Fine-tunes on training data
- Evaluates on validation and test sets
- Saves best model checkpoint

In [None]:
import os
import sys

# Put the parent of the current working directory at the front of sys.path so
# the `src.*` imports below resolve (assumes the kernel is started from a
# directory one level below the project root -- TODO confirm).
project_root = os.path.split(os.getcwd())[0]
sys.path[:0] = [project_root]

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import get_scheduler
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from src.data_utils import load_text_classification_data
from src.model.bert_model import BertClassifier, get_tokenizer
from src.train_nn import train_epoch_with_scheduler, eval_epoch_bert
from src.evaluate import evaluate_classification

# Seed the numpy and torch random number generators before any model or
# DataLoader is created, so torch-driven randomness (e.g. shuffling, dropout)
# is repeatable under Restart Kernel -> Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Load Data

In [None]:
# Read the three splits. Only the training call's label mappings are kept;
# the mappings returned for 'val' and 'test' are discarded.
train_texts, train_labels, label2id, id2label = load_text_classification_data('train')
val_texts, val_labels, _, _ = load_text_classification_data('val')
test_texts, test_labels, _, _ = load_text_classification_data('test')

# Minimal cleanup: trim leading/trailing whitespace from every text.
train_texts = [raw.strip() for raw in train_texts]
val_texts = [raw.strip() for raw in val_texts]
test_texts = [raw.strip() for raw in test_texts]

# Report split sizes and the label mapping.
for split_name, split_texts in (
    ("Training", train_texts),
    ("Validation", val_texts),
    ("Test", test_texts),
):
    print(f"{split_name} samples: {len(split_texts)}")
print(f"Label mapping: {label2id}")

## 2. Initialize Tokenizer and Model

In [None]:
# Checkpoint name and the fixed sequence length used when tokenizing.
model_name = 'distilbert-base-uncased'
max_length = 128

# Build the tokenizer through the project helper for the chosen checkpoint.
tokenizer = get_tokenizer(model_name)
print(f"Tokenizer loaded: {model_name}\nMax length: {max_length}")

## 3. Create Dataset Class

In [None]:
class BertDataset(Dataset):
    """Map-style dataset that tokenizes a single text on every lookup.

    Args:
        texts: Indexable collection of inputs; each item is converted with
            ``str()`` before it is handed to the tokenizer.
        labels: Indexable collection of class ids, aligned with ``texts``.
        tokenizer: Callable invoked with the text plus the ``truncation``,
            ``padding``, ``max_length`` and ``return_tensors`` keywords. Its
            result must be indexable by ``'input_ids'`` and
            ``'attention_mask'``.
        max_length: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self):
        """Return the number of texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``input_ids``/``attention_mask``/``labels`` dict for ``idx``."""
        raw_text, target = self.texts[idx], self.labels[idx]

        tokenized = self.tokenizer(
            str(raw_text),
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt',
        )

        # The tokenizer is asked for 'pt' tensors; squeeze(0) drops the
        # leading dimension of each returned tensor.
        sample = {
            field: tokenized[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        sample['labels'] = torch.tensor(target, dtype=torch.long)
        return sample

## 4. Create DataLoaders

In [None]:
batch_size = 32

# One dataset per split, all sharing the same tokenizer and max_length.
train_dataset = BertDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = BertDataset(val_texts, val_labels, tokenizer, max_length)
test_dataset = BertDataset(test_texts, test_labels, tokenizer, max_length)

# Only the training loader shuffles; val/test keep their original order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

for loader_name, loader in (
    ("Train", train_loader),
    ("Val", val_loader),
    ("Test", test_loader),
):
    print(f"{loader_name} batches: {len(loader)}")

## 5. Initialize Model

In [None]:
# Build the classifier with freeze_bert=False and move it to the selected device.
model = BertClassifier(
    model_name=model_name,
    num_classes=3,
    dropout=0.3,
    freeze_bert=False,
)
model = model.to(device)

# Count all parameters and the subset that will receive gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count

print(f"Model parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

## 6. Setup Training

In [None]:
# Loss: each class is weighted by n_samples / (n_classes * class_count), so
# classes with fewer training examples get a larger weight.
class_counts = np.bincount(train_labels)
class_weights = (
    len(train_labels) / (len(class_counts) * class_counts)
).astype(np.float32)
class_weights_tensor = torch.tensor(class_weights, device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)

# AdamW over every model parameter.
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, weight_decay=0.01)

# "linear" schedule; the first 10% of all optimizer steps are warmup.
num_epochs = 3
num_training_steps = num_epochs * len(train_loader)
num_warmup_steps = int(0.1 * num_training_steps)

scheduler = get_scheduler(
    "linear",
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

current_lr = optimizer.param_groups[0]['lr']
print(f"Training for {num_epochs} epochs")
print(f"Total steps: {num_training_steps}")
print(f"Warmup steps: {num_warmup_steps}")
print(f"Initial learning rate: {current_lr:.2e}")
print(f"Class weights: {class_weights.tolist()}")

## 7. Training Loop

In [None]:
# Training loop: train for num_epochs, validate after every epoch, and keep a
# snapshot of the weights from the epoch with the highest validation accuracy.
best_val_acc = 0
best_model_state = None

print("Starting training...")
print("=" * 60)

for epoch in range(num_epochs):
    # Train
    train_loss, train_acc = train_epoch_with_scheduler(
        train_loader, model, criterion, optimizer, scheduler, device
    )

    # Validate
    val_loss, val_acc, _, _ = eval_epoch_bert(val_loader, model, criterion, device)

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        # state_dict() returns references to the live parameter tensors and
        # dict.copy() is only a shallow copy, so each tensor has to be cloned.
        # Without the clone, later epochs update the "snapshot" in place and
        # the reload below silently restores the last epoch instead of the best.
        # The copy is kept on the CPU so it does not hold extra GPU memory.
        best_model_state = {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }
        print(f"  ✓ New best validation accuracy: {best_val_acc:.4f}")
    print()

# Load best model (load_state_dict copies the values back onto the model's device)
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"✓ Loaded best model with validation accuracy: {best_val_acc:.4f}")

## 8. Evaluate on Validation Set

In [None]:
# Run the model over the validation loader and collect predictions and targets.
val_loss, val_acc, val_pred, val_true = eval_epoch_bert(
    val_loader, model, criterion, device
)

# Summarise with the project's evaluation helper and print the headline metrics.
val_results = evaluate_classification(val_true, val_pred)
print(
    "Validation Results:",
    f"Accuracy: {val_results['accuracy']:.4f}",
    f"F1 Macro: {val_results['f1_macro']:.4f}",
    "\nClassification Report:",
    val_results['report'],
    sep="\n",
)

## 9. Evaluate on Test Set

In [None]:
# Run the model over the test loader and collect predictions and targets.
test_loss, test_acc, test_pred, test_true = eval_epoch_bert(
    test_loader, model, criterion, device
)

# Summarise with the project's evaluation helper and print the headline metrics.
test_results = evaluate_classification(test_true, test_pred)
print(
    "Test Results:",
    f"Accuracy: {test_results['accuracy']:.4f}",
    f"F1 Macro: {test_results['f1_macro']:.4f}",
    "\nClassification Report:",
    test_results['report'],
    sep="\n",
)

## 10. Confusion Matrix

In [None]:
# Confusion matrix of the validation predictions, with tick labels taken from
# id2label in ascending key order.
ordered_ids = sorted(id2label.keys())
label_names = [id2label[class_id] for class_id in ordered_ids]
cm = confusion_matrix(val_true, val_pred)

fig, ax = plt.subplots(figsize=(6, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=label_names)
disp.plot(ax=ax, cmap='Blues', values_format='d')
plt.title('Confusion Matrix (Validation Set)')
plt.tight_layout()
plt.show()

# Repeat the headline validation metrics next to the figure.
print(
    f"Validation Accuracy: {val_results['accuracy']:.4f}",
    f"Validation F1 Macro: {val_results['f1_macro']:.4f}",
    sep="\n",
)

## 11. Save Model

In [None]:
# Persist the model weights together with the metadata stored alongside them
# (checkpoint name, label mappings, class count).
import os

checkpoint = {
    'model_state_dict': model.state_dict(),
    'model_name': model_name,
    'label2id': label2id,
    'id2label': id2label,
    'num_classes': 3,
}

# Create the target directory first; the path is relative to the working directory.
os.makedirs('../src/model', exist_ok=True)
torch.save(checkpoint, '../src/model/bert_finetuned.pt')

print("✓ Model saved to ../src/model/bert_finetuned.pt")
print(f"  - Validation Accuracy: {val_results['accuracy']:.4f}")
print(f"  - Test Accuracy: {test_results['accuracy']:.4f}")

## Summary

**Model:** DistilBERT (distilbert-base-uncased)

**Configuration:**
- Full fine-tuning (all parameters trainable)
- Learning rate: 2e-5 (AdamW, weight decay 0.01) with linear warmup over the first 10% of steps, followed by linear decay
- Batch size: 32
- Epochs: 3
- Max sequence length: 128

**Results:**
- Validation Accuracy: ~75%
- Test Accuracy: ~75%
- F1 Macro Score: ~75%

**For Assignment:**
- This notebook demonstrates complete BERT fine-tuning workflow
- Model saved and ready for evaluation in assignment report
- See `05_error_analysis.ipynb` for model comparison