Importing Important Libraries

In [1]:
from transformers import pipeline, DistilBertTokenizerFast, DistilBertModel, Trainer, TrainingArguments, get_linear_schedule_with_warmup, DefaultDataCollator, AutoModelForQuestionAnswering
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR
import torch.quantization as quantization
from torch.quantization import QuantStub, DeQuantStub
from torch.ao.quantization import float_qparams_weight_only_qconfig
from torch.optim import AdamW
from datasets import Dataset, DatasetDict, load_dataset
from torch import nn
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import accuracy_score, classification_report, f1_score, precision_score, recall_score, precision_recall_curve, roc_curve,roc_auc_score, auc
import matplotlib.pyplot as plt
import time
import numpy as np
import torch
import pandas as pd
import re
import kagglehub

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SQuAD 1.1 Dataset

In [2]:
squad_data = load_dataset("squad")
print(len(squad_data['train']))
print(len(squad_data['validation']))

87599
10570


Tokenizing

In [4]:
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    inputs = tokenizer(
        questions,
        examples["context"],
        max_length=384,
        truncation="only_second",
        return_offsets_mapping=True,
        padding="max_length",
    )

    offset_mapping = inputs.pop("offset_mapping")
    answers = examples["answers"]
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = answers[i]
        start_char = answer["answer_start"][0]
        end_char = answer["answer_start"][0] + len(answer["text"][0])
        sequence_ids = inputs.sequence_ids(i)

        idx = 0
        while sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
            start_positions.append(0)
            end_positions.append(0)
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)

            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs["start_positions"] = start_positions
    inputs["end_positions"] = end_positions
    return inputs

In [5]:
tokenized_squad = squad_data.map(preprocess_function, batched=True, remove_columns=squad_data["train"].column_names)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

Formatting

In [6]:
tokenized_squad["train"].set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'])
tokenized_squad["validation"].set_format(type='torch', columns=['input_ids', 'attention_mask', 'start_positions', 'end_positions'])

In [7]:
train_batch_size = 32
val_batch_size = 32

train_params = {'batch_size': train_batch_size,
                'shuffle': True,
                'num_workers': 0
                }

test_params = {'batch_size': val_batch_size,
                'shuffle': True,
                'num_workers': 0
                }
batch_train = DataLoader(tokenized_squad["train"], **train_params)
batch_val = DataLoader(tokenized_squad["validation"], **test_params)

In [8]:
print(batch_train.dataset)

Dataset({
    features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 87599
})


In [9]:
print(batch_val.dataset)

Dataset({
    features: ['input_ids', 'attention_mask', 'start_positions', 'end_positions'],
    num_rows: 10570
})


In [16]:
class QADistilBERT(nn.Module):
    def __init__(self):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.qa_outputs = nn.Linear(self.distilbert.config.hidden_size, 2)
        
    def forward(self, input_ids, attention_mask=None):
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = logits.split(1, dim=-1)

        return start_logits.squeeze(-1), end_logits.squeeze(-1)
    
model = QADistilBERT()

In [None]:
training_params = {
    'epochs':5,
    'lr':1e-5,
    'optimizer':'AdamW'
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
optimizer = AdamW(model.parameters(), training_params['lr'], weight_decay= 0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

for epoch in range(training_params['epochs']):
    print(f"\nEpoch {epoch+1}")
    start_time = time.time()

    total_train_loss = 0
    model.train()
    for i, batch in enumerate(batch_train):

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        optimizer.zero_grad()
        start_logits, end_logits = model(input_ids=input_ids, attention_mask=attention_mask)
        start_loss = nn.CrossEntropyLoss()(start_logits, start_positions)
        end_loss = nn.CrossEntropyLoss()(end_logits, end_positions)

        loss = (start_loss + end_loss)/2
        loss.backward()

        total_train_loss += loss.item()
        optimizer.step()

        if i % 10 == 0:
            print(f"Batch Loss: {loss:.4f}")

    scheduler.step()
    avg_train_loss = total_train_loss / len(batch_train)

    end_time = time.time()
    epoch_time = end_time - start_time
    mins = int(epoch_time // 60)
    secs = int(epoch_time % 60)

    print(f"Epoch Train Loss: {avg_train_loss:.4f}")
    print(f"Epoch Time: {mins}m {secs}s")

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'final_model.pth')

print("Training complete. Final model saved to 'final_model.pth'.")

Baseline Training

In [None]:
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np

def compute_exact_match(pred_start, pred_end, true_start, true_end):

    return int(pred_start == true_start and pred_end == true_end)

def compute_f1(pred_start, pred_end, true_start, true_end):

    pred_tokens = set(range(pred_start, pred_end + 1))
    true_tokens = set(range(true_start, true_end + 1))
    
    if len(pred_tokens) == 0 or len(true_tokens) == 0:
        return int(pred_tokens == true_tokens)
    
    common = pred_tokens & true_tokens
    if len(common) == 0:
        return 0
    
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(true_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    
    return f1

def evaluate_qa_model(model, dataloader, device='cuda'):

    model.eval()
    model.to(device)
    
    exact_matches = []
    f1_scores = []
    start_accuracies = []
    end_accuracies = []
    
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)
            
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            
            start_logits = outputs.start_logits if hasattr(outputs, 'start_logits') else outputs[0]
            end_logits = outputs.end_logits if hasattr(outputs, 'end_logits') else outputs[1]
            
            pred_starts = torch.argmax(start_logits, dim=1)
            pred_ends = torch.argmax(end_logits, dim=1)
            
            for i in range(len(input_ids)):
                pred_start = pred_starts[i].item()
                pred_end = pred_ends[i].item()
                true_start = start_positions[i].item()
                true_end = end_positions[i].item()
                
                em = compute_exact_match(pred_start, pred_end, true_start, true_end)
                exact_matches.append(em)
                
                f1 = compute_f1(pred_start, pred_end, true_start, true_end)
                f1_scores.append(f1)
                
                start_accuracies.append(int(pred_start == true_start))
                end_accuracies.append(int(pred_end == true_end))
    
    metrics = {
        'exact_match': np.mean(exact_matches) * 100,
        'f1_score': np.mean(f1_scores),
        'start_accuracy': np.mean(start_accuracies) * 100,
        'end_accuracy': np.mean(end_accuracies) * 100,
        'num_samples': len(exact_matches)
    }
    
    return metrics

def print_metrics(metrics):
    print("="*50)
    print("Evaluation Matrices")
    print("="*50)
    print(f"Number of samples: {metrics['num_samples']}")
    print(f"Exact Match (EM):  {metrics['exact_match']:.4f}%")
    print(f"F1 Score:          {metrics['f1_score']:.4f}")
    print(f"Start Accuracy:    {metrics['start_accuracy']:.4f}%")
    print(f"End Accuracy:      {metrics['end_accuracy']:.4f}%")
    print("="*50)



In [18]:
state_dict = torch.load('5_baseline.pth')
model.load_state_dict(state_dict['model_state_dict'])

<All keys matched successfully>

In [19]:
metrics = evaluate_qa_model(model, batch_val, device=device)
print_metrics(metrics)

Evaluating: 100%|██████████| 331/331 [00:56<00:00,  5.89it/s]


EVALUATION RESULTS
Number of samples: 10570
Exact Match (EM):  55.2129%
F1 Score:          0.7364
Start Accuracy:    66.2535%
End Accuracy:      69.9905%






# Quantization

In [None]:
class QADistilBERTQuantized(nn.Module):
    def __init__(self):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained("distilbert-base-uncased")
        self.quant = QuantStub()  
        self.dequant = DeQuantStub() 
        self.qa_outputs = nn.Linear(self.distilbert.config.hidden_size, 2) 

    def forward(self, input_ids, attention_mask):
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state

        sequence_output = self.quant(sequence_output)
        
        logits = self.qa_outputs(sequence_output)
        logits = self.dequant(logits)
        start_logits, end_logits = logits.split(1, dim=-1)

        return start_logits.squeeze(-1), end_logits.squeeze(-1)
    
def prepare_for_qat(model):
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    model = torch.quantization.prepare_qat(model, inplace=True)
    return model

model = QADistilBERTQuantized()
model.train()
QAT_model = prepare_for_qat(model)


For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  model = torch.quantization.prepare_qat(model, inplace=True)


QAT Training

In [None]:
training_params = {
    'epochs':5,
    'lr':1e-5,
    'optimizer':'AdamW'
}

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
QAT_model.to(device)
optimizer = AdamW(QAT_model.parameters(), training_params['lr'], weight_decay= 0.01)
scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=0)

for epoch in range(training_params['epochs']):
    print(f"\nEpoch {epoch+1}")
    start_time = time.time()
    total_train_loss = 0
    
    for i, batch in enumerate(batch_train):

        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        optimizer.zero_grad()
        start_logits, end_logits = QAT_model(input_ids=input_ids, attention_mask=attention_mask)
        start_loss = nn.CrossEntropyLoss()(start_logits, start_positions)
        end_loss = nn.CrossEntropyLoss()(end_logits, end_positions)

        loss = (start_loss + end_loss)/2
        loss.backward()

        total_train_loss += loss.item()
        optimizer.step()

    scheduler.step()
    avg_train_loss = total_train_loss / len(batch_train)

    end_time = time.time()
    epoch_time = end_time - start_time
    mins = int(epoch_time // 60)
    secs = int(epoch_time % 60)

    print(f"Epoch Train Loss: {avg_train_loss:.4f}")
    print(f"Epoch Time: {mins}m {secs}s")

torch.save({
    'model_state_dict': QAT_model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, '5_quantized.pth')

print("Training complete. Final model saved to '5_quantized.pth'.")

In [None]:
QAT_model.eval()
QAT_model.to('cpu')  # Quantization only works on CPU
quantized_model = torch.quantization.convert(QAT_model, inplace=False)

torch.save(quantized_model.state_dict(), 'quantized_qa_model.pth')

# Evaluating Quantized vs Baseline

Evaluation is done with inference in CPU, since quantization works to speed up inference on CPU device

In [30]:
model = QADistilBERT()
state_dict = torch.load('5_baseline.pth')
model.load_state_dict(state_dict['model_state_dict'])

model.eval()
metrics = evaluate_qa_model(model, batch_val, device='cpu')
print_metrics(metrics)

Evaluating: 100%|██████████| 331/331 [13:38<00:00,  2.47s/it]


EVALUATION RESULTS
Number of samples: 10570
Exact Match (EM):  55.2129%
F1 Score:          0.7364
Start Accuracy:    66.2535%
End Accuracy:      69.9905%






Dynamic Quantization

In [22]:
model = QADistilBERT()
state_dict = torch.load('5_baseline.pth')
model.load_state_dict(state_dict['model_state_dict'])

model.eval()

quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},
    dtype=torch.qint8
)

metrics = evaluate_qa_model(quantized_model, batch_val, device='cpu')
print_metrics(metrics)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  quantized_model = torch.quantization.quantize_dynamic(
Evaluating: 100%|██████████| 331/331 [07:40<00:00,  1.39s/it]


EVALUATION RESULTS
Number of samples: 10570
Exact Match (EM):  46.9442%
F1 Score:          0.6591
Start Accuracy:    59.2148%
End Accuracy:      63.6235%




