In [None]:
# Standard Library Imports
import ast
import copy
import csv
import json
import math
import os
import re
import time
import warnings
import logging
import random
import collections
from collections import Counter
from typing import List, Tuple, Optional
from IPython.display import HTML, display
import math
import time
from unidecode import unidecode


# Data Handling Libraries
import numpy as np
import pandas as pd
import csv
from torch.utils.data import random_split
import datasets
from datasets import ClassLabel, Sequence, Dataset, DatasetDict

# Data Visualization Libraries
import matplotlib.pyplot as plt
import seaborn as sns
# import scikitplot as skplt  # Uncomment if scikit-plot is installed and needed

# Machine Learning: Model Preparation
from sklearn.metrics import accuracy_score, confusion_matrix, precision_recall_fscore_support
from sklearn.model_selection import cross_val_score, cross_validate, KFold, train_test_split
from sklearn.preprocessing import MinMaxScaler

# Machine Learning: Models and Frameworks
import torch
import evaluate
import xgboost
import wandb
from xgboost import plot_importance  # Uncomment if xgboost importance plot is required


# NLP and Transformers
from transformers import (AdamW, AutoModelForSequenceClassification, AutoModelForQuestionAnswering,
                          AutoTokenizer, CamembertForSequenceClassification, DistilBertConfig,
                          DistilBertForSequenceClassification, DistilBertModel, EarlyStoppingCallback,
                          get_linear_schedule_with_warmup, RobertaForSequenceClassification, EvalPrediction,
                          Trainer, TrainerCallback, TrainingArguments, XLMRobertaForSequenceClassification,
                         DefaultDataCollator, BertForQuestionAnswering, DataCollatorWithPadding, PreTrainedTokenizerFast,
                         default_data_collator)
from datasets import Dataset, DatasetDict, load_dataset
from transformers.trainer_utils import PredictionOutput, speed_metrics

# Experiment Tracking and Optimization Utilities
import optuna
from optuna.trial import TrialState
# import wandb  # Uncomment if using Weights & Biases for experiment tracking

# Progress Bar Utilities
from tqdm.notebook import tqdm


# Data Preparation and Preprocessing

## Data Loading

In [None]:
def dataset_retrieval(size: int = 130319, test_size: float = 0.2):
    squad_dataset = load_dataset("squad_v2", split=f"train[:{size}]")
    train_testvalid = squad_dataset.train_test_split(test_size=test_size)
    test_valid = train_testvalid['test'].train_test_split(test_size=0.5)
    squad_raw = datasets.DatasetDict({
                                    'train': train_testvalid['train'],
                                    'validation': test_valid['train'],
                                    'test': test_valid['test']
                                    })
    return squad_raw

squad_raw = dataset_retrieval()

## Retrieved Balance Dataset

In [None]:
def balance_dataset(squad_raw, target_size: int = 4000, num_bins: int = 10):
    train_data = squad_raw['train']
    id_to_idx = {sample['id']: idx for idx, sample in enumerate(train_data)}
    
    # Create a DataFrame with ID and answer text length
    df = pd.DataFrame({
        'id': [sample['id'] for sample in train_data],
        'answer_length': [len(sample['answers']['text'][0]) if sample['answers']['text'] else 0 for sample in train_data]
    })
    
    # Divide the data into bins based on answer length
    bins = np.linspace(0, df['answer_length'].max(), num_bins + 1)
    bin_indices = np.digitize(df['answer_length'], bins)
    df['bin'] = bin_indices
    
    # Determine the balanced count for each bin
    bin_counts = df.groupby('bin')['bin'].count()
    balanced_bin_count = target_size // num_bins
    
    # Balance the bins
    balanced_ids = []
    for bin_idx in range(1, num_bins):
        bin_data = df[df['bin'] == bin_idx]
        if len(bin_data) > balanced_bin_count:
            balanced_ids.extend(bin_data.sample(balanced_bin_count, replace=False)['id'].tolist())
        else:
            balanced_ids.extend(bin_data['id'].tolist())
    
    # Handle empty answer examples separately
    empty_answer_data = df[df['answer_length'] == 0]
    desired_empty_answer_count = target_size // num_bins
    if len(empty_answer_data) > desired_empty_answer_count:
        max_empty_answer_count = min(desired_empty_answer_count, len(empty_answer_data))
        balanced_ids.extend(empty_answer_data.sample(max_empty_answer_count, replace=False)['id'].tolist())
    else:
        balanced_ids.extend(empty_answer_data['id'].tolist())
    
    # Ensure the balanced dataset size is equal to the target size
    if len(balanced_ids) > target_size:
        balanced_ids = balanced_ids[:target_size]
    elif len(balanced_ids) < target_size:
        remaining_count = target_size - len(balanced_ids)
        remaining_ids = df[~df['id'].isin(balanced_ids)]['id'].tolist()
        balanced_ids.extend(np.random.choice(remaining_ids, size=remaining_count, replace=False))
    
    balanced_indices = [id_to_idx[id] for id in balanced_ids]
    
    val_data = squad_raw['validation']
    val_target_size = int(target_size * 0.2)
    val_indices = random.sample(range(len(val_data)), val_target_size)
    val_dataset = val_data.select(val_indices)
    
    # Create a new DatasetDict with the balanced dataset
    balanced_dataset = train_data.select(balanced_indices)
    balanced_squad_raw = DatasetDict({
        'train': balanced_dataset,
        'validation': val_dataset
    })
    
    return balanced_squad_raw

In [None]:
def plot_answer_length_distribution_char(dataset, bin_length: int = 20):
    answer_lengths = []
    for sample in dataset:
        answer_text = sample['answers']['text']
        if answer_text:
            answer_lengths.append(len(answer_text[0]))
        else:
            answer_lengths.append(0)
    
    print(f'Mean answer length: {np.mean(answer_lengths):.2f}')
    print(f'Highest answer length: {np.max(answer_lengths)}')
    print(f"empty list to total length ratio: {(len([i for i in answer_lengths if i == 0])/len(answer_lengths)) * 100}%")
    # Create a figure and axis
    fig, ax = plt.subplots(figsize=(10, 6))
    
    # Plot the data as a bar chart
    ax.hist(answer_lengths, bins=bin_length, edgecolor='black')
    
    # Set the title and labels
    ax.set_title('Distribution of Answer Text Lengths')
    ax.set_xlabel('Answer Text Length')
    ax.set_ylabel('Count')
    
    # Display the plot
    plt.show()

In [None]:
plot_answer_length_distribution_char(squad_raw['train'])

In [None]:
balance_dataset = balance_dataset(squad_raw, 1000, 10)

In [None]:
plot_answer_length_distribution_char(balance_dataset['train'])

##  Training and Validation Data Preparation and Tokenisation

In [None]:
max_length = 512  # Adjust max_length to fit model constraints
global global_doc_stride
global_doc_stride = 64
pretrained_model_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name)
assert isinstance(tokenizer, PreTrainedTokenizerFast)
pad_on_right = tokenizer.padding_side == 'right'
right_padding = pad_on_right

In [None]:
def long_sequence_filter(examples: DatasetDict, max_length: int = 480):
    filtered_datasets = {}

    for split, dataset in examples.items():
        questions = dataset["question"]
        contexts = dataset["context"]

        max_tokenized_length = tokenizer.model_max_length
        filtered_examples = []

        for i, (q, c) in enumerate(zip(questions, contexts)):
            # Estimate the length of the tokenized sequence
            estimated_length = len(tokenizer.encode(q, add_special_tokens=True)) + len(tokenizer.encode(c, add_special_tokens=False)) + 3

            # Check if the estimated length exceeds the desired maximum length
            if estimated_length < max_tokenized_length:
                filtered_examples.append(dataset[i])
            else:
                print(f"Skipping example with context length {len(c)} and question length {len(q)}")
                print(f"Row {i} Estimated tokenized length: {estimated_length} (max allowed: {max_tokenized_length})")

        print(f"Filtered out {len(dataset) - len(filtered_examples)} examples due to length for {split} set.")

        filtered_datasets[split] = Dataset.from_dict({
            'id': [example['id'] for example in filtered_examples],
            'title': [example['title'] for example in filtered_examples],
            'context': [example['context'] for example in filtered_examples],
            'question': [example['question'] for example in filtered_examples],
            'answers': [example['answers'] for example in filtered_examples]
        })

    filtered_dataset_dict = DatasetDict(filtered_datasets)

    return filtered_dataset_dict

# filtered_dataset = dataset.filter(lambda example: len(example['context']) <= max_length)

In [None]:
def preprocess_function(examples):
    questions = [q.strip() for q in examples["question"]]
    tokenized_examples = tokenizer(
        questions,
        examples["context"],
        max_length=max_length,  # Updated to handle model max length
        truncation="only_second" if pad_on_right else "only_first",
        stride=global_doc_stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )
    
    offset_mapping = tokenized_examples.pop("offset_mapping")
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    answers = examples["answers"]
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # We will label impossible answers with the index of the CLS token.
        input_ids = tokenized_examples["input_ids"][i]
        if tokenizer.cls_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.cls_token_id)
        elif tokenizer.bos_token_id in input_ids:
            cls_index = input_ids.index(tokenizer.bos_token_id)
        else:
            cls_index = 0
        
        # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        
        # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        answers = examples['answers'][sample_index]
        # If no answers are given, set the cls_index as answer.
        if len(answers["answer_start"]) == 0:
            tokenized_examples["start_positions"].append(cls_index)
            tokenized_examples["end_positions"].append(cls_index)
        else:
            # Start/end character index of the answer in the text.
            # Start/end character index of the answer in the text.
            start_char = answers["answer_start"][0]
            end_char = start_char + len(answers["text"][0])

            # Start token index of the current span in the text.
            token_start_index = 0
            while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
                token_start_index += 1

            # End token index of the current span in the text.
            token_end_index = len(input_ids) - 1
            while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
                token_end_index -= 1

            # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
            if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
                tokenized_examples["start_positions"].append(cls_index)
                tokenized_examples["end_positions"].append(cls_index)
            else:
                # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
                # Note: we could go after the last offset if the answer is the last word (edge case).
                while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
                    token_start_index += 1
                tokenized_examples["start_positions"].append(token_start_index - 1)
                while offsets[token_end_index][1] >= end_char:
                    token_end_index -= 1
                tokenized_examples["end_positions"].append(token_end_index + 2)
                            

    return tokenized_examples
        

In [None]:
# start_char = answers["answer_start"][0]
# end_char = start_char + len(answers["text"][0])

# # Start token index of the current span in the text.
# token_start_index = 0
# while sequence_ids[token_start_index] != (1 if pad_on_right else 0):
#     token_start_index += 1

# # End token index of the current span in the text.
# token_end_index = len(input_ids) - 1
# while sequence_ids[token_end_index] != (1 if pad_on_right else 0):
#     token_end_index -= 1

# # Detect if the answer is out of the span (in which case this feature is labeled with the CLS index).
# if not (offsets[token_start_index][0] <= start_char and offsets[token_end_index][1] >= end_char):
#     tokenized_examples["start_positions"].append(cls_index)
#     tokenized_examples["end_positions"].append(cls_index)
# else:
#     # Otherwise move the token_start_index and token_end_index to the two ends of the answer.
#     # Note: we could go after the last offset if the answer is the last word (edge case).
#     while token_start_index < len(offsets) and offsets[token_start_index][0] <= start_char:
#         token_start_index += 1
#     tokenized_examples["start_positions"].append(token_start_index - 1)
#     while offsets[token_end_index][1] >= end_char:
#         token_end_index -= 1
#     tokenized_examples["end_positions"].append(token_end_index + 2)

In [None]:
def prepare_validation_features(examples):
  # Some of the questions have lots of whitespace on the left, which is not useful and will make the
  # truncation of the context fail (the tokenized question will take a lots of space). So we remove that
  # left whitespace
    examples["question"] = [q.lstrip() for q in examples["question"]]
  # Tokenize our examples with truncation and maybe padding, but keep the overflows using a stride. This results
  # in one example possible giving several features when a context is long, each of those features having a
  # context that overlaps a bit the context of the previous feature.
    tokenized_examples = tokenizer(
      examples["question"],
      examples["context"],
      truncation="only_second" if right_padding else "only_first",
      max_length=max_length,
      stride=global_doc_stride,
      return_overflowing_tokens=True,
      return_offsets_mapping=True,
      padding="max_length",
  )

  # Since one example might give us several features if it has a long context, we need a map from a feature to
  # its corresponding example. This key gives us just that.
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")

  # For evaluation, we will need to convert our predictions to substrings of the context, so we keep the
  # corresponding example_id and we will store the offset mappings.
    tokenized_examples["example_id"] = []

    for i in range(len(tokenized_examples["input_ids"])):
      # Grab the sequence corresponding to that example (to know what is the context and what is the question).
        sequence_ids = tokenized_examples.sequence_ids(i)
        context_index = 0

      # One example can give several spans, this is the index of the example containing this span of text.
        sample_index = sample_mapping[i]
        tokenized_examples["example_id"].append(examples["id"][sample_index])

      # Set to None the offset_mapping that are not part of the context so it's easy to determine if a token
      # position is part of the context or not.
        tokenized_examples["offset_mapping"][i] = [
            (o if sequence_ids[k] == context_index else None)
            for k, o in enumerate(tokenized_examples["offset_mapping"][i])
      ]

    return tokenized_examples

## Applying Preprocessing

Run the below line to filter (remove long sequence) out the data

In [None]:
# squad_raw = long_sequence_filter(squad_raw)

filtered_dataset = long_sequence_filter(balance_dataset)

In [None]:
global train_dataset
global eval_dataset
## APPLY PREPROCESSING
def apply_preprocessing(dataset):
    train_dataset = dataset['train'].map(
                preprocess_function,
                batched=True,
                remove_columns=dataset["train"].column_names,
                desc="Running tokenizer on train dataset",
            )
    eval_dataset = dataset['validation'].map(
                prepare_validation_features,
                batched=True,
                remove_columns=dataset["train"].column_names,
                desc="Running tokenizer on validation dataset",
            )
    eval_examples =  dataset["validation"]
    return train_dataset, eval_dataset, eval_examples

In [None]:
train_dataset, eval_dataset, eval_examples = apply_preprocessing(filtered_dataset)

In [None]:
def retrieve_and_compare_answers(dataset, examples, split_type='train'):
    answer_mismatches = []
    counter = 0
    for i, example in enumerate(examples):
        # Retrieve stored start and end positions
        start_pos = dataset[i]['start_positions']
        end_pos = dataset[i]['end_positions']
        
        # Fetch the context and calculate predicted answer text
        context = example['context']
        question = example['question']
        answer_text = tokenizer.decode(dataset[i]['input_ids'][start_pos:end_pos])
        
        # Normalize and compare with actual answer
        actual_answer = example['answers']['text'][0] if example['answers']['text'] else ""
        normalized_actual_answer = unidecode(actual_answer.lower().replace(" ", ""))
        normalized_predicted_answer = unidecode(answer_text.lower().replace(" ", ""))
        
        if normalized_actual_answer != normalized_predicted_answer:
            counter += 1
            answer_mismatches.append({
                'Raw Example ID': example['id'],
                'Raw Question': question,
                'Decoded Question': question,
                'Decoded Answer': answer_text,
                'Actual Answer': actual_answer,
                'Raw Context': context[:200] + '...'  # Truncating context for display purposes
            })
    
    # Convert list of dictionaries to DataFrame
    mismatches_df = pd.DataFrame(answer_mismatches)
    print(f"Number of mismatches in {split_type}: {len(answer_mismatches)}")
    return mismatches_df

# Retrieve mismatches for training and validation datasets
train_mismatches = retrieve_and_compare_answers(train_dataset, filtered_dataset['train'], 'train')
# Uncomment and adjust as needed for validation dataset
#validation_mismatches = retrieve_and_compare_answers(eval_dataset, squad_raw['validation'], 'validation')

# Optionally, display some mismatches in a DataFrame
display(train_mismatches.tail(50))  # Display first 12 mismatches from training
#display(validation_mismatches.head(3))  # Display first 3 mismatches from validation if needed


# Model Structure

## Architecture

In [None]:
#wandb.login(key='8f7092f0fdaf14add2b4cc07cb0e740080cdd8e7')
wandb.login()

In [None]:
class LoggingCallback(TrainerCallback):
    def __init__(self, log_path):
        self.log_path = log_path
    def on_log(self, args, state, control, logs=None, **kwargs):
        _ = logs.pop("total_flos", None)
        if state.is_local_process_zero:
            with open(self.log_path, "a") as f:
                f.write(json.dumps(logs) + "\n")

### Compute_metrics function for Question and Answering problem is different to classification, more preocessing required.

metric = evaluate.load("squad_v2")

def compute_metrics(p: EvalPrediction):
        return metric.compute(predictions=p.predictions, references=p.label_ids)

if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)} is available.")
else:
    print("No GPU available. Training will run on CPU.")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [None]:
"""
A subclass of `Trainer` specific to Question-Answering tasks
"""
try:
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    has_tpu = True
except ImportError:
    has_tpu = False

class QuestionAnsweringTrainer(Trainer):
    def __init__(self, *args, eval_examples=None, post_process_function=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.eval_examples = eval_examples
        self.post_process_function = post_process_function

    def evaluate(self, eval_dataset=None, eval_examples=None, ignore_keys=None, metric_key_prefix: str = "eval"):
        eval_dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        eval_dataloader = self.get_eval_dataloader(eval_dataset)
        eval_examples = self.eval_examples if eval_examples is None else eval_examples

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        start_time = time.time()
        try:
            output = eval_loop(
                eval_dataloader,
                description="Evaluation",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )
        if self.post_process_function is not None and self.compute_metrics is not None and self.args.should_save:
            # Only the main node write the results by default
            eval_preds = self.post_process_function(eval_examples, eval_dataset, output.predictions)
            metrics = self.compute_metrics(eval_preds)

            # Prefix all keys with metric_key_prefix + '_'
            for key in list(metrics.keys()):
                if not key.startswith(f"{metric_key_prefix}_"):
                    metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
            metrics.update(output.metrics)
        else:
            metrics = output.metrics

        if self.args.should_log:
            # Only the main node log the results by default
            self.log(metrics)

        if (self.args.tpu_metrics_debug and has_tpu) or (self.args.debug and has_tpu):
            # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
            xm.master_print(met.metrics_report())

        self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, metrics)
        return metrics

    def predict(self, predict_dataset, predict_examples, ignore_keys=None, metric_key_prefix: str = "test"):
        predict_dataloader = self.get_test_dataloader(predict_dataset)

        # Temporarily disable metric computation, we will do it in the loop here.
        compute_metrics = self.compute_metrics
        self.compute_metrics = None
        eval_loop = self.prediction_loop if self.args.use_legacy_prediction_loop else self.evaluation_loop
        start_time = time.time()
        try:
            output = eval_loop(
                predict_dataloader,
                description="Prediction",
                # No point gathering the predictions if there are no metrics, otherwise we defer to
                # self.args.prediction_loss_only
                prediction_loss_only=True if compute_metrics is None else None,
                ignore_keys=ignore_keys,
                metric_key_prefix=metric_key_prefix,
            )
        finally:
            self.compute_metrics = compute_metrics
        total_batch_size = self.args.eval_batch_size * self.args.world_size
        if f"{metric_key_prefix}_jit_compilation_time" in output.metrics:
            start_time += output.metrics[f"{metric_key_prefix}_jit_compilation_time"]
        output.metrics.update(
            speed_metrics(
                metric_key_prefix,
                start_time,
                num_samples=output.num_samples,
                num_steps=math.ceil(output.num_samples / total_batch_size),
            )
        )

        if self.post_process_function is None or self.compute_metrics is None:
            return output

        predictions = self.post_process_function(predict_examples, predict_dataset, output.predictions, "predict")
        metrics = self.compute_metrics(predictions)

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)
        metrics.update(output.metrics)
        return PredictionOutput(predictions=predictions.predictions, label_ids=predictions.label_ids, metrics=metrics)

In [None]:
class AdvancedEarlyStoppingCallback(TrainerCallback):
    """
    A callback to stop training when either the performance falls below a certain threshold
    or if there is no improvement over a set number of epochs.
    """
    def __init__(self, metric_name, patience, threshold):
        self.metric_name = metric_name
        self.patience = patience
        self.threshold = threshold
        self.best_score = None
        self.no_improve_epochs = 0

    def on_evaluate(self, args, state, control, **kwargs):
        metric_value = kwargs['metrics'].get(self.metric_name)

        if self.best_score is None or metric_value > self.best_score:
            self.best_score = metric_value
            self.no_improve_epochs = 0
        else:
            self.no_improve_epochs += 1

        # Check if performance is below the threshold
        if metric_value < self.threshold:
            control.should_training_stop = True
            print(f"Stopping training: {self.metric_name} below threshold of {self.threshold}")

        # Check if no improvement has been seen over the allowed patience
        if self.no_improve_epochs >= self.patience:
            control.should_training_stop = True
            print(f"Stopping training: No improvement in {self.metric_name} for {self.patience} epochs")


In [None]:
data_collator = DefaultDataCollator()

## Post processing QA

In [104]:
logger = logging.getLogger(__name__)

def postprocess_qa_predictions(
    examples,
    features,
    predictions: Tuple[np.ndarray, np.ndarray],
    version_2_with_negative: bool = True,
    n_best_size: int = 20,
    max_answer_length: int = 30,
    null_score_diff_threshold: float = 0.0,
    output_dir: Optional[str] = None,
    prefix: Optional[str] = None,
    log_level: Optional[int] = logging.WARNING,
):
    """
    Post-processes the predictions of a question-answering model to convert them to answers that are substrings of the
    original contexts. This is the base postprocessing functions for models that only return start and end logits.

    Args:
        examples: The non-preprocessed dataset (see the main script for more information).
        features: The processed dataset (see the main script for more information).
        predictions (:obj:`Tuple[np.ndarray, np.ndarray]`):
            The predictions of the model: two arrays containing the start logits and the end logits respectively. Its
            first dimension must match the number of elements of :obj:`features`.
        version_2_with_negative (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether or not the underlying dataset contains examples with no answers.
        n_best_size (:obj:`int`, `optional`, defaults to 20):
            The total number of n-best predictions to generate when looking for an answer.
        max_answer_length (:obj:`int`, `optional`, defaults to 30):
            The maximum length of an answer that can be generated. This is needed because the start and end predictions
            are not conditioned on one another.
        null_score_diff_threshold (:obj:`float`, `optional`, defaults to 0):
            The threshold used to select the null answer: if the best answer has a score that is less than the score of
            the null answer minus this threshold, the null answer is selected for this example (note that the score of
            the null answer for an example giving several features is the minimum of the scores for the null answer on
            each feature: all features must be aligned on the fact they `want` to predict a null answer).

            Only useful when :obj:`version_2_with_negative` is :obj:`True`.
        output_dir (:obj:`str`, `optional`):
            If provided, the dictionaries of predictions, n_best predictions (with their scores and logits) and, if
            :obj:`version_2_with_negative=True`, the dictionary of the scores differences between best and null
            answers, are saved in `output_dir`.
        prefix (:obj:`str`, `optional`):
            If provided, the dictionaries mentioned above are saved with `prefix` added to their names.
        log_level (:obj:`int`, `optional`, defaults to ``logging.WARNING``):
            ``logging`` log level (e.g., ``logging.WARNING``)
    """
    if len(predictions) != 2:
        raise ValueError("`predictions` should be a tuple with two elements (start_logits, end_logits).")
    all_start_logits, all_end_logits = predictions

    if len(predictions[0]) != len(features):
        raise ValueError(f"Got {len(predictions[0])} predictions and {len(features)} features.")

    # Build a map example to its corresponding features.
    example_id_to_index = {k: i for i, k in enumerate(examples["id"])}
    features_per_example = collections.defaultdict(list)
    for i, feature in enumerate(features):
        features_per_example[example_id_to_index[feature["example_id"]]].append(i)

    # The dictionaries we have to fill.
    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()
    if version_2_with_negative:
        scores_diff_json = collections.OrderedDict()

    # Logging.
    logger.setLevel(log_level)
    logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

    # Let's loop over all the examples!
    for example_index, example in enumerate(tqdm(examples)):
        # Those are the indices of the features associated to the current example.
        feature_indices = features_per_example[example_index]

        min_null_prediction = None
        prelim_predictions = []

        # Looping through all the features associated to the current example.
        for feature_index in feature_indices:
            # We grab the predictions of the model for this feature.
            start_logits = all_start_logits[feature_index]
            end_logits = all_end_logits[feature_index]
            # This is what will allow us to map some the positions in our logits to span of texts in the original
            # context.
            offset_mapping = features[feature_index]["offset_mapping"]
            # Optional `token_is_max_context`, if provided we will remove answers that do not have the maximum context
            # available in the current feature.
            token_is_max_context = features[feature_index].get("token_is_max_context", None)

            # Update minimum null prediction.
            feature_null_score = start_logits[0] + end_logits[0]
            if min_null_prediction is None or min_null_prediction["score"] > feature_null_score:
                min_null_prediction = {
                    "offsets": (0, 0),
                    "score": feature_null_score,
                    "start_logit": start_logits[0],
                    "end_logit": end_logits[0],
                }

            # Go through all possibilities for the `n_best_size` greater start and end logits.
            start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
            end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Don't consider out-of-scope answers, either because the indices are out of bounds or correspond
                    # to part of the input_ids that are not in the context.
                    if (
                        start_index >= len(offset_mapping)
                        or end_index >= len(offset_mapping)
                        or offset_mapping[start_index] is None
                        or len(offset_mapping[start_index]) < 2
                        or offset_mapping[end_index] is None
                        or len(offset_mapping[end_index]) < 2
                    ):
                        continue
                    # Don't consider answers with a length that is either < 0 or > max_answer_length.
                    if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                        continue
                    # Don't consider answer that don't have the maximum context available (if such information is
                    # provided).
                    if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                        continue

                    prelim_predictions.append(
                        {
                            "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                            "score": start_logits[start_index] + end_logits[end_index],
                            "start_logit": start_logits[start_index],
                            "end_logit": end_logits[end_index],
                        }
                    )
        if version_2_with_negative and min_null_prediction is not None:
            # Add the minimum null prediction
            prelim_predictions.append(min_null_prediction)
            null_score = min_null_prediction["score"]

        # Only keep the best `n_best_size` predictions.
        predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

        # Add back the minimum null prediction if it was removed because of its low score.
        if (
            version_2_with_negative
            and min_null_prediction is not None
            and not any(p["offsets"] == (0, 0) for p in predictions)
        ):
            predictions.append(min_null_prediction)

        # Use the offsets to gather the answer text in the original context.
        context = example["context"]
        for pred in predictions:
            offsets = pred.pop("offsets")
            pred["text"] = context[offsets[0] : offsets[1]]

        # In the very rare edge case we have not a single non-null prediction, we create a fake prediction to avoid
        # failure.
        if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
            predictions.insert(0, {"text": "empty", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

        # Compute the softmax of all scores (we do it with numpy to stay independent from torch/tf in this file, using
        # the LogSumExp trick).
        scores = np.array([pred.pop("score") for pred in predictions])
        exp_scores = np.exp(scores - np.max(scores))
        probs = exp_scores / exp_scores.sum()

        # Include the probabilities in our predictions.
        for prob, pred in zip(probs, predictions):
            pred["probability"] = prob

        # Pick the best prediction. If the null answer is not possible, this is easy.
        if not version_2_with_negative:
            all_predictions[example["id"]] = predictions[0]["text"]
        else:
            # Otherwise we first need to find the best non-empty prediction.
            i = 0
            while predictions[i]["text"] == "":
                i += 1
            best_non_null_pred = predictions[i]

            # Then we compare to the null prediction using the threshold.
            score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
            scores_diff_json[example["id"]] = float(score_diff)  # To be JSON-serializable.
            if score_diff > null_score_diff_threshold:
                all_predictions[example["id"]] = ""
            else:
                all_predictions[example["id"]] = best_non_null_pred["text"]

        # Make `predictions` JSON-serializable by casting np.float back to float.
        all_nbest_json[example["id"]] = [
            {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
            for pred in predictions
        ]

    # If we have an output_dir, let's save all those dicts.
    if output_dir is not None:
        if not os.path.isdir(output_dir):
            raise EnvironmentError(f"{output_dir} is not a directory.")

        prediction_file = os.path.join(
            output_dir, "predictions.json" if prefix is None else f"{prefix}_predictions.json"
        )
        nbest_file = os.path.join(
            output_dir, "nbest_predictions.json" if prefix is None else f"{prefix}_nbest_predictions.json"
        )
        if version_2_with_negative:
            null_odds_file = os.path.join(
                output_dir, "null_odds.json" if prefix is None else f"{prefix}_null_odds.json"
            )

        logger.info(f"Saving predictions to {prediction_file}.")
        with open(prediction_file, "w") as writer:
            writer.write(json.dumps(all_predictions, indent=4) + "\n")
        logger.info(f"Saving nbest_preds to {nbest_file}.")
        with open(nbest_file, "w") as writer:
            writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
        if version_2_with_negative:
            logger.info(f"Saving null_odds to {null_odds_file}.")
            with open(null_odds_file, "w") as writer:
                writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

    return all_predictions

In [78]:
def post_processing_function(examples, features, predictions, stage="eval"):
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions = postprocess_qa_predictions(
            examples=examples,
            features=features,
            predictions=predictions,
            max_answer_length=max_length
        )
        # Format the result to the format the metric expects.
        if 1==1:
            formatted_predictions = [
                {"id": str(k), "prediction_text": v, "no_answer_probability": 0.0} for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [{"id": str(k), "prediction_text": v} for k, v in predictions.items()]

        references = [{"id": str(ex["id"]), "answers": ex["answers"]} for ex in examples]
        return EvalPrediction(predictions=formatted_predictions, label_ids=references)

In [86]:
examples = squad_raw['validation']
val_target_size = 3000
val_indices = random.sample(range(len(examples)), val_target_size)
examples = examples.select(val_indices)

In [87]:
features = []
for ex_idx in range(len(examples)):
    ex = examples[ex_idx]
    ex_id = ex["id"]
    context = ex["context"]
    answer_text = ex["answers"]["text"]
    answer_start = ex["answers"]["answer_start"]
    if answer_text and answer_start:
        offset = [answer_start[0], answer_start[0] + len(answer_text[0])]
        features.append({
            "example_id": ex_id,
            "offset_mapping": [offset]
        })

# Generate dummy start and end logits predictions
dummy_start_logits = np.random.rand(len(features), 4)
dummy_end_logits = np.random.rand(len(features), 4)
dummy_predictions = (dummy_start_logits, dummy_end_logits)

# Call the postprocess_qa_predictions function
predictions = postprocess_qa_predictions(
    examples=examples,
    features=features,
    predictions=dummy_predictions,
    version_2_with_negative=True,
    n_best_size=1,
    max_answer_length=30,
    null_score_diff_threshold=0.0,
    output_dir=None,
    prefix=None,
    log_level=logging.INFO
)

# Inspect the predictions
for ex_idx in range(len(examples)):
    ex = examples[ex_idx]
    ex_id = ex["id"]
    if ex["answers"]['text']:
        prediction = predictions.get(ex_id, "")
        print(f"Context: {ex['context']}")
        print(f"Question: {ex['question']}")
        print(f"Predicted Answer: {prediction}")
        print("Ground Truth Answers: ", ex["answers"]["text"][0])
        print("-" * 50)

  0%|          | 0/3000 [00:00<?, ?it/s]

Context: Broadly speaking, Daylight Saving Time was abandoned in the years after the war (with some notable exceptions including Canada, the UK, France, and Ireland for example). However, it was brought back for periods of time in many different places during the following decades, and commonly during the Second World War. It became widely adopted, particularly in North America and Europe starting in the 1970s as a result of the 1970s energy crisis.
Question: What country joined Canada, the UK, and Ireland in continuing to observe Daylight Saving Time after the war?
Predicted Answer: 
Ground Truth Answers:  Ireland
--------------------------------------------------
Context: Society throughout Europe was disturbed by the dislocations caused by the Black Death. Lands that had been marginally productive were abandoned, as the survivors were able to acquire more fertile areas. Although serfdom declined in Western Europe it became more common in Eastern Europe, as landlords imposed it on th

In [90]:
# Test edge case for examples with no answers
example_with_no_answer = None
for ex in examples:
    if not ex["answers"]["text"]:
        example_with_no_answer = ex
        break

print(example_with_no_answer)

{'id': '5a7cca31e8bc7e001a9e2024', 'title': 'Thuringia', 'context': 'Of the approximately 850 municipalities of Thuringia, 126 are classed as towns (within a district) or cities (forming their own urban district). Most of the towns are small with a population of less than 10,000; only the ten biggest ones have a population greater than 30,000. The first towns emerged during the 12th century, whereas the latest ones received town status only in the 20th century. Today, all municipalities within districts are equal in law, whether they are towns or villages. Independent cities (i.e. urban districts) have greater powers (the same as any district) than towns within a district.', 'question': 'How many municipalities in Thuringia are classified as hostile?', 'answers': {'text': [], 'answer_start': []}}


In [105]:
if example_with_no_answer:
    ex_id = example_with_no_answer["id"]
    features_with_no_answer = [{"example_id": ex_id, "offset_mapping": [[0, 10], [10, 20], [20, 30], [30, 40]]}]
    dummy_start_logits_no_answer = np.array([0.0])
    dummy_end_logits_no_answer = np.array([0.0])
    dummy_predictions_no_answer = (dummy_start_logits_no_answer, dummy_end_logits_no_answer)

    predictions_no_answer = postprocess_qa_predictions(
        examples=[example_with_no_answer],
        features=features_with_no_answer,
        predictions=dummy_predictions_no_answer,
        version_2_with_negative=True,
        n_best_size=1,
        max_answer_length=30,
        null_score_diff_threshold=0.0,
        output_dir=None,
        prefix=None,
        log_level=logging.INFO
    )

    print(f"Prediction for example with no answer: {predictions_no_answer.get(ex_id, '')}")

{'example_id': '5a7cca31e8bc7e001a9e2024', 'offset_mapping': [[0, 10], [10, 20], [20, 30], [30, 40]]}


KeyError: '5a7cca31e8bc7e001a9e2024'

## BERT

### Optuna Optimisation

In [None]:
torch.cuda.empty_cache()

In [None]:
model = AutoModelForQuestionAnswering.from_pretrained(pretrained_model_name)
training_args = TrainingArguments(
    output_dir=f'{pretrained_model_name}-finetuned-manual',
    overwrite_output_dir = True,
    metric_for_best_model='f1',
    greater_is_better=True,
    load_best_model_at_end=True,
    save_total_limit=4, 
    eval_strategy="epoch",
    save_strategy="epoch",
    report_to="wandb",  # Enable logging to Weights & Biases
    run_name=f"{pretrained_model_name}-finetune-manual",  # Optionally set a specific run name    
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
)

trainer = QuestionAnsweringTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        eval_examples=eval_examples,
        tokenizer=tokenizer,
        data_collator=data_collator,
        post_process_function=post_processing_function,
        compute_metrics=compute_metrics,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()