In [2]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer, util
import numpy as np
from sklearn.neighbors import NearestNeighbors
import requests
import json
import os
import re
from sklearn.metrics import f1_score as sklearn_f1_score
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple, Dict

In [3]:
# Load the ground-truth answers for the test set.
# The encoding is stated explicitly so decoding does not depend on the
# platform's default locale encoding (the answers are later scanned for
# non-ASCII dashes such as "–", which a non-UTF-8 default would garble).
with open('true_answers.json', 'r', encoding='utf-8') as f:
    true_answers = json.load(f)

In [22]:
def load_predictions(file_path: str) -> List[str]:
    """Read a JSON file of model predictions and return its decoded content.

    Args:
        file_path: Path to the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file; callers in this notebook
        iterate over it as a list of prediction strings.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Explicit UTF-8: without it, open() falls back to the locale's preferred
    # encoding, which can mis-decode non-ASCII characters (e.g. "–") on some
    # platforms.
    with open(file_path, 'r', encoding='utf-8') as file:
        predictions = json.load(file)
    return predictions

In [30]:
# parse the data set
def parse_years(years_str: str) -> List[int]:
    """Expand a comma-separated list of years and year ranges into a flat list.

    Unicode dashes (en dash, em dash, minus sign) are first normalised to an
    ASCII "-". Each comma-separated item is then read either as a single year
    ("1999") or as an inclusive range ("1990-1993"). Items that cannot be
    parsed are skipped silently, so a partially malformed string still yields
    its valid years.

    Args:
        years_str: Text such as "1990–1993, 1999".

    Returns:
        The years in order of appearance, with ranges expanded inclusively.
        Duplicates are kept. A range whose end precedes its start contributes
        nothing.
    """
    for dash in ('–', '—', '−'):
        years_str = years_str.replace(dash, '-')

    years: List[int] = []
    # Split on bare commas and strip each item so that "1990,1991" and
    # "1990,  1991" are parsed exactly like "1990, 1991" instead of being
    # discarded as one unparseable token.
    for part in years_str.split(','):
        part = part.strip()
        if not part:
            continue
        try:
            if '-' in part:
                # Anything other than exactly two integer endpoints raises
                # ValueError (either from int() or from the unpacking).
                start, end = map(int, part.split('-'))
                years.extend(range(start, end + 1))
            else:
                years.append(int(part))
        except ValueError:
            # Best effort: ignore items such as "present" or "1990-".
            continue
    return years

def parse_answer(answer: str) -> Tuple[List[str], Dict[str, List[int]]]:
    """Split an answer string into entity names and their year lists.

    The answer is cut at every "), " boundary; each resulting piece is
    expected to look like "Entity (years)". Pieces that do not fit that shape
    are dropped.

    Args:
        answer: Text such as "Foo (1990-1992), Bar (2001)".

    Returns:
        A pair ``(entities, timelines)``: ``entities`` lists the names in
        order of appearance (repeats included), and ``timelines`` maps each
        name to the years produced by ``parse_years`` for its parenthesised
        part. When a name repeats, the last occurrence wins in ``timelines``.
    """
    entities: List[str] = []
    timelines: Dict[str, List[int]] = {}

    for chunk in re.split(r'\),\s*', answer):
        # The split pattern swallows the closing parenthesis of every piece it
        # terminates; put one back on any piece that does not end with ")".
        if not chunk.endswith(')'):
            chunk = chunk + ')'

        found = re.match(r'(.+?)\s+\((.+)\)', chunk)
        if found is None:
            # Not of the form "name (something)": ignore this piece.
            continue

        name = found.group(1).strip()
        year_list = parse_years(found.group(2).strip())
        entities.append(name)
        timelines[name] = year_list

    return entities, timelines

In [27]:
# EM of entities
def evaluate_entities_em(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    """Fraction of ground-truth entities that also appear in the prediction.

    Predictions and ground truths are paired positionally with ``zip``, so any
    surplus items in the longer list are ignored.

    Args:
        parsed_predictions: One ``(entities, timelines)`` pair per question.
        parsed_true_answers: Ground truth in the same shape.

    Returns:
        (number of distinct entity names shared by prediction and ground
        truth, summed over questions) divided by (total length of the
        ground-truth entity lists). Returns 0 if there are no ground-truth
        entities.
    """
    hits = 0
    gold_count = 0

    for (predicted_names, _), (gold_names, _) in zip(parsed_predictions, parsed_true_answers):
        # Denominator uses the raw list length; numerator uses distinct names.
        gold_count += len(gold_names)
        hits += len(set(predicted_names) & set(gold_names))

    if gold_count > 0:
        return hits / gold_count
    return 0

# EM of timeline 
def evaluate_timeline_em(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    """Share of correctly predicted entities whose year list matches exactly.

    Only ground-truth entities that the prediction also names (and for which
    both sides carry a timeline entry) enter the denominator; entities the
    model missed are not penalised here.

    Args:
        parsed_predictions: One ``(entities, timelines)`` pair per question.
        parsed_true_answers: Ground truth in the same shape.

    Returns:
        matches / comparisons, or 0 when nothing could be compared.
    """
    compared = 0
    exact = 0

    for (pred_names, pred_spans), (gold_names, gold_spans) in zip(parsed_predictions, parsed_true_answers):
        for name in gold_names:
            comparable = name in pred_names and name in pred_spans and name in gold_spans
            if not comparable:
                continue
            compared += 1
            # Year lists must be equal element for element, order included.
            if pred_spans[name] == gold_spans[name]:
                exact += 1

    if compared > 0:
        return exact / compared
    return 0


# Completeness
def evaluate_completeness(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    """Average, over questions, of the per-question entity coverage.

    For each question the coverage is (distinct names shared by prediction and
    ground truth) / (length of the ground-truth entity list), or 0 when that
    list is empty. The sum is divided by the number of ground-truth questions,
    even if fewer predictions were supplied.

    Args:
        parsed_predictions: One ``(entities, timelines)`` pair per question.
        parsed_true_answers: Ground truth in the same shape.

    Returns:
        Mean coverage, or 0 when there are no ground-truth questions.
    """
    question_count = len(parsed_true_answers)
    coverage_sum = 0

    for (pred_names, _), (gold_names, _) in zip(parsed_predictions, parsed_true_answers):
        shared = set(pred_names) & set(gold_names)
        if len(gold_names) > 0:
            coverage_sum += len(shared) / len(gold_names)
        else:
            coverage_sum += 0

    if question_count > 0:
        return coverage_sum / question_count
    return 0

In [9]:
# F1 score 

def calculate_entity_precision_recall(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> Tuple[float, float, List[set]]:
    """Micro-averaged entity precision and recall over all questions.

    Predictions and ground truths are paired positionally with ``zip``.

    Args:
        parsed_predictions: One ``(entities, timelines)`` pair per question.
        parsed_true_answers: Ground truth in the same shape.

    Returns:
        ``(precision, recall, correct_entities_per_question)`` where

        * precision = shared names / total predicted entity-list length,
        * recall    = shared names / total ground-truth entity-list length,
        * the third element is a list holding, for each question, the set of
          names present in both prediction and ground truth (this is what
          ``calculate_timeline_precision_recall`` consumes).

        Either ratio is 0 when its denominator is 0.
    """
    true_positives = 0
    predicted_entities = 0
    actual_entities = 0
    # One set per question; the return annotation previously claimed ``int``.
    correct_entities_per_question: List[set] = []

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, _ = pred
        gt_entities, _ = gt

        correct_entities = set(pred_entities).intersection(set(gt_entities))
        true_positives += len(correct_entities)
        # Denominators use raw list lengths, so a name repeated in a list is
        # counted every time, whereas the numerator counts distinct names.
        predicted_entities += len(pred_entities)
        actual_entities += len(gt_entities)

        correct_entities_per_question.append(correct_entities)

    precision = true_positives / predicted_entities if predicted_entities > 0 else 0
    recall = true_positives / actual_entities if actual_entities > 0 else 0
    return precision, recall, correct_entities_per_question

def calculate_timeline_precision_recall(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]], correct_entities_set: List[set]) -> Tuple[float, float]:
    """Timeline precision and recall, restricted to correctly named entities.

    For every question, each name in the matching element of
    ``correct_entities_set`` is looked up in both timeline dicts; when present
    in both, the pair is counted and checked for exact equality.

    Note that the "predicted" and "actual" totals are incremented together for
    every compared entity, so the two returned values are always identical.

    Args:
        parsed_predictions: One ``(entities, timelines)`` pair per question.
        parsed_true_answers: Ground truth in the same shape.
        correct_entities_set: Per question, the names to compare.

    Returns:
        ``(precision, recall)``; both are 0 when nothing was compared.
    """
    exact = 0
    compared = 0

    triples = zip(parsed_predictions, parsed_true_answers, correct_entities_set)
    for (_, pred_spans), (_, gold_spans), shared_names in triples:
        for name in shared_names:
            if name not in pred_spans or name not in gold_spans:
                continue
            compared += 1
            if pred_spans[name] == gold_spans[name]:
                exact += 1

    if compared > 0:
        precision = exact / compared
        recall = exact / compared
    else:
        precision = 0
        recall = 0
    return precision, recall

def calculate_f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall.

    Args:
        precision: Precision value.
        recall: Recall value.

    Returns:
        ``2 * precision * recall / (precision + recall)``, or 0 when the two
        inputs sum to zero (avoids a division by zero).
    """
    denominator = precision + recall
    if denominator != 0:
        numerator = 2 * (precision * recall)
        return numerator / denominator
    return 0


In [24]:
def _parse_true_answer(answer_items: List[str]) -> Tuple[List[str], Dict[str, List[int]]]:
    """Parse one ground-truth answer given as a list of "Entity (years)" strings."""
    entities: List[str] = []
    timelines: Dict[str, List[int]] = {}
    for item in answer_items:
        # Match once per item (the pattern is the one parse_answer uses for
        # predictions, so both sides are tokenised the same way).
        match = re.match(r'(.+?)\s+\((.+)\)', item)
        if match is None:
            raise ValueError(
                f"Ground-truth entry is not of the form 'Entity (years)': {item!r}"
            )
        entity = match.group(1).strip()
        entities.append(entity)
        # A repeated entity keeps its last timeline, as in a dict comprehension.
        timelines[entity] = parse_years(match.group(2).strip())
    return entities, timelines


def calculate_scores(predictions: List[str], true_answers: List[List[str]]) -> Tuple[float, float, float]:
    """Compute entity F1, timeline F1 and completeness for one prediction set.

    Args:
        predictions: Raw prediction strings, one per question; each is parsed
            with ``parse_answer``.
        true_answers: Per question, a list of "Entity (years)" strings.

    Returns:
        ``(entity_f1_score, timeline_f1_score, completeness_score)``.

    Raises:
        ValueError: If a ground-truth string does not look like
            "Entity (years)". (Previously this surfaced as an AttributeError
            on ``None.group``.)
    """
    parsed_predictions = [parse_answer(pred) for pred in predictions]
    parsed_true_answers = [_parse_true_answer(ans) for ans in true_answers]

    entity_precision, entity_recall, correct_entities_set = calculate_entity_precision_recall(parsed_predictions, parsed_true_answers)
    entity_f1_score = calculate_f1_score(entity_precision, entity_recall)

    # Timeline metrics are evaluated only on the entities that were named
    # correctly, which is why the per-question sets are passed along.
    timeline_precision, timeline_recall = calculate_timeline_precision_recall(parsed_predictions, parsed_true_answers, correct_entities_set)
    timeline_f1_score = calculate_f1_score(timeline_precision, timeline_recall)

    completeness_score = evaluate_completeness(parsed_predictions, parsed_true_answers)

    return entity_f1_score, timeline_f1_score, completeness_score


In [28]:
# Parse the ground truth once into (entities, timelines) pairs so the EM
# metrics below can reuse it for every prediction file.
# Each "Entity (years)" string is matched a single time; the inner generator
# yields, per question, the list of match objects for that question's items.
parsed_true_answers = [
    (
        [m.group(1).strip() for m in matches],
        {m.group(1).strip(): parse_years(m.group(2).strip()) for m in matches},
    )
    for matches in (
        [re.match(r'(.+?)\s+\((.+)\)', x) for x in ans]
        for ans in true_answers
    )
]

In [31]:
# Closed-book setting, without fine-tuning.
# Evaluation of flan xl for each few-shot configuration.

files_xl = [
    "predictions_3shots_flanxl.json",
    "predictions_5shots_flanxl.json",
    "predictions_7shots_flanxl.json",
    "predictions_10shots_flanxl.json",
]

for file in files_xl:
    predictions = load_predictions(file)
    parsed_predictions = [parse_answer(pred) for pred in predictions]

    # Exact-match style metrics work on the pre-parsed ground truth.
    entity_em = evaluate_entities_em(parsed_predictions, parsed_true_answers)
    timeline_em = evaluate_timeline_em(parsed_predictions, parsed_true_answers)

    # calculate_scores takes the raw strings and parses them itself.
    entity_f1, timeline_f1, completeness = calculate_scores(predictions, true_answers)

    # One report line per metric; the trailing "\n" leaves a blank line
    # between files.
    print(
        f"File: {file}",
        f"Entity EM Score: {entity_em}",
        f"Timeline EM Score:{timeline_em}",
        f"Entity F1 Score: {entity_f1}",
        f"Timeline F1 Score: {timeline_f1}",
        f"Completeness: {completeness}\n",
        sep="\n",
    )


File: predictions_3shots_flanxl.json
Entity EM Score: 0.036346396965865994
Timeline EM Score:0.1810344827586207
Entity F1 Score: 0.04353586977096347
Timeline F1 Score: 0.1826086956521739
Completeness: 0.03968470579114838

File: predictions_5shots_flanxl.json
Entity EM Score: 0.040455120101137804
Timeline EM Score:0.21875
Entity F1 Score: 0.04824726724462872
Timeline F1 Score: 0.21875
Completeness: 0.04379566641471405

File: predictions_7shots_flanxl.json
Entity EM Score: 0.03666245259165613
Timeline EM Score:0.22413793103448276
Entity F1 Score: 0.04408132243967319
Timeline F1 Score: 0.22413793103448276
Completeness: 0.04389231476066211

File: predictions_10shots_flanxl.json
Entity EM Score: 0.00695322376738306
Timeline EM Score:0.13636363636363635
Entity F1 Score: 0.009411764705882352
Timeline F1 Score: 0.13636363636363635
Completeness: 0.008564166407303661



In [36]:
# Evaluation of flan large for each few-shot configuration.

files_large = [
    "predictions_3shot_flanlarge.json",
    "predictions_5shot_flanlarge.json",
    "predictions_7shot_flanlarge.json",
    "predictions_10shot_flanlarge.json",
]

for file in files_large:
    predictions = load_predictions(file)
    parsed_predictions = [parse_answer(pred) for pred in predictions]

    # Exact-match style metrics work on the pre-parsed ground truth.
    entity_em = evaluate_entities_em(parsed_predictions, parsed_true_answers)
    timeline_em = evaluate_timeline_em(parsed_predictions, parsed_true_answers)

    # calculate_scores takes the raw strings and parses them itself.
    entity_f1, timeline_f1, completeness = calculate_scores(predictions, true_answers)

    # One report line per metric; the trailing "\n" leaves a blank line
    # between files.
    print(
        f"File: {file}",
        f"Entity EM Score: {entity_em}",
        f"Timeline EM Score:{timeline_em}",
        f"Entity F1 Score: {entity_f1}",
        f"Timeline F1 Score: {timeline_f1}",
        f"Completeness: {completeness}\n",
        sep="\n",
    )


File: predictions_3shot_flanlarge.json
Entity EM Score: 0.02307206068268015
Timeline EM Score:0.1917808219178082
Entity F1 Score: 0.022461538461538463
Timeline F1 Score: 0.1917808219178082
Completeness: 0.023889185303751122

File: predictions_5shot_flanlarge.json
Entity EM Score: 0.021491782553729456
Timeline EM Score:0.27941176470588236
Entity F1 Score: 0.026003824091778205
Timeline F1 Score: 0.27941176470588236
Completeness: 0.02448090347250011

File: predictions_7shot_flanlarge.json
Entity EM Score: 0.02054361567635904
Timeline EM Score:0.16923076923076924
Entity F1 Score: 0.01935099732063114
Timeline F1 Score: 0.16923076923076924
Completeness: 0.024513509107346638

File: predictions_10shot_flanlarge.json
Entity EM Score: 0.008217446270543615
Timeline EM Score:0.15384615384615385
Entity F1 Score: 0.010509296685529508
Timeline F1 Score: 0.15384615384615385
Completeness: 0.00980132793858284

