## Training

In [1]:
import torch
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq

In [2]:
# Load datasets
train_dataset = load_dataset('json', data_files='./data/train_TLQA.json')['train']
val_dataset = load_dataset('json', data_files='./data/val_TLQA.json')['train']

In [3]:
# Load the Flan-T5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-base')
model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-base')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
def preprocess_function(examples):
    inputs = examples['question']
    targets = ["; ".join(ans) for ans in examples['answers']]  # Join list of answers with a separator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    model_inputs['question'] = inputs  # Keep the original question

    # Convert inputs and labels to tensors
    input_ids_tensor = torch.tensor(model_inputs['input_ids'])
    labels_tensor = torch.tensor(model_inputs['labels'])

    # Check for NaN values
    if torch.isnan(input_ids_tensor).any() or torch.isnan(labels_tensor).any():
        raise ValueError("NaN values found in input data")

    return model_inputs


tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)

In [5]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [6]:
training_args = TrainingArguments(
    output_dir="./results-FlanT5-base",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
)



In [7]:
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
)

In [8]:
# Train the model
trainer.train()

  0%|          | 0/2409 [00:00<?, ?it/s]

{'loss': 1.82, 'grad_norm': 0.40136584639549255, 'learning_rate': 3.9622249896222504e-05, 'epoch': 0.62}


  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.10614319145679474, 'eval_runtime': 2145.0468, 'eval_samples_per_second': 0.666, 'eval_steps_per_second': 0.166, 'epoch': 1.0}
{'loss': 0.1245, 'grad_norm': 0.19393013417720795, 'learning_rate': 2.9244499792445003e-05, 'epoch': 1.25}
{'loss': 0.116, 'grad_norm': 0.16061444580554962, 'learning_rate': 1.8866749688667497e-05, 'epoch': 1.87}


  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.10266786813735962, 'eval_runtime': 2129.3576, 'eval_samples_per_second': 0.671, 'eval_steps_per_second': 0.168, 'epoch': 2.0}
{'loss': 0.1145, 'grad_norm': 0.2039393037557602, 'learning_rate': 8.488999584889996e-06, 'epoch': 2.49}


  0%|          | 0/357 [00:00<?, ?it/s]

{'eval_loss': 0.10166715085506439, 'eval_runtime': 1420.6897, 'eval_samples_per_second': 1.005, 'eval_steps_per_second': 0.251, 'epoch': 3.0}
{'train_runtime': 41531.4605, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'train_loss': 0.47009055809442213, 'epoch': 3.0}


TrainOutput(global_step=2409, training_loss=0.47009055809442213, metrics={'train_runtime': 41531.4605, 'train_samples_per_second': 0.232, 'train_steps_per_second': 0.058, 'total_flos': 6598321848188928.0, 'train_loss': 0.47009055809442213, 'epoch': 3.0})

In [9]:
# Save the model
model.save_pretrained('./results-FlanT5-base')
tokenizer.save_pretrained('./results-FlanT5-base')

('./results-FlanT5-base\\tokenizer_config.json',
 './results-FlanT5-base\\special_tokens_map.json',
 './results-FlanT5-base\\spiece.model',
 './results-FlanT5-base\\added_tokens.json')

## Prediction

In [10]:
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import json

In [11]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [12]:
# Preprocess function
def preprocess_function(examples):
    inputs = examples['question']
    targets = ["; ".join(ans) for ans in examples['answers']]  # Join list of answers with a separator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    model_inputs['question'] = inputs  # Keep the original question
    return model_inputs

In [13]:
# Tokenize the test dataset
tokenized_test = test_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/1071 [00:00<?, ? examples/s]



In [14]:
# Perform inference on test set
def generate_answers(batch):
    inputs = batch['input_ids'].to(model.device)
    outputs = model.generate(
        inputs,
        max_length=512,
        num_beams=5,
        early_stopping=True,
        repetition_penalty=2.5,
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1  # Generate a single sequence per input
    )
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [output.split("; ") for output in decoded_outputs]


In [15]:
predictions = []
for batch in trainer.get_test_dataloader(tokenized_test):
    batch_predictions = generate_answers(batch)
    predictions.extend(batch_predictions)

In [16]:
# Save predictions to a JSON file
with open('predictions-FlanT5-base.json', 'w') as f:
    json.dump(predictions, f)

## Evaluation

In [1]:
from datasets import load_dataset
import json

In [12]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
import re

# Define metric functions
def entity_match(predictions, references):
    scores = []
    for pred_list, ref_list in zip(predictions, references):
        pred_entities = set([re.sub(r'\s*\(.*?\)', '', pred.strip().lower()) for pred in pred_list])
        ref_entities = set([re.sub(r'\s*\(.*?\)', '', ref.strip().lower()) for ref in ref_list])

        # Debug prints
        # print("Pred Entities:", pred_entities)
        # print("Ref Entities:", ref_entities)

        matches = len(pred_entities & ref_entities)  # Intersection of sets
        scores.append(matches / len(ref_entities) if ref_entities else 0)
    return sum(scores) / len(scores)


def timeline_match(predictions, references):
    def extract_years(text):
        matches = re.findall(r'\d{4}', text)
        if matches:
            return set(map(int, matches))
        return set()

    scores = []
    for pred_list, ref_list in zip(predictions, references):
        match_scores = []
        for i, ref in enumerate(ref_list):
            ref_years = extract_years(ref)
            try:
                pred_years = extract_years(pred_list[i])
                # print(f"Matching prediction timeline {i + 1} with reference timeline {i + 1}")
                # print("Pred Years:", pred_years)
                # print("Ref Years:", ref_years)
                matches = len(pred_years & ref_years)
                match_scores.append(matches / len(ref_years) if ref_years else 0)
            except IndexError:
                # print(f"No prediction timeline for reference timeline {i + 1}")
                match_scores.append(0)
        scores.append(sum(match_scores) / len(match_scores) if match_scores else 0)
    return sum(scores) / len(scores)

def f1_metric(predictions, references):
    def compute_f1(pred_list, ref_list):
        pred_tokens = set(pred_list)
        ref_tokens = set(ref_list)

        if not ref_tokens:
            return 0.0

        true_positives = len(pred_tokens & ref_tokens)
        precision = true_positives / len(pred_tokens) if pred_tokens else 0
        recall = true_positives / len(ref_tokens) if ref_tokens else 0

        if precision + recall == 0:
            return 0.0

        return 2 * (precision * recall) / (precision + recall)

    f1_scores = []

    for pred_list, ref_list in zip(predictions, references):
        # Flatten the lists for token-level comparison
        all_preds = [token for pred in pred_list for token in pred.split()]
        all_refs = [token for ref in ref_list for token in ref.split()]
        f1_scores.append(compute_f1(all_preds, all_refs))

    return sum(f1_scores) / len(f1_scores) if f1_scores else 0

def extract_years(text):
    matches = re.findall(r'\d{4}', text)
    if matches:
        years = list(map(int, matches))
        return min(years), max(years)  # Return the earliest and latest years
    return None, None

def time_metric(predictions, references):
    time_diffs = []
    
    for pred_list, ref_list in zip(predictions, references):
        for pred, ref in zip(pred_list, ref_list):
            pred_start, pred_end = extract_years(pred)
            ref_start, ref_end = extract_years(ref)
            
            # print(f"Prediction: {pred}, Extracted Years: ({pred_start}, {pred_end})")  # Debugging
            # print(f"Reference: {ref}, Extracted Years: ({ref_start}, {ref_end})")  # Debugging
            
            if pred_start is not None and ref_start is not None:
                time_diff = abs(pred_start - ref_start) + abs(pred_end - ref_end)
                time_diffs.append(time_diff)
            else:
                time_diffs.append(float('inf'))  # Invalid prediction or reference

    valid_diffs = [diff for diff in time_diffs if diff != float('inf')]
    if not valid_diffs:
        return float('inf')  # If all are invalid, return inf
    return sum(valid_diffs) / len(valid_diffs) # Average time difference

def completeness(predictions, references):
    scores = []
    
    for pred_list, ref_list in zip(predictions, references):
        pred_items = set([item.strip().lower() for pred in pred_list for item in pred.split(", ")])
        ref_items = set([item.strip().lower() for ref in ref_list for item in ref.split(", ")])
        
        correct_count = len(pred_items.intersection(ref_items))
        total_count = len(ref_items)
        
        scores.append(correct_count / total_count if total_count > 0 else 0)
    return sum(scores) / len(scores)

In [3]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [4]:
# References
references = [ans for ans in test_dataset['answers']]  # Ensure references are lists of answers

In [5]:
# Load predictions from a JSON file
with open('predictions-FlanT5-base.json', 'r') as f:
    predictions = json.load(f)

In [6]:
# Debug: Print some predictions and references to ensure alignment
for pred, ref in zip(predictions, references):
    print("Prediction:", pred)
    print("Reference:", ref)

Prediction: ['England national association football team (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'West Bromwich Albion F.C. (2011, 2012)']
Reference: ['Southend United F.C. (2010, 2011, 2012)', 'Stevenage F.C. (2012, 2013)', 'Crewe Alexandra F.C. (2013, 2014, 2015)', 'Port Vale F.C. (2015, 2016, 2017, 2018, 2019, 2020)']
Prediction: ['President of the Russian Federation (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'Prime Minister of Russia (2020)']
Reference: ['Prime Minister of Ukraine (2010)', 'First Deputy Prime Minister of Ukraine (2010)', "People's Deputy of Ukraine (2012, 2013, 2014, 2015)", 'Chairman of the Verkhovna Rada (2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'President of Ukraine (2014)']
Prediction: ['Democratic Party of Ireland (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)']
Reference: ['Socialist Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019)', 'RISE (Ireland) (2019, 2020)']


In [7]:
# Ensure predictions and references have the same length
predictions = predictions[:len(references)]

In [13]:
# Evaluate predictions with custom metrics
entity_score = entity_match(predictions, references)
timeline_score = timeline_match(predictions, references)
f1 = f1_metric(predictions, references)
time_metric_score = time_metric(predictions, references)
completeness_score = completeness(predictions, references)

print(f"Entity Match: {entity_score}")
print(f"Timeline Match: {timeline_score}")
print(f"F1 Score: {f1}")
print(f"Time Metric: {time_metric_score}")
print(f"Completeness: {completeness_score}")

Entity Match: 0.02174240066396929
Timeline Match: 0.4359764076570791
F1 Score: 0.5376337207637637
Time Metric: 4.389843166542196
Completeness: 0.5761945313051584
