In [1]:
import json
import numpy as np
import re
from pprint import pprint

## Input Helper Functions

### Noun Extractor

In [2]:
# NOUN POS TAGGING
import nltk 
nltk.download('averaged_perceptron_tagger')
nltk.download('indian')
from nltk.corpus import stopwords 
from nltk.tokenize import word_tokenize, sent_tokenize
def NounExtractor(text):
    # nouns = [(noun, word_idx)...]
    nouns = []
    words = text.split()
    tagged = nltk.pos_tag(words)
    for i,(word, tag) in enumerate(tagged):
        if tag == 'NN' or tag == 'NNS' or tag == 'NNPS' or tag == 'NNP': nouns.append((word,i))

    return nouns
def HindiNounExtractor(text):
    # Tokenize the sentence into a list of words
    words = nltk.word_tokenize(text)

    # Initialize the TnT tagger with the Indian corpus
    tagger = nltk.tag.tnt.TnT()
    tagger.train(nltk.corpus.indian.tagged_sents())

    # Tag the words in the sentence
    tags = tagger.tag(words)

    # Extract the nouns and their indices
    nouns = [(word, index) for index, (word, tag) in enumerate(tags) if tag.startswith('NN')]

    # Print the list of nouns with their indices
    return nouns

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /home/lenovo/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package indian to /home/lenovo/nltk_data...
[nltk_data]   Package indian is already up-to-date!


In [3]:
def add_nouns(datafile, word_annotation_idx, language="hindi"):
    lines = []
    document = ""
    # read file line by line, save lines in lines_list, get all words as list
    with open(datafile) as f:
        for i,line in enumerate(f):
            lines.append(line)

            if line == "\n": continue
            if line.split()[0] == '#begin': 
                # print(line.split()[-1])
                if line.split()[-1] == "nouns_added": 
                    # print("worked")
                    return
                continue
            if line.split()[0] == '#end': continue

            line = line.split()

            word = line[word_annotation_idx]
            document += (word + " ")

    # pass all words to noun extractor
    if language == "hindi": nouns_with_idxs = HindiNounExtractor(document)
    else: nouns_with_idxs = NounExtractor(document)
    noun_idxs = [elt[1] for elt in nouns_with_idxs] 

    # print(nouns_with_idxs)
    word_idx = 0
    # add nouns to end of each line of lines_list, rewrite to file
    for i in range(0,len(lines)-1):
        if i == 0: 
            lines[i] = lines[i][:-1] + " nouns_added\n" 
            continue
        try:
            if word_idx in noun_idxs: lines[i] = lines[i][:-1] + "\tnoun\n"
            else: lines[i] = lines[i][:-1] + "\t-\n"
            word_idx += 1
        except: continue
    
    # rewrite the file with nouns
    with open(datafile, 'w') as f:
        for i, line in enumerate(lines): f.write(str(lines[i]))
            

In [None]:
add_nouns("../datasets/conll_files_05_3/test.conll",3)

### Mention Class

In [2]:
# is storing this much a waste of space?
# an object can be edited - all multi word mentions, appended to list self.mention
# while reading input file maintain dict {chain idx:{mentionid:Mention}}
class Mention:
    def __init__(self, chain_idx, mention_idx, sentence_idx, word_idx, word, linguistic_head, linguistic_head_idx, nouns, multi_word_mention=False):
        
        # comes from annotated data
        self.chain_idx = chain_idx
        self.mention_idx = mention_idx 

        # calculated when reading data input - required to indentify words uniquely
        # word_idx for multi word mentions is the idx of the first word
        self.sentence_idx = sentence_idx
        self.word_idx = word_idx
        
        #list of multiword mentions as separate words,
        self.words = word
        # toggle when object is edited 
        self.multi_word_mention = multi_word_mention

        # should this be calculated during input or during creation of required format list?
        # what i did -> saved linguistic head, calculate nouns when requied
        self.nouns = nouns # [(word,idx)]
        self.linguistic_head = linguistic_head
        self.linguistic_head_idx = linguistic_head_idx
        

## Reading Input

### Mujadia input

In [4]:
def get_mujadia_input(ip_file):
    chains = {}
    sentence_idx = 1
    word_idx = 1
    with open(ip_file) as f:
        for i,line in enumerate(f):
            # end of sentence
            if line[0] == '|': 
                sentence_idx +=1 
                word_idx = 0
                continue
            
            # new line
            if line[0] == "\n": continue

            # print(line)
            line = line.split()
            # print(line)

            word = line[0]

            # if not a mention, continue
            if line[1] == "_": continue

            chain_idx = line[1].split(":",1)[1]

            # TODO - account for multiple mentions for the same word - check if necessary
            chain_idx = chain_idx.split(",",1)[0]
            mention_idx = line[1].split(":",1)[0].split("%",1)[0]

            unique_idx = str(sentence_idx) + "%" + str(word_idx)

            
            linguistic_head = ""
            linguistic_head_idx = ""
            if line[2] != '_': 
                linguistic_head = line[2].split(":",1)[0]
                linguistic_head_idx = line[2].split(":",1)[1]
            
            # first time encountering a chain - (chain and entity used replacably)
            if chain_idx not in chains.keys():
                chains[chain_idx] = []
            
            # first time seeing mention - create new object
            if not any(mention.mention_idx == mention_idx for mention in chains[chain_idx]):
                mention = Mention(chain_idx, mention_idx,  sentence_idx, word_idx, [word], linguistic_head, linguistic_head_idx, multi_word_mention=False)
                chains[chain_idx].append(mention)

            # multi word mentions - edit already created object
            else:
                mention = next((mention for mention in chains[chain_idx] if mention.mention_idx == mention_idx), None)
                mention.multi_word_mention = True
                mention.words.append(word)
                if mention.linguistic_head == "":
                    mention.linguistic_head = linguistic_head
                    mention.linguistic_head_idx = linguistic_head_idx
            

            word_idx += 1

    # print(len(chains))
    return chains

### Conll Input

In [3]:
# active list is list of all mentions in their unique index format entity_idx%mention_idx that are currently open where mention idx is the number of mentions seen for that entity so far
# global_stack is a dictionary of entity_idx: entity_stack that keeps track of the mentions active by their unique index in a stack as the key value.

def get_conll_input(ip_file, word_annotation_idx, label_idx, verbose=False, get_nouns=False, language="hindi"):

    if get_nouns: add_nouns(ip_file, word_annotation_idx,language)
    
    chains = {}
    active_list = []
    global_stack = {}
    word_idx = 0

    # keeping track of number of mentions for each entity
    entity_mention_count = {}

    #using a unique index for each mention - chain_idx%entity_mention_count[chain_idx] to keep a global stack of which mentions are open

    with open(ip_file) as f:
        for i,line in enumerate(f):
            
            if line == "\n": continue
            if line.split()[0] == '#begin': continue
            if line.split()[0] == '#end': continue

            word_idx += 1

            line = line.split()

            word = line[word_annotation_idx]
            label = line[label_idx]
            is_noun = line[-1]

            # add words to those in the global stack that are currently active
            if label == "-":                 
                for mention_unq_idx in active_list:
                    # print("mention unique idx: ", mention_unq_idx)
                    # print(mention_unq_idx)
                    chain_idx = mention_unq_idx.split("%")[0]
                    mention_idx = mention_unq_idx.split("%")[1]
                    linguistic_head = ""
                    linguistic_head_idx = ""
                    
                    # multi word mentions - edit already created object
                    # print("chain idx ",chain_idx," mention idx ",mention_idx)
                    # print("chains[chain_idx]: " , vars(chains[chain_idx][0]))

                    # return chains

                    mention = next((mention for mention in chains[str(chain_idx)] if mention.mention_idx == int(mention_idx)), None)
                    # print("mention: ",mention)
                    # pprint(vars(mention))
                    mention.multi_word_mention = True
                    mention.words.append(word)
                    if mention.linguistic_head == "":
                        mention.linguistic_head = linguistic_head
                        mention.linguistic_head_idx = linguistic_head_idx
                    
                    if is_noun == "noun": mention.nouns.append(word)

            
                continue
            
            if verbose:
                print("LINE NUMBER ", i+1)
                print("label: ", label)

            # ---------------- SINGLE WORD MENTIONS -------------------
            # STEP 1: look for (0-9) exact matches for single word mentions and remove that string
            # 1a: getting numbers between parantheses - single word mentions
            single_word_entity_idxs = re.findall(r'\(\d+\)', label)
            single_word_entity_idxs = sorted(single_word_entity_idxs, key=len,reverse=True)
            for r in single_word_entity_idxs:
                label = label.replace(r,"")
            for r_idx in range(len(single_word_entity_idxs)):
                single_word_entity_idxs[r_idx] = single_word_entity_idxs[r_idx].replace("(","")
                single_word_entity_idxs[r_idx] = single_word_entity_idxs[r_idx].replace(")","")
            # print("single word entity idxs: ",single_word_entity_idxs)
            # print("label: ", label)

            # 1b: creating new mentions for single word mentions
            for entity_idx in single_word_entity_idxs:
                # print(mention_unq_idx)
                chain_idx = entity_idx

                if entity_idx in entity_mention_count.keys(): entity_mention_count[entity_idx] += 1
                else: entity_mention_count[entity_idx] = 0
                mention_idx = entity_mention_count[entity_idx]

                linguistic_head = ""
                linguistic_head_idx = ""

                nouns = []
                if is_noun == "noun": nouns.append(word)

                # first time encountering a chain - (chain and entity used replacably)
                if chain_idx not in chains.keys():
                    chains[chain_idx] = []
                # print("CREATED MENTION ",chain_idx + "%" + str(mention_idx))
                mention = Mention(chain_idx, mention_idx,  0, word_idx, [word], linguistic_head, linguistic_head_idx, multi_word_mention=False, nouns=nouns)
                chains[chain_idx].append(mention)
            
            # ------------------ MENTIONS STARTING AT THIS WORD ------------
            # STEP 2: look for (0-9 exact matches and remove, push number into live stack
            # 2a: multi word mentions starting at this word
            start_entity_idxs = re.findall(r'\(\d+', label)
            start_entity_idxs = sorted(start_entity_idxs, key=len,reverse=True)
            for r in start_entity_idxs:
                label = label.replace(r,"")
            for r_idx in range(len(start_entity_idxs)):
                start_entity_idxs[r_idx] = start_entity_idxs[r_idx].replace("(","")
                start_entity_idxs[r_idx] = start_entity_idxs[r_idx].replace(")","")
            # print("start entity idxs: ",start_entity_idxs)
            # print("label: ", label)

            # 2b: adding to existing mentions in active list for multi word mentions - currently active mentions, all these mentions have already been created
            for mention_unq_idx in active_list:
                # print("mention unique idx: ", mention_unq_idx)
                # print(mention_unq_idx)
                chain_idx = mention_unq_idx.split("%")[0]
                mention_idx = mention_unq_idx.split("%")[1]
                linguistic_head = ""
                linguistic_head_idx = ""
                    
                # multi word mentions - edit already created object
                # print("chain idx ",chain_idx," mention idx ",mention_idx)
                # print("chains[chain_idx]: " , vars(chains[chain_idx][0]))

                # return chains

                mention = next((mention for mention in chains[str(chain_idx)] if mention.mention_idx == int(mention_idx)), None)
                # print("mention: ",mention)
                # pprint(vars(mention))
                mention.multi_word_mention = True
                mention.words.append(word)
                if mention.linguistic_head == "":
                    mention.linguistic_head = linguistic_head
                    mention.linguistic_head_idx = linguistic_head_idx
                
                if is_noun == "noun": mention.nouns.append(word)

            
            # 2c: adding starting exist to a list of currently active mentions and creating those mentions
            for entity_idx in start_entity_idxs:
                # print(mention_unq_idx)
                chain_idx = entity_idx

                if entity_idx in entity_mention_count.keys(): entity_mention_count[entity_idx] += 1
                else: entity_mention_count[entity_idx] = 0
                mention_idx = entity_mention_count[entity_idx]

                linguistic_head = ""
                linguistic_head_idx = ""
                
                nouns = []
                if is_noun == "noun": nouns.append(word)
                # first time encountering a chain - (chain and entity used replacably)
                if chain_idx not in chains.keys():
                    chains[chain_idx] = []
                # print("CREATED MENTION ",chain_idx + "%" + str(mention_idx))
                mention = Mention(chain_idx, mention_idx,  0, word_idx, [word], linguistic_head, linguistic_head_idx, multi_word_mention=False, nouns=nouns)
                chains[chain_idx].append(mention)

                active_list.append(entity_idx + "%" + str(entity_mention_count[entity_idx]))
                
                if entity_idx not in global_stack.keys(): global_stack[entity_idx] = []
                global_stack[entity_idx].append(entity_idx + "%" + str(entity_mention_count[entity_idx]))


            # ---------------- MENTIONS ENDING AT THIS WORD ----------------------
            # STEP 3: look for 0-9) exact matches and remove, remove from live stack list
            # 3a: get multi word mentions idxs ending at this word
            end_entity_idxs = re.findall(r'\d+\)', label)
            end_entity_idxs = sorted(end_entity_idxs, key=len,reverse=True)
            for r in end_entity_idxs:
                label = label.replace(r,"")
            for r_idx in range(len(end_entity_idxs)):
                end_entity_idxs[r_idx] = end_entity_idxs[r_idx].replace("(","")
                end_entity_idxs[r_idx] = end_entity_idxs[r_idx].replace(")","")
            # print("end entity idxs: ",end_entity_idxs)
            # print("label: ", label)
            

            # 3b: removing idxs from end idx from currently active mentions and updating the global stack
            for entity_idx in end_entity_idxs:
                mention_idx_being_removed = global_stack[entity_idx].pop()
                active_list.remove(mention_idx_being_removed)

            # print("entity_mention_count ", entity_mention_count.items())
            # print("global stack ", global_stack.items())
            # print("active list ", *active_list)    
            # print("\n\n")
    # print(len(chains))
    return chains

## Helper Functions - Format Conversion + Intersection

### Mujadia

In [10]:
# returns an entity which is represented as a list of "sentidx%wordidx" each of which represents a mention
# multiword mentions: "sentidx%word1idx-word2idx-word3-idx" -> update: wrong
# treat multiword mentions all as normal single word mentions by splitting them apart - eg  "sentidx%word1idx-word2idx-word3-idx" ->  ["sentidx%word1idx" "sentidx%-word2idx" "sentidx%-word3idx"]

# key: ["the pretty girl","the lady", "a women", "the teacher", "priya"]
# resp: ["lady", "the", "priya", "pretty"]  (the is from the lady)
# words ->
# b3 recall : 4^2/10 / 10 = 0.16
# muc recall:  7
# binary match -> 
# b3 recall : 0.0
#  7 

# treating all words (even words belonging to multiword mentions as individual mentions)
def get_words(entity):
    entity_f = []
    for mention in entity:
        for i,word in enumerate(mention.words): 
            entity_f.append(str(i+mention.word_idx))
    return entity_f

# have to calculate nouns outside bc length of text can affect how it is calculated
def get_nouns(entity):
    entity_f = []

    for mention in entity:
        if mention.multi_word_mention:
            for (word,idx) in mention.nouns: entity_f.append(mention.sentence_idx + "%" + str(idx + self.word_idx))
        else: entity_f.append(mention.sentence_idx + "%" + str(self.word_idx))

    return entity_f

def get_linguistic_head(entity):
    entity_f = []
    for mention in entity:
        entity_f.append(mention.sentence_idx + "%" + mention.linguistic_head_idx)
    return entity_f

# TODO CLARIFY BIG ASSUMPTION - multi word mentions are all one after the other. edit: no idea what this means
def get_binary_match(entity):
    entity_f = []
    for mention in entity:
        if mention.multi_word_mention:
            word_idx = str(mention.word_idx)
            for i,word in enumerate(mention.words): word_idx += str("-" + str(mention.word_idx + i + 1))
            entity_f.append(word_idx)
        else: entity_f.append(str(mention.word_idx))
    return entity_f

# assumption - all multi word mentions are sequentially together ? word index simple to calculate, only need to store first
# assumption - each mention only has one linguistic head ? save as string : save as list
# assumption - once a mention object is created, no need to edit it (multi word mention, more words added later)

# method 1 - word - each word is a separate link
# method 2 - nouns - each mention is a nouns
# method 3 - linguistic head
# method 4 - each mention is a binary match

# returns an entity - list of mentions in the unique idx format
# example - [1%2,3%4,5%1-2-3-4]
def get_required_format_mujadia(entity, method="binary match"):
    
    if method == "word": return get_words(entity)

    if method == "nouns": return get_nouns(entity)

    if method == "linguistic head": return get_linguistic_head(entity)

    # strict matching
    if method == "binary match": return get_binary_match(entity)

### CoNLL

In [4]:
# returns an entity which is represented as a list of "sentidx%wordidx" each of which represents a mention
# multiword mentions: "sentidx%word1idx-word2idx-word3-idx" -> update: wrong
# treat multiword mentions all as normal single word mentions by splitting them apart - eg  "sentidx%word1idx-word2idx-word3-idx" ->  ["sentidx%word1idx" "sentidx%-word2idx" "sentidx%-word3idx"]

# key: ["the pretty girl","the lady", "a women", "the teacher", "priya"]
# resp: ["lady", "the", "priya", "pretty"]  (the is from the lady)
# words ->
# b3 recall : 4^2/10 / 10 = 0.16
# muc recall:  7
# binary match -> 
# b3 recall : 0.0
#  7 

def get_words(entity):
    entity_f = []
    for mention in entity:
        for i,word in enumerate(mention.words): 
            entity_f.append(str(mention.sentence_idx) + "%" + str(i+mention.word_idx))
    return entity_f

# have to calculate nouns outside bc length of text can affect how it is calculated
def get_nouns(entity):
    entity_f = []
    for mention in entity:
        for i,word in enumerate(mention.nouns): 
            entity_f.append(str(mention.sentence_idx) + "%" + str(i+mention.word_idx))
    return entity_f

def get_linguistic_head(entity):
    entity_f = []
    for mention in entity:
        entity_f.append(mention.sentence_idx + "%" + mention.linguistic_head_idx)
    return entity_f

# TODO CLARIFY BIG ASSUMPTION - multi word mentions are all one after the other
def get_binary_match(entity):
    entity_f = []
    for mention in entity:
        if mention.multi_word_mention:
            word_idx = str(mention.word_idx)
            for i,word in enumerate(mention.words): word_idx += str("-" + str(mention.word_idx + i + 1))
            entity_f.append(str(mention.sentence_idx) + "%" + word_idx)
        else: entity_f.append(str(mention.sentence_idx) + "%" + str(mention.word_idx))
    return entity_f

# assumption - all multi word mentions are sequentially together ? word index simple to calculate, only need to store first
# assumption - each mention only has one linguistic head ? save as string : save as list
# assumption - once a mention object is created, no need to edit it (multi word mention, more words added later)

# method 1 - word - each word is a separate 
# method 2 - nouns - each mention is a nouns
# method 3 - linguistic head
# method 4 - each mention is a binary match

# returns an entity - list of mentions in the unique idx format
# example - [1%2,3%4,5%1-2-3-4]
def get_required_format_conll(entity, method='binary match'):
    if method == "word": return get_words(entity)

    if method == "nouns": return get_nouns(entity)

    if method == "linguistic head": return get_linguistic_head(entity)

    if method == "binary match": return get_binary_match(entity)

In [5]:
# TODO: convert to a class
DATASET = "conll"

def get_required_format(entity, method="binary match"):

    if DATASET == "conll": return get_required_format_conll(entity, method)
    if DATASET == "mujadia": return get_required_format_mujadia(entity, method)

In [6]:
def intersection(key_entity, resp_entity, method="binary match", metric="muc"):

    key_entity = get_required_format(key_entity, method)
    resp_entity = get_required_format(resp_entity, method)

    # print(f"key entity: {key_entity}")
    # print(f"response entity: {resp_entity}")
    
    return list(set(key_entity) & set(resp_entity))

## Metrics

### B Cubed

In [7]:
# entities -> list of entities where each entity is a list of mentions
def b_cubed(key_entities, resp_entities, score = "recall", method="binary match"):

    entities1 = key_entities
    entities2 = resp_entities

    if score == "precision":
        entities1 = resp_entities
        entities2 = key_entities

    numerator = 0
    denominator = 0

    for entity1 in entities1:

        # differs based on method used
        entity1_size = 0

        if method == "binary match": 
            entity1_size = len(entity1)
        if method == "word": 
            for mention in entity1: entity1_size += len(mention.words)

        denominator += entity1_size

        for entity2 in entities2:
            numerator += ( pow(len(intersection(entity1,entity2,method=method)),2) / entity1_size )

    try:
        recall = numerator / denominator
    except:
        # print("denominator=0")
        return -1

    # print(recall)

    if recall > 1: print("GREATER THAN 1")
    # else: print(recall)
    
    return recall

### MUC

In [8]:
# condition: a mention belongs only to one entity (or else using variable common to count common mentions across sets won't be useful)
# update: implementation fixed to account for mentions belonging to more than one entity
'''
algorithm - 
given an entity set (eg key) for which partition is created wrt to opposite set (eg response) -> # of partitions - 
= number of opposite sets it has at least one element in common with + number of mentions which are not in common with any opp sets                  

'''

# TODO: add case where # of key entities = 0
def muc(key_entities, resp_entities, score = "recall", method="binary match"):

    numerator = 0
    denominator = 0

    entities1 = key_entities
    entities2 = resp_entities

    if score == "precision":
        entities1 = resp_entities
        entities2 = key_entities

    for entity1 in entities1:

        # differs based on method used
        entity1_size = 0

        if method == "binary match": 
            entity1_size = len(entity1)
        if method == "word": 
            for mention in entity1: entity1_size += len(mention.words)
        
        numerator += (entity1_size - partition(entity1, entities2, method=method))
        
        denominator += (entity1_size - 1)
    
    # print (numerator / denominator)
    if denominator == 0: return 0
    return numerator / denominator

### CEAFe
best scoring alignment: first entity where intersection is maximum across all

In [9]:
def phi(entity1, entity2, method="binary match"):

    entity1_f = get_required_format(entity1, method=method)
    entity2_f = get_required_format(entity2, method=method)

    return (2 * len(intersection(entity1,entity2, method=method)) ) / ( len(entity1_f) + len(entity2_f) )


In [10]:
def ceafe(key_entities, resp_entities, score = "recall", method="binary match"):

    numerator = 0
    denominator = 0

    # num of elts common across two entity sets
    alignment_score = 0
    phi_val = 0

    entities1 = key_entities
    entities2 = resp_entities

    if score == "precision":
        entities1 = resp_entities
        entities2 = key_entities
    
    for entity1 in entities1:

        for entity2 in entities2:

            if alignment_score < len(intersection(entity1, entity2, method=method)):
                alignment_score = len(intersection(entity1, entity2, method=method))
                phi_val = phi(entity1,entity2,method=method)
        
        numerator += phi_val
        denominator += 1

    # print(numerator/denominator)
    return numerator/denominator

### LEA

In [11]:
def link(n):
    return n*(n-1)/2

In [17]:
# assuming key_entities is list of mentions
# recall -> entities1 : key_entities, entities1 : resp_entities
# precision -> entities1 : resp_entities, entities1 : key_entities
def lea(entities1, entities2, method="binary match"):

    numerator = 0 
    denominator = 0

    # print("\n\nRUNNING THROUGH ALL PAIRS OF ENTITIES....")    
    
    for entity1 in entities1:
        
        # print("\n**************************\n")

        # print(f"(key) entity1 idx: {entity1[0].chain_idx}")

        entity1_mentions = get_required_format(entity1, method=method)
        # print(f"(key) entity1_mentions: {entity1_mentions}")

        importance = len(entity1_mentions) #importance = size of entity (number of mentions)
        resolution_score = 0

        # print("\nCOMPARING THE KEY WITH ALL OTHER ENTITIES...\n")
        
        for entity2 in entities2:

            entity2_mentions = get_required_format(entity2, method=method)

            # print(f"(response) entity2 idx: {entity2[0].chain_idx}")
            # print(f"(response) entity2_mentions: {entity2_mentions}")

            # link based, therefore they intersect only if there are at least two mentions in common (a link is present)
            common_mentions = intersection(entity1, entity2, method=method)
            # print(f"common mentions: {common_mentions}")

            # for singleton mentions, link(ki intersection ri) = 1 only when ki = ri = 1
            if len(common_mentions) == 1:
                if ((len(entity1_mentions) == 1) and (len(entity2_mentions) == 1)): resolution_score += 1
            
            # at least 2 common mentions means there is at least one link
            if len(common_mentions) >= 2:
                # denominator will never be 0 since len(entity1) >= 2 if len(common_mentions) >= 2
                # print("common mentions (link) found...")
                # print(f"entity1: {entity1_mentions}")
                # print(f"entity2 mentions: {entity2_mentions}")
                # print(f"len common mentions: {len(common_mentions)}\na = (link(len(common_mentions))): {link(len(common_mentions))}")
                # print(f"len entity1_len: {len(entity1_mentions)}\nb = (link(len(entity1_mentions))): {link(len(entity1_mentions))}")
                val = (link(len(common_mentions)))/(link(len(entity1_mentions)))
                # print(f"VALUE BEING ADDED TO RESOLUTION SCORE (a/b) = {val}")
                
                if(val > 1): 
                    # print("ERROR: value is greater than one")
                    return -1
                
                resolution_score += (link(len(common_mentions)))/(link(len(entity1_mentions)))
                # print(f"RESOLUTION SCORE = {resolution_score}")

            # print("\n")
        numerator += (importance * resolution_score)
        # print(f"IMPORTANCE: {importance}")
        # print(f"RESOLUTION SCORE: {resolution_score}")
        # print(f"NUMERATOR VALUE ADDED DURING THIS ITERATION: {(importance * resolution_score)}")
        # print(f"NUMERATOR AFTER THIS ITERATION: {numerator}")

        # if numerator is 0, so is denominator - avoiding divide by 0 error
        if numerator == 0: continue
        denominator += len(entity1_mentions)
        # if numerator/denominator > 1:
            # print(f"ERROR: numerator = {numerator} is greater than denominator = {denominator}")
        #     for mention in entity1:
        #         pprint(vars(mention))
        #     return -1
        
    
    # print("WENT THROUGH ALL PAIRS OF ENTITIES....\n\n")
    if denominator == 0: return 0
    return (numerator/denominator)

### Evaluating Metrics

In [13]:
def evaluate_metric(metric, entities1, entities2, method="binary match"):
    if metric == "muc":
        return muc(entities1, entities2, method=method)
    if metric == "bcub":
        return b_cubed(entities1, entities2, method=method)
    if metric == "ceafe":
        return ceafe(entities1, entities2, method=method)
    if metric == "lea":
        return lea(entities1, entities2, method=method)

In [14]:
def get_scores(entities1, entities2, metric, method, verbose=False):
    if verbose: print("**********",metric,"**********")
    # try: 
    recall = round(evaluate_metric(metric, entities1, entities2, method=method),2)
    if verbose: print("recall: ",recall)
    # except Exception as error: 
        # print(f"ERROR IN RECALL: {error}")
    try: 
        precision = round(evaluate_metric(metric, entities2, entities1, method=method),2)
        if verbose: print("precision: ",precision)
    except Exception as error: 
         print(f"ERROR IN PRECISION: {error}")
    try: 
        f1 = round((2 * recall * precision) / (recall + precision),2)
        if verbose: print("f1: ",f1)
    except: f1 = 0

    return {'recall':recall, 'precision':precision, 'f1':f1}

In [15]:
def calculate_metrics(key_chains, response_chains, metric='all', method="binary match",verbose=False):

    key_entities = [key_chains[key] for key in key_chains.keys()]
    response_entities = [response_chains[key] for key in response_chains.keys()]

    # print("\n\ncalculate_metrics entities: ", key_entities)

    scores = {}

    if verbose:
        print("KEYS:")
        for entity in key_entities:
            for mention in entity:
                print(vars(mention),"\n")

        print("\n\nRESPONSE:")
        for entity in response_entities:
            print(vars(entity[0]),"\n")

    if metric == "all":

        scores['muc'] = get_scores(key_entities, response_entities, "muc",method=method, verbose=verbose)
        scores['bcubed'] = get_scores(key_entities, response_entities, "bcub",method=method, verbose=verbose)
        scores['ceafe'] = get_scores(key_entities, response_entities, "ceafe",method=method, verbose=verbose)
        scores['lea'] = get_scores(key_entities, response_entities, "lea",method=method, verbose=verbose)

    else: 
        if metric == 'b_cubed': scores["bcubed"] = get_scores(key_entities, response_entities, "bcub",method=method, verbose=verbose)
        if metric == 'muc': scores["muc"] = get_scores(key_entities, response_entities, "muc",method=method, verbose=verbose)
        if metric == 'ceafe': scores["ceafe"] = get_scores(key_entities, response_entities, "ceafe",method=method, verbose=verbose)
        if metric == 'lea': scores["lea"] = get_scores(key_entities, response_entities, "lea",method=method, verbose=verbose)

    return scores

### Running Evaluator

Running for given pair of ground truth and predicted...

In [None]:

# truth = get_conll_input("../datasets/coref_model_outputs/235/true/235_final_test_data_true_17.conll",3,16)
# pred = get_conll_input("../datasets/coref_model_outputs/235/true/235_final_test_data_true_17.conll",3,16)
truth = get_conll_input("../datasets/all_files/true_GT.conll",3,16)
pred = get_conll_input("../datasets/all_files/247_ab11_final_test_data_pred.conll",2,3)
# pred = get_conll_input("../datasets/conll_files_05_3/pred3_wlcoref.conll",3,11)
# scores = calculate_metrics(truth, pred, method="nouns", metric="lea",verbose=True)
scores = calculate_metrics(truth, pred, metric="lea")
scores

Running for a set of files present in a directory at the same time....

In [16]:
# RUNNING EVALUATOR FOR MULTIPLE FILES, AVERAGING ACROSS THEM
import os
import pandas as pd
def get_averaged_scores(gt_dirpath, pred_dirpath, pred_word_idx, pred_annotation_idx):
    scores = {}
    for file_name in os.listdir(gt_dirpath):
        # print("\n\n",file_name)
        truth = get_conll_input(gt_dirpath + file_name,3,16)
        # file_name = file_name.replace("true","pred")
        pred = get_conll_input(pred_dirpath + file_name,pred_word_idx,pred_annotation_idx)
        score = calculate_metrics(truth, pred, metric="lea")
        scores[file_name] = score['lea']
    df = pd.DataFrame(scores)
    df = df.transpose()
    display(df.mean(axis=0))
    return df


In [28]:
df = get_averaged_scores("../datasets/all_files/filewise_split/true_GT/", "../datasets/all_files/filewise_split/247_ab11_final_test_data_pred/", 2, 3)

recall       0.019286
precision    0.170000
f1           0.030357
dtype: float64

In [29]:
df = get_averaged_scores("../datasets/all_files/filewise_split/true2_GT/", "../datasets/all_files/filewise_split/247_ab12_final_test_data_pred/",2, 3)

recall       0.018214
precision    0.241071
f1           0.031071
dtype: float64

In [30]:
df = get_averaged_scores("../datasets/all_files/filewise_split/true3_GT/", "../datasets/all_files/filewise_split/247_ab13_final_test_data_pred/", 2, 3)

recall       0.004643
precision    0.166786
f1           0.009286
dtype: float64

## Taking Input from model output 
### model output in required format -  model-mishra-format

In [100]:
len(os.listdir("../datasets/model-mishra-format"))

1730

In [133]:
import os
idx=-1
a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r = 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0

scores = {'recall':{'binary match':{'bcubed':0,'muc':0,'ceafe':0},'words':{'bcubed':0,'muc':0,'ceafe':0}}, 'precision':{'binary match':{'bcubed':0,'muc':0,'ceafe':0},'words':{'bcubed':0,'muc':0,'ceafe':0}}}

for file_idx,filename in enumerate(os.listdir("../datasets/model-mishra-format")):


    if "pred" in filename: continue


    else:
        truth = get_input("../datasets/model-mishra-format/"+filename)
        pred = get_input("../datasets/model-mishra-format/"+filename.replace('true','pred'))
    
        # only calculate metrics on non empty chains 


    # print("calculating ", file_idx)
    idx += 1


    try:

        a = calculate_metrics(truth, pred, score='recall',metric='b_cubed', method="binary match")
        b = calculate_metrics(truth, pred, score='precicision',metric='b_cubed', method="binary match")

        d = calculate_metrics(truth, pred, score='recall',metric='muc', method="binary match")
        f = calculate_metrics(truth, pred, score='precicision',metric='muc', method="binary match")

        g = calculate_metrics(truth, pred, score='recall',metric='ceafe', method="binary match")
        h = calculate_metrics(truth, pred, score='precicision',metric='ceafe', method="binary match")

        j = calculate_metrics(truth, pred, score='recall',metric='b_cubed', method="word")
        k = calculate_metrics(truth, pred, score='precicision',metric='b_cubed', method="word")

        m = calculate_metrics(truth, pred, score='recall',metric='muc', method="word")
        n = calculate_metrics(truth, pred, score='precicision',metric='muc', method="word")

        p = calculate_metrics(truth, pred, score='recall',metric='ceafe', method="word")
        q = calculate_metrics(truth, pred, score='precicision',metric='ceafe', method="word")

        scores['recall']['binary match']['bcubed'] += a
        scores['precision']['binary match']['bcubed'] += b
        scores['recall']['binary match']['muc'] += d
        scores['precision']['binary match']['muc'] += f
        scores['recall']['binary match']['ceafe'] += g
        scores['precision']['binary match']['ceafe'] += h

        scores['recall']['words']['bcubed'] += j
        scores['precision']['words']['bcubed'] += k
        scores['recall']['words']['muc'] += m
        scores['precision']['words']['muc'] += n
        scores['recall']['words']['ceafe'] += p
        scores['precision']['words']['ceafe'] += q

    
    except:
        idx -= 1

for k in scores.keys():
    for k2 in scores[k].keys():
        for k3 in scores[k][k2].keys():
            scores[k][k2][k3] /= (idx+1)


print(scores)


# print("METRICS with ", idx+1, " files")
# print(a/(idx+1),b/(idx+1),d/(idx+1),f/(idx+1),g/(idx+1),h/(idx+1),j/(idx+1),k/(idx+1),m/(idx+1),n/(idx+1),p/(idx+1),q/(idx+1))

{'recall': {'binary match': {'bcubed': 1.0, 'muc': 0.11885334770243315, 'ceafe': 0.5875895210527876}, 'words': {'bcubed': 1.0, 'muc': 0.39173646879309965, 'ceafe': 0.7460635287551103}}, 'precision': {'binary match': {'bcubed': 1.0, 'muc': 0.15692453718368354, 'ceafe': 0.5732044249572937}, 'words': {'bcubed': 1.0, 'muc': 0.5558751402808626, 'ceafe': 0.7454180649218175}}}
