## Training

In [33]:
import torch
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq

In [34]:
# Load datasets
train_dataset = load_dataset('json', data_files='./data/train_TLQA.json')['train']
val_dataset = load_dataset('json', data_files='./data/val_TLQA.json')['train']

In [35]:
# Debug: Verify the structure of loaded data
for example in train_dataset['answers'][:5]:
    print("Example answers:", example)

Example answers: ['Carlson (2010, 2011, 2012, 2013, 2014, 2015, 2016)', 'HNA Group (2016, 2017, 2018)', 'Jinjiang International (2018, 2019, 2020)']
Example answers: ['Church Farm School (2014, 2015)', 'Davidson College (2016, 2017, 2018, 2019, 2020)']
Example answers: ['Secretary of State for Business, Energy and Industrial Strategy (2015, 2016)', 'Secretary of State for Housing, Communities and Local Government (2016, 2017, 2018)', 'Home Secretary (2018, 2019)', 'Chancellor of the Exchequer (2019, 2020)']
Example answers: ['Democratic Liberal Party (2010, 2011, 2012, 2013, 2014)', "People's Movement Party (2014)", 'independent politician (2014, 2015, 2016, 2017, 2018)', 'Liberty, Unity and Solidarity Party (2019, 2020)']
Example answers: ['Pat Fenlon (2011, 2012, 2013)', 'Jack Ross (2019, 2020)', 'Paul Heckingbottom (2019)']


In [36]:
# Load model and tokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-base')

In [37]:
def preprocess_function(examples):
    inputs = examples['question']
    targets = ["; ".join(ans) for ans in examples['answers']]  # Join list of answers with a separator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    model_inputs['question'] = inputs  # Keep the original question
    return model_inputs

tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)

In [38]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [39]:
# Training arguments
training_args = TrainingArguments(
    output_dir='./results-Bart-base', 
    eval_strategy='epoch', 
    per_device_train_batch_size=4, 
    per_device_eval_batch_size=4, 
    num_train_epochs=3, 
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    fp16=True  # Enable mixed precision training
)

In [40]:
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
)

In [41]:
# Train the model
trainer.train()

  0%|          | 0/2409 [00:00<?, ?it/s]

{'loss': 14.7735, 'grad_norm': 26.16888427734375, 'learning_rate': 4.9896222498962225e-05, 'epoch': 0.01}
{'loss': 9.2197, 'grad_norm': 50.38726043701172, 'learning_rate': 4.968866749688668e-05, 'epoch': 0.02}
{'loss': 5.3204, 'grad_norm': 50.44267654418945, 'learning_rate': 4.9481112494811124e-05, 'epoch': 0.04}
{'loss': 3.7672, 'grad_norm': 48.95286560058594, 'learning_rate': 4.927355749273558e-05, 'epoch': 0.05}
{'loss': 2.6651, 'grad_norm': 44.98599624633789, 'learning_rate': 4.906600249066003e-05, 'epoch': 0.06}
{'loss': 1.6598, 'grad_norm': 34.98147201538086, 'learning_rate': 4.8858447488584476e-05, 'epoch': 0.07}
{'loss': 0.9411, 'grad_norm': 19.710186004638672, 'learning_rate': 4.865089248650893e-05, 'epoch': 0.09}
{'loss': 0.4781, 'grad_norm': 8.235982894897461, 'learning_rate': 4.8443337484433374e-05, 'epoch': 0.1}
{'loss': 0.2431, 'grad_norm': 2.6681935787200928, 'learning_rate': 4.823578248235783e-05, 'epoch': 0.11}
{'loss': 0.1768, 'grad_norm': 1.1386349201202393, 'learnin

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.1433, 'grad_norm': 0.5068555474281311, 'learning_rate': 3.9726027397260274e-05, 'epoch': 0.62}
{'loss': 0.1223, 'grad_norm': 0.5291364789009094, 'learning_rate': 3.951847239518473e-05, 'epoch': 0.64}
{'loss': 0.1199, 'grad_norm': 0.5624876618385315, 'learning_rate': 3.931091739310917e-05, 'epoch': 0.65}
{'loss': 0.1035, 'grad_norm': 0.8106922507286072, 'learning_rate': 3.9103362391033626e-05, 'epoch': 0.66}
{'loss': 0.1215, 'grad_norm': 0.5476059317588806, 'learning_rate': 3.889580738895808e-05, 'epoch': 0.67}
{'loss': 0.1008, 'grad_norm': 0.42654210329055786, 'learning_rate': 3.8688252386882525e-05, 'epoch': 0.68}
{'loss': 0.1054, 'grad_norm': 0.31270891427993774, 'learning_rate': 3.848069738480698e-05, 'epoch': 0.7}
{'loss': 0.1076, 'grad_norm': 0.44748976826667786, 'learning_rate': 3.827314238273142e-05, 'epoch': 0.71}
{'loss': 0.1419, 'grad_norm': 0.6891569495201111, 'learning_rate': 3.8065587380655876e-05, 'epoch': 0.72}
{'loss': 0.1054, 'grad_norm': 0.5200108289718628,

  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.09756193310022354, 'eval_runtime': 25.3283, 'eval_samples_per_second': 56.38, 'eval_steps_per_second': 14.095, 'epoch': 1.0}
{'loss': 0.1135, 'grad_norm': 0.5673679709434509, 'learning_rate': 3.329182233291823e-05, 'epoch': 1.01}
{'loss': 0.1041, 'grad_norm': 0.6700311303138733, 'learning_rate': 3.308426733084268e-05, 'epoch': 1.02}
{'loss': 0.1151, 'grad_norm': 0.48665374517440796, 'learning_rate': 3.287671232876712e-05, 'epoch': 1.03}
{'loss': 0.1016, 'grad_norm': 0.6035619378089905, 'learning_rate': 3.266915732669157e-05, 'epoch': 1.05}
{'loss': 0.1106, 'grad_norm': 0.55247563123703, 'learning_rate': 3.2461602324616026e-05, 'epoch': 1.06}
{'loss': 0.0925, 'grad_norm': 0.4972968101501465, 'learning_rate': 3.225404732254048e-05, 'epoch': 1.07}
{'loss': 0.092, 'grad_norm': 0.40038466453552246, 'learning_rate': 3.2046492320464924e-05, 'epoch': 1.08}
{'loss': 0.1103, 'grad_norm': 0.3486107587814331, 'learning_rate': 3.183893731838937e-05, 'epoch': 1.1}
{'loss': 0.0923, 'g

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.1092, 'grad_norm': 0.5253717303276062, 'learning_rate': 2.9348277293482773e-05, 'epoch': 1.25}
{'loss': 0.0891, 'grad_norm': 0.582353949546814, 'learning_rate': 2.9140722291407226e-05, 'epoch': 1.26}
{'loss': 0.1239, 'grad_norm': 0.5782084465026855, 'learning_rate': 2.8933167289331675e-05, 'epoch': 1.27}
{'loss': 0.1146, 'grad_norm': 0.704578161239624, 'learning_rate': 2.8725612287256128e-05, 'epoch': 1.28}
{'loss': 0.1074, 'grad_norm': 0.4141498804092407, 'learning_rate': 2.851805728518057e-05, 'epoch': 1.3}
{'loss': 0.0935, 'grad_norm': 0.46054351329803467, 'learning_rate': 2.8310502283105023e-05, 'epoch': 1.31}
{'loss': 0.0856, 'grad_norm': 0.5290907621383667, 'learning_rate': 2.8102947281029472e-05, 'epoch': 1.32}
{'loss': 0.0817, 'grad_norm': 0.3894304037094116, 'learning_rate': 2.7895392278953925e-05, 'epoch': 1.33}
{'loss': 0.1055, 'grad_norm': 0.7035382986068726, 'learning_rate': 2.7687837276878374e-05, 'epoch': 1.34}
{'loss': 0.0862, 'grad_norm': 0.44194528460502625

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.107, 'grad_norm': 0.32365885376930237, 'learning_rate': 1.897052718970527e-05, 'epoch': 1.87}
{'loss': 0.0887, 'grad_norm': 0.3940908908843994, 'learning_rate': 1.8762972187629724e-05, 'epoch': 1.88}
{'loss': 0.0888, 'grad_norm': 0.5864721536636353, 'learning_rate': 1.8555417185554173e-05, 'epoch': 1.89}
{'loss': 0.0892, 'grad_norm': 0.34468117356300354, 'learning_rate': 1.8347862183478623e-05, 'epoch': 1.91}
{'loss': 0.0904, 'grad_norm': 0.3960183262825012, 'learning_rate': 1.8140307181403075e-05, 'epoch': 1.92}
{'loss': 0.0973, 'grad_norm': 0.3186906576156616, 'learning_rate': 1.793275217932752e-05, 'epoch': 1.93}
{'loss': 0.0818, 'grad_norm': 0.44049227237701416, 'learning_rate': 1.7725197177251974e-05, 'epoch': 1.94}
{'loss': 0.1302, 'grad_norm': 0.41678982973098755, 'learning_rate': 1.751764217517642e-05, 'epoch': 1.96}
{'loss': 0.0854, 'grad_norm': 0.4122941792011261, 'learning_rate': 1.7310087173100873e-05, 'epoch': 1.97}
{'loss': 0.1133, 'grad_norm': 0.40840247273445

  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.09274683147668839, 'eval_runtime': 25.9307, 'eval_samples_per_second': 55.07, 'eval_steps_per_second': 13.767, 'epoch': 2.0}
{'loss': 0.0841, 'grad_norm': 0.3441394865512848, 'learning_rate': 1.6687422166874224e-05, 'epoch': 2.0}
{'loss': 0.0798, 'grad_norm': 0.25682705640792847, 'learning_rate': 1.6479867164798674e-05, 'epoch': 2.02}
{'loss': 0.1006, 'grad_norm': 0.4381483197212219, 'learning_rate': 1.6272312162723123e-05, 'epoch': 2.03}
{'loss': 0.0935, 'grad_norm': 0.7371136546134949, 'learning_rate': 1.6064757160647572e-05, 'epoch': 2.04}
{'loss': 0.0861, 'grad_norm': 0.461798757314682, 'learning_rate': 1.585720215857202e-05, 'epoch': 2.05}
{'loss': 0.0949, 'grad_norm': 0.5524975657463074, 'learning_rate': 1.564964715649647e-05, 'epoch': 2.07}
{'loss': 0.0754, 'grad_norm': 0.5157014727592468, 'learning_rate': 1.5442092154420924e-05, 'epoch': 2.08}
{'loss': 0.0825, 'grad_norm': 0.40135130286216736, 'learning_rate': 1.5234537152345371e-05, 'epoch': 2.09}
{'loss': 0.09

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.0736, 'grad_norm': 0.493009477853775, 'learning_rate': 8.592777085927771e-06, 'epoch': 2.49}
{'loss': 0.0779, 'grad_norm': 0.3048422336578369, 'learning_rate': 8.385222083852222e-06, 'epoch': 2.5}
{'loss': 0.0844, 'grad_norm': 0.5319421291351318, 'learning_rate': 8.177667081776672e-06, 'epoch': 2.52}
{'loss': 0.0753, 'grad_norm': 0.28069958090782166, 'learning_rate': 7.970112079701121e-06, 'epoch': 2.53}
{'loss': 0.0671, 'grad_norm': 0.47599202394485474, 'learning_rate': 7.76255707762557e-06, 'epoch': 2.54}
{'loss': 0.0906, 'grad_norm': 0.38459405303001404, 'learning_rate': 7.555002075550021e-06, 'epoch': 2.55}
{'loss': 0.079, 'grad_norm': 0.26903247833251953, 'learning_rate': 7.3474470734744716e-06, 'epoch': 2.57}
{'loss': 0.0925, 'grad_norm': 0.3840484619140625, 'learning_rate': 7.139892071398921e-06, 'epoch': 2.58}
{'loss': 0.0758, 'grad_norm': 0.40407612919807434, 'learning_rate': 6.93233706932337e-06, 'epoch': 2.59}
{'loss': 0.0693, 'grad_norm': 0.33008506894111633, 'le

  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.0916919857263565, 'eval_runtime': 26.1837, 'eval_samples_per_second': 54.538, 'eval_steps_per_second': 13.634, 'epoch': 3.0}
{'train_runtime': 873.0956, 'train_samples_per_second': 11.037, 'train_steps_per_second': 2.759, 'train_loss': 0.25920408917146387, 'epoch': 3.0}


TrainOutput(global_step=2409, training_loss=0.25920408917146387, metrics={'train_runtime': 873.0956, 'train_samples_per_second': 11.037, 'train_steps_per_second': 2.759, 'total_flos': 2937710255800320.0, 'train_loss': 0.25920408917146387, 'epoch': 3.0})

In [42]:
# Save the model
model.save_pretrained('./results-Bart-base')
tokenizer.save_pretrained('./results-Bart-base')

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


('./results-Bart-base\\tokenizer_config.json',
 './results-Bart-base\\special_tokens_map.json',
 './results-Bart-base\\vocab.json',
 './results-Bart-base\\merges.txt',
 './results-Bart-base\\added_tokens.json')

## Prediction

In [43]:
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import json

In [45]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [47]:
# Tokenize the test dataset
tokenized_test = test_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/1071 [00:00<?, ? examples/s]



In [48]:
# Perform inference on test set
def generate_answers(batch):
    inputs = batch['input_ids'].to(model.device)
    outputs = model.generate(
        inputs,
        max_length=512, 
        num_beams=5, 
        early_stopping=True, 
        repetition_penalty=2.5, 
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1  # ensure alignment with targets 
    )
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [output.split("; ") for output in decoded_outputs]


In [49]:
predictions = []
for batch in trainer.get_test_dataloader(tokenized_test):
    batch_predictions = generate_answers(batch)
    predictions.extend(batch_predictions)


In [50]:
# Save predictions to a JSON file
with open('predictions-Bart-base.json', 'w') as f:
    json.dump(predictions, f)

## Evaluation

In [18]:
from datasets import load_dataset
import json

In [51]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
import re

# Define metric functions
def exact_match(predictions, references):
    scores = []
    for pred_list, ref_list in zip(predictions, references):
        # Convert to lower case and strip spaces for comparison
        pred_set = set([pred.strip().lower() for pred in pred_list])
        ref_set = set([ref.strip().lower() for ref in ref_list])
        exact_matches = len(pred_set & ref_set) # Intersection of sets
        scores.append(exact_matches / len(ref_set) if ref_set else 0) 
    return sum(scores) / len(scores) 

def f1_metric(predictions, references):
    def compute_f1(pred_list, ref_list):
        pred_tokens = set(pred_list)
        ref_tokens = set(ref_list)

        if not ref_tokens:
            return 0.0

        true_positives = len(pred_tokens & ref_tokens)
        precision = true_positives / len(pred_tokens) if pred_tokens else 0
        recall = true_positives / len(ref_tokens) if ref_tokens else 0

        if precision + recall == 0:
            return 0.0

        return 2 * (precision * recall) / (precision + recall)

    f1_scores = []

    for pred_list, ref_list in zip(predictions, references):
        # Flatten the lists for token-level comparison
        all_preds = [token for pred in pred_list for token in pred.split()]
        all_refs = [token for ref in ref_list for token in ref.split()]
        f1_scores.append(compute_f1(all_preds, all_refs))

    return sum(f1_scores) / len(f1_scores) if f1_scores else 0

def extract_years(text):
    matches = re.findall(r'\d{4}', text)
    if matches:
        years = list(map(int, matches))
        return min(years), max(years)  # Return the earliest and latest years
    return None, None

def time_metric(predictions, references):
    time_diffs = []
    
    for pred_list, ref_list in zip(predictions, references):
        for pred, ref in zip(pred_list, ref_list):
            pred_start, pred_end = extract_years(pred)
            ref_start, ref_end = extract_years(ref)
            
            print(f"Prediction: {pred}, Extracted Years: ({pred_start}, {pred_end})")  # Debugging
            print(f"Reference: {ref}, Extracted Years: ({ref_start}, {ref_end})")  # Debugging
            
            if pred_start is not None and ref_start is not None:
                time_diff = abs(pred_start - ref_start) + abs(pred_end - ref_end)
                time_diffs.append(time_diff)
            else:
                time_diffs.append(float('inf'))  # Invalid prediction or reference

    valid_diffs = [diff for diff in time_diffs if diff != float('inf')]
    if not valid_diffs:
        return float('inf')  # If all are invalid, return inf
    return sum(valid_diffs) / len(valid_diffs) # Average time difference

def completeness(predictions, references):
    scores = []
    
    for pred_list, ref_list in zip(predictions, references):
        pred_items = set([item.strip().lower() for pred in pred_list for item in pred.split(", ")])
        ref_items = set([item.strip().lower() for ref in ref_list for item in ref.split(", ")])
        
        correct_count = len(pred_items.intersection(ref_items))
        total_count = len(ref_items)
        
        scores.append(correct_count / total_count if total_count > 0 else 0)
    return sum(scores) / len(scores)

In [29]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [30]:
# References
references = [ans for ans in test_dataset['answers']]  # Ensure references are lists of answers

In [31]:
# Load predictions from a JSON file
with open('predictions-Bart-base.json', 'r') as f:
    predictions = json.load(f)

In [52]:
# Debug: Print some predictions and references to ensure alignment
for pred, ref in zip(predictions, references):
    print("Prediction:", pred)
    print("Reference:", ref)

Prediction: ['Aston Villa F.C. (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'England national under-21 football team (2011)', 'Manchester City AFC (2012)', 'Bolton Wanderers FC (2013)']
Reference: ['Southend United F.C. (2010, 2011, 2012)', 'Stevenage F.C. (2012, 2013)', 'Crewe Alexandra F.C. (2013, 2014, 2015)', 'Port Vale F.C. (2015, 2016, 2017, 2018, 2019, 2020)']
Prediction: ['member of the European Parliament (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'Prime Minister of Ukraine (2020)']
Reference: ['Prime Minister of Ukraine (2010)', 'First Deputy Prime Minister of Ukraine (2010)', "People's Deputy of Ukraine (2012, 2013, 2014, 2015)", 'Chairman of the Verkhovna Rada (2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'President of Ukraine (2014)']
Prediction: ['Democratic Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017)', 'independent politician (2017, 2018, 2019, 2020)']
Reference: ['Socialist Party (2010, 2011, 2012, 2013, 2014, 

In [53]:
# Evaluate predictions with custom metrics
em_scores = [exact_match(pred, ref) for pred, ref in zip(predictions, references)]
em = sum(em_scores) / len(em_scores)

f1 = f1_metric(predictions, references)
time_metric_score = time_metric(predictions, references)
completeness_score = completeness(predictions, references)

print(f"Exact Match (EM): {em}")
print(f"F1 Score: {f1}")
print(f"Time Metric: {time_metric_score}")
print(f"Completeness: {completeness_score}")

Prediction: Aston Villa F.C. (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2010, 2020)
Reference: Southend United F.C. (2010, 2011, 2012), Extracted Years: (2010, 2012)
Prediction: England national under-21 football team (2011), Extracted Years: (2011, 2011)
Reference: Stevenage F.C. (2012, 2013), Extracted Years: (2012, 2013)
Prediction: Manchester City AFC (2012), Extracted Years: (2012, 2012)
Reference: Crewe Alexandra F.C. (2013, 2014, 2015), Extracted Years: (2013, 2015)
Prediction: Bolton Wanderers FC (2013), Extracted Years: (2013, 2013)
Reference: Port Vale F.C. (2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2015, 2020)
Prediction: member of the European Parliament (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2010, 2020)
Reference: Prime Minister of Ukraine (2010), Extracted Years: (2010, 2010)
Prediction: Prime Minister of Ukraine (2020), Extracted Years: (2020, 2020)
Reference: First Dep