## Training

In [10]:
import torch
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import re
import json

In [11]:
# Load datasets
train_dataset = load_dataset('json', data_files='./data/train_TLQA.json')['train']
val_dataset = load_dataset('json', data_files='./data/val_TLQA.json')['train']

In [12]:
# Load the Flan-T5 model and tokenizer
tokenizer = T5Tokenizer.from_pretrained('google/flan-t5-small')
model = T5ForConditionalGeneration.from_pretrained('google/flan-t5-small')

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [13]:
def preprocess_function(examples):
    inputs = examples['question']
    targets = ["; ".join(ans) for ans in examples['answers']]  # Join list of answers with a separator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    model_inputs['question'] = inputs  # Keep the original question

    # Convert inputs and labels to tensors
    input_ids_tensor = torch.tensor(model_inputs['input_ids'])
    labels_tensor = torch.tensor(model_inputs['labels'])

    # Check for NaN values
    if torch.isnan(input_ids_tensor).any() or torch.isnan(labels_tensor).any():
        raise ValueError("NaN values found in input data")

    return model_inputs


tokenized_train = train_dataset.map(preprocess_function, batched=True)
tokenized_val = val_dataset.map(preprocess_function, batched=True)

In [14]:
# Data collator
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [15]:
training_args = TrainingArguments(
    output_dir="./results-FlanT5-small",
    evaluation_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=0.01,
)

In [16]:
# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
)

In [17]:
# Train the model
trainer.train()

  0%|          | 0/2409 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [27]:
# Save the model
model.save_pretrained('./results-FlanT5-small')
tokenizer.save_pretrained('./results-FlanT5-small')

('./results-FlanT5-small\\tokenizer_config.json',
 './results-FlanT5-small\\special_tokens_map.json',
 './results-FlanT5-small\\spiece.model',
 './results-FlanT5-small\\added_tokens.json')

## Prediction

In [10]:
from datasets import load_dataset
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import json

In [12]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [13]:
# Preprocess function
def preprocess_function(examples):
    inputs = examples['question']
    targets = ["; ".join(ans) for ans in examples['answers']]  # Join list of answers with a separator
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, padding="max_length")

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=512, truncation=True, padding="max_length")

    model_inputs['labels'] = labels['input_ids']
    model_inputs['question'] = inputs  # Keep the original question
    return model_inputs

In [14]:
# Tokenize the test dataset
tokenized_test = test_dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/1071 [00:00<?, ? examples/s]



In [15]:
# Perform inference on test set
def generate_answers(batch):
    inputs = batch['input_ids'].to(model.device)
    outputs = model.generate(
        inputs,
        max_length=512,
        num_beams=5,
        early_stopping=True,
        repetition_penalty=2.5,
        length_penalty=1.0,
        no_repeat_ngram_size=2,
        num_return_sequences=1  # Generate a single sequence per input
    )
    decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return [output.split("; ") for output in decoded_outputs]


In [16]:
predictions = []
for batch in trainer.get_test_dataloader(tokenized_test):
    batch_predictions = generate_answers(batch)
    predictions.extend(batch_predictions)

In [17]:
# Save predictions to a JSON file
with open('predictions-FlanT5-small.json', 'w') as f:
    json.dump(predictions, f)

## Evaluation

In [18]:
from datasets import load_dataset
import json

In [19]:
from sklearn.metrics import f1_score
from sklearn.metrics import precision_recall_fscore_support
import re

# Define metric functions
def exact_match(predictions, references):
    scores = []
    for pred_list, ref_list in zip(predictions, references):
        # Convert to lower case and strip spaces for comparison
        pred_set = set([pred.strip().lower() for pred in pred_list])
        ref_set = set([ref.strip().lower() for ref in ref_list])
        exact_matches = len(pred_set & ref_set) # Intersection of sets
        scores.append(exact_matches / len(ref_set) if ref_set else 0) 
    return sum(scores) / len(scores) 

def f1_metric(predictions, references):
    def compute_f1(pred_list, ref_list):
        pred_tokens = set(pred_list)
        ref_tokens = set(ref_list)

        if not ref_tokens:
            return 0.0

        true_positives = len(pred_tokens & ref_tokens)
        precision = true_positives / len(pred_tokens) if pred_tokens else 0
        recall = true_positives / len(ref_tokens) if ref_tokens else 0

        if precision + recall == 0:
            return 0.0

        return 2 * (precision * recall) / (precision + recall)

    f1_scores = []

    for pred_list, ref_list in zip(predictions, references):
        # Flatten the lists for token-level comparison
        all_preds = [token for pred in pred_list for token in pred.split()]
        all_refs = [token for ref in ref_list for token in ref.split()]
        f1_scores.append(compute_f1(all_preds, all_refs))

    return sum(f1_scores) / len(f1_scores) if f1_scores else 0

def extract_years(text):
    matches = re.findall(r'\d{4}', text)
    if matches:
        years = list(map(int, matches))
        return min(years), max(years)  # Return the earliest and latest years
    return None, None

def time_metric(predictions, references):
    time_diffs = []
    
    for pred_list, ref_list in zip(predictions, references):
        for pred, ref in zip(pred_list, ref_list):
            pred_start, pred_end = extract_years(pred)
            ref_start, ref_end = extract_years(ref)
            
            print(f"Prediction: {pred}, Extracted Years: ({pred_start}, {pred_end})")  # Debugging
            print(f"Reference: {ref}, Extracted Years: ({ref_start}, {ref_end})")  # Debugging
            
            if pred_start is not None and ref_start is not None:
                time_diff = abs(pred_start - ref_start) + abs(pred_end - ref_end)
                time_diffs.append(time_diff)
            else:
                time_diffs.append(float('inf'))  # Invalid prediction or reference

    valid_diffs = [diff for diff in time_diffs if diff != float('inf')]
    if not valid_diffs:
        return float('inf')  # If all are invalid, return inf
    return sum(valid_diffs) / len(valid_diffs) # Average time difference

def completeness(predictions, references):
    scores = []
    
    for pred_list, ref_list in zip(predictions, references):
        pred_items = set([item.strip().lower() for pred in pred_list for item in pred.split(", ")])
        ref_items = set([item.strip().lower() for ref in ref_list for item in ref.split(", ")])
        
        correct_count = len(pred_items.intersection(ref_items))
        total_count = len(ref_items)
        
        scores.append(correct_count / total_count if total_count > 0 else 0)
    return sum(scores) / len(scores)

In [None]:
# Load the test dataset
test_dataset = load_dataset('json', data_files='./data/test_TLQA.json')['train']

In [20]:
# References
references = [ans for ans in test_dataset['answers']]  # Ensure references are lists of answers

In [None]:
# Load predictions from a JSON file
with open('predictions-FlanT5-small.json', 'r') as f:
    predictions = json.load(f)

In [24]:
# Debug: Print some predictions and references to ensure alignment
for pred, ref in zip(predictions, references):
    print("Prediction:", pred)
    print("Reference:", ref)

Prediction: ['Manchester City F.C. (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)']
Reference: ['Southend United F.C. (2010, 2011, 2012)', 'Stevenage F.C. (2012, 2013)', 'Crewe Alexandra F.C. (2013, 2014, 2015)', 'Port Vale F.C. (2015, 2016, 2017, 2018, 2019, 2020)']
Prediction: ['Member of the Russian Parliament (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'Minister of Foreign Affairs (2019, 2020)']
Reference: ['Prime Minister of Ukraine (2010)', 'First Deputy Prime Minister of Ukraine (2010)', "People's Deputy of Ukraine (2012, 2013, 2014, 2015)", 'Chairman of the Verkhovna Rada (2014, 2015, 2016, 2017, 2018, 2019, 2020)', 'President of Ukraine (2014)']
Prediction: ['Democratic Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020)']
Reference: ['Socialist Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019)', 'RISE (Ireland) (2019, 2020)']
Prediction: ['Manchester City F.C. (2010, 2011, 2012, 2013, 2014, 201

In [25]:
# Ensure predictions and references have the same length
predictions = predictions[:len(references)]

In [26]:
# Evaluate predictions with custom metrics
em_scores = [exact_match(pred, ref) for pred, ref in zip(predictions, references)]
em = sum(em_scores) / len(em_scores)

f1 = f1_metric(predictions, references)
time_metric_score = time_metric(predictions, references)
completeness_score = completeness(predictions, references)

print(f"Exact Match (EM): {em}")
print(f"F1 Score: {f1}")
print(f"Time Metric: {time_metric_score}")
print(f"Completeness: {completeness_score}")

Prediction: Manchester City F.C. (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2010, 2020)
Reference: Southend United F.C. (2010, 2011, 2012), Extracted Years: (2010, 2012)
Prediction: Member of the Russian Parliament (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2010, 2020)
Reference: Prime Minister of Ukraine (2010), Extracted Years: (2010, 2010)
Prediction: Minister of Foreign Affairs (2019, 2020), Extracted Years: (2019, 2020)
Reference: First Deputy Prime Minister of Ukraine (2010), Extracted Years: (2010, 2010)
Prediction: Democratic Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2010, 2020)
Reference: Socialist Party (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019), Extracted Years: (2010, 2019)
Prediction: Manchester City F.C. (2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019, 2020), Extracted Years: (2010, 2020)
Reference: Newcastle