In [28]:
### Standard libraries
import operator
from tqdm import tqdm
import numpy as np
import pandas as pd
from ast import literal_eval

In [29]:
def _read_results(path):
    """Read one model's results CSV and drop its leading column.

    NOTE(review): the dropped first column is presumably the unnamed index
    written by an earlier ``to_csv`` call -- TODO confirm.
    """
    frame = pd.read_csv(path)
    return frame.iloc[:, 1:]


bert_res = _read_results("bert_res.csv")
bart_res = _read_results("bart_res.csv")
w2v_res = _read_results("w2v_res.csv")
glove_res = _read_results("glove_res.csv")

In [30]:
def _parse_literal(value):
    """Parse a stringified Python literal, passing already-parsed tuples/lists through.

    The pass-through makes this cell safe to re-run: ``literal_eval`` raises
    ValueError when handed an object that is already a tuple or list. Any other
    non-string value (e.g. NaN) still goes to ``literal_eval`` and raises, as before.
    """
    return value if isinstance(value, (tuple, list)) else literal_eval(value)


# `task` holds a tuple of words and `top_10` a list of (word, score) pairs;
# both are stored in the CSVs as the string repr of the Python literal.
for results in (bert_res, glove_res, w2v_res, bart_res):
    for column in ('task', 'top_10'):
        results[column] = results[column].apply(_parse_literal)

In [31]:
# Rank assigned to tasks whose w2v rank is missing (NaN). Presumably these are
# words outside the w2v vocabulary -- TODO confirm. Any rank other than 1 counts
# as a miss in the top-1 tables.
MISSING_RANK = 200_000


def tag_model_columns(results: pd.DataFrame, model: str, keep_category: bool = False) -> pd.DataFrame:
    """Return a copy of one model's results with model-suffixed column names.

    Every column except ``task`` is renamed, in order, to ``cos_<model>``,
    ``rank_<model>``, ``top10_<model>`` and ``category``. The category column is
    then dropped unless ``keep_category`` is True. The input frame is not modified.

    Raises:
        ValueError: if the frame does not have exactly four columns besides ``task``.
    """
    value_columns = [col for col in results.columns if col != 'task']
    new_names = [f'cos_{model}', f'rank_{model}', f'top10_{model}', 'category']
    if len(value_columns) != len(new_names):
        raise ValueError(
            f"{model}: expected {len(new_names)} columns besides 'task', got {value_columns}"
        )
    tagged = results.rename(columns=dict(zip(value_columns, new_names)))
    return tagged if keep_category else tagged.drop(columns='category')


# One row per task with every model's cosine score, rank and top-10 list side
# by side. Only the BART table's category is kept.
joined = (
    tag_model_columns(bert_res, 'bert')
    .merge(tag_model_columns(w2v_res, 'w2v'), how='inner', on='task')
    .merge(tag_model_columns(glove_res, 'glove'), how='inner', on='task')
    .merge(tag_model_columns(bart_res, 'bart', keep_category=True), how='inner', on='task')
    # Inner merges multiply repeated task keys; keep the first row per task.
    .drop_duplicates(['task'])
    .reset_index(drop=True)
)

joined['rank_w2v'] = joined['rank_w2v'].fillna(MISSING_RANK)
# Categories starting with 'gram' are labelled syntactic ('syn'); all others semantic ('sem').
joined['type'] = joined['category'].apply(lambda cat: 'syn' if cat.startswith('gram') else 'sem')

In [32]:
# Preview the merged table: one row per task with each model's cosine score, rank and top-10 list.
joined.head()

Unnamed: 0,task,cos_bert,rank_bert,top10_bert,cos_w2v,rank_w2v,top10_w2v,cos_glove,rank_glove,top10_glove,cos_bart,rank_bart,top10_bart,category,type
0,"(athens, greece, baghdad, iraq)",0.613978,1,"[(iraq, 0.6139779), (mesopotamia, 0.60773236),...",0.459147,9.0,"[(saddam, 0.48523029685020447), (afghanistan, ...",0.723732,1.0,"[(iraq, 0.7237322926521301), (iraqi, 0.6456960...",0.642231,1,"[(iraq, 0.64223135), (iraqi, 0.55553335), (syr...",capital-common-countries,sem
1,"(athens, greece, bangkok, thailand)",0.653616,1,"[(thailand, 0.6536159), (cambodia, 0.5840745),...",0.56846,2.0,"[(europe, 0.5729362368583679), (thailand, 0.56...",0.770864,1.0,"[(thailand, 0.7708642482757568), (thai, 0.5920...",0.69567,1,"[(thailand, 0.69567), (cambodia, 0.4977392), (...",capital-common-countries,sem
2,"(athens, greece, beijing, china)",0.527557,4,"[(tianjin, 0.54582757), (nanjing, 0.53440684),...",0.2386,19956.0,"[(europe, 0.5613299608230591), (poland, 0.5535...",0.775182,1.0,"[(china, 0.7751821279525757), (chinese, 0.6065...",0.661186,1,"[(china, 0.661186), (chinese, 0.47734934), (ja...",capital-common-countries,sem
3,"(athens, greece, berlin, germany)",0.547957,1,"[(germany, 0.5479567), (italy, 0.43626735), (s...",0.548939,2.0,"[(german, 0.5611334443092346), (germany, 0.548...",0.749873,1.0,"[(germany, 0.7498731017112732), (german, 0.587...",0.601279,1,"[(germany, 0.6012795), (france, 0.4640275), (d...",capital-common-countries,sem
4,"(athens, greece, cairo, egypt)",0.591381,1,"[(egypt, 0.5913807), (egyptians, 0.53934705), ...",0.574773,1.0,"[(egypt, 0.5747732520103455), (malta, 0.527613...",0.743846,1.0,"[(egypt, 0.7438464164733887), (egyptian, 0.596...",0.607878,1,"[(egypt, 0.60787797), (syria, 0.48144093), (mo...",capital-common-countries,sem


### Table 1: top-1 accuracy (share of tasks with `rank == 1`) per model, by task type and overall

In [33]:
# Per task type: number of tasks and, for each model, the fraction of tasks
# whose rank for that model equals 1.
top1_columns = ('rank_bert', 'rank_w2v', 'rank_glove', 'rank_bart')
top1_by_type_spec = {'task': 'count'}
top1_by_type_spec.update({column: (lambda ranks: (ranks == 1).mean()) for column in top1_columns})

(
    joined
    .groupby('type')
    .agg(top1_by_type_spec)
    .round(3)
)

Unnamed: 0_level_0,task,rank_bert,rank_w2v,rank_glove,rank_bart
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
sem,2278,0.641,0.234,0.759,0.846
syn,7244,0.754,0.667,0.692,0.825


In [34]:
# Over all tasks: the same task count and rank == 1 fractions, without grouping by type.
pooled_top1_spec = {'task': 'count'}
for rank_column in ['rank_bert', 'rank_w2v', 'rank_glove', 'rank_bart']:
    pooled_top1_spec[rank_column] = lambda ranks: (ranks == 1).mean()

joined.agg(pooled_top1_spec).round(3)

task          9522.000
rank_bert        0.727
rank_w2v         0.563
rank_glove       0.708
rank_bart        0.830
dtype: float64

### Table 2: mean cosine similarity (`cos_*`) per model, by task type and overall

In [35]:
# Per task type: number of tasks and each model's mean cosine score.
cos_columns = ['cos_bert', 'cos_w2v', 'cos_glove', 'cos_bart']
cos_by_type_spec = {'task': 'count', **dict.fromkeys(cos_columns, 'mean')}

(
    joined
    .groupby('type')
    .agg(cos_by_type_spec)
    .round(3)
)

Unnamed: 0_level_0,task,cos_bert,cos_w2v,cos_glove,cos_bart
type,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
sem,2278,0.5,0.504,0.6,0.525
syn,7244,0.61,0.582,0.61,0.596


In [36]:
# Over all tasks: task count and each model's mean cosine score, without grouping by type.
pooled_cos_spec = dict(task='count')
pooled_cos_spec.update((column, 'mean') for column in ('cos_bert', 'cos_w2v', 'cos_glove', 'cos_bart'))

joined.agg(pooled_cos_spec).round(3)

task         9522.000
cos_bert        0.584
cos_w2v         0.564
cos_glove       0.607
cos_bart        0.579
dtype: float64