In [2]:
import matplotlib
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

static_df = pd.read_csv('static_number_experiment.csv')
static_df=static_df[['data_set', 'best_model_score_0', 'best_model_score_1', 'best_model_score_2', 'best_model_score_3', 'best_model_score_4','avg_saved_prototypes']]
static_df.head()

Unnamed: 0,data_set,best_model_score_0,best_model_score_1,best_model_score_2,best_model_score_3,best_model_score_4,avg_saved_prototypes
0,rottentomatoes,0.7335,0.73949,0.73949,0.66066,0.74395,12
1,amazon,0.88249,0.874,0.882,0.52816,0.8755,12
2,hotel,0.74947,0.67052,0.88736,0.76606,0.67017,6
3,imdb,0.85399,0.85659,0.846,0.86119,0.7778,12
4,rottentomatoes,0.72583,0.73199,0.614,0.57383,0.74229,6


In [3]:
static_df = pd.melt(static_df, id_vars=['data_set','avg_saved_prototypes'], value_vars=['best_model_score_1', 'best_model_score_2', 'best_model_score_3', 'best_model_score_4'])
static_df.columns = ['dataset','n', 'variable', 'acc']
static_df = static_df[['dataset','n', 'acc']]
static_df.head()

Unnamed: 0,dataset,n,acc
0,rottentomatoes,12,0.73949
1,amazon,12,0.874
2,hotel,6,0.67052
3,imdb,12,0.85659
4,rottentomatoes,6,0.73199


In [4]:
dynamic_df = pd.read_csv('dynamic_number_experiment.csv')
dynamic_df=dynamic_df[['data_set','best_model_score_0', 'best_model_score_1', 'best_model_score_2', 'best_model_score_3', 'best_model_score_4','number_of_prototypes_0','number_of_prototypes_1','number_of_prototypes_2','number_of_prototypes_3','number_of_prototypes_4']]
dynamic_df.head()

Unnamed: 0,data_set,best_model_score_0,best_model_score_1,best_model_score_2,best_model_score_3,best_model_score_4,number_of_prototypes_0,number_of_prototypes_1,number_of_prototypes_2,number_of_prototypes_3,number_of_prototypes_4
0,hotel,0.89789,0.90736,0.89894,0.91464,0.89989,23,22,23,20,21
1,rottentomatoes,0.7325,0.73566,0.73299,0.734,0.74195,8,9,10,8,11
2,yelp,0.83633,0.83399,0.82966,0.83999,0.83197,10,11,7,9,11
3,amazon,0.88533,0.88349,0.8765,0.88116,0.87316,8,14,10,11,9
4,imdb,0.85879,0.85159,0.8488,0.85219,0.85379,10,12,10,6,9


In [5]:
data = []
for id, row in dynamic_df.iterrows():
    for i in range(5):
        data.append((row['data_set'], row[f'best_model_score_{i}'], row[f'number_of_prototypes_{i}']))
dynamic_df=pd.DataFrame(data)
dynamic_df.columns=['dataset','acc','n']
dynamic_df.head()

Unnamed: 0,dataset,acc,n
0,hotel,0.89789,23
1,hotel,0.90736,22
2,hotel,0.89894,23
3,hotel,0.91464,20
4,hotel,0.89989,21


In [6]:
cnn_accuracy = {
    'imdb': 0.893,
    'amazon': 0.911,
    'yelp': 0.867,
    'rottentomatoes': 0.776,
    'hotel': 0.929
}

titles = ['IMDB','Amazon Reviews', 'Yelp Reviews', 'Rotten Tomatoes', 'Hotel Reviews']

In [20]:
sns.set_style('whitegrid')
sns.set_context('paper', font_scale=1.4,  rc={"lines.linewidth": 1})
pd.set_option('display.max_rows', 110)

fig, axs = plt.subplots(1, 5, figsize=(19,3))

for i, ((ds, cnn_acc), ax) in enumerate(zip(cnn_accuracy.items(), axs)):
    dyn_df = dynamic_df[dynamic_df['dataset']==ds]
    stat_df = static_df[static_df['dataset']==ds]

    sns.lineplot(x='n', y='acc', data=stat_df, err_style='bars', ci='sd', label='Static', ax=ax, legend=False)
    ax.set_xlabel('Number of prototypes')
    ax.set_ylabel('')
    if i==0:
        ax.set_ylabel('Accuracy')
    ax.axhline(cnn_acc, color='gray',linestyle='--', label='CNN')
    ax.errorbar([dyn_df['n'].mean()], [dyn_df['acc'].mean()], yerr=[dyn_df['acc'].std()], xerr=[dyn_df['n'].std()],
                fmt='.', color='orange', label='Dynamic', markersize=5, linewidth=1.3)
    ax.set_xscale('log', base=2)
    ax.set_xticks([2, 4,8,16,32,64])
    ax.get_xaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter())
    ax.set_title(titles[i])
    ax.set_ylim(0.5)

plt.legend(title='Legend', bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.)
# plt.show()
# fig.text(0.5, 0.04, 'Number of prototypes', ha='center', va='center')
plt.savefig('nprotos.png', bbox_inches='tight',dpi=300)
plt.close()
