In [1]:
from _plotly_future_ import v4_subplots
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from plotly.subplots import make_subplots
import plotly.graph_objs as go

# Set up plotly's offline notebook mode before any figure is rendered.
# NOTE(review): connected=True presumably loads plotly.js from a CDN instead of embedding it
# in the notebook (needs network access when viewing) — confirm against the plotly offline docs.
# NOTE(review): download_plotlyjs / plot / iplot are imported above but not used in the
# visible cells (figures are shown via fig.show()).
init_notebook_mode(connected=True)

In [2]:
import pandas

import os
from pathlib import Path

# Directory holding the enhanced TACRED JSON splits. Set the TACRED_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific absolute default.
DATA_DIR = Path(os.environ.get('TACRED_DATA_DIR', r'C:\Users\JYellin\re_1\tacred\data\json-enhanced'))

# Unfiltered splits, kept under their own names so later cells can still reach them.
train_raw = pandas.read_json(DATA_DIR / 'train.json')
test_raw = pandas.read_json(DATA_DIR / 'test.json')

# 'no_relation' sentences cannot contribute to recall over the real relations,
# so drop them from both splits.
train = train_raw[train_raw['relation'] != 'no_relation']
test = test_raw[test_raw['relation'] != 'no_relation']


In [3]:
# Training-example count for every (relation, ucca_path) pair, in a 'total' column;
# relations ascending, and within a relation the most frequent paths first.
train_total_by_r_and_p = (
    train
    .groupby(['relation', 'ucca_path'])['id']
    .count()
    .rename('total')
    .reset_index()
    .sort_values(['relation', 'total'], ascending=[True, False])
    .reset_index(drop=True)
)

In [4]:
# Inner-join the test split with the train (relation, ucca_path) counts. This removes
# every test row whose (relation, ucca_path) pair never occurs in train, and adds the
# matching train count to each remaining row as a 'total' column.
# A 'total' column left by a previous run of this cell is dropped first, so re-running
# the cell yields the same frame instead of 'total_x'/'total_y' columns.
test = test.drop(columns=['total'], errors='ignore').merge(
    train_total_by_r_and_p,
    on=['relation', 'ucca_path'],
    # The filter relies on each (relation, ucca_path) pair appearing once in the counts.
    validate='many_to_one',
)

In [5]:
# Matched test examples per (relation, ucca_path_len), in a 'total' column; relations
# descending and, within a relation, the most common path lengths first.
# This is the data behind each per-relation pie chart.
test_total_by_relation_and_len = (
    test
    .groupby(['relation', 'ucca_path_len'])['id']
    .count()
    .rename('total')
    .reset_index()
    .sort_values(['relation', 'total'], ascending=[False, False])
    .reset_index(drop=True)
)

In [6]:
# Matched test examples per relation, largest first. This order also fixes the
# order of the subplots and their titles below.
test_total_by_relation = (
    test
    .groupby(['relation'])['id']
    .count()
    .rename('total')
    .reset_index()
    .sort_values(['total'], ascending=[False])
    .reset_index(drop=True)
)

relations = test_total_by_relation['relation'].tolist()
totals = test_total_by_relation['total'].tolist()

# One subplot title per relation: "<relation> (<count> matches)".
labels = ['{} ({} matches)'.format(*pair) for pair in zip(relations, totals)]

In [7]:
import math

# Subplot grid for the per-relation pies: a fixed column count and just enough
# rows to hold every relation (same order as the subplot titles).
num_columns = 4

relations = test_total_by_relation['relation'].drop_duplicates().tolist()
num_relations = len(relations)
num_rows = math.ceil(num_relations / num_columns)

In [26]:
# Palette for pie slices. It is applied through the layout's piecolorway (below), which
# colours slices by label, i.e. by UCCA path length. Passing it as marker.colors would
# colour slices by position, so the same path length would change colour from pie to pie
# and disagree with the shared legend.
colors = ['#FFED6F', '#CCEBC5', '#BC80BD', '#D9D9D9', '#FCCDE5', '#80B1D3']

# Pixels per row of pies; the figure height grows with the number of rows instead of
# being a fixed value.
row_height = 320

layout = go.Layout(
    title='UCCA path coverage',
    width=950,
    height=row_height * num_rows,
    font=dict(family="Arial", size=12),
    piecolorway=colors,
)

fig = make_subplots(
    rows=num_rows,
    cols=num_columns,
    subplot_titles=labels,
    specs=[[{"type": "pie"}] * num_columns] * num_rows,
)
fig.layout.update(layout)

# One pie per relation, filled row by row; plotly subplot rows/cols are 1-based.
for ind, relation in enumerate(relations):
    row_num, column_num = (k + 1 for k in divmod(ind, num_columns))

    # Path-length counts for this relation only (computed once, used for labels and values).
    by_len = test_total_by_relation_and_len[test_total_by_relation_and_len['relation'] == relation]

    pie = go.Pie(
        labels=by_len['ucca_path_len'].tolist(),
        values=by_len['total'].tolist(),
    )
    fig.add_trace(pie, row=row_num, col=column_num)

fig.update_traces(
    hoverinfo='label+percent',
    textinfo='label+value',
    textfont_size=10,
)

# Subplot titles are stored as layout annotations; render them smaller than the
# figure-wide 12pt font.
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(family="Arial", size=10)

fig.show()
    