In [25]:
import torch
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from sentence_transformers import SentenceTransformer, util
import numpy as np
from sklearn.neighbors import NearestNeighbors
import requests
import json
import os
import re
from sklearn.metrics import f1_score as sklearn_f1_score
from sklearn.metrics.pairwise import cosine_similarity
from typing import List, Tuple, Dict
from collections import Counter
import string

In [37]:
#load the test answer data set
with open('true_answers.json', 'r') as f:
    true_answers = json.load(f)

In [4]:
def load_predictions(file_path: str) -> List[str]:
    with open(file_path, 'r') as file:
        predictions = json.load(file)
    return predictions

In [5]:
# parse the data set
def parse_years(years_str: str) -> List[int]:
    
    years_str = years_str.replace('–', '-').replace('—', '-').replace('−', '-')
    # 处理 "since" 关键字
    years_str = re.sub(r'since (\d{4})', r'\1-', years_str)
    years = []
    for part in years_str.split(', '):
        if '-' in part:
            try:
                start, end = map(int, part.split('-'))
                years.extend(range(start, end + 1))
            except ValueError as e:
                continue
                #print(f"Skipping invalid year range part: {part}. Error: {e}")
        else:
            try:
                years.append(int(part))
            except ValueError as e:
                #print(f"Skipping invalid year part: {part}. Error: {e}")
                continue
    return years

def parse_answer(answer: str) -> Tuple[List[str], Dict[str, List[int]]]:
    entities = []
    timelines = {}
    parts = re.split(r'\),\s*', answer)
    
    parts = [part + ')' if not part.endswith(')') else part for part in parts]

    for part in parts:
        entity_match = re.match(r'(.+?)\s+\((.+)\)', part)
        if entity_match:
            entity = entity_match.group(1).strip()
            years_str = entity_match.group(2).strip()
            years = parse_years(years_str)
            entities.append(entity)
            timelines[entity] = years
        else:
            continue
            #print(f"Skipping invalid part: {part}")
    return entities, timelines

In [6]:
# EM of entities
def evaluate_entities_em(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    total_questions = 0
    matching_questions = 0

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, _ = pred
        gt_entities, _ = gt

        if set(pred_entities) == set(gt_entities) and set(pred_entities):
            matching_questions += 1
            print(f"Matched entities: {set(pred_entities)}")
        total_questions += 1

    entity_em_score = matching_questions / total_questions if total_questions > 0 else 0
    return entity_em_score


# EM of timeline 
def evaluate_timeline_em(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    total_entities = 0
    matching_timelines = 0

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, pred_timelines = pred
        gt_entities, gt_timelines = gt

        if set(pred_entities) == set(gt_entities) and set(pred_entities):
            for entity in gt_entities:
                if entity in pred_entities and entity in pred_timelines and entity in gt_timelines:
                    if pred_timelines[entity] == gt_timelines[entity]:
                        matching_timelines += 1
                    total_entities += 1

    timeline_em_score = matching_timelines / total_entities if total_entities > 0 else 0
    return timeline_em_score




In [45]:
def evaluate_completeness(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    total_completeness = 0
    num_questions = len(parsed_true_answers)

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, pred_timelines = pred
        gt_entities, gt_timelines = gt

        
        correct_entities = set(pred_entities).intersection(set(gt_entities))
        correct_timelines = 0
        total_gt_timelines = len(gt_entities)  

        for entity in correct_entities:
            if entity in pred_timelines and entity in gt_timelines:
                if pred_timelines[entity] == gt_timelines[entity]:
                    correct_timelines += 1

        completeness = correct_timelines / total_gt_timelines if total_gt_timelines > 0 else 0
        total_completeness += completeness

    average_completeness = total_completeness / num_questions if num_questions > 0 else 0
    return average_completeness

In [49]:
def evaluate_entity_completeness(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> float:
    total_completeness = 0
    num_questions = len(parsed_true_answers)

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, _ = pred
        gt_entities, _ = gt

        correct_entities = set(pred_entities).intersection(set(gt_entities))
        total_gt_entities = len(gt_entities)  

        completeness = len(correct_entities) / total_gt_entities if total_gt_entities > 0 else 0
        total_completeness += completeness

    average_completeness = total_completeness / num_questions if num_questions > 0 else 0
    return average_completeness



In [21]:
# F1 score 

def evaluate_entity_f1(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> Tuple[float, float, float]:
    total_true_positives = 0
    total_predicted = 0
    total_actual = 0

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, _ = pred
        gt_entities, _ = gt

        true_positives = len(set(pred_entities).intersection(set(gt_entities)))
        total_true_positives += true_positives
        total_predicted += len(pred_entities)
        total_actual += len(gt_entities)

    return precision_recall_f1(total_true_positives, total_predicted, total_actual)

def evaluate_timeline_f1(parsed_predictions: List[Tuple[List[str], Dict[str, List[int]]]], parsed_true_answers: List[Tuple[List[str], Dict[str, List[int]]]]) -> Tuple[float, float, float]:
    total_true_positives = 0
    total_predicted = 0
    total_actual = 0

    for pred, gt in zip(parsed_predictions, parsed_true_answers):
        pred_entities, pred_timelines = pred
        gt_entities, gt_timelines = gt

        for entity in gt_entities:
            if entity in pred_entities and entity in pred_timelines and entity in gt_timelines:
                if pred_timelines[entity] == gt_timelines[entity]:
                    total_true_positives += 1
                total_actual += 1

        total_predicted += len(pred_timelines)

    return precision_recall_f1(total_true_positives, total_predicted, total_actual)

def precision_recall_f1(true_positives, predicted, actual):
    precision = true_positives / predicted if predicted > 0 else 0
    recall = true_positives / actual if actual > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0
    return precision, recall, f1

def process_answer(answers):
    
    final_answer = ""
    for answer in answers:
        
        final_answer+= (answer+' ')

    return final_answer


def normalize_answer(s):
    """
    Lower text and remove punctuation, articles and extra whitespace.

    Args:
        s: String to normalize.

    Returns:
        Cleaned string with lowercase, no punctuations, no articles, and
            and extraneous whitespace.
    """

    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def f1_score(prediction, ground_truth):
    """Calculates F1 score.

    Args:
        prediction: Predicted answer span (string).
        ground_truth: True answer span (string).

    Returns:
        F1 score.
    """
    prediction_tokens = normalize_answer(prediction).split()

    ground_truth = process_answer(ground_truth)
    ground_truth_tokens = normalize_answer(ground_truth).split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1

In [50]:
def calculate_scores(predictions: List[str], true_answers: List[List[str]]) -> Tuple[float, float, float]:
    parsed_predictions = [parse_answer(pred) for pred in predictions]
    parsed_true_answers = [
        (
            list(map(lambda x: re.match(r'(.+?)\s+\((.+)\)', x).group(1).strip(), ans)), 
            {re.match(r'(.+?)\s+\((.+)\)', x).group(1).strip(): parse_years(re.match(r'(.+?)\s+\((.+)\)', x).group(2).strip()) for x in ans}
        ) 
        for ans in true_answers
    ]

    
    overall_f1_score = f1_score(" ".join(predictions), [" ".join(ans) for ans in true_answers])
    _, _, entity_f1_score = evaluate_entity_f1(parsed_predictions, parsed_true_answers)
    
    _, _, timeline_f1_score = evaluate_timeline_f1(parsed_predictions, parsed_true_answers)
    
    
    entities_em_score = evaluate_entities_em(parsed_predictions, parsed_true_answers)
    
    timeline_em_score =  evaluate_timeline_em(parsed_predictions, parsed_true_answers)


    completeness_score = evaluate_completeness(parsed_predictions, parsed_true_answers)
    entity_completeness_score = evaluate_entity_completeness(parsed_predictions, parsed_true_answers)
    
    return overall_f1_score, entity_f1_score, timeline_f1_score, entities_em_score, timeline_em_score, completeness_score, entity_completeness_score


In [69]:
parsed_true_answers = [
    (
        list(map(lambda x: re.match(r'(.+?)\s+\((.+)\)', x).group(1).strip(), ans)), 
        {re.match(r'(.+?)\s+\((.+)\)', x).group(1).strip(): parse_years(re.match(r'(.+?)\s+\((.+)\)', x).group(2).strip()) for x in ans}
    ) 
    for ans in true_answers
]

In [52]:
# closed book
# withouot fine-tuning
# Evaluation of flan xl

files_xl = [
    "predictions_3shots_flanxl.json",
    "predictions_5shots_flanxl.json",
    "predictions_7shots_flanxl.json",
    "predictions_10shots_flanxl.json"
]

for file in files_xl:
    
    predictions = load_predictions(file)
    
    overall_f1, entity_f1, timeline_f1, entity_em, timeline_em, completeness,entity_com = calculate_scores(predictions, true_answers)
    print(f"File: {file}\nEntity EM Score: {entity_em}\nTimeline EM Score:{timeline_em}\nOverall F1 Score: {overall_f1}\nEntity F1 Score: {entity_f1}\nTimeline F1 Score: {timeline_f1}\nCompleteness: {completeness}\nEntity Completness: {entity_com}\n")


Matched entities: {'Jim Prentice', 'Rachel Notley', 'Jason Kenney'}
Matched entities: {'Democratic Party', 'Republican Party'}
Matched entities: {'United States representative', 'United States senator'}
Matched entities: {'Chief of the Defence Staff', 'Chief of the General Staff'}
File: predictions_3shots_flanxl.json
Entity EM Score: 0.003734827264239029
Timeline EM Score:0.0
Overall F1 Score: 0.5571964956195243
Entity F1 Score: 0.04353586977096347
Timeline F1 Score: 0.020537897310513448
Completeness: 0.007847583477835578
Entity Completness: 0.03968470579114838

Matched entities: {'Democratic Party', 'Republican Party'}
Matched entities: {'United States representative', 'United States senator'}
Matched entities: {'Chief of the Defence Staff', 'Chief of the General Staff'}
Matched entities: {'Serhiy Arbuzov', 'Yulia Tymoshenko', 'Arseniy Yatsenyuk', 'Oleksiy Honcharuk', 'Mykola Azarov', 'Denys Shmyhal', 'Volodymyr Groysman'}
File: predictions_5shots_flanxl.json
Entity EM Score: 0.003734

In [54]:
# Evaluation of flan large

files_large = [
    "predictions_3shot_flanlarge.json",
    "predictions_5shot_flanlarge.json",
    "predictions_7shot_flanlarge.json",
    "predictions_10shot_flanlarge.json"
]

for file in files_large:
    
    predictions = load_predictions(file)
    
    overall_f1, entity_f1, timeline_f1, entity_em, timeline_em, completeness,entity_com = calculate_scores(predictions, true_answers)
    print(f"File: {file}\nEntity EM Score: {entity_em}\nTimeline EM Score:{timeline_em}\nOverall F1 Score: {overall_f1}\nEntity F1 Score: {entity_f1}\nTimeline F1 Score: {timeline_f1}\nCompleteness: {completeness}\nEntity Completness: {entity_com}\n")


File: predictions_3shot_flanlarge.json
Entity EM Score: 0.0
Timeline EM Score:0
Overall F1 Score: 0.5633279683050638
Entity F1 Score: 0.022461538461538463
Timeline F1 Score: 0.008281573498964804
Completeness: 0.005239132690113081
Entity Completness: 0.023889185303751122

File: predictions_5shot_flanlarge.json
Entity EM Score: 0.0
Timeline EM Score:0
Overall F1 Score: 0.5188428724786711
Entity F1 Score: 0.026003824091778205
Timeline F1 Score: 0.018590998043052837
Completeness: 0.008092125739184563
Entity Completness: 0.02448090347250011

File: predictions_7shot_flanlarge.json
Entity EM Score: 0.0
Timeline EM Score:0
Overall F1 Score: 0.57602983556535
Entity F1 Score: 0.01935099732063114
Timeline F1 Score: 0.00613325899080011
Completeness: 0.004668534080298785
Entity Completness: 0.024513509107346638

Matched entities: {'Libertarian Party', 'Republican Party'}
Matched entities: {'United States representative', 'United States senator'}
File: predictions_10shot_flanlarge.json
Entity EM Sco