In [1]:
from _plotly_future_ import v4_subplots
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
from plotly.subplots import make_subplots
import plotly.graph_objs as go

# Set up Plotly's offline notebook rendering before any figure is shown.
# NOTE(review): per the plotly.offline docs, connected=True loads plotly.js from a CDN
# instead of embedding it in the notebook, so viewing needs network access — confirm
# against the installed plotly version.
init_notebook_mode(connected=True)

In [2]:
# Evaluation batches are numbered 1 through 7.
batches = list(range(1, 8))

# One label per filter setting: path lengths 6..14 carry an 'f' suffix, and
# 'nofilter' is the unfiltered baseline.
filter_lengths = ['{}f'.format(length) for length in range(6, 15)] + ['nofilter']

In [3]:
import pandas

# Load every batch's per-relation scores for each filter setting.
# `results` maps a filter label (e.g. '6f', 'nofilter') to a single DataFrame holding
# the rows of all batches stacked on top of each other.
csv_template = r'C:\Users\JYellin\re_1\tacred\analysis\filtering-by-path-length\results\batch{}-{}.csv'

results = {}

for filter_length in filter_lengths:

    # File names use the bare number ('6', not '6f'); 'nofilter' is used verbatim.
    suffix = filter_length if filter_length == 'nofilter' else filter_length[:-1]

    # Each batch file is sorted by descending 'count' before the batches are stacked.
    per_batch_frames = [
        pandas.read_csv(csv_template.format(batch, suffix)).sort_values(['count'], ascending=[False])
        for batch in batches
    ]

    results[filter_length] = pandas.concat(per_batch_frames)
    


In [4]:
# Average the scores of each relation across batches, separately per filter setting.
# `average_results` maps a filter label to a DataFrame indexed by relation and ordered
# by descending mean 'count'.
metric_names = ('precision', 'recall', 'f1')

average_results = {}

for filter_length in filter_lengths:

    per_relation_mean = (
        results[filter_length]
        .groupby('relation')
        .mean()
        .sort_values(['count'], ascending=[False])
    )

    # Tag the metric columns with the filter label (e.g. 'f1' -> 'f1-6f') so they stay
    # distinguishable after the per-filter tables are combined. The 'count' column is
    # deliberately kept and keeps its name.
    suffixed_names = {
        metric: '{}-{}'.format(metric, filter_length)
        for metric in metric_names
    }

    average_results[filter_length] = per_relation_mean.rename(columns=suffixed_names)
    

In [5]:
# Place the per-filter average tables side by side (axis=1), aligned on their shared
# 'relation' index. Because a dict is passed, its keys (the filter labels) become the
# outer level of the column index, so one filter's columns are reached with
# results_table[filter_length].
results_table = pandas.concat(average_results, axis=1)

In [6]:
# Count how many examples each relation has in the pooled train + dev splits.
train_path = r'C:\Users\JYellin\re_1\tacred\data\json-enhanced\train.json'
dev_path = r'C:\Users\JYellin\re_1\tacred\data\json-enhanced\dev.json'

train = pandas.read_json(train_path)
dev = pandas.read_json(dev_path)
data = pandas.concat([train, dev])

# One row per relation, with a single 'total' column holding the number of non-null ids.
relations_count = data.groupby(['relation'])['id'].count().to_frame('total')

# Plain {relation: count} mapping, with the 'no_relation' class dropped.
relations_count_dict = relations_count['total'].to_dict()
del relations_count_dict['no_relation']

# Add a pseudo-relation 'total' holding the sum over all remaining relations.
relations_count_dict['total'] = sum(relations_count_dict.values())

In [7]:
# Order relations from most to least frequent, using the train + dev counts in
# relations_count_dict (not the 'count' column of the results table).
relations_by_frequency = sorted(
    relations_count_dict.items(),
    key=lambda item: item[1],
    reverse=True,
)
relations = [relation for relation, _ in relations_by_frequency]

# Subplot titles look like "<relation> (<count with thousands separators>)".
titles = ['{} ({:,})'.format(relation, count) for relation, count in relations_by_frequency]


In [8]:
import math
# Subplot grid: a fixed number of columns and as many rows as needed to fit every
# relation (integer ceiling division).
num_columns = 3
num_rows = (len(relations) + num_columns - 1) // num_columns

In [9]:
# Build one subplot per relation showing precision, recall and f1 as a function of the
# filter setting, laid out on a num_rows x num_columns grid.
layout = go.Layout(
    title='UCCA Filtering Results',
    width=1500,
    height=3000,
    font=dict(family="Arial", size=12)
)

fig = make_subplots(
    rows=num_rows,
    cols=num_columns,
    subplot_titles=list(titles),
)

# Apply the global title / size / font settings to the subplot figure.
fig['layout'].update(layout)

# Metric name -> line colour; the order here is the order traces are added per subplot.
metric_colors = [
    ('precision', '#CCEBC5'),
    ('recall', '#BC80BD'),
    ('f1', '#FCCDE5'),
]

# Align the averaged results with the plotting order once, instead of re-filtering the
# table for every (relation, filter, metric) combination. A relation that is counted in
# train/dev but has no row in the results becomes an all-NaN row, which plots as an
# empty subplot rather than raising and aborting the whole figure.
plot_table = results_table.reindex(relations)

for ind, relation in enumerate(relations):

    # Convert the running index into 1-based (row, column) grid coordinates.
    row_num, column_num = divmod(ind, num_columns)
    row_num += 1
    column_num += 1

    # Row for this relation as a Series keyed by (filter label, column name). Label-based
    # lookup is used on purpose: integer keys on a string-labelled Series are deprecated.
    relation_scores = plot_table.loc[relation]

    for metric, color in metric_colors:
        metric_values = [
            relation_scores[(filter_length, '{}-{}'.format(metric, filter_length))]
            for filter_length in filter_lengths
        ]

        fig.add_trace(
            go.Scatter(
                x=filter_lengths,
                y=metric_values,
                name='{} {}'.format(relation, metric),
                line=dict(color=color)),
            row=row_num,
            col=column_num)

# Subplot titles are stored as layout annotations; shrink them so long relation names fit.
for annotation in fig['layout']['annotations']:
    annotation['font'] = dict(family="Arial", size=10)

fig.show()
    