In [1]:
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Initialise plotly's offline notebook mode; `iplot` (imported above) is used
# further down to render the precision bar chart.
# NOTE(review): presumably this must run before any `iplot` call - confirm
# against the installed plotly version.
init_notebook_mode(connected=True)

In [2]:
import pandas

# Both frames are deliberately indexed by the sentence 'id' column: a later
# cell joins them (by index) against the ids whose NER annotations were
# vetted, which is how sentences with incorrect NER annotations get filtered out.
# NOTE: these are absolute, machine-specific paths - adjust them when running elsewhere.
TRAIN_UCCA_CSV = r'C:\Users\JYellin\re_1\tacred\results\general\train-ucca_paths_v0.0.5.csv'
TEST_UCCA_CSV = r'C:\Users\JYellin\re_1\tacred\results\general\test-ucca_paths_v0.0.5.csv'

train = pandas.read_csv(TRAIN_UCCA_CSV).set_index(['id'])
test = pandas.read_csv(TEST_UCCA_CSV).set_index(['id'])

In [3]:
# Load the UD-path files that carry the NER columns, indexed by sentence 'id'
# in the same way as the UCCA frames.
# NOTE: these are absolute, machine-specific paths - adjust them when running elsewhere.
TRAIN_NER_CSV = r'C:\Users\JYellin\re_1\tacred\results\general\train-ud_paths_with_ner_v0.0.6.csv'
TEST_NER_CSV = r'C:\Users\JYellin\re_1\tacred\results\general\test-ud_paths_with_ner_v0.0.6.csv'

train_ner = pandas.read_csv(TRAIN_NER_CSV).set_index(['id'])
test_ner = pandas.read_csv(TEST_NER_CSV).set_index(['id'])

# Columns removed after filtering; the frames are kept for their 'id' index.
_NER_COLUMNS_TO_DROP = [
    'docid', 'tokens', 'relation', 'path', 'lemmas_on_path',
    'type1', 'type2', 'ent1_head', 'ent2_head',
    'type1_corenlp', 'type2_corenlp',
]


def _ids_with_matching_ner(ner_frame):
    """Keep the rows whose entity types agree with their CoreNLP counterparts.

    A row is kept when ``type1 == type1_corenlp`` and ``type2 == type2_corenlp``.
    The columns listed in ``_NER_COLUMNS_TO_DROP`` are then removed from the
    result; the frame's index is left as it was.
    """
    first_type_agrees = ner_frame['type1'] == ner_frame['type1_corenlp']
    second_type_agrees = ner_frame['type2'] == ner_frame['type2_corenlp']
    return ner_frame[first_type_agrees & second_type_agrees].drop(_NER_COLUMNS_TO_DROP, axis=1)


train_ids_with_proper_ner = _ids_with_matching_ner(train_ner)
test_ids_with_proper_ner = _ids_with_matching_ner(test_ner)

In [4]:
# Report the (rows, columns) shape of each split before the NER-based filtering.
for _label, _frame in (('Train shape ', train), ('Test shape ', test)):
    print(_label, _frame.shape)

Train shape  (67617, 10)
Test shape  (15417, 10)


In [5]:
# Join the vetted ids with the UCCA frames on their shared 'id' index, drop
# every row that contains a missing value, and turn the index back into a column.
# NOTE: this rebinds `train` / `test`. Re-running the cell without re-running
# the loading cell first would join against the already-reset frames.
train = (
    train_ids_with_proper_ner
    .join(train)
    .dropna(axis=0)
    .reset_index()
)
test = (
    test_ids_with_proper_ner
    .join(test)
    .dropna(axis=0)
    .reset_index()
)

In [6]:
# Report the (rows, columns) shape of each split after the NER-based filtering.
for _label, _frame in (('Train shape ', train), ('Test shape ', test)):
    print(_label, _frame.shape)

Train shape  (56124, 11)
Test shape  (12763, 11)


In [7]:
def transform_row(r):
    """Wrap a row's path between the 3-character prefixes of its entity types.

    Reads the ``type1``, ``path`` and ``type2`` attributes of ``r`` and returns
    the string ``"<type1[:3]> <path> <type2[:3]>"``.
    """
    template = '{entity1} {org_path} {entity2}'
    fields = {
        'entity1': r.type1[:3],
        'org_path': r.path,
        'entity2': r.type2[:3],
    }
    return template.format(**fields)

# Replace 'path' in both splits with the entity-type-decorated path.
# NOTE: 'path' is overwritten in place, so running this cell a second time
# would wrap the already-decorated path in the type prefixes again.
for _frame in (train, test):
    _frame['path'] = _frame.apply(transform_row, axis=1)

In [8]:
# Map every relation to the distinct paths observed for it in the training data.
# The per-relation values are stored as sets (instead of the arrays returned by
# `unique()`), so each later `path in relation_to_paths[relation]` test is a
# hash lookup rather than a linear scan. That test runs once per test row and
# relation.
relation_to_paths = {
    relation: set(paths)
    for relation, paths in train.groupby(['relation'])['path'].unique().items()
}

In [9]:
def label_sentence_recall(test_row, relation, paths_by_relation=None):
    """Label one test sentence with respect to one candidate relation.

    Parameters
    ----------
    test_row : mapping-like
        Must support ``test_row['path']`` and ``test_row['relation']``.
    relation : hashable
        The relation being checked; used as a key into the path mapping.
    paths_by_relation : mapping, optional
        Maps a relation to a container of paths that supports ``in``.
        When omitted, the notebook-level ``relation_to_paths`` is used, so
        existing two-argument calls behave as before.

    Returns
    -------
    str
        ``'TP'`` when the row's path is among the relation's paths and the
        row's own relation equals ``relation``; ``'FP'`` when the path is among
        them but the row's relation differs; ``'N'`` when the path is not
        among the relation's paths.

    Raises
    ------
    KeyError
        When ``relation`` is missing from the mapping (for a ``dict`` mapping).
    """
    if paths_by_relation is None:
        # Fall back to the notebook global, resolved at call time.
        paths_by_relation = relation_to_paths

    path = test_row['path']
    true_relation = test_row['relation']

    if path not in paths_by_relation[relation]:
        return 'N'
    return 'TP' if true_relation == relation else 'FP'

In [10]:
# Add one column per relation (except 'no_relation') to `test`, holding the
# 'TP' / 'FP' / 'N' label of every test row for that relation.
for relation in relation_to_paths.keys():
    if relation == 'no_relation':
        continue
    test[relation] = test.apply(label_sentence_recall, axis=1, args=(relation,))

In [11]:
def _safe_precision(true_positives, false_positives):
    """Return TP / (TP + FP), or 0.0 when there are no positive predictions.

    The guard avoids a ZeroDivisionError for a relation whose paths match no
    test row at all (both counts are then 0).
    """
    predicted_positives = true_positives + false_positives
    if predicted_positives == 0:
        return 0.0
    return true_positives / predicted_positives


# Per-relation precision, plus an 'overall' entry computed from the pooled
# TP / FP counts of all relations. 'no_relation' is excluded.
stats = []
overall_tp = 0
overall_fp = 0

for relation in relation_to_paths.keys():
    if relation == 'no_relation':
        continue

    # Label counts for this relation's column, e.g. {'N': ..., 'TP': ..., 'FP': ...};
    # labels that never occur are simply absent, hence the `.get(..., 0)` defaults.
    counts = test[relation].value_counts().to_dict()

    tp = counts.get('TP', 0)
    fp = counts.get('FP', 0)
    precision = _safe_precision(tp, fp)

    stats.append((relation, round(precision, 4)))

    overall_tp += tp
    overall_fp += fp

overall_precision = _safe_precision(overall_tp, overall_fp)

stats.append(('overall', round(overall_precision, 4)))

# Highest precision first.
stats = sorted(stats, key=lambda x: x[1], reverse=True)

In [12]:
import plotly.graph_objs as go

# Position of the aggregate 'overall' entry within `stats`, so that its bar
# can be given a different colour from the per-relation bars.
overall_index = next(i for (i, (relation, score)) in enumerate(stats) if relation == 'overall')

# One colour per bar; only the 'overall' bar gets the contrasting colour.
colors = ['rgb(58,200,225)'] * len(stats)
colors[overall_index] = 'rgb(181,59,89)'


# Bar trace: relation names on the x axis, their precision values on the y axis,
# in the order in which they appear in `stats`.
trace1 = go.Bar(
    x=[stat[0] for stat in stats],
    y=[stat[1] for stat in stats],
    name='Precision',
    width=0.5,
    marker=dict(
        color=colors,
        line=dict(
            color='rgb(8,48,107)',
            width=0.5),
        ),
    opacity=0.6
)

data = [trace1]
# Figure layout: title, axis titles, fixed figure width and x tick labels
# rotated by 45 (the relation names are long).
layout = go.Layout(
    title='UCCA Path with Vetted Entity Type - Precision',
    xaxis_title="Relation",
    yaxis_title="Precision ",        
    barmode='overlay',
    width=1200,
    xaxis = go.layout.XAxis(
        tickangle = 45,
        automargin = True
        
    )
    
        
)

# Assemble the figure and hand it to plotly's offline `iplot`.
fig = go.Figure(data=data, layout=layout)
iplot(fig)

In [13]:
# Show the (relation, precision) pairs, sorted by precision in descending
# order, as the cell output.
stats

[('org:website', 0.625),
 ('per:charges', 0.5938),
 ('per:title', 0.5923),
 ('org:city_of_headquarters', 0.5),
 ('org:stateorprovince_of_headquarters', 0.4848),
 ('org:top_members/employees', 0.4417),
 ('per:religion', 0.4333),
 ('per:stateorprovinces_of_residence', 0.427),
 ('per:cities_of_residence', 0.4219),
 ('org:country_of_headquarters', 0.398),
 ('per:employee_of', 0.3725),
 ('per:age', 0.3634),
 ('org:alternate_names', 0.325),
 ('per:cause_of_death', 0.3125),
 ('per:origin', 0.266),
 ('per:countries_of_residence', 0.2578),
 ('org:founded', 0.1353),
 ('per:stateorprovince_of_birth', 0.1333),
 ('overall', 0.1195),
 ('org:number_of_employees/members', 0.1148),
 ('org:parents', 0.0959),
 ('org:founded_by', 0.095),
 ('per:date_of_death', 0.0846),
 ('per:city_of_death', 0.0684),
 ('org:subsidiaries', 0.0682),
 ('per:schools_attended', 0.0664),
 ('per:parents', 0.0339),
 ('per:country_of_birth', 0.0282),
 ('per:date_of_birth', 0.0277),
 ('per:siblings', 0.0249),
 ('per:city_of_birth',