In [1]:
import os
from tqdm import tqdm
import json
import jsonlines
import pandas as pd
from collections import Counter
import random
import re
import ast

In [2]:
# Relation labels accepted in generated triplets; a row containing any other
# label is rejected when the CSV rows are parsed.
valid_relations = (
    "Kill",
    "Located_In",
    "Live_In",
    "Work_For",
    "OrgBased_In",
)

In [4]:
# Load the gold and generated relation triplets for each sentence.
# The first CSV column is the index written out when the file was saved; read
# it back as the index so it does not appear as a spurious "Unnamed: 0" column.
# Note: the text is not lower-cased here, so the string comparisons in the
# later cells are case-sensitive.
df = pd.read_csv("conll_flan_explanations_generated.csv", index_col=0)
df.head()

Unnamed: 0,text,gold_labels,generated
0,"John Wilkes Booth , who assassinated President...","[['John Wilkes Booth', 'Kill', 'President Linc...","[['John Wilkes Booth', 'Kill', 'President Linc..."
1,The opera company performed at the Palace of F...,"[['Palace of Fine Arts', 'Located_In', 'San Fr...","[['Palace of Fine Arts', 'OrgBased_In', 'San F..."
2,"In the field of mechanics , Wang Ziqiang at th...","[['Wang Ziqiang', 'Work_For', 'Institute of Me...","[['Wang Ziqiang', 'Work_For', 'Institute of Me..."
3,"Sun Hung Kai Properties , a Hong Kong construc...","[['Sun Hung Kai Properties', 'OrgBased_In', 'H...","[['Sun Hung Kai Properties', 'OrgBased_In', 'H..."
4,"Marie Magdefrau Ferraro , 50 , of Bethany , Co...","[['Marie Magdefrau Ferraro', 'Live_In', 'Betha...","[['Marie Magdefrau Ferraro', 'Live_In', 'Betha..."


In [5]:
print ("TOTAL ROWS: ", len(df))

# Parsed rows that pass validation; they become the columns of the rebuilt
# frame at the bottom of this cell.
text = []
gold = []
generated = []

# Every relation label seen in the generated triplets, valid or not.
unique_types = set()
c = Counter()

# Number of rejected rows (unparseable labels or an unknown relation label).
# Each rejected row is counted exactly once.
errors = 0
for _, row in df.iterrows():
    try:
        # The label columns hold Python-literal strings (lists of triplets).
        t = ast.literal_eval(row["gold_labels"])
        g = ast.literal_eval(row["generated"])
        for r in g:
            # Tally before validating so unknown labels still show up in `c`
            # and `unique_types`.
            c[r[1]] += 1
            unique_types.add(r[1])
            if r[1] not in valid_relations:
                raise ValueError("Invalid relation")
    except (ValueError, SyntaxError, TypeError, LookupError) as exc:
        # Best effort: skip rows whose labels cannot be parsed or validated,
        # without swallowing KeyboardInterrupt/SystemExit like a bare except.
        errors += 1
        print ("ERROR SOMEWHERE")
        print ("REASON: ", repr(exc))
        print (row["text"])
        print ("GOLD: ", row["gold_labels"])
        print ("GENERATED", row["generated"])
    else:
        print (row["text"])
        print ("GOLD: ", t)
        print ("GENERATED", g)
        text.append(row["text"])
        gold.append(t)
        generated.append(g)
    print ("\n-----------------------\n")

# Rebuild the frame from the accepted rows only; rejected rows are simply never
# appended, so the source frame does not need to be mutated while iterating.
df = pd.DataFrame({"text": text, "gold": gold, "generated": generated})

TOTAL ROWS:  231
John Wilkes Booth , who assassinated President Lincoln , was an actor .
GOLD:  [['John Wilkes Booth', 'Kill', 'President Lincoln']]
GENERATED [['John Wilkes Booth', 'Kill', 'President Lincoln']]

-----------------------

ERROR SOMEWHERE
The opera company performed at the Palace of Fine Arts , in San Francisco , on June 30 and July 1-2 , said Kevin O 'Brien , a spokesman for the theater.
GOLD:  [['Palace of Fine Arts', 'Located_In', 'San Francisco']]
GENERATED [['Palace of Fine Arts', 'OrgBased_In', 'San Francisco'], ['Kevin O'Brien', 'Work_For', 'the Palace of Fine Arts'], ['Kevin O'Brien', 'Live_In', 'San Francisco']]

-----------------------

In the field of mechanics , Wang Ziqiang at the Institute of Mechanics has made considerable headway in the area of elastoplastic crack mechanics .
GOLD:  [['Wang Ziqiang', 'Work_For', 'Institute of Mechanics']]
GENERATED [['Wang Ziqiang', 'Work_For', 'Institute of Mechanics']]

-----------------------

Sun Hung Kai Properties ,

In [6]:
# (rows, columns) of the rebuilt frame, i.e. only the rows that parsed and validated.
df.shape

(226, 3)

In [7]:
# Number of distinct relation labels seen in generated triplets; labels are
# recorded before validation, so labels outside `valid_relations` are included.
len(unique_types)

6

In [8]:
# Per-label counts of generated relations, tallied before validation, so
# triplets from rows that were later rejected are included.
c

Counter({'Kill': 40,
         'Work_For': 76,
         'OrgBased_In': 63,
         'Live_In': 83,
         'Located_In': 59,
         'Creator': 1})

In [9]:
def _triplets_overlap(gold_triplet, generated_triplet):
    """Return True when two [head, relation, tail] triplets leniently match.

    The relation labels must be equal as strings. The heads must overlap by
    substring containment in either direction, and so must the tails.
    """
    gold_head, gen_head = str(gold_triplet[0]), str(generated_triplet[0])
    gold_tail, gen_tail = str(gold_triplet[2]), str(generated_triplet[2])
    return (
        (gold_head in gen_head or gen_head in gold_head)
        and (gold_tail in gen_tail or gen_tail in gold_tail)
        and str(gold_triplet[1]) == str(generated_triplet[1])
    )


# True positives across all rows; collected in a list and frozen into the
# `tp` tuple at the end to avoid quadratic tuple concatenation.
tp_found = []

# Per-row records for sentences with at least one true positive.
# NOTE(review): `input` shadows the builtin; the name is kept because code
# outside this cell may read it -- confirm before renaming.
input = []
relations = []
prefix = []
gold_relations = []

# Number of sentences (rows) that have at least one true positive.
count = 0
for _, row in df.iterrows():
    curr_tp = []
    for t_triplet in row["gold"]:
        if t_triplet in row["generated"]:
            # Exact match.
            curr_tp.append(t_triplet)
            continue
        for g_triplet in row["generated"]:
            if _triplets_overlap(t_triplet, g_triplet):
                curr_tp.append(g_triplet)
                # Credit each gold triplet at most once; without this, a gold
                # triplet that leniently matches several generated triplets
                # would be counted several times and inflate TP.
                break
    tp_found.extend(curr_tp)
    if curr_tp:
        count += 1
        print (row["text"] + "\n")
        input.append(row["text"])
        relations.append(curr_tp)
        prefix.append("CONLL04")
        gold_relations.append(row["gold"])
        print ("TRUE (SET): ", row["gold"])
        print ("TP (SET): ", curr_tp)
        print ("----------------")

tp = tuple(tp_found)

# The value printed here is the number of sentences with at least one TP.
print ("TOTAL TRUE RELATIONS: ", count)

John Wilkes Booth , who assassinated President Lincoln , was an actor .

TRUE (SET):  [['John Wilkes Booth', 'Kill', 'President Lincoln']]
TP (SET):  [['John Wilkes Booth', 'Kill', 'President Lincoln']]
----------------
In the field of mechanics , Wang Ziqiang at the Institute of Mechanics has made considerable headway in the area of elastoplastic crack mechanics .

TRUE (SET):  [['Wang Ziqiang', 'Work_For', 'Institute of Mechanics']]
TP (SET):  [['Wang Ziqiang', 'Work_For', 'Institute of Mechanics']]
----------------
Sun Hung Kai Properties , a Hong Kong construction firm with a 27 percent share ;

TRUE (SET):  [['Sun Hung Kai Properties', 'OrgBased_In', 'Hong Kong']]
TP (SET):  [['Sun Hung Kai Properties', 'OrgBased_In', 'Hong Kong']]
----------------
Marie Magdefrau Ferraro , 50 , of Bethany , Conn. , was shot to death Thursday when two bandits armed with assault rifles emerged from nearby bushes and began firing at a van carrying a Connecticut Audubon Society wildlife wild tour gro

In [10]:
# False positives: generated triplets that match no gold triplet of the same
# sentence, neither exactly nor leniently.
fp = ()

fp_list = []
text = []

for _, record in df.iterrows():
    gold_triplets = record["gold"]
    for candidate in record["generated"]:
        # An exact copy of a gold triplet is never a false positive.
        if candidate in gold_triplets:
            continue

        # Lenient match against every gold triplet: same relation label, and
        # head/tail overlapping by substring containment in either direction.
        lenient_hits = [
            (
                (str(ref[0]) in str(candidate[0]) or str(candidate[0]) in str(ref[0]))
                and (str(ref[2]) in str(candidate[2]) or str(candidate[2]) in str(ref[2]))
                and str(ref[1]) == str(candidate[1])
            )
            for ref in gold_triplets
        ]
        if any(lenient_hits):
            continue

        text.append(record["text"])
        fp_list.append(candidate)
        fp += (candidate,)
        print(record["text"] + "\n")
        print("TRUE (SET): ", gold_triplets)
        print("FALSE POSITIVE GENERATED: " + str(candidate))
        print("-----------------------------------------------")

Lightning ignited a firestorm that charred 10 , 000 acres in central Idaho , including much of the small resort town of Lowman , about 70 miles northeast of Boise.

TRUE (SET):  [['Lowman', 'Located_In', 'Idaho']]
FALSE POSITIVE GENERATED: ['Lotman', 'Located_In', 'Idaho']
-----------------------------------------------
The three reactors , all at the Savannah River Plant , in Aiken , S.C. , have been shut down since last April undergoing changes to make them safer .

TRUE (SET):  [['Savannah River Plant', 'Located_In', 'Aiken'], ['Savannah River Plant', 'Located_In', 'S.C.'], ['Aiken', 'Located_In', 'S.C.']]
FALSE POSITIVE GENERATED: ['Savana River Plant', 'OrgBased_In', 'Aiken']
-----------------------------------------------
The three reactors , all at the Savannah River Plant , in Aiken , S.C. , have been shut down since last April undergoing changes to make them safer .

TRUE (SET):  [['Savannah River Plant', 'Located_In', 'Aiken'], ['Savannah River Plant', 'Located_In', 'S.C.'], 

In [12]:
# Number of generated triplets judged to be false positives.
print (len(fp_list))
print ("TP: ", len(tp))
# Precision = TP / (TP + FP). Report 0.0 instead of raising ZeroDivisionError
# when there are neither true nor false positives.
predicted_total = len(tp) + len(fp)
precision = len(tp) / predicted_total if predicted_total else 0.0
print ("PRECISION: ", precision)

123
TP:  201
PRECISION:  0.6203703703703703


In [13]:
# False negatives: gold triplets that no generated triplet of the same
# sentence matches, neither exactly nor leniently.
fn = ()
fn_list = []
text = []
gen_list = []
for _, record in df.iterrows():
    predicted_triplets = record["generated"]
    for expected in record["gold"]:
        # A gold triplet reproduced verbatim was found.
        if expected in predicted_triplets:
            continue

        # Lenient match against every generated triplet: same relation label,
        # and head/tail overlapping by substring containment in either direction.
        lenient_hits = [
            (
                (str(expected[0]) in str(guess[0]) or str(guess[0]) in str(expected[0]))
                and (str(expected[2]) in str(guess[2]) or str(guess[2]) in str(expected[2]))
                and str(expected[1]) == str(guess[1])
            )
            for guess in predicted_triplets
        ]
        if any(lenient_hits):
            continue

        text.append(record["text"])
        fn_list.append(expected)
        gen_list.append(str(predicted_triplets))
        fn += (expected,)
        print(record["text"] + "\n")
        print("TRUE (SET): ", record["gold"])
        print("GENERATED: " + str(predicted_triplets))
        print("FALSE NEGATIVE: " + str(expected))
        print("-----------------------------------------------")

Lightning ignited a firestorm that charred 10 , 000 acres in central Idaho , including much of the small resort town of Lowman , about 70 miles northeast of Boise.

TRUE (SET):  [['Lowman', 'Located_In', 'Idaho']]
GENERATED: [['Lotman', 'Located_In', 'Idaho']]
FALSE NEGATIVE: ['Lowman', 'Located_In', 'Idaho']
-----------------------------------------------
The three reactors , all at the Savannah River Plant , in Aiken , S.C. , have been shut down since last April undergoing changes to make them safer .

TRUE (SET):  [['Savannah River Plant', 'Located_In', 'Aiken'], ['Savannah River Plant', 'Located_In', 'S.C.'], ['Aiken', 'Located_In', 'S.C.']]
GENERATED: [['Savana River Plant', 'OrgBased_In', 'Aiken'], ['Savana River Plant', 'OrgBased_In', 'S.C.'], ['Aiken', 'Located_In', 'S.C.']]
FALSE NEGATIVE: ['Savannah River Plant', 'Located_In', 'Aiken']
-----------------------------------------------
The three reactors , all at the Savannah River Plant , in Aiken , S.C. , have been shut down s

In [14]:
print ("FN: ", len(fn))
# Recall = TP / (TP + FN). Report 0.0 instead of raising ZeroDivisionError
# when there are neither true positives nor false negatives.
gold_total = len(tp) + len(fn)
recall = len(tp) / gold_total if gold_total else 0.0
print ("RECALL: ", recall)
# F1 is the harmonic mean of precision and recall; it is 0.0 when both are 0.
pr_sum = precision + recall
f1 = 2 * (precision * recall) / pr_sum if pr_sum else 0.0
print ("F1 SCORE: ", f1)

FN:  136
RECALL:  0.5964391691394659
F1 SCORE:  0.6081694402420574
