In [1]:
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
# Set up plotly's offline notebook mode for the iplot() call that draws the
# recall bar chart further down.
init_notebook_mode(connected=True)

In [2]:
import pandas

# NOTE(review): absolute, machine-specific paths — point these at your local
# copy of the TACRED UCCA-path results before running.
TRAIN_CSV = r'C:\Users\JYellin\re_1\tacred\results\general\train-ucca_paths_v0.0.5.csv'
TEST_CSV = r'C:\Users\JYellin\re_1\tacred\results\general\test-ucca_paths_v0.0.5.csv'


def _load_labelled(csv_path):
    """Read a UCCA-paths CSV and keep only rows that carry an actual relation.

    'no_relation' rows are dropped: there is no relation whose recall they
    could contribute to. The result is ``.copy()``'d so it is an independent
    frame; later column assignments (``train['path'] = ...``) then do not
    trigger pandas' SettingWithCopyWarning.
    """
    frame = pandas.read_csv(csv_path)
    return frame[frame['relation'] != 'no_relation'].copy()


train = _load_labelled(TRAIN_CSV)
test = _load_labelled(TEST_CSV)


In [3]:
def transform_row(r):
    """Label a row's UCCA path with its two entity types.

    The first three characters of ``type1`` and ``type2`` are placed before
    and after ``path`` respectively, space-separated, e.g.
    type1='PERSON', path='A B', type2='ORGANIZATION' -> 'PER A B ORG'.
    """
    head = r.type1[:3]
    tail = r.type2[:3]
    return f'{head} {r.path} {tail}'

# Wrap each UCCA path with its entity types. The raw path is preserved in
# 'ucca_path', and 'path' is always rebuilt from it, so re-running this cell
# does not wrap an already-labelled path a second time.
for frame in (train, test):
    if 'ucca_path' not in frame.columns:
        frame['ucca_path'] = frame['path']
    frame['path'] = frame.assign(path=frame['ucca_path']).apply(transform_row, axis=1)

In [4]:
def _count_by(frame, keys):
    """Count sentences (non-null 'id' values) per group of `keys`.

    Returns a new frame with the `keys` columns plus a 'total' column. Rows are
    ordered by all but the last key ascending, then by 'total' descending. For
    ['relation', 'path'], each relation therefore lists its most frequent paths
    first. For ['path'], the most frequent paths come first overall.
    """
    sort_keys = list(keys[:-1]) + ['total']
    ascending = [True] * (len(keys) - 1) + [False]
    return (frame.groupby(keys)['id'].count()
                 .reset_index(name='total')
                 .sort_values(sort_keys, ascending=ascending)
                 .reset_index(drop=True))


# Both per-(relation, path) tables share one ordering (relation ascending) so
# they line up when inspected side by side.
test_total_by_r_and_p = _count_by(test, ['relation', 'path'])
train_total_by_r_and_p = _count_by(train, ['relation', 'path'])
train_total_by_p = _count_by(train, ['path'])


In [5]:
def label_count_in_train_across_relations(test_row, counts=None):
    """Return how many train sentences, of any relation, share this row's path.

    Parameters
    ----------
    test_row : pandas.Series
        A row with a 'path' entry (e.g. a row of ``test_total_by_r_and_p``).
    counts : pandas.DataFrame, optional
        Table with 'path' and 'total' columns. Defaults to the global
        ``train_total_by_p``.

    Returns
    -------
    int
        The matching 'total', or 0 when the path never occurs in ``counts``.
    """
    if counts is None:
        counts = train_total_by_p
    matches = counts.loc[counts['path'] == test_row['path'], 'total']
    # Select the count by column name rather than position so a reordering of
    # the counts table cannot silently return the wrong field.
    return 0 if matches.empty else matches.iloc[0]

def label_count_in_train_by_relation(test_row, counts=None):
    """Return how many train sentences of the same relation share this row's path.

    Parameters
    ----------
    test_row : pandas.Series
        A row with 'relation' and 'path' entries (e.g. a row of
        ``test_total_by_r_and_p``).
    counts : pandas.DataFrame, optional
        Table with 'relation', 'path' and 'total' columns. Defaults to the
        global ``train_total_by_r_and_p``.

    Returns
    -------
    int
        The matching 'total', or 0 when the (relation, path) pair never occurs
        in ``counts``.
    """
    if counts is None:
        counts = train_total_by_r_and_p
    same_pair = ((counts['relation'] == test_row['relation'])
                 & (counts['path'] == test_row['path']))
    matches = counts.loc[same_pair, 'total']
    # Select the count by column name rather than position so a reordering of
    # the counts table cannot silently return the wrong field.
    return 0 if matches.empty else matches.iloc[0]

# Attach to every (relation, path) test group how often the same labelled
# path occurs in train: once restricted to the same relation, once across all
# relations. The row-wise helpers are passed to apply() directly.
test_total_by_r_and_p['total-for-relation'] = test_total_by_r_and_p.apply(
    label_count_in_train_by_relation, axis=1)
test_total_by_r_and_p['total-across-relations'] = test_total_by_r_and_p.apply(
    label_count_in_train_across_relations, axis=1)

In [6]:
def _path_recall(groups):
    """Recall of the labelled-path heuristic over a set of test groups.

    `groups` has one row per (relation, path) with a 'total' test-sentence
    count and a 'total-for-relation' train count. A test sentence counts as
    matched when its labelled path also occurs in train for the same relation.
    Returns matched / total rounded to 4 decimals, as a float.
    """
    matched = groups.loc[groups['total-for-relation'] > 0, 'total'].sum()
    return float(round(matched / groups['total'].sum(), 4))


# Recall is measured over test sentences, so iterate the relations that occur
# in test; every such relation has a non-zero test total. Iterating train's
# relations instead would divide by zero for a relation with no test rows.
# Sorting keeps ties in the final ranking in alphabetical order.
relations = sorted(test_total_by_r_and_p['relation'].unique().tolist())
stats = [
    (relation, _path_recall(test_total_by_r_and_p[test_total_by_r_and_p['relation'] == relation]))
    for relation in relations
]
stats.insert(0, ('overall', _path_recall(test_total_by_r_and_p)))
stats = sorted(stats, key=lambda stat: stat[1], reverse=True)

In [7]:
import plotly.graph_objs as go


# Every relation bar shares one colour; the aggregate 'overall' bar gets a
# contrasting colour so it stands out in the ranking.
relation_bar_color = 'rgb(58,200,225)'
overall_bar_color = 'rgb(181,59,89)'

labels = [name for name, _ in stats]
recalls = [value for _, value in stats]

overall_index = next(i for i, name in enumerate(labels) if name == 'overall')
colors = [overall_bar_color if i == overall_index else relation_bar_color
          for i in range(len(stats))]

trace = go.Bar(
    x=labels,
    y=recalls,
    name='Recall',
    width=0.5,
    marker=dict(color=colors, line=dict(color='rgb(8,48,107)', width=0.5)),
    opacity=0.6,
)

data = [trace]
layout = go.Layout(
    title='UCCA Path with Entity Type - Recall',
    xaxis_title="Relation",
    yaxis_title="Recall",
    barmode='overlay',
    width=1200,
    # Tilt the long relation names and let plotly grow the margin to fit them.
    xaxis=go.layout.XAxis(tickangle=45, automargin=True),
)

fig = go.Figure(data=data, layout=layout)
iplot(fig)

In [8]:
# (relation, recall) pairs ranked by recall, highest first; the aggregate
# 'overall' entry is included in the ranking.
stats

[('org:top_members/employees', 0.8863),
 ('org:alternate_names', 0.8804),
 ('per:title', 0.8448),
 ('per:age', 0.8333),
 ('org:founded', 0.8056),
 ('per:date_of_birth', 0.7778),
 ('per:employee_of', 0.7731),
 ('org:city_of_headquarters', 0.7531),
 ('org:number_of_employees/members', 0.7368),
 ('per:origin', 0.6894),
 ('overall', 0.6833),
 ('per:stateorprovinces_of_residence', 0.6625),
 ('per:cities_of_residence', 0.6489),
 ('per:countries_of_residence', 0.6486),
 ('org:founded_by', 0.6324),
 ('per:date_of_death', 0.5926),
 ('org:parents', 0.5902),
 ('per:siblings', 0.5818),
 ('org:stateorprovince_of_headquarters', 0.58),
 ('org:subsidiaries', 0.5682),
 ('per:other_family', 0.5667),
 ('per:spouse', 0.5606),
 ('org:country_of_headquarters', 0.5234),
 ('per:alternate_names', 0.5),
 ('per:stateorprovince_of_birth', 0.5),
 ('per:schools_attended', 0.4667),
 ('per:parents', 0.4432),
 ('per:religion', 0.4043),
 ('org:website', 0.4),
 ('per:city_of_birth', 0.4),
 ('per:country_of_birth', 0.4),