In [8]:
import json
from os import path

In [9]:
# Path of the JSON-lines log that the evaluation cell further down opens and
# parses one line at a time.
# NOTE(review): absolute, machine-specific path — point it at a local copy of
# the log before re-running the notebook.
log_file = "/home/shtoshni/Research/events/models/events_kbp_2015_aec9f8c195882335c0a88616bf5f2a5f/test.log.jsonl"

In [61]:
def f1(p_num, p_den, r_num, r_den, beta=1):
    """Return the F-beta score computed from raw precision/recall counts.

    Args:
        p_num: Numerator of precision.
        p_den: Denominator of precision; precision is taken as 0 when this is 0.
        r_num: Numerator of recall.
        r_den: Denominator of recall; recall is taken as 0 when this is 0.
        beta: Weight of recall relative to precision (1 gives the balanced F1).

    Returns:
        (1 + beta^2) * p * r / (beta^2 * p + r), or 0 when that denominator
        is 0.
    """
    p = 0 if p_den == 0 else p_num / float(p_den)
    r = 0 if r_den == 0 else r_num / float(r_den)
    # Guard on the value that is actually divided by: `p + r` can be non-zero
    # while `beta^2 * p + r` is zero (beta == 0 together with zero recall).
    denominator = beta * beta * p + r
    return 0 if denominator == 0 else (1 + beta * beta) * p * r / denominator

class Evaluator(object):
    """Accumulate link-level counts over documents and report P/R/F.

    Each call to `update` adds four counts produced by `self.metric`:
    right/wrong coreference links and right/wrong non-coreference links.
    Precision, recall and F-beta are computed separately for the two link
    classes and then averaged.

    Any ratio whose denominator is zero is reported as 0.0 instead of raising
    ZeroDivisionError, so the scores can be read from an evaluator that has
    seen no links of one class (or no data at all).
    """

    def __init__(self, beta=1):
        """Start with all counts at zero.

        Args:
            beta: Weight of recall relative to precision in the F-score.
        """
        self.right_coref = 0
        self.wrong_coref = 0
        self.right_non = 0
        self.wrong_non = 0
        # Module-level scoring function, looked up when the evaluator is created.
        self.metric = blanc
        self.beta = beta

    @staticmethod
    def _ratio(numerator, denominator):
        """Return numerator / denominator, or 0.0 when the denominator is 0."""
        return numerator / denominator if denominator else 0.0

    @staticmethod
    def _f_beta(precision, recall, beta):
        """Return the F-beta of a precision/recall pair, or 0.0 if undefined."""
        denominator = beta * beta * precision + recall
        if not denominator:
            return 0.0
        return (1 + beta * beta) * precision * recall / denominator

    def _coref_pr(self):
        """Return (precision, recall) over coreference links."""
        precision = self._ratio(self.right_coref, self.right_coref + self.wrong_coref)
        # A wrongly predicted non-link is treated as a missed coreference link.
        recall = self._ratio(self.right_coref, self.right_coref + self.wrong_non)
        return precision, recall

    def _non_coref_pr(self):
        """Return (precision, recall) over non-coreference links."""
        precision = self._ratio(self.right_non, self.right_non + self.wrong_non)
        # A wrongly predicted coreference link is treated as a missed non-link.
        recall = self._ratio(self.right_non, self.right_non + self.wrong_coref)
        return precision, recall

    def update(self, predicted, gold):
        """Score one document and add its four counts to the running totals.

        Args:
            predicted: Predicted clusters, passed through to `self.metric`.
            gold: Gold clusters, passed through to `self.metric`.
        """
        rc, wc, rn, wn = self.metric(predicted, gold)

        self.right_coref += rc
        self.wrong_coref += wc
        self.right_non += rn
        self.wrong_non += wn

    def get_f1(self):
        """Return the mean of the coreference and non-coreference F-beta scores."""
        coref_p, coref_r = self._coref_pr()
        non_p, non_r = self._non_coref_pr()

        fc = self._f_beta(coref_p, coref_r, self.beta)
        fn = self._f_beta(non_p, non_r, self.beta)

        return (fc + fn) / 2

    def get_recall(self):
        """Return the mean of the coreference and non-coreference recalls."""
        _, coref_r = self._coref_pr()
        _, non_r = self._non_coref_pr()

        return (coref_r + non_r) / 2

    def get_precision(self):
        """Return the mean of the coreference and non-coreference precisions."""
        coref_p, _ = self._coref_pr()
        non_p, _ = self._non_coref_pr()

        return (coref_p + non_p) / 2

    def get_prf(self):
        """Return the (precision, recall, F-beta) triple."""
        return self.get_precision(), self.get_recall(), self.get_f1()
    

    
def blanc(predicted, gold):
    """Count right/wrong coreference and non-coreference links.

    A link is an unordered pair of mentions. Two mentions in the same cluster
    form a coreference link; every other pair of mentions found in the same
    clustering forms a non-coreference link.

    Args:
        predicted: Iterable of clusters, each a sequence of mentions. A mention
            is any sequence whose `tuple(...)` is hashable.
        gold: Gold clusters in the same format.

    Returns:
        Tuple `(rc, wc, rn, wn)`:
            rc: predicted coreference links that are gold coreference links,
            wc: predicted coreference links that are not,
            rn: predicted non-coreference links that are gold
                non-coreference links,
            wn: predicted non-coreference links that are not.
    """

    def get_coref_and_non_coref_links(clusters):
        """Return (coreference links, non-coreference links) as sets."""
        coref_links = set()
        mentions = set()
        for cluster in clusters:
            members = [tuple(mention) for mention in cluster]
            mentions.update(members)
            for idx, mention1 in enumerate(members):
                for mention2 in members[idx + 1:]:
                    # frozenset makes the link independent of mention order, so
                    # the same pair always compares equal, even when two
                    # mentions share a span and no sort key can separate them.
                    # A mention listed twice in one cluster yields a degenerate
                    # single-mention link, counted like any other link.
                    coref_links.add(frozenset((mention1, mention2)))

        non_coref_links = set()
        distinct_mentions = list(mentions)
        for idx, mention1 in enumerate(distinct_mentions):
            for mention2 in distinct_mentions[idx + 1:]:
                link = frozenset((mention1, mention2))
                if link not in coref_links:
                    non_coref_links.add(link)

        return coref_links, non_coref_links

    gold_cl, gold_noncl = get_coref_and_non_coref_links(gold)
    predicted_cl, predicted_noncl = get_coref_and_non_coref_links(predicted)

    # Links are set members, so right/wrong counts are plain set intersections.
    rc = len(predicted_cl & gold_cl)
    wc = len(predicted_cl) - rc
    rn = len(predicted_noncl & gold_noncl)
    wn = len(predicted_noncl) - rn

    return rc, wc, rn, wn


def get_clusters(orig_clusters, key="subtype_val"):
    """Flatten each mention's info record down to a single selected field.

    Args:
        orig_clusters: Iterable of clusters; every mention must unpack into
            exactly `(span_start, span_end, mention_info)`, where
            `mention_info` supports `mention_info[key]`.
        key: Field of `mention_info` to keep as the third tuple element.

    Returns:
        A list of clusters, each a list of
        `(span_start, span_end, mention_info[key])` tuples.
    """
    return [
        [(start, end, info[key]) for (start, end, info) in members]
        for members in orig_clusters
    ]


In [62]:
# Sanity check of blanc() on a toy example: the prediction puts (5, 9) into
# the second cluster, whereas gold keeps it as a cluster of its own.
predicted = [[(1, 5), (2, 6)], [(3, 7), (4, 8), (5, 9)]]
gold = [[(1, 5), (2, 6)], [(3, 7), (4, 8)], [(5, 9)]]
# Prints the four link counts in the order (rc, wc, rn, wn).
print(blanc(predicted, gold))

(2, 2, 6, 0)


In [63]:
# Accumulate link counts over every record of the JSON-lines log.
evaluator = Evaluator()

with open(log_file) as f:
    for line in f:
        # A blank line (e.g. a trailing empty line) holds no record; skip it
        # rather than let json.loads fail on an empty string.
        if not line.strip():
            continue
        instance = json.loads(line)
        # Gold mentions carry an info record; get_clusters keeps only its
        # "subtype_val" field.
        # NOTE(review): blanc compares whole mention tuples, so predicted
        # mentions presumably share this (start, end, subtype) layout — confirm.
        evaluator.update(instance["predicted_clusters"], get_clusters(instance["clusters"], key="subtype_val"))

In [64]:
# Report the accumulated score: the mean of the coreference-link and
# non-coreference-link F-scores.
print(evaluator.get_f1())

0.38484240701105954
