In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D

from utils import import_in_df, import_in_df_copy

# Global seaborn look for every figure in this notebook: dark grid background.
# The other arguments are seaborn's own defaults, spelled out for the reader.
sns.set_theme(context="notebook", style="darkgrid", palette="deep")

In [None]:
# Synthetic results: the 'ntk_rf' set is filtered on its 'fmap' column;
# keep only the 'ntk' rows.
dfs = import_in_df_copy('./data_points/synthetic/', 'ntk_rf').loc[lambda frame: frame['fmap'] == 'ntk']

# Real-data results, restricted to N <= 5000 (MNIST) and N <= 1000 (CIFAR-10).
# NOTE(review): dfs / dfmnist / dfcifar are reassigned by the 'rf' loading cell
# further down, so re-running a plotting cell out of order silently plots the
# other feature map.
dfmnist = import_in_df('./data_points/MNIST/', 'ntk').loc[lambda frame: frame['N'] <= 5000]
dfcifar = import_in_df('./data_points/CIFAR-10/', 'ntk').loc[lambda frame: frame['N'] <= 1000]

In [None]:
# Figure: test vs. spurious accuracy of the 'ntk' runs as a function of N,
# on synthetic data (panels 1-2), MNIST (panel 3) and CIFAR-10 (panel 4).
# sharey=True links the y-limits and hides the y tick labels of panels 2-4.
fig, ax = plt.subplots(1, 4, figsize=(11, 3.2), sharey=True)

new_palette = ['#c44e52', 'black']  # red and black

# Synthetic panels: colour encodes `alpha_min_1`, line style encodes `phi`.
# seaborn's automatic legends are disabled here; one combined legend is built below.
sns.lineplot(x='N', y='score_test', hue='alpha_min_1', style='phi', data=dfs, ax=ax[0], palette=new_palette, legend=False)
sns.lineplot(x='N', y='score_spurious', hue='alpha_min_1', style='phi', data=dfs, ax=ax[1], palette=new_palette, legend=False)

# MNIST panel keeps its legend; it explains the two curves of panels 3 and 4.
sns.lineplot(x='N', y='score_spurious', data=dfmnist, ax=ax[2], label='Spurious accuracy', color=new_palette[1])
sns.lineplot(x='N', y='score_test', data=dfmnist, ax=ax[2], label='Test accuracy', color=new_palette[0])

sns.lineplot(x='N', y='score_spurious', data=dfcifar, ax=ax[3], label='Spurious accuracy', color=new_palette[1], legend=False)
sns.lineplot(x='N', y='score_test', data=dfcifar, ax=ax[3], label='Test accuracy', color=new_palette[0], legend=False)

# Synthetic legend from explicit proxy artists, instead of picking entries out of
# seaborn's auto-generated legend by position (its layout is a seaborn internal
# and its style-only entries are not drawn in the palette colours).
# Assumes `alpha_min_1` and `phi` each have exactly two levels: palette colours
# follow the sorted hue levels, and seaborn's default dash sequence draws the
# first style level solid and the second with dashes (4, 1.5) — verify against
# the installed seaborn version.
legend_handles = [
    Line2D([], [], color=new_palette[0]),
    Line2D([], [], color=new_palette[1]),
    Line2D([], [], color=new_palette[0], dashes=(4, 1.5)),
    Line2D([], [], color=new_palette[1], dashes=(4, 1.5)),
]
new_labels = [r'$\phi_2, \alpha=1/2$', r'$\phi_2, \alpha=1/4$', r'$\phi_4, \alpha=1/2$', r'$\phi_4, \alpha=1/4$']
ax[0].legend(legend_handles, new_labels, loc='upper left')

# Shared y-axis: one call sets the limits of all four panels.
ax[0].set_ylim(0.48, 1.02)

titles = ['Synthetic - Test accuracy', 'Synthetic - Spurious accuracy', 'MNIST (1 and 7)', 'CIFAR-10 (cats and ships)']
for axis, title in zip(ax, titles):
    axis.set_xscale('log')
    axis.set_xlabel(r'$N$')
    axis.set_ylabel('')
    axis.set_title(title)

fig.tight_layout()
fig.savefig('NTK.pdf')

In [None]:
# Synthetic results: the 'ntk_rf' set is filtered on its 'fmap' column;
# keep only the 'rf' rows.
dfs = import_in_df_copy('./data_points/synthetic/', 'ntk_rf').loc[lambda frame: frame['fmap'] == 'rf']

# Real-data results, restricted to N <= 5000 (MNIST) and N <= 1000 (CIFAR-10).
# NOTE(review): this overwrites the dfs / dfmnist / dfcifar frames loaded for
# the 'ntk' figure above; re-running that plotting cell after this one would
# plot 'rf' data into NTK.pdf.
dfmnist = import_in_df('./data_points/MNIST/', 'rf').loc[lambda frame: frame['N'] <= 5000]
dfcifar = import_in_df('./data_points/CIFAR-10/', 'rf').loc[lambda frame: frame['N'] <= 1000]

In [None]:
# Figure: test vs. spurious accuracy of the 'rf' runs as a function of N,
# on synthetic data (panels 1-2), MNIST (panel 3) and CIFAR-10 (panel 4).
# sharey=True links the y-limits and hides the y tick labels of panels 2-4.
fig, ax = plt.subplots(1, 4, figsize=(11, 3.2), sharey=True)

new_palette = ['#c44e52', 'black']  # red and black

# Synthetic panels: colour encodes `alpha_min_1`, line style encodes `phi`.
# seaborn's automatic legends are disabled here; one combined legend is built below.
sns.lineplot(x='N', y='score_test', hue='alpha_min_1', style='phi', data=dfs, ax=ax[0], palette=new_palette, legend=False)
sns.lineplot(x='N', y='score_spurious', hue='alpha_min_1', style='phi', data=dfs, ax=ax[1], palette=new_palette, legend=False)

# MNIST panel keeps its legend; it explains the two curves of panels 3 and 4.
sns.lineplot(x='N', y='score_spurious', data=dfmnist, ax=ax[2], label='Spurious accuracy', color=new_palette[1])
sns.lineplot(x='N', y='score_test', data=dfmnist, ax=ax[2], label='Test accuracy', color=new_palette[0])

sns.lineplot(x='N', y='score_spurious', data=dfcifar, ax=ax[3], label='Spurious accuracy', color=new_palette[1], legend=False)
sns.lineplot(x='N', y='score_test', data=dfcifar, ax=ax[3], label='Test accuracy', color=new_palette[0], legend=False)

# Synthetic legend from explicit proxy artists, instead of picking entries out of
# seaborn's auto-generated legend by position (its layout is a seaborn internal
# and its style-only entries are not drawn in the palette colours).
# Assumes `alpha_min_1` and `phi` each have exactly two levels: palette colours
# follow the sorted hue levels, and seaborn's default dash sequence draws the
# first style level solid and the second with dashes (4, 1.5) — verify against
# the installed seaborn version.
legend_handles = [
    Line2D([], [], color=new_palette[0]),
    Line2D([], [], color=new_palette[1]),
    Line2D([], [], color=new_palette[0], dashes=(4, 1.5)),
    Line2D([], [], color=new_palette[1], dashes=(4, 1.5)),
]
new_labels = [r'$\phi_2, \alpha=1/2$', r'$\phi_2, \alpha=1/4$', r'$\phi_4, \alpha=1/2$', r'$\phi_4, \alpha=1/4$']
ax[0].legend(legend_handles, new_labels, loc='upper left')

# Shared y-axis: one call sets the limits of all four panels.
ax[0].set_ylim(0.48, 1.02)

titles = ['Synthetic - Test accuracy', 'Synthetic - Spurious accuracy', 'MNIST (1 and 7)', 'CIFAR-10 (cats and ships)']
for axis, title in zip(ax, titles):
    axis.set_xscale('log')
    axis.set_xlabel(r'$N$')
    axis.set_ylabel('')
    axis.set_title(title)

fig.tight_layout()
fig.savefig('RF.pdf')

In [None]:
# Trained-network results, one frame per run name, loaded in the order
# FC_synthetic, FC_MNIST, CNN_MNIST, CNN_CIFAR-10.
dffs, dffm, dfcm, dfcc = (
    import_in_df('./data_points/NN/', run_name)
    for run_name in ('FC_synthetic', 'FC_MNIST', 'CNN_MNIST', 'CNN_CIFAR-10')
)

In [None]:
# Figure: test vs. spurious accuracy of the trained networks as a function of N,
# one panel per run. sharey=True links the y-limits and hides the y tick labels
# of panels 2-4.
fig, ax = plt.subplots(1, 4, figsize=(11, 3.2), sharey=True)

new_palette = ['#c44e52', 'black']  # red (test) and black (spurious)

panels = [
    (dffs, 'Synthetic - FC'),
    (dffm, 'MNIST - FC'),
    (dfcm, 'MNIST - CNN'),
    (dfcc, 'CIFAR-10 - CNN'),
]
for i, (data, title) in enumerate(panels):
    axis = ax[i]
    # Only the first panel gets a legend; the two curves mean the same thing
    # in every panel. 'auto' is seaborn's default legend mode.
    legend_mode = 'auto' if i == 0 else False
    sns.lineplot(x='N', y='score_spurious', data=data, ax=axis, label='Spurious accuracy', color=new_palette[1], legend=legend_mode)
    sns.lineplot(x='N', y='score_test', data=data, ax=axis, label='Test accuracy', color=new_palette[0], legend=legend_mode)

    axis.set_xscale('log')
    axis.set_xlabel(r'$N$')
    axis.set_ylabel('')
    axis.set_title(title)

# Shared y-axis: one call sets the limits of all four panels.
ax[0].set_ylim(0.08, 1.02)

fig.tight_layout()
fig.savefig('NN.pdf')