In [1]:
import json
import torch
import torch.nn as nn
import os
from tqdm import tqdm
from transformers import BertModel, BertTokenizerFast, AdamW
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import ExponentialLR
import matplotlib.pyplot as plt

# Check and set device to GPU if available, otherwise use CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


  torch.utils._pytree._register_pytree_node(


In [2]:
def load_data(file_path):
    with open(file_path, 'r') as file:
        data = json.load(file)
    
    contexts, questions, answers = [], [], []
    question_count = 0
    
    # Traverse through the JSON structure
    for entry in data['data']:
        for paragraph in entry['paragraphs']:
            context_text = paragraph['context']
            for qa in paragraph['qas']:
                question_text = qa['question']
                question_count += 1
                # Collect each answer associated with the question
                for answer in qa['answers']:
                    contexts.append(context_text.lower())
                    questions.append(question_text.lower())
                    answers.append(answer)
    
    return question_count, contexts, questions, answers


In [3]:
# Load training and validation datasets
train_question_count, train_contexts, train_questions, train_answers = load_data('../spoken_train-v1.1.json')
val_question_count, val_contexts, val_questions, val_answers = load_data('../spoken_test-v1.1.json')


In [4]:
def add_answer_end_position(answers, contexts):
    for answer, context in zip(answers, contexts):
        # Convert answer text to lowercase for consistency
        answer['text'] = answer['text'].lower()
        # Calculate and store the end position of each answer
        answer['answer_end'] = answer['answer_start'] + len(answer['text'])

# Apply the function to both training and validation answers
add_answer_end_position(train_answers, train_contexts)
add_answer_end_position(val_answers, val_contexts)


In [5]:
# Set model-specific parameters
MAX_LENGTH = 512  # Maximum token length for BERT input
MODEL_NAME = "bert-base-uncased"  # Using the BERT base uncased model
DOC_STRIDE = 128  # Stride for handling long contexts

# Initialize tokenizer
tokenizer = BertTokenizerFast.from_pretrained(MODEL_NAME)

# Check if padding is on the right side (needed for tokenization step later)
padding_side = tokenizer.padding_side == "right"


In [6]:
def truncate_contexts(contexts, answers):
    truncated_contexts = []
    for i, context in enumerate(contexts):
        # Check if context length exceeds MAX_LENGTH
        if len(context) > MAX_LENGTH:
            answer_start = answers[i]['answer_start']
            answer_end = answer_start + len(answers[i]['text'])
            mid_point = (answer_start + answer_end) // 2
            
            # Calculate start and end points to center the answer within MAX_LENGTH
            start = max(0, min(mid_point - MAX_LENGTH // 2, len(context) - MAX_LENGTH))
            end = start + MAX_LENGTH
            truncated_contexts.append(context[start:end])
            
            # Adjust answer's start position in the truncated context
            answers[i]['answer_start'] = max(0, answer_start - start)
        else:
            truncated_contexts.append(context)
    
    return truncated_contexts

# Apply truncation to training contexts
train_contexts_truncated = truncate_contexts(train_contexts, train_answers)


In [7]:
# Tokenize and encode the questions and contexts with truncation and padding
train_encodings = tokenizer(
    train_questions,
    train_contexts_truncated,
    max_length=MAX_LENGTH,
    truncation=True,
    stride=DOC_STRIDE,
    padding=True,
    return_offsets_mapping=True  # To help with aligning answer spans later
)

val_encodings = tokenizer(
    val_questions,
    val_contexts,
    max_length=MAX_LENGTH,
    truncation=True,
    stride=DOC_STRIDE,
    padding=True,
    return_offsets_mapping=True
)
    

In [8]:
def find_answer_positions(encodings, answers):
    start_positions = []
    end_positions = []
    
    for i in range(len(answers)):
        answer = answers[i]['text']
        answer_start = answers[i]['answer_start']
        answer_end = answers[i]['answer_end']
        
        # Retrieve the offset mapping for each token in the encoding
        offsets = encodings['offset_mapping'][i]
        
        # Initialize start and end token positions
        start_pos = end_pos = None
        for j, (offset_start, offset_end) in enumerate(offsets):
            if offset_start <= answer_start < offset_end:
                start_pos = j
            if offset_start < answer_end <= offset_end:
                end_pos = j
                break

        # Append found positions or default to (0, 0) if not found
        start_positions.append(start_pos if start_pos is not None else 0)
        end_positions.append(end_pos if end_pos is not None else 0)
    
    return start_positions, end_positions

# Find answer positions in the tokenized data for training and validation sets
train_start_positions, train_end_positions = find_answer_positions(train_encodings, train_answers)
val_start_positions, val_end_positions = find_answer_positions(val_encodings, val_answers)


In [9]:
# Update tokenized encodings with the start and end positions
train_encodings.update({
    'start_positions': train_start_positions,
    'end_positions': train_end_positions
})

val_encodings.update({
    'start_positions': val_start_positions,
    'end_positions': val_end_positions
})


In [10]:
class QADataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        # Returns a dictionary of tensor-encoded data at the specified index
        return {
            'input_ids': torch.tensor(self.encodings['input_ids'][idx]),
            'token_type_ids': torch.tensor(self.encodings['token_type_ids'][idx]),
            'attention_mask': torch.tensor(self.encodings['attention_mask'][idx]),
            'start_positions': torch.tensor(self.encodings['start_positions'][idx]),
            'end_positions': torch.tensor(self.encodings['end_positions'][idx])
        }

    def __len__(self):
        return len(self.encodings['input_ids'])

# Create dataset instances for training and validation
train_dataset = QADataset(train_encodings)
val_dataset = QADataset(val_encodings)


In [11]:
# Create DataLoader instances for training and validation datasets
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


In [12]:
# Load the pretrained BERT model
base_bert_model = BertModel.from_pretrained(MODEL_NAME)


In [13]:
class CustomQAModel(nn.Module):
    def __init__(self, base_model):
        super(CustomQAModel, self).__init__()
        self.bert = base_model  # Load the BERT model as a backbone
        self.dropout = nn.Dropout(0.1)
        
        # Define fully connected layers for span prediction
        self.fc1 = nn.Linear(768 * 2, 768 * 2)
        self.fc2 = nn.Linear(768 * 2, 2)  # Output layer for start and end logits
        
        # Sequential layer with dropout, fully connected, and activation
        self.classifier = nn.Sequential(
            self.dropout,
            self.fc1,
            nn.LeakyReLU(),
            self.fc2
        )

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Forward pass through BERT
        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, output_hidden_states=True)
        
        # Concatenate the last and third-last hidden states
        hidden_states = outputs.hidden_states
        combined_hidden = torch.cat((hidden_states[-1], hidden_states[-3]), dim=-1)
        
        # Pass through classifier to get logits
        logits = self.classifier(combined_hidden)
        
        # Split logits for start and end predictions
        start_logits, end_logits = logits.split(1, dim=-1)
        
        # Squeeze to remove the last dimension
        return start_logits.squeeze(-1), end_logits.squeeze(-1)

# Instantiate the custom QA model with the loaded BERT model
model = CustomQAModel(base_bert_model)


In [14]:
def focal_loss(start_logits, end_logits, start_positions, end_positions, gamma=1.0):
    # Softmax for probability calculation
    softmax = nn.Softmax(dim=1)
    probs_start = softmax(start_logits)
    probs_end = softmax(end_logits)
    
    # Inverse probabilities
    inv_probs_start = 1 - probs_start
    inv_probs_end = 1 - probs_end
    
    # Log softmax for focal loss computation
    log_softmax = nn.LogSoftmax(dim=1)
    log_probs_start = log_softmax(start_logits)
    log_probs_end = log_softmax(end_logits)
    
    # Negative log-likelihood for focal loss
    nll_loss = nn.NLLLoss()
    
    # Calculate focal loss for start and end positions
    focal_start = nll_loss(torch.pow(inv_probs_start, gamma) * log_probs_start, start_positions)
    focal_end = nll_loss(torch.pow(inv_probs_end, gamma) * log_probs_end, end_positions)
    
    # Average the two losses
    return (focal_start + focal_end) / 2


In [15]:
# Initialize the AdamW optimizer with weight decay
optimizer = AdamW(model.parameters(), lr=2e-5, weight_decay=2e-2)




In [20]:
def train_model_epoch(model, data_loader, optimizer):
    model.train()  # Set the model to training mode
    epoch_loss = 0
    epoch_accuracy = 0
    batch_count = 0

    for batch in tqdm(data_loader, desc="Training"):
        optimizer.zero_grad()  # Reset gradients
        
        # Move batch data to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        token_type_ids = batch['token_type_ids'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)

        # Forward pass to get start and end logits
        start_logits, end_logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        # Compute the focal loss
        loss = focal_loss(start_logits, end_logits, start_positions, end_positions)
        
        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        
        # Track total loss for the epoch
        epoch_loss += loss.item()

        # Calculate batch accuracy
        start_preds = torch.argmax(start_logits, dim=1)
        end_preds = torch.argmax(end_logits, dim=1)
        batch_accuracy = ((start_preds == start_positions).float().mean() + (end_preds == end_positions).float().mean()) / 2
        epoch_accuracy += batch_accuracy.item()

        batch_count += 1

    # Return the average loss and accuracy for the epoch
    avg_loss = epoch_loss / batch_count
    avg_accuracy = epoch_accuracy / batch_count
    return avg_loss, avg_accuracy


In [21]:
def evaluate_model(model, data_loader):
    model.eval()  # Set the model to evaluation mode
    eval_loss = 0
    eval_accuracy = 0
    predictions = []
    references = []
    batch_count = 0

    with torch.no_grad():  # Disable gradient calculation for evaluation
        for batch in tqdm(data_loader, desc="Evaluating"):
            # Move batch data to device
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            token_type_ids = batch['token_type_ids'].to(device)
            start_positions = batch['start_positions'].to(device)
            end_positions = batch['end_positions'].to(device)

            # Forward pass to get start and end logits
            start_logits, end_logits = model(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

            # Calculate focal loss
            loss = focal_loss(start_logits, end_logits, start_positions, end_positions)
            eval_loss += loss.item()

            # Calculate accuracy
            start_preds = torch.argmax(start_logits, dim=1)
            end_preds = torch.argmax(end_logits, dim=1)
            batch_accuracy = ((start_preds == start_positions).float().mean() + (end_preds == end_positions).float().mean()) / 2
            eval_accuracy += batch_accuracy.item()

            # Decode predictions and references for comparison
            for i in range(input_ids.size(0)):
                pred_answer = tokenizer.decode(input_ids[i][start_preds[i]:end_preds[i] + 1], skip_special_tokens=True)
                true_answer = tokenizer.decode(input_ids[i][start_positions[i]:end_positions[i] + 1], skip_special_tokens=True)
                predictions.append(pred_answer)
                references.append(true_answer)

            batch_count += 1

    # Calculate average loss and accuracy
    avg_loss = eval_loss / batch_count
    avg_accuracy = eval_accuracy / batch_count
    return avg_loss, avg_accuracy, predictions, references


In [22]:
import jiwer

def calculate_wer(predictions, references):
    # Calculate WER between predictions and references
    wer_score = jiwer.wer(references, predictions)
    return wer_score



In [25]:
EPOCHS = 6

# Move the model to the specified device (GPU or CPU)
model.to(device)

# List to store WER scores for each epoch
wer_scores = []

print("Starting training and evaluation...")

# Loop over the epochs
for epoch in range(EPOCHS):
    print(f"\nEpoch {epoch + 1}/{EPOCHS}")
    
    # Training phase
    train_loss, train_accuracy = train_model_epoch(model, train_loader, optimizer)  # Ensure to pass epoch number
    print(f"Training Loss: {train_loss:.4f}, Training Accuracy: {train_accuracy:.4f}")
    
    # Evaluation phase
    val_loss, val_accuracy, predictions, references = evaluate_model(model, val_loader)
    print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")
    
    # Postprocess predictions and references
    predictions = ["$" if pred == "" else pred for pred in predictions]
    references = ["$" if ref == "" else ref for ref in references]

    # Calculate WER for the validation predictions
    wer_score = calculate_wer(predictions, references)
    wer_scores.append(wer_score)
    print(f"WER for Epoch {epoch + 1}: {wer_score:.4f}")

print("\nFinal WER Scores:", wer_scores)


Starting training and evaluation...

Epoch 1/6


Training: 100%|██████████| 2320/2320 [05:31<00:00,  7.00it/s]


Training Loss: 1.2784, Training Accuracy: 0.6066


Evaluating: 100%|██████████| 15875/15875 [03:42<00:00, 71.21it/s]


Validation Loss: 2.0385, Validation Accuracy: 0.4520
WER for Epoch 1: 1.2941

Epoch 2/6


Training: 100%|██████████| 2320/2320 [05:31<00:00,  7.00it/s]


Training Loss: 0.9991, Training Accuracy: 0.6753


Evaluating: 100%|██████████| 15875/15875 [03:42<00:00, 71.26it/s]


Validation Loss: 2.3180, Validation Accuracy: 0.4317
WER for Epoch 2: 1.1413

Epoch 3/6


Training: 100%|██████████| 2320/2320 [05:31<00:00,  7.00it/s]


Training Loss: 0.8208, Training Accuracy: 0.7261


Evaluating: 100%|██████████| 15875/15875 [03:42<00:00, 71.28it/s]


Validation Loss: 2.5650, Validation Accuracy: 0.4370
WER for Epoch 3: 1.2577

Epoch 4/6


Training: 100%|██████████| 2320/2320 [05:31<00:00,  7.00it/s]


Training Loss: 0.6713, Training Accuracy: 0.7715


Evaluating: 100%|██████████| 15875/15875 [03:42<00:00, 71.22it/s]


Validation Loss: 2.6840, Validation Accuracy: 0.4469
WER for Epoch 4: 1.4817

Epoch 5/6


Training: 100%|██████████| 2320/2320 [05:31<00:00,  7.00it/s]


Training Loss: 0.5428, Training Accuracy: 0.8087


Evaluating: 100%|██████████| 15875/15875 [03:42<00:00, 71.25it/s]


Validation Loss: 3.3613, Validation Accuracy: 0.4044
WER for Epoch 5: 1.5595

Epoch 6/6


Training: 100%|██████████| 2320/2320 [05:31<00:00,  7.00it/s]


Training Loss: 0.4169, Training Accuracy: 0.8455


Evaluating: 100%|██████████| 15875/15875 [03:42<00:00, 71.21it/s]


Validation Loss: 3.3431, Validation Accuracy: 0.4266
WER for Epoch 6: 2.1712

Final WER Scores: [1.2941071388696603, 1.1413259495790813, 1.2576647388517965, 1.481700645334167, 1.559542683607619, 2.1712256883192285]
