## Training

In [1]:
import torch
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq

In [2]:
# Load datasets
train_dataset = load_dataset('json', data_files='./data/train_TLQA.json')['train']
val_dataset = load_dataset('json', data_files='./data/val_TLQA.json')['train']

In [3]:
# Debug: Verify the structure of loaded data
for example in train_dataset['answers'][:5]:
    print("Example answers:", example)

Example answers: ['Carlson (2010, 2011, 2012, 2013, 2014, 2015, 2016)', 'HNA Group (2016, 2017, 2018)', 'Jinjiang International (2018, 2019, 2020)']
Example answers: ['Church Farm School (2014, 2015)', 'Davidson College (2016, 2017, 2018, 2019, 2020)']
Example answers: ['Secretary of State for Business, Energy and Industrial Strategy (2015, 2016)', 'Secretary of State for Housing, Communities and Local Government (2016, 2017, 2018)', 'Home Secretary (2018, 2019)', 'Chancellor of the Exchequer (2019, 2020)']
Example answers: ['Democratic Liberal Party (2010, 2011, 2012, 2013, 2014)', "People's Movement Party (2014)", 'independent politician (2014, 2015, 2016, 2017, 2018)', 'Liberty, Unity and Solidarity Party (2019, 2020)']
Example answers: ['Pat Fenlon (2011, 2012, 2013)', 'Jack Ross (2019, 2020)', 'Paul Heckingbottom (2019)']


In [4]:
# Load model and tokenizer
tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')

In [5]:
def preprocess_function(examples):
    inputs = examples['question']
    targets = ["; ".join(ans) for ans in examples['answers']]  # Join list of answers with a separator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    model_inputs['question'] = inputs  # Keep the original question
    return model_inputs

tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)

In [6]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [7]:
# Training arguments
training_args = TrainingArguments(
    output_dir='./results-Bart-large', 
    eval_strategy='epoch', 
    per_device_train_batch_size=4, 
    per_device_eval_batch_size=4, 
    num_train_epochs=3, 
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
    save_steps=500,
    save_total_limit=2,
    fp16=True  # Enable mixed precision training
)

In [8]:
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    data_collator=data_collator,
)

In [9]:
# Train the model
trainer.train()

  0%|          | 0/2409 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 15.9486, 'grad_norm': 27.173036575317383, 'learning_rate': 4.987546699875467e-05, 'epoch': 0.01}
{'loss': 10.9457, 'grad_norm': 28.17203712463379, 'learning_rate': 4.966791199667912e-05, 'epoch': 0.02}
{'loss': 8.5466, 'grad_norm': 39.82261276245117, 'learning_rate': 4.9481112494811124e-05, 'epoch': 0.04}
{'loss': 6.3937, 'grad_norm': inf, 'learning_rate': 4.9294312992943134e-05, 'epoch': 0.05}
{'loss': 5.1197, 'grad_norm': 37.79838180541992, 'learning_rate': 4.908675799086758e-05, 'epoch': 0.06}
{'loss': 3.9821, 'grad_norm': 39.6391487121582, 'learning_rate': 4.887920298879203e-05, 'epoch': 0.07}
{'loss': 2.8727, 'grad_norm': 37.628902435302734, 'learning_rate': 4.8671647986716485e-05, 'epoch': 0.09}
{'loss': 1.838, 'grad_norm': 31.730440139770508, 'learning_rate': 4.846409298464093e-05, 'epoch': 0.1}
{'loss': 0.931, 'grad_norm': 16.828102111816406, 'learning_rate': 4.8256537982565384e-05, 'epoch': 0.11}
{'loss': 0.4697, 'grad_norm': 6.782123565673828, 'learning_rate': 4.8048

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.1361, 'grad_norm': 0.9829115271568298, 'learning_rate': 3.974678289746783e-05, 'epoch': 0.62}
{'loss': 0.1158, 'grad_norm': 0.9494917988777161, 'learning_rate': 3.9539227895392284e-05, 'epoch': 0.64}
{'loss': 0.1294, 'grad_norm': 1.0087183713912964, 'learning_rate': 3.933167289331673e-05, 'epoch': 0.65}
{'loss': 0.0961, 'grad_norm': 0.7435879111289978, 'learning_rate': 3.912411789124118e-05, 'epoch': 0.66}
{'loss': 0.1137, 'grad_norm': 0.6297553777694702, 'learning_rate': 3.891656288916563e-05, 'epoch': 0.67}
{'loss': 0.0928, 'grad_norm': 0.47065243124961853, 'learning_rate': 3.870900788709008e-05, 'epoch': 0.68}
{'loss': 0.1031, 'grad_norm': 0.7489163875579834, 'learning_rate': 3.8501452885014534e-05, 'epoch': 0.7}
{'loss': 0.1035, 'grad_norm': 0.5788906216621399, 'learning_rate': 3.829389788293898e-05, 'epoch': 0.71}
{'loss': 0.135, 'grad_norm': 0.904438853263855, 'learning_rate': 3.8086342880863426e-05, 'epoch': 0.72}
{'loss': 0.0982, 'grad_norm': 0.8447710871696472, 'lea

  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.09586504846811295, 'eval_runtime': 771.0661, 'eval_samples_per_second': 1.852, 'eval_steps_per_second': 0.463, 'epoch': 1.0}
{'loss': 0.1084, 'grad_norm': 0.6260566711425781, 'learning_rate': 3.331257783312578e-05, 'epoch': 1.01}
{'loss': 0.0964, 'grad_norm': 0.8623329997062683, 'learning_rate': 3.310502283105023e-05, 'epoch': 1.02}
{'loss': 0.0988, 'grad_norm': 1.6388957500457764, 'learning_rate': 3.2897467828974684e-05, 'epoch': 1.03}
{'loss': 0.0917, 'grad_norm': 0.6393274068832397, 'learning_rate': 3.268991282689913e-05, 'epoch': 1.05}
{'loss': 0.1006, 'grad_norm': 0.5739547610282898, 'learning_rate': 3.2482357824823576e-05, 'epoch': 1.06}
{'loss': 0.0868, 'grad_norm': 0.9895109534263611, 'learning_rate': 3.227480282274803e-05, 'epoch': 1.07}
{'loss': 0.0862, 'grad_norm': 0.7220048904418945, 'learning_rate': 3.206724782067248e-05, 'epoch': 1.08}
{'loss': 0.0991, 'grad_norm': 0.34455621242523193, 'learning_rate': 3.1859692818596934e-05, 'epoch': 1.1}
{'loss': 0.0845,

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.0965, 'grad_norm': 1.1866178512573242, 'learning_rate': 2.9369032793690333e-05, 'epoch': 1.25}
{'loss': 0.097, 'grad_norm': 0.6203020215034485, 'learning_rate': 2.916147779161478e-05, 'epoch': 1.26}
{'loss': 0.1097, 'grad_norm': 0.6265918016433716, 'learning_rate': 2.8953922789539228e-05, 'epoch': 1.27}
{'loss': 0.1054, 'grad_norm': 1.3524749279022217, 'learning_rate': 2.874636778746368e-05, 'epoch': 1.28}
{'loss': 0.0977, 'grad_norm': 0.4561586380004883, 'learning_rate': 2.853881278538813e-05, 'epoch': 1.3}
{'loss': 0.0829, 'grad_norm': 0.794069766998291, 'learning_rate': 2.8331257783312583e-05, 'epoch': 1.31}
{'loss': 0.0783, 'grad_norm': 0.6238290071487427, 'learning_rate': 2.8123702781237026e-05, 'epoch': 1.32}
{'loss': 0.0759, 'grad_norm': 0.6510286331176758, 'learning_rate': 2.791614777916148e-05, 'epoch': 1.33}
{'loss': 0.0986, 'grad_norm': 1.0802148580551147, 'learning_rate': 2.7708592777085928e-05, 'epoch': 1.34}
{'loss': 0.0789, 'grad_norm': 0.3864585757255554, 'le

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.0917, 'grad_norm': 0.5561153292655945, 'learning_rate': 1.901203819012038e-05, 'epoch': 1.87}
{'loss': 0.0794, 'grad_norm': 0.7868625521659851, 'learning_rate': 1.8804483188044834e-05, 'epoch': 1.88}
{'loss': 0.078, 'grad_norm': 0.7516345381736755, 'learning_rate': 1.8596928185969283e-05, 'epoch': 1.89}
{'loss': 0.0824, 'grad_norm': 0.4143904745578766, 'learning_rate': 1.8389373183893733e-05, 'epoch': 1.91}
{'loss': 0.0812, 'grad_norm': 0.39110463857650757, 'learning_rate': 1.8181818181818182e-05, 'epoch': 1.92}
{'loss': 0.0914, 'grad_norm': 0.5670859217643738, 'learning_rate': 1.797426317974263e-05, 'epoch': 1.93}
{'loss': 0.0731, 'grad_norm': 0.3663381040096283, 'learning_rate': 1.776670817766708e-05, 'epoch': 1.94}
{'loss': 0.1177, 'grad_norm': 0.47243449091911316, 'learning_rate': 1.7559153175591534e-05, 'epoch': 1.96}
{'loss': 0.0726, 'grad_norm': 0.6965134143829346, 'learning_rate': 1.7351598173515983e-05, 'epoch': 1.97}
{'loss': 0.1013, 'grad_norm': 0.4686548709869385

  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.08962031453847885, 'eval_runtime': 772.5122, 'eval_samples_per_second': 1.849, 'eval_steps_per_second': 0.462, 'epoch': 2.0}
{'loss': 0.0708, 'grad_norm': 1.1819771528244019, 'learning_rate': 1.672893316728933e-05, 'epoch': 2.0}
{'loss': 0.0624, 'grad_norm': 0.32025691866874695, 'learning_rate': 1.6521378165213784e-05, 'epoch': 2.02}
{'loss': 0.0857, 'grad_norm': 0.4836205542087555, 'learning_rate': 1.6313823163138233e-05, 'epoch': 2.03}
{'loss': 0.0797, 'grad_norm': 0.840100884437561, 'learning_rate': 1.6106268161062682e-05, 'epoch': 2.04}
{'loss': 0.0739, 'grad_norm': 0.539995014667511, 'learning_rate': 1.5898713158987132e-05, 'epoch': 2.05}
{'loss': 0.0792, 'grad_norm': 0.580274224281311, 'learning_rate': 1.569115815691158e-05, 'epoch': 2.07}
{'loss': 0.0595, 'grad_norm': 0.5334646105766296, 'learning_rate': 1.548360315483603e-05, 'epoch': 2.08}
{'loss': 0.0675, 'grad_norm': 0.4350792169570923, 'learning_rate': 1.5276048152760483e-05, 'epoch': 2.09}
{'loss': 0.0751, 

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


{'loss': 0.058, 'grad_norm': 0.5126574635505676, 'learning_rate': 8.634288086342881e-06, 'epoch': 2.49}
{'loss': 0.0621, 'grad_norm': 0.31712618470191956, 'learning_rate': 8.42673308426733e-06, 'epoch': 2.5}
{'loss': 0.07, 'grad_norm': 0.7903696298599243, 'learning_rate': 8.21917808219178e-06, 'epoch': 2.52}
{'loss': 0.0615, 'grad_norm': 0.3125416040420532, 'learning_rate': 8.011623080116231e-06, 'epoch': 2.53}
{'loss': 0.0533, 'grad_norm': 0.4105648100376129, 'learning_rate': 7.804068078040682e-06, 'epoch': 2.54}
{'loss': 0.074, 'grad_norm': 0.37184369564056396, 'learning_rate': 7.596513075965131e-06, 'epoch': 2.55}
{'loss': 0.0691, 'grad_norm': 0.466719388961792, 'learning_rate': 7.388958073889581e-06, 'epoch': 2.57}
{'loss': 0.0724, 'grad_norm': 0.5285528898239136, 'learning_rate': 7.181403071814032e-06, 'epoch': 2.58}
{'loss': 0.0589, 'grad_norm': 0.4906187355518341, 'learning_rate': 6.973848069738481e-06, 'epoch': 2.59}
{'loss': 0.0592, 'grad_norm': 0.4849088490009308, 'learning_r

  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.08577268570661545, 'eval_runtime': 1172.4621, 'eval_samples_per_second': 1.218, 'eval_steps_per_second': 0.304, 'epoch': 3.0}
{'train_runtime': 27495.9821, 'train_samples_per_second': 0.35, 'train_steps_per_second': 0.088, 'train_loss': 0.32656063143768566, 'epoch': 3.0}


TrainOutput(global_step=2409, training_loss=0.32656063143768566, metrics={'train_runtime': 27495.9821, 'train_samples_per_second': 0.35, 'train_steps_per_second': 0.088, 'total_flos': 1.0441109972975616e+16, 'train_loss': 0.32656063143768566, 'epoch': 3.0})

In [10]:
# Save the model
model.save_pretrained('./results-Bart-large')
tokenizer.save_pretrained('./results-Bart-large')

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


('./results-Bart-large\\tokenizer_config.json',
 './results-Bart-large\\special_tokens_map.json',
 './results-Bart-large\\vocab.json',
 './results-Bart-large\\merges.txt',
 './results-Bart-large\\added_tokens.json')

## Prediction

In [11]:
from datasets import load_dataset
from transformers import BartTokenizer, BartForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import json

In [12]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [13]:
# Tokenize the test dataset
tokenized_test = test_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/1071 [00:00<?, ? examples/s]



In [14]:
# Perform inference on test set
def generate_answers(batch):
    inputs = batch['input_ids'].to(model.device)
    outputs = model.generate(
        inputs,
        max_length=512, 
        num_beams=5, 
        early_stopping=True, 
        repetition_penalty=2.5, 
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1  # ensure alignment with targets 
    )
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [output.split("; ") for output in decoded_outputs]


In [15]:
predictions = []
for batch in trainer.get_test_dataloader(tokenized_test):
    batch_predictions = generate_answers(batch)
    predictions.extend(batch_predictions)


In [16]:
# Save predictions to a JSON file
with open('predictions-Bart-large.json', 'w') as f:
    json.dump(predictions, f)

## Evaluation

In [1]:
from datasets import load_dataset
import json

In [12]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
import re

# Define metric functions
def entity_match(predictions, references):
    scores = []
    for pred_list, ref_list in zip(predictions, references):
        pred_entities = set([re.sub(r'\s*\(.*?\)', '', pred.strip().lower()) for pred in pred_list])
        ref_entities = set([re.sub(r'\s*\(.*?\)', '', ref.strip().lower()) for ref in ref_list])

        # Debug prints
        # print("Pred Entities:", pred_entities)
        # print("Ref Entities:", ref_entities)

        matches = len(pred_entities & ref_entities)  # Intersection of sets
        scores.append(matches / len(ref_entities) if ref_entities else 0)
    return sum(scores) / len(scores)


def timeline_match(predictions, references):
    def extract_years(text):
        matches = re.findall(r'\d{4}', text)
        if matches:
            return set(map(int, matches))
        return set()

    scores = []
    for pred_list, ref_list in zip(predictions, references):
        match_scores = []
        for i, ref in enumerate(ref_list):
            ref_years = extract_years(ref)
            try:
                pred_years = extract_years(pred_list[i])
                # print(f"Matching prediction timeline {i + 1} with reference timeline {i + 1}")
                # print("Pred Years:", pred_years)
                # print("Ref Years:", ref_years)
                matches = len(pred_years & ref_years)
                match_scores.append(matches / len(ref_years) if ref_years else 0)
            except IndexError:
                # print(f"No prediction timeline for reference timeline {i + 1}")
                match_scores.append(0)
        scores.append(sum(match_scores) / len(match_scores) if match_scores else 0)
    return sum(scores) / len(scores)

def f1_metric(predictions, references):
    def compute_f1(pred_list, ref_list):
        pred_tokens = set(pred_list)
        ref_tokens = set(ref_list)

        if not ref_tokens:
            return 0.0

        true_positives = len(pred_tokens & ref_tokens)
        precision = true_positives / len(pred_tokens) if pred_tokens else 0
        recall = true_positives / len(ref_tokens) if ref_tokens else 0

        if precision + recall == 0:
            return 0.0

        return 2 * (precision * recall) / (precision + recall)

    f1_scores = []

    for pred_list, ref_list in zip(predictions, references):
        # Flatten the lists for token-level comparison
        all_preds = [token for pred in pred_list for token in pred.split()]
        all_refs = [token for ref in ref_list for token in ref.split()]
        f1_scores.append(compute_f1(all_preds, all_refs))

    return sum(f1_scores) / len(f1_scores) if f1_scores else 0

def extract_years(text):
    matches = re.findall(r'\d{4}', text)
    if matches:
        years = list(map(int, matches))
        return min(years), max(years)  # Return the earliest and latest years
    return None, None

def time_metric(predictions, references):
    time_diffs = []
    
    for pred_list, ref_list in zip(predictions, references):
        for pred, ref in zip(pred_list, ref_list):
            pred_start, pred_end = extract_years(pred)
            ref_start, ref_end = extract_years(ref)
            
            # print(f"Prediction: {pred}, Extracted Years: ({pred_start}, {pred_end})")  # Debugging
            # print(f"Reference: {ref}, Extracted Years: ({ref_start}, {ref_end})")  # Debugging
            
            if pred_start is not None and ref_start is not None:
                time_diff = abs(pred_start - ref_start) + abs(pred_end - ref_end)
                time_diffs.append(time_diff)
            else:
                time_diffs.append(float('inf'))  # Invalid prediction or reference

    valid_diffs = [diff for diff in time_diffs if diff != float('inf')]
    if not valid_diffs:
        return float('inf')  # If all are invalid, return inf
    return sum(valid_diffs) / len(valid_diffs) # Average time difference

def completeness(predictions, references):
    scores = []
    
    for pred_list, ref_list in zip(predictions, references):
        pred_items = set([item.strip().lower() for pred in pred_list for item in pred.split(", ")])
        ref_items = set([item.strip().lower() for ref in ref_list for item in ref.split(", ")])
        
        correct_count = len(pred_items.intersection(ref_items))
        total_count = len(ref_items)
        
        scores.append(correct_count / total_count if total_count > 0 else 0)
    return sum(scores) / len(scores)

In [3]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [4]:
# References
references = [ans for ans in test_dataset['answers']]  # Ensure references are lists of answers

In [5]:
# Load predictions from a JSON file
with open('predictions-Bart-large.json', 'r') as f:
    predictions = json.load(f)

In [6]:
# Debug: Print some predictions and references to ensure alignment
for pred, ref in zip(predictions, references):
    print("Prediction:", pred)
    print("Reference:", ref)

Prediction: ['England national association football team (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'Queens Park Rangers F.C. (2011)']
Reference: ['Southend United F.C. (2010, 2011, 2012)', 'Stevenage F.C. (2012, 2013)', 'Crewe Alexandra F.C. (2013, 2014, 2015)', 'Port Vale F.C. (2015, 2016, 2017, 2018, 2019, 2020)']
Prediction: ['President of Ukraine (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019)', 'Chairman of the State Duma (2019, 2020)']
Reference: ['Prime Minister of Ukraine (2010)', 'First Deputy Prime Minister of Ukraine (2010)', "People's Deputy of Ukraine (2012, 2013, 2014, 2015)", 'Chairman of the Verkhovna Rada (2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'President of Ukraine (2014)']
Prediction: ['Democratic Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019)', 'Sinn Féin (2019, 2020)']
Reference: ['Socialist Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019)', 'RISE (Ireland) (2019, 2020)']
Prediction:

In [13]:
# Evaluate predictions with custom metrics
entity_score = entity_match(predictions, references)
timeline_score = timeline_match(predictions, references)
f1 = f1_metric(predictions, references)
time_metric_score = time_metric(predictions, references)
completeness_score = completeness(predictions, references)

print(f"Entity Match: {entity_score}")
print(f"Timeline Match: {timeline_score}")
print(f"F1 Score: {f1}")
print(f"Time Metric: {time_metric_score}")
print(f"Completeness: {completeness_score}")

Entity Match: 0.08566537726201592
Timeline Match: 0.5456470870960825
F1 Score: 0.5082865438538069
Time Metric: 4.616387337057728
Completeness: 0.5556493441627903
