In [1]:
# Cell 2: Required libraries
import json
import torch
from transformers import BertTokenizerFast, BertForQuestionAnswering, get_linear_schedule_with_warmup
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from accelerate import Accelerator
from fuzzywuzzy import fuzz, process
from tqdm import tqdm
from collections import Counter
import re
import string

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Set up device and Accelerator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
accelerator = Accelerator()

# Use bert-base-uncased fast tokenizer for the task
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased").to(device)
print(device)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda


In [3]:
# Cell 4: Function to load SQuAD-style dataset
def load_squad(path):
    with open(path, 'r', encoding='utf-8') as f:
        squad_dict = json.load(f)

    contexts, questions, answers = [], [], []
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                answer_key = 'plausible_answers' if 'plausible_answers' in qa else 'answers'
                for answer in qa[answer_key]:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)
    return contexts, questions, answers



In [4]:
# Cell 5: Load datasets
# -------------------------------------------------------------
# Loads Spoken SQuAD-style JSON files for training and validation,
# including noisy test sets (WER44, WER54).
# -------------------------------------------------------------

train_contexts, train_questions, train_answers = load_squad("spoken_train-v1.1.json")
val_contexts, val_questions, val_answers = load_squad("spoken_test-v1.1.json")
val_contexts_wer44, val_questions_wer44, val_answers_wer44 = load_squad("spoken_test-v1.1_WER44.json")
val_contexts_wer54, val_questions_wer54, val_answers_wer54 = load_squad("spoken_test-v1.1_WER54.json")

print("✅ Datasets loaded successfully.")

✅ Datasets loaded successfully.


In [5]:
# Cell 6: Align answers with exact match and fuzzy fallback
# -------------------------------------------------------------
# Aligns each answer span within its context to ensure the
# start and end indices match correctly, using fuzzy matching
# for ASR-related mismatches.
# -------------------------------------------------------------

def align_answers(answers, contexts):
    for answer, context in zip(answers, contexts):
        if not answer["text"].strip():
            continue
        start_idx = answer["answer_start"]
        end_idx = start_idx + len(answer["text"])

        if context[start_idx:end_idx] == answer["text"]:
            answer["answer_end"] = end_idx
        else:
            possible_matches = [
                context[max(0, start_idx - n):min(len(context), end_idx + n)]
                for n in range(-4, 5)
            ]
            best_match, score = process.extractOne(
                answer["text"], possible_matches, scorer=fuzz.partial_ratio
            )
            if score > 40:
                match_start_idx = context.find(best_match, max(0, start_idx - 4))
                if match_start_idx != -1:
                    answer["answer_start"] = match_start_idx
                    answer["answer_end"] = match_start_idx + len(best_match)
            else:
                answer["answer_end"] = start_idx + len(answer["text"])  # Default

# Apply alignment to all splits
for ans_set, ctx_set in zip(
    [train_answers, val_answers, val_answers_wer44, val_answers_wer54],
    [train_contexts, val_contexts, val_contexts_wer44, val_contexts_wer54]
):
    align_answers(ans_set, ctx_set)

print("✅ Answer alignment complete for all datasets.")



✅ Answer alignment complete for all datasets.


In [6]:
# Cell 7: Tokenization with overlapping windows
# -------------------------------------------------------------
# Tokenizes question–context pairs using BERT tokenizer
# with stride-based overlapping windows for long contexts.
# -------------------------------------------------------------

train_encodings = tokenizer(
    train_contexts, train_questions,
    truncation=True, padding="max_length",
    max_length=384, stride=128,
    return_overflowing_tokens=True, return_offsets_mapping=True
)
val_encodings = tokenizer(
    val_contexts, val_questions,
    truncation=True, padding="max_length",
    max_length=384, stride=128,
    return_overflowing_tokens=True, return_offsets_mapping=True
)
val_encodings_wer44 = tokenizer(
    val_contexts_wer44, val_questions_wer44,
    truncation=True, padding="max_length",
    max_length=384, stride=128,
    return_overflowing_tokens=True, return_offsets_mapping=True
)
val_encodings_wer54 = tokenizer(
    val_contexts_wer54, val_questions_wer54,
    truncation=True, padding="max_length",
    max_length=384, stride=128,
    return_overflowing_tokens=True, return_offsets_mapping=True
)

print("✅ Tokenization complete for all datasets.")

✅ Tokenization complete for all datasets.


In [7]:
# Cell 8: Add token positions for start and end indices
# -------------------------------------------------------------
# Converts character-level answer positions into token-level
# indices for model training.
# -------------------------------------------------------------

def add_token_positions(encodings, answers):
    start_positions, end_positions = [], []
    sample_mapping = encodings["overflow_to_sample_mapping"]
    offset_mapping = encodings["offset_mapping"]

    for i, offsets in enumerate(offset_mapping):
        sample_idx = sample_mapping[i]
        answer = answers[sample_idx]

        if "answer_start" not in answer or "answer_end" not in answer:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_char = answer["answer_start"]
        end_char = answer["answer_end"]
        start, end = None, None

        for j, (offset_start, offset_end) in enumerate(offsets):
            if offset_start <= start_char < offset_end:
                start = j
            if offset_start < end_char <= offset_end:
                end = j
            if start is not None and end is not None:
                break

        if start is None:
            start = tokenizer.model_max_length
        if end is None:
            end = tokenizer.model_max_length

        start_positions.append(start)
        end_positions.append(end)

    encodings.update({"start_positions": start_positions, "end_positions": end_positions})

# Apply to all encoding sets
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)
add_token_positions(val_encodings_wer44, val_answers_wer44)
add_token_positions(val_encodings_wer54, val_answers_wer54)

print("✅ Token position mapping complete.")

✅ Token position mapping complete.


In [8]:
# Cell 9: Custom Dataset class
# -------------------------------------------------------------
# Wraps tokenized encodings for use in PyTorch DataLoader.
# -------------------------------------------------------------

class SquadDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx] if val[idx] is not None else 0)
                for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings["input_ids"])

# Initialize datasets
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)
val_dataset_wer44 = SquadDataset(val_encodings_wer44)
val_dataset_wer54 = SquadDataset(val_encodings_wer54)

print("✅ Dataset objects created successfully.")

✅ Dataset objects created successfully.


In [9]:
# Cell 10: Model training setup
# -------------------------------------------------------------
# Defines optimizer, scheduler, and prepares dataloaders
# using the Accelerator for efficient device placement.
# -------------------------------------------------------------

optimizer = AdamW(model.parameters(), lr=2e-5)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
num_training_steps = len(train_loader) * 5

scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
)

model, optimizer, train_loader, scheduler = accelerator.prepare(
    model, optimizer, train_loader, scheduler
)

print("✅ Training configuration ready.")

✅ Training configuration ready.


In [10]:
# Cell 11: Training loop with gradient accumulation
# -------------------------------------------------------------
# Performs multi-epoch fine-tuning of BERT for QA with
# gradient accumulation for memory efficiency.
# -------------------------------------------------------------

gradient_accumulation_steps = 4

for epoch in range(5):
    model.train()
    epoch_loss = 0.0

    for i, batch in enumerate(tqdm(train_loader, desc=f"🚀 Training Epoch {epoch + 1}/5")):
        optimizer.zero_grad()
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        start_positions = batch["start_positions"].to(device)
        end_positions = batch["end_positions"].to(device)

        outputs = model(
            input_ids,
            attention_mask=attention_mask,
            start_positions=start_positions,
            end_positions=end_positions,
        )

        loss = outputs.loss / gradient_accumulation_steps
        accelerator.backward(loss)
        epoch_loss += loss.item()

        if (i + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

    print(f"📘 Epoch {epoch + 1}/5 completed | Average Loss: {epoch_loss / len(train_loader):.4f}")

🚀 Training Epoch 1/5: 100%|██████████| 4664/4664 [24:28<00:00,  3.18it/s]


📘 Epoch 1/5 completed | Average Loss: 0.6437


🚀 Training Epoch 2/5: 100%|██████████| 4664/4664 [24:40<00:00,  3.15it/s]


📘 Epoch 2/5 completed | Average Loss: 0.4028


🚀 Training Epoch 3/5: 100%|██████████| 4664/4664 [24:19<00:00,  3.20it/s]


📘 Epoch 3/5 completed | Average Loss: 0.3401


🚀 Training Epoch 4/5: 100%|██████████| 4664/4664 [24:16<00:00,  3.20it/s]


📘 Epoch 4/5 completed | Average Loss: 0.2977


🚀 Training Epoch 5/5: 100%|██████████| 4664/4664 [24:19<00:00,  3.19it/s]

📘 Epoch 5/5 completed | Average Loss: 0.2646





In [11]:
# Cell 12: Normalization and scoring functions
# -------------------------------------------------------------
# Defines helper functions for EM and F1 metric computation.
# -------------------------------------------------------------

def normalize_answer(s):
    def remove_articles(text): return re.sub(r"\b(a|an|the)\b", " ", text)
    def white_space_fix(text): return " ".join(text.split())
    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)
    def lower(text): return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def exact_match_score(prediction, ground_truth):
    return normalize_answer(prediction) == normalize_answer(ground_truth)


def f1_score(prediction, ground_truth):
    prediction_tokens = normalize_answer(prediction).split()
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0.0
    precision = num_same / len(prediction_tokens)
    recall = num_same / len(ground_truth_tokens)
    return 2 * precision * recall / (precision + recall)

In [12]:
# Cell 13: Evaluation
# -------------------------------------------------------------
# Evaluates the fine-tuned model on a given dataset
# using token-based decoding to compute EM and F1 scores.
# -------------------------------------------------------------

def evaluate(model, dataset):
    model.eval()
    answers, references = [], []
    exact_match_total = f1_total = 0

    data_loader = DataLoader(dataset, batch_size=8)
    for batch in tqdm(data_loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
            start_preds = torch.argmax(outputs.start_logits, dim=1)
            end_preds = torch.argmax(outputs.end_logits, dim=1)

            for i in range(len(start_preds)):
                start_pred, end_pred = start_preds[i].item(), end_preds[i].item()
                answer_tokens = tokenizer.convert_ids_to_tokens(input_ids[i][start_pred:end_pred + 1])
                answer = tokenizer.convert_tokens_to_string(answer_tokens)

                start_true, end_true = batch["start_positions"][i].item(), batch["end_positions"][i].item()
                ref_tokens = tokenizer.convert_ids_to_tokens(input_ids[i][start_true:end_true + 1])
                reference = tokenizer.convert_tokens_to_string(ref_tokens)

                answers.append(answer)
                references.append(reference)
                exact_match_total += exact_match_score(answer, reference)
                f1_total += f1_score(answer, reference)

    exact_match = 100.0 * exact_match_total / len(answers)
    f1 = 100.0 * f1_total / len(answers)
    print(f"📊 Evaluation → Exact Match: {exact_match:.2f}% | F1 Score: {f1:.2f}%")

In [13]:
print("Evaluating on standard validation set:")
evaluate(model, val_dataset)
print("Evaluating on WER44 validation set:")
evaluate(model, val_dataset_wer44)
print("Evaluating on WER54 validation set:")
evaluate(model, val_dataset_wer54)


Evaluating on standard validation set:


Evaluating: 100%|██████████| 2010/2010 [03:33<00:00,  9.42it/s]


📊 Evaluation → Exact Match: 51.83% | F1 Score: 64.57%
Evaluating on WER44 validation set:


Evaluating: 100%|██████████| 2263/2263 [04:19<00:00,  8.71it/s]


📊 Evaluation → Exact Match: 10.20% | F1 Score: 19.41%
Evaluating on WER54 validation set:


Evaluating: 100%|██████████| 2264/2264 [04:19<00:00,  8.71it/s]

📊 Evaluation → Exact Match: 6.64% | F1 Score: 15.42%





In [14]:
# model.save_pretrained("model")
# tokenizer.save_pretrained("tokenizer")