In [1]:
import pandas as pd
import json
import textstat as txt
from itertools import groupby

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

import pickle

# Loading TACRED dataset

In [2]:
# Load the three TACRED splits. Each file is read inside a context manager so
# the handle is closed immediately, and the encoding is pinned so the result
# does not depend on the platform's default locale.
with open('./../../dataset/tacred/json/train.json', encoding='utf-8') as fh:
    train_data = json.load(fh)
print("Number of Training instances :: {}".format(len(train_data)))

with open('./../../dataset/tacred/json/dev.json', encoding='utf-8') as fh:
    dev_data = json.load(fh)
print("Number of Dev instances :: {}".format(len(dev_data)))

with open('./../../dataset/tacred/json/test.json', encoding='utf-8') as fh:
    test_data = json.load(fh)
print("Number of Test instances :: {}".format(len(test_data)))

Number of Training instances :: 68124
Number of Dev instances :: 22631
Number of Test instances :: 15509


In [3]:
# Count how often each relation label occurs in every split. After each split
# is tallied, the total is printed as a sanity check against the split size.
train_rel_freq, dev_rel_freq, test_rel_freq = {}, {}, {}
for split, counts in (
    (train_data, train_rel_freq),
    (dev_data, dev_rel_freq),
    (test_data, test_rel_freq),
):
    for example in split:
        label = example['relation']
        counts[label] = counts[label] + 1 if label in counts else 1
    print(sum(counts.values()))

68124
22631
15509


In [4]:
# Print every training relation next to its count, in first-seen order.
for relation in train_rel_freq:
    print(relation, train_rel_freq[relation])

org:founded_by 124
no_relation 55112
per:employee_of 1524
org:alternate_names 808
per:cities_of_residence 374
per:children 211
per:title 2443
per:siblings 165
per:religion 53
per:age 390
org:website 111
per:stateorprovinces_of_residence 331
org:member_of 122
org:top_members/employees 1890
per:countries_of_residence 445
org:city_of_headquarters 382
org:members 170
org:country_of_headquarters 468
per:spouse 258
org:stateorprovince_of_headquarters 229
org:number_of_employees/members 75
org:parents 286
org:subsidiaries 296
per:origin 325
org:political/religious_affiliation 105
per:other_family 179
per:stateorprovince_of_birth 38
org:dissolved 23
per:date_of_death 134
org:shareholders 76
per:alternate_names 104
per:parents 152
per:schools_attended 149
per:cause_of_death 117
per:city_of_death 81
per:stateorprovince_of_death 49
org:founded 91
per:country_of_birth 28
per:date_of_birth 63
per:city_of_birth 65
per:charges 72
per:country_of_death 6


# Heat-Map Analysis of Confusion Matrix

#### List of cmap values

Accent, Accent_r, Blues, Blues_r, BrBG, BrBG_r, BuGn, BuGn_r, BuPu, BuPu_r, CMRmap, CMRmap_r, Dark2, Dark2_r, GnBu, GnBu_r, Greens, Greens_r, Greys, Greys_r, OrRd, OrRd_r, Oranges, Oranges_r, PRGn, PRGn_r, Paired, Paired_r, Pastel1, Pastel1_r, Pastel2, Pastel2_r, PiYG, PiYG_r, PuBu, PuBuGn, PuBuGn_r, PuBu_r, PuOr, PuOr_r, PuRd, PuRd_r, Purples, Purples_r, RdBu, RdBu_r, RdGy, RdGy_r, RdPu, RdPu_r, RdYlBu, RdYlBu_r, RdYlGn, RdYlGn_r, Reds, Reds_r, Set1, Set1_r, Set2, Set2_r, Set3, Set3_r, Spectral, Spectral_r, Wistia, Wistia_r, YlGn, YlGnBu, YlGnBu_r, YlGn_r, YlOrBr, YlOrBr_r, YlOrRd, YlOrRd_r, afmhot, afmhot_r, autumn, autumn_r, binary, binary_r, bone, bone_r, brg, brg_r, bwr, bwr_r, cividis, cividis_r, cool, cool_r, coolwarm, coolwarm_r, copper, copper_r, cubehelix, cubehelix_r, flag, flag_r, gist_earth, gist_earth_r, gist_gray, gist_gray_r, gist_heat, gist_heat_r, gist_ncar, gist_ncar_r, gist_rainbow, gist_rainbow_r, gist_stern, gist_stern_r, gist_yarg, gist_yarg_r, gnuplot, gnuplot2, gnuplot2_r, gnuplot_r, gray, gray_r, hot, hot_r, hsv, hsv_r, icefire, icefire_r, inferno, inferno_r, jet, jet_r, magma, magma_r, mako, mako_r, nipy_spectral, nipy_spectral_r, ocean, ocean_r, pink, pink_r, plasma, plasma_r, prism, prism_r, rainbow, rainbow_r, rocket, rocket_r, seismic, seismic_r, spring, spring_r, summer, summer_r, tab10, tab10_r, tab20, tab20_r, tab20b, tab20b_r, tab20c, tab20c_r, terrain, terrain_r, twilight, twilight_r, twilight_shifted, twilight_shifted_r, viridis, viridis_r, vlag, vlag_r, winter, winter_r

## TACRED Analysis

In [5]:
# TACRED relation label -> integer class id (42 classes, ids 0-41; id 0 is
# 'no_relation'). The confusion-matrix cells below use this to name columns.
LABEL_TO_ID = {'no_relation': 0, 'per:title': 1, 'org:top_members/employees': 2, 'per:employee_of': 3, 'org:alternate_names': 4, 'org:country_of_headquarters': 5, 'per:countries_of_residence': 6, 'org:city_of_headquarters': 7, 'per:cities_of_residence': 8, 'per:age': 9, 'per:stateorprovinces_of_residence': 10, 'per:origin': 11, 'org:subsidiaries': 12, 'org:parents': 13, 'per:spouse': 14, 'org:stateorprovince_of_headquarters': 15, 'per:children': 16, 'per:other_family': 17, 'per:alternate_names': 18, 'org:members': 19, 'per:siblings': 20, 'per:schools_attended': 21, 'per:parents': 22, 'per:date_of_death': 23, 'org:member_of': 24, 'org:founded_by': 25, 'org:website': 26, 'per:cause_of_death': 27, 'org:political/religious_affiliation': 28, 'org:founded': 29, 'per:city_of_death': 30, 'org:shareholders': 31, 'org:number_of_employees/members': 32, 'per:date_of_birth': 33, 'per:city_of_birth': 34, 'per:charges': 35, 'per:stateorprovince_of_death': 36, 'per:religion': 37, 'per:stateorprovince_of_birth': 38, 'per:country_of_birth': 39, 'org:dissolved': 40, 'per:country_of_death': 41}
# Inverse lookup: integer class id -> relation label.
ID_TO_LABEL = {val:key for key, val in LABEL_TO_ID.items()}

### PARNN

In [None]:
# PARNN confusion matrix on TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm against the
# script that wrote the pickle.
with open('./confusion-matrix/parnn', 'rb') as fh:  # closed right after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/parnn')

#### List of relations true-positives <= 50% and false-negative  > 20%
- diagonal entries <= 50%
- prediction: no_relation > 20%

In [None]:
# Relations whose diagonal share is at most 50% while more than 20% of their
# row falls in column 0 (the no_relation column).
rels_fn = [
    ID_TO_LABEL[idx]
    for idx in range(1, len(LABEL_TO_ID))
    if round(df.iloc[idx, idx], 2) <= 50 and round(df.iloc[idx, 0], 2) > 20
]

# Attach the number of examples each flagged relation has in every split.
parnn_fn = pd.DataFrame({
    'relations': rels_fn,
    'train': [train_rel_freq[rel] for rel in rels_fn],
    'dev': [dev_rel_freq[rel] for rel in rels_fn],
    'test': [test_rel_freq[rel] for rel in rels_fn],
})

parnn_fn

#### List of relations true-positives <= 50% and false-positives  > 20%
- diagonal entries <= 50%
- prediction: some other relation > 20%

In [None]:
# Relations whose diagonal share is at most 50% and that send more than 20% of
# their row to some other relation (column 0 and the diagonal are excluded).
rels_fp = []
for idx in range(1, len(LABEL_TO_ID)):
    row = df.iloc[idx]
    if not (round(row.iloc[idx], 2) <= 50):
        continue
    left_of_diag = round(row.iloc[1:idx], 2) > 20
    right_of_diag = round(row.iloc[idx + 1:], 2) > 20
    if left_of_diag.any() or right_of_diag.any():
        rels_fp.append(ID_TO_LABEL[idx])

# Attach the number of examples each flagged relation has in every split.
parnn = pd.DataFrame({
    'relations': rels_fp,
    'train': [train_rel_freq[rel] for rel in rels_fp],
    'dev': [dev_rel_freq[rel] for rel in rels_fp],
    'test': [test_rel_freq[rel] for rel in rels_fp],
})
parnn

#### Common relations for PARNN False-Negatives and False Positives

In [None]:
# Relations flagged by both the false-negative and the false-positive criteria.
set(rels_fn) & set(rels_fp)

### LSTM

In [None]:
# LSTM confusion matrix on TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/lstm', 'rb') as fh:  # closed right after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/lstm')

### BiLSTM

In [None]:
# BiLSTM confusion matrix on TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/bilstm', 'rb') as fh:  # closed right after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/bilstm')

### CGCN

In [None]:
# CGCN confusion matrix on TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/cgcn', 'rb') as fh:  # closed right after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/cgcn')

#### List of relations true-positives <= 50% and false-negative  > 20%
- diagonal entries <= 50%
- prediction: no_relation > 20%

In [None]:
# Relations whose diagonal share is at most 50% while more than 20% of their
# row falls in column 0 (the no_relation column).
rels_fn = [
    ID_TO_LABEL[idx]
    for idx in range(1, len(LABEL_TO_ID))
    if round(df.iloc[idx, idx], 2) <= 50 and round(df.iloc[idx, 0], 2) > 20
]

# Attach the number of examples each flagged relation has in every split.
cgcn_fn = pd.DataFrame({
    'relations': rels_fn,
    'train': [train_rel_freq[rel] for rel in rels_fn],
    'dev': [dev_rel_freq[rel] for rel in rels_fn],
    'test': [test_rel_freq[rel] for rel in rels_fn],
})

cgcn_fn

#### List of relations true-positives <= 50% and false-positives  > 20%
- diagonal entries <= 50%
- prediction: some other relation > 20%

In [None]:
# Relations whose diagonal share is at most 50% and that send more than 20% of
# their row to some other relation (column 0 and the diagonal are excluded).
rels_fp = []
for idx in range(1, len(LABEL_TO_ID)):
    row = df.iloc[idx]
    if not (round(row.iloc[idx], 2) <= 50):
        continue
    left_of_diag = round(row.iloc[1:idx], 2) > 20
    right_of_diag = round(row.iloc[idx + 1:], 2) > 20
    if left_of_diag.any() or right_of_diag.any():
        rels_fp.append(ID_TO_LABEL[idx])

# Attach the number of examples each flagged relation has in every split.
cgcn = pd.DataFrame({
    'relations': rels_fp,
    'train': [train_rel_freq[rel] for rel in rels_fp],
    'dev': [dev_rel_freq[rel] for rel in rels_fp],
    'test': [test_rel_freq[rel] for rel in rels_fp],
})

cgcn

#### Common relations for CGCN False-Negatives and False-Positives

In [None]:
# Relations flagged by both the false-negative and the false-positive criteria.
set(rels_fn) & set(rels_fp)

### GCN

In [None]:
# GCN confusion matrix on TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/gcn', 'rb') as fh:  # closed right after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/gcn')

### CNN

In [None]:
# CNN confusion matrix on TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/cnn', 'rb') as fh:  # closed right after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/cnn')

### Common relations between PARNN and CGCN

#### False Negatives

In [None]:
# Inner join on all shared columns: rows present in both false-negative tables.
pd.merge(parnn_fn, cgcn_fn)

#### False Positives

In [None]:
# Inner join on all shared columns: rows present in both false-positive tables.
pd.merge(parnn, cgcn)

## TACRED excluding no_relation

### PARNN

In [None]:
# PARNN confusion matrix restricted to positive relations (no_relation removed
# from both axes), drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('./confusion-matrix/parnn-all_positive', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
# Drop the first row -- presumably the no_relation row; TODO confirm row order.
df = df.loc[1:,:]
df = df.set_index(0)
df.index.name = 'labels'
# Drop column 1, which the rename below would have mapped to class id 0
# ('no_relation').
df = df.loc[:,2:]
# Column i carries class id i - 1; keys missing from the frame are ignored.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/all-positive-parnn')

#### List of relations true-positives <= 50% and false-positives  > 20%
- diagonal entries <= 50%
- prediction: some other relation > 20%

In [None]:
# Relations whose diagonal share is at most 50% and that send more than 20% of
# their row to some other relation, on the matrix without no_relation.
#
# With no_relation removed from both axes, position p (row and column) holds
# class id p + 1, so position 0 is a real relation. The scan therefore starts
# at row 0 and the off-diagonal check covers column 0 as well; starting at 1
# and slicing from 1 would skip the first relation and ignore confusions
# into it.
rels_fp = list()
for i in range(len(LABEL_TO_ID) - 1):
    low_recall = round(df.iloc[i, i], 2) <= 50
    confused_before = (round(df.iloc[i, :i], 2) > 20).any()
    confused_after = (round(df.iloc[i, i+1:], 2) > 20).any()
    if low_recall and (confused_before or confused_after):
        rels_fp.append(ID_TO_LABEL[i+1])

# Attach the number of examples each flagged relation has in every split.
parnn_allpos = pd.DataFrame({'relations':rels_fp, 
                             'train':[train_rel_freq[r] for r in rels_fp], 
                             'dev':[dev_rel_freq[r] for r in rels_fp],
                             'test':[test_rel_freq[r] for r in rels_fp]})
parnn_allpos

#### Intersection with PARNN

In [None]:
# Inner join on all shared columns: relations flagged with and without no_relation.
pd.merge(parnn, parnn_allpos)

### CGCN

In [None]:
# CGCN confusion matrix restricted to positive relations (no_relation removed
# from both axes), drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('./confusion-matrix/cgcn-all_positive', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
# Drop the first row -- presumably the no_relation row; TODO confirm row order.
df = df.loc[1:,:]
df = df.set_index(0)
df.index.name = 'labels'
# Drop column 1, which the rename below would have mapped to class id 0
# ('no_relation').
df = df.loc[:,2:]
# Column i carries class id i - 1; keys missing from the frame are ignored.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/all-positive-cgcn')

#### List of relations true-positives <= 50% and false-positives  > 20%
- diagonal entries <= 50%
- prediction: some other relation > 20%

In [None]:
# Relations whose diagonal share is at most 50% and that send more than 20% of
# their row to some other relation, on the matrix without no_relation.
#
# With no_relation removed from both axes, position p (row and column) holds
# class id p + 1, so position 0 is a real relation. The scan therefore starts
# at row 0 and the off-diagonal check covers column 0 as well; starting at 1
# and slicing from 1 would skip the first relation and ignore confusions
# into it.
rels_fp = list()
for i in range(len(LABEL_TO_ID) - 1):
    low_recall = round(df.iloc[i, i], 2) <= 50
    confused_before = (round(df.iloc[i, :i], 2) > 20).any()
    confused_after = (round(df.iloc[i, i+1:], 2) > 20).any()
    if low_recall and (confused_before or confused_after):
        rels_fp.append(ID_TO_LABEL[i+1])

# Attach the number of examples each flagged relation has in every split.
cgcn_allpos = pd.DataFrame({'relations':rels_fp, 
                            'train':[train_rel_freq[r] for r in rels_fp], 
                            'dev':[dev_rel_freq[r] for r in rels_fp],
                            'test':[test_rel_freq[r] for r in rels_fp]})

cgcn_allpos

#### Intersection with CGCN

In [None]:
# Inner join on all shared columns: relations flagged with and without no_relation.
pd.merge(cgcn, cgcn_allpos)

## TACREV Analysis

### PARNN

In [None]:
# PARNN confusion matrix on TACREV, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/tacrev-parnn', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
# Number of test instances in test set for 'per:country_of_birth' = 0, so its
# row total is zero and the division above cannot yield percentages; the row
# is overwritten with 0.
df.loc['per:country_of_birth'] = 0
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/tacrev-parnn')

### CGCN

In [None]:
# CGCN confusion matrix on TACREV, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/tacrev-cgcn', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
# Number of test instances in test set for 'per:country_of_birth' = 0, so its
# row total is zero and the division above cannot yield percentages; the row
# is overwritten with 0.
df.loc['per:country_of_birth'] = 0
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/tacrev-cgcn')

## Re-TACRED Analysis

In [None]:
# Re-TACRED relation label -> integer class id (40 classes, ids 0-39; id 0 is
# 'no_relation'). The ids differ from the TACRED mapping, so the Re-TACRED
# confusion matrices must be labelled with this table.
RE_LABEL_TO_ID = {'no_relation': 0, 'org:members': 1, 'per:siblings': 2, 'per:spouse': 3, 'org:country_of_branch': 4, 'per:country_of_death': 5, 'per:parents': 6, 'per:stateorprovinces_of_residence': 7, 'org:top_members/employees': 8, 'org:dissolved': 9, 'org:number_of_employees/members': 10, 'per:stateorprovince_of_death': 11, 'per:origin': 12, 'per:children': 13, 'org:political/religious_affiliation': 14, 'per:city_of_birth': 15, 'per:title': 16, 'org:shareholders': 17, 'per:employee_of': 18, 'org:member_of': 19, 'org:founded_by': 20, 'per:countries_of_residence': 21, 'per:other_family': 22, 'per:religion': 23, 'per:identity': 24, 'per:date_of_birth': 25, 'org:city_of_branch': 26, 'org:alternate_names': 27, 'org:website': 28, 'per:cause_of_death': 29, 'org:stateorprovince_of_branch': 30, 'per:schools_attended': 31, 'per:country_of_birth': 32, 'per:date_of_death': 33, 'per:city_of_death': 34, 'org:founded': 35, 'per:cities_of_residence': 36, 'per:age': 37, 'per:charges': 38, 'per:stateorprovince_of_birth': 39}
# Inverse lookup: integer class id -> relation label.
RE_ID_TO_LABEL = {val:key for key, val in RE_LABEL_TO_ID.items()}

### PARNN

In [None]:
# PARNN confusion matrix on Re-TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix-156/retacred-parnn', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# (A bare `df.loc['per:country_of_birth']` lookup used to sit here; being a
# mid-cell expression it displayed nothing, so it was removed.)
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: RE_ID_TO_LABEL[i-1] for i in range(1, len(RE_ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
# Number of test instances in test set for 'per:country_of_birth' = 0, so its
# row total is zero and the division above cannot yield percentages; the row
# is overwritten with 0.
df.loc['per:country_of_birth'] = 0
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/retacred-parnn')

### CGCN

In [None]:
# CGCN confusion matrix on Re-TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix-156/retacred-cgcn', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: RE_ID_TO_LABEL[i-1] for i in range(1, len(RE_ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
# Number of test instances in test set for 'per:country_of_birth' = 0, so its
# row total is zero and the division above cannot yield percentages; the row
# is overwritten with 0.
df.loc['per:country_of_birth'] = 0
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/retacred-cgcn')

# Re-Annotation Analysis

## PARNN

In [None]:
# PARNN confusion matrix after re-annotation ('parnn-reann-cos' pickle), drawn
# as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/parnn-reann-cos', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/parnn-reann')

In [None]:
# Build the percentage-point difference between the re-annotated and the
# original PARNN confusion matrices (positive = share grew after re-annotation).
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# NOTE: this rebinds `parnn`, which an earlier cell uses for the false-positive
# summary table; re-run that cell before reusing the table.
with open('./confusion-matrix/parnn', 'rb') as fh:  # closed right after loading
    parnn = pd.DataFrame(pickle.load(fh))
parnn = parnn.set_index(0)
parnn.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
parnn = parnn.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
parnn = parnn.div(parnn.sum(axis=1), axis=0).round(3) * 100

with open('./confusion-matrix/parnn-reann-cos', 'rb') as fh:  # closed right after loading
    parnn_re = pd.DataFrame(pickle.load(fh))
parnn_re = parnn_re.set_index(0)
parnn_re.index.name = 'labels'
parnn_re = parnn_re.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
parnn_re = parnn_re.div(parnn_re.sum(axis=1), axis=0).round(3) * 100

# Element-wise difference, aligned on row and column labels.
df = parnn_re - parnn

In [None]:
# Training-set relation frequencies before and after applying the PARNN
# re-annotation. `pickle` comes from the notebook's import cell.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('./reannotation-ids/Preplace_cosine.pkl', 'rb') as fh:  # closed after loading
    # Looked up below as example id -> replacement relation.
    reann_ids = pickle.load(fh)
print(len(reann_ids))
re_train_rel_freq = dict()
for eg in train_data:
    sid = eg['id']
    rel = eg['relation']
    if sid in reann_ids:
        rel = reann_ids[sid]        
    re_train_rel_freq[rel] = re_train_rel_freq.get(rel, 0) + 1
# Sanity check on the dictionary just built: every training example is counted
# exactly once, so this must equal the training-set size.
print(sum(re_train_rel_freq.values()))

X = [rel.strip() for rel in list(parnn.index)]
Y = [train_rel_freq[rel] for rel in X]
# The first entry is zeroed out -- presumably no_relation, whose count would
# otherwise dwarf every other bar; TODO confirm row order.
Y[0] = 0
# A relation can disappear entirely after re-annotation; count it as 0
# instead of raising KeyError.
Y_re = [re_train_rel_freq.get(rel, 0) for rel in X]
Y_re[0] = 0

freq = pd.DataFrame({'labels':X, 'original':Y, 'reann':Y_re})

In [None]:
def show_values_on_bars(axs):
    """Write each bar's height, truncated to an integer, above the bar.

    Parameters
    ----------
    axs : Axes-like object or numpy.ndarray of such objects
        Every object must expose ``patches`` (items with ``get_x``,
        ``get_y``, ``get_width`` and ``get_height``) and a ``text`` method.
        An ndarray of any shape is walked element by element.

    Notes
    -----
    Bars whose height is NaN are skipped, because ``int(nan)`` raises
    ``ValueError``.
    """
    def _show_on_single_plot(ax):
        for p in ax.patches:
            height = p.get_height()
            if np.isnan(height):
                continue
            # Centre the label horizontally and place it at the bar's top edge.
            _x = p.get_x() + p.get_width() / 2
            _y = p.get_y() + height
            value = '{}'.format(int(height))
            ax.text(_x, _y, value, ha="center")

    if isinstance(axs, np.ndarray):
        for ax in axs.flat:
            _show_on_single_plot(ax)
    else:
        _show_on_single_plot(axs)


In [None]:
# Two-panel figure: the confusion-matrix difference on top (twice the height)
# and the original vs. re-annotated training frequencies below.
# Both plots are drawn on the axes returned by plt.subplots, so the height
# ratios actually apply. Calling plt.subplot(211/212) afterwards would lay new
# axes on a default grid over them.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(25, 30), gridspec_kw={'height_ratios': [2, 1]})
sns.heatmap(df, cmap='BrBG', annot=True, fmt='g', ax=ax1)

# Overlaid bars: re-annotated counts in the pastel palette, original counts in
# the muted palette drawn on top.
sns.set_color_codes("pastel")
sns.barplot(x='labels', y='reann', data=freq, label="Reannotated", color="b", ax=ax2)
sns.set_color_codes("muted")
sns.barplot(x='labels', y='original', data=freq, label="Original", color="b", ax=ax2)
show_values_on_bars(ax2)
sns.despine(fig=fig, left=True, bottom=True)
ax2.tick_params(axis='x', labelrotation=90, labelsize=10)
ax2.legend(ncol=2, loc="upper right", frameon=True)

fig.savefig('./images/confusion_matrix/parnn-reann-diff')

## CGCN

In [None]:
# CGCN confusion matrix after re-annotation ('cgcn-reann-knn' pickle), drawn
# as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/cgcn-reann-knn', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
df = df.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/cgcn-reann')

In [None]:
# Build the percentage-point difference between the re-annotated and the
# original CGCN confusion matrices (positive = share grew after re-annotation).
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# NOTE: this rebinds `cgcn`, which an earlier cell uses for the false-positive
# summary table; re-run that cell before reusing the table.
with open('./confusion-matrix/cgcn', 'rb') as fh:  # closed right after loading
    cgcn = pd.DataFrame(pickle.load(fh))
cgcn = cgcn.set_index(0)
cgcn.index.name = 'labels'
# Column i carries class id i - 1; the range is derived from the mapping.
cgcn = cgcn.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
cgcn = cgcn.div(cgcn.sum(axis=1), axis=0).round(3) * 100

with open('./confusion-matrix/cgcn-reann-knn', 'rb') as fh:  # closed right after loading
    cgcn_re = pd.DataFrame(pickle.load(fh))
cgcn_re = cgcn_re.set_index(0)
cgcn_re.index.name = 'labels'
cgcn_re = cgcn_re.rename(columns={i: ID_TO_LABEL[i-1] for i in range(1, len(ID_TO_LABEL) + 1)})
cgcn_re = cgcn_re.div(cgcn_re.sum(axis=1), axis=0).round(3) * 100

# Element-wise difference, aligned on row and column labels.
df = cgcn_re - cgcn

In [None]:
# Training-set relation frequencies before and after applying the CGCN
# re-annotation. `pickle` comes from the notebook's import cell.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
with open('./reannotation-ids/Creplace_knn.pkl', 'rb') as fh:  # closed after loading
    # Looked up below as example id -> replacement relation.
    reann_ids = pickle.load(fh)
re_train_rel_freq = dict()
for eg in train_data:
    sid = eg['id']
    rel = eg['relation']
    if sid in reann_ids:
        rel = reann_ids[sid]
    re_train_rel_freq[rel] = re_train_rel_freq.get(rel, 0) + 1
# Sanity check on the dictionary just built: every training example is counted
# exactly once, so this must equal the training-set size.
print(sum(re_train_rel_freq.values()))

X = [rel.strip() for rel in list(cgcn.index)]
Y = [train_rel_freq[rel] for rel in X]
# The first entry is zeroed out -- presumably no_relation, whose count would
# otherwise dwarf every other bar; TODO confirm row order.
Y[0] = 0
# A relation can disappear entirely after re-annotation; count it as 0
# instead of raising KeyError.
Y_re = [re_train_rel_freq.get(rel, 0) for rel in X]
Y_re[0] = 0

freq = pd.DataFrame({'labels':X, 'original':Y, 'reann':Y_re})

In [None]:
# Two-panel figure: the confusion-matrix difference on top (twice the height)
# and the original vs. re-annotated training frequencies below.
# Both plots are drawn on the axes returned by plt.subplots, so the height
# ratios actually apply. Calling plt.subplot(211/212) afterwards would lay new
# axes on a default grid over them.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(25, 30), gridspec_kw={'height_ratios': [2, 1]})
sns.heatmap(df, cmap='BrBG', annot=True, fmt='g', ax=ax1)

# Overlaid bars: re-annotated counts in the pastel palette, original counts in
# the muted palette drawn on top.
sns.set_color_codes("pastel")
sns.barplot(x='labels', y='reann', data=freq, label="Reannotated", color="b", ax=ax2)
sns.set_color_codes("muted")
sns.barplot(x='labels', y='original', data=freq, label="Original", color="b", ax=ax2)
show_values_on_bars(ax2)
sns.despine(fig=fig, left=True, bottom=True)
ax2.tick_params(axis='x', labelrotation=90, labelsize=10)
ax2.legend(ncol=2, loc="upper right", frameon=True)

fig.savefig('./images/confusion_matrix/cgcn-reann-diff')

# Relabel

## TACRED

In [None]:
# Relabeled TACRED relation label -> integer class id (36 classes, ids 0-35;
# id 0 is 'no_relation'). Compared with the TACRED mapping, this table uses
# merged location labels such as 'per:locations_of_residence',
# 'per:location_of_death' and 'per:location_of_birth'.
RELABELED_LABEL_TO_ID = {'no_relation': 0, 'per:title': 1, 'org:top_members/employees': 2, 'per:employee_of': 3, 'org:alternate_names': 4, 
            'org:city_of_headquarters': 5, 'per:locations_of_residence': 6, 'per:age': 7, 'per:origin': 8, 'org:subsidiaries': 9,
            'org:parents': 10, 'per:other_family': 11, 'per:alternate_names': 12, 'org:members': 13, 'per:schools_attended': 14, 
            'per:date_of_death': 15, 'org:member_of': 16, 'org:founded_by': 17, 'org:website': 18, 'per:cause_of_death': 19, 
            'org:political/religious_affiliation': 20, 'org:founded': 21, 'per:location_of_death': 22, 'org:shareholders': 23, 
            'org:number_of_employees/members': 24, 'per:date_of_birth': 25, 'per:location_of_birth': 26, 'per:charges': 27, 
            'per:religion': 28, 'org:dissolved': 29, 'per:children': 30, 'per:parents' : 31, 'per:siblings' : 32, 'per:spouse' : 33,
            'org:country_of_headquarters': 34, 'org:stateorprovince_of_headquarters': 35}
# Inverse lookup: integer class id -> relation label.
RELABELED_ID_TO_LABEL = {val:key for key, val in RELABELED_LABEL_TO_ID.items()}

### PARNN

In [None]:
# PARNN confusion matrix on relabeled TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/parnn-relabel', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1. The mapping has 36 classes, so the range
# must reach column 36; a hard-coded range(1, 36) stops at 35 and leaves the
# last column with its numeric name. Keys missing from the frame are ignored
# by rename.
df = df.rename(columns={i: RELABELED_ID_TO_LABEL[i-1] for i in range(1, len(RELABELED_ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/parnn-relabel')

### CGCN

In [None]:
# CGCN confusion matrix on relabeled TACRED, drawn as a row-normalised heat-map.
# NOTE: pickle.load can execute arbitrary code; only open trusted files.
# Assumes column 0 of the pickled table holds the row labels and columns
# 1..N hold per-class counts in class-id order -- TODO confirm.
with open('./confusion-matrix/cgcn-relabel', 'rb') as fh:  # closed after loading
    df = pd.DataFrame(pickle.load(fh))
df = df.set_index(0)
df.index.name = 'labels'
# Column i carries class id i - 1. The mapping has 36 classes, so the range
# must reach column 36; a hard-coded range(1, 36) stops at 35 and leaves the
# last column with its numeric name. Keys missing from the frame are ignored
# by rename.
df = df.rename(columns={i: RELABELED_ID_TO_LABEL[i-1] for i in range(1, len(RELABELED_ID_TO_LABEL) + 1)})
# Turn every row into percentages of that row's total.
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/cgcn-relabel')

## ReTACRED

In [None]:
# Relabeled Re-TACRED relation label -> integer class id (34 classes, ids
# 0-33; id 0 is 'no_relation'). The ids differ from RELABELED_LABEL_TO_ID, so
# the two tables are not interchangeable.
RE_RELABELED_LABEL_TO_ID = {'no_relation': 0, 'org:members': 1, 'per:siblings': 2, 'per:spouse': 3, 'org:country_of_branch': 4, 
						 'per:location_of_death': 5, 'per:parents': 6, 'per:locations_of_residence': 7, 'org:top_members/employees': 8,
						 'org:dissolved': 9, 'org:number_of_employees/members': 10, 'per:origin': 11, 'per:children': 12, 
						 'org:political/religious_affiliation': 13, 'per:location_of_birth': 14, 'per:title': 15, 'org:shareholders': 16, 
						 'per:employee_of': 17, 'org:member_of': 18, 'org:founded_by': 19, 'per:other_family': 20, 'per:religion': 21, 
						 'per:identity': 22, 'per:date_of_birth': 23, 'org:city_of_branch': 24, 'org:alternate_names': 25, 'org:website': 26, 
						 'per:cause_of_death': 27, 'org:stateorprovince_of_branch': 28, 'per:schools_attended': 29, 'per:date_of_death': 30, 
						 'org:founded': 31, 'per:age': 32, 'per:charges': 33}

RE_RELABELED_ID_TO_LABEL = {val:key for key, val in RELABELED_LABEL_TO_ID.items()}

### PARNN

In [None]:
df = pd.DataFrame(pickle.load(open('./confusion-matrix/retacred-parnn-relabel.pkl', 'rb')))
df = df.set_index(0)
df.index.name = 'labels'
df = df.rename(columns={i:RE_RELABELED_ID_TO_LABEL[i-1] for i in range(1,36)})
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/retacred-parnn-relabel')

### CGCN

In [None]:
df = pd.DataFrame(pickle.load(open('./confusion-matrix/retacred-cgcn-relabel.pkl', 'rb')))
df = df.set_index(0)
df.index.name = 'labels'
df = df.rename(columns={i:RE_RELABELED_ID_TO_LABEL[i-1] for i in range(1,36)})
df = df.div(df.sum(axis=1), axis=0).round(3) * 100
plt.figure(figsize=(25,15))
sns.heatmap(df, cmap='BuPu', annot=True, fmt='g')
plt.savefig('./images/confusion_matrix/retacred-cgcn-relabel')

# Rough