In [32]:
import os
import numpy as np
import matplotlib.pyplot as plt
import mne
import pickle
import seaborn as sns
import pandas as pd
from scipy import stats
from statannotations.Annotator import Annotator

plt.rcParams.update({'font.size': 14})

# Accuracies

In [71]:
df = pd.read_csv('excel_data.txt', sep="\t", header=0)
df = df.stack().reset_index()
df = df.rename(columns={'level_0': 'subject', 'level_1': 'model', 0: 'Validation accuracy'})

model_type = []
for i in range(len(df)):
    if 'emb' in df['model'][i]:
        model_type.append('group\nembedding')
    elif 'group' in df['model'][i]:
        model_type.append('group')
    else:
        model_type.append('subject')
df['model type'] = model_type

In [74]:
df[df['model']=='nonlin-group-emb']

Unnamed: 0,subject,model,Validation accuracy,model type
5,0,nonlin-group-emb,0.490662,group\nembedding
12,1,nonlin-group-emb,0.318117,group\nembedding
19,2,nonlin-group-emb,0.112696,group\nembedding
26,3,nonlin-group-emb,0.522857,group\nembedding
33,4,nonlin-group-emb,0.705128,group\nembedding
40,5,nonlin-group-emb,0.47151,group\nembedding
47,6,nonlin-group-emb,0.637269,group\nembedding
54,7,nonlin-group-emb,0.212554,group\nembedding
61,8,nonlin-group-emb,0.46495,group\nembedding
68,9,nonlin-group-emb,0.454416,group\nembedding


In [31]:
%matplotlib widget

    
# Putting the parameters in a dictionary avoids code duplication
# since we use the same for `sns.boxplot` and `Annotator` calls
plot_params = {
    'kind':    'violin',
    'aspect':  2,
    'cut':     0,
    'ci':      None,
    'scale':   'area',
    'hue':     'model type',
    'dodge':   False,
    'data':    df,
    'x':       'model',
    'y':       'Validation accuracy'
}

g = sns.catplot(**plot_params)

ax = g.axes[0][0]
ax.axhline(0.008, ls='-', color='black', label='chance')
plt.ylim(0, 1)
plt.xlabel('')
plt.text(6.77,0.001,'chance')
plt.xticks(plt.xticks()[0], ['linear\nsubject',
                             'nonlinear\nsubject',
                             'linear\ngroup',
                             'nonlinear\ngroup',
                             'linear\ngroup-emb',
                             'nonlinear\ngroup-emb',
                             'nonlinear\ngroup-emb\nfinetuned'])

ymin = 0.02
ymax = 0.7
alpha = 0.5
dash = '--'
color = 'red'
ax.axvline(1.5, ymin, ymax, ls=dash, color=color, alpha=alpha)
ax.axvline(3.5, ymin, ymax, ls=dash, color=color, alpha=alpha)
plt.text(0.22,0.87,'subject models')
plt.text(2,0.8,'group models')
plt.text(3.7,0.8,'group models\nwith embedding')

# which pairs to computer stats on
pairs = [('lin-subject', 'nonlin-subject'),
         ('lin-subject', 'nonlin-group-emb'),
         ('lin-subject', 'nonlin-group-emb finetuned'),
         ('nonlin-group-emb', 'nonlin-group-emb finetuned'),
         ('nonlin-group-emb', 'nonlin-group')]

# Add statistics annotations
annotator = Annotator(ax, pairs, data=df, x='model', y='Validation accuracy')
annotator.configure(test='t-test_paired', verbose=True, line_offset_to_group=10).apply_and_annotate()

plt.savefig('group_acc.pdf', format='pdf')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

p-value annotation legend:
      ns: p <= 1.00e+00
       *: 1.00e-02 < p <= 5.00e-02
      **: 1.00e-03 < p <= 1.00e-02
     ***: 1.00e-04 < p <= 1.00e-03
    ****: p <= 1.00e-04

lin-subject vs. nonlin-subject: t-test paired samples, P_val:5.736e-04 t=4.428e+00
nonlin-group-emb vs. nonlin-group-emb finetuned: t-test paired samples, P_val:1.127e-06 t=-8.135e+00
nonlin-group vs. nonlin-group-emb: t-test paired samples, P_val:1.911e-06 t=-7.773e+00
lin-subject vs. nonlin-group-emb: t-test paired samples, P_val:1.285e-02 t=2.850e+00
lin-subject vs. nonlin-group-emb finetuned: t-test paired samples, P_val:1.066e-03 t=-4.108e+00


# Generalization to new subject

In [76]:
path = os.path.join('..', 'results', 'cichy_epoched', 'indiv_wavenetlinear_MNN', 'val_loss_general.npy')
accs = np.load(path)
train1 = [0.591525424, 0.303672316, 0.121468925, 0.680790966, 0.885593221, 0.662429377, 0.730225995, 0.159604517, 0.579096052, 0.627118642, 0.223163842, 0.151129942, 0.06497175, 0.483050848, 0.412429377]
accs = np.concatenate((accs, np.array(train1).reshape(-1, 1)), axis=1)
chance = [0.00847] * 15
accs = np.concatenate((np.array(chance).reshape(-1, 1), accs), axis=1)

In [77]:
accs_df = pd.DataFrame(accs)
accs_df = accs_df.stack().reset_index()
accs_df = accs_df.rename(columns={'level_0': 'subject', 'level_1': 'Training ratio', 0: 'Validation accuracy'})
accs_df['level'] = ['subject'] * len(accs_df)

In [159]:
accs_df

Unnamed: 0,level_0,level_1,0
0,0,0,0.008470
1,0,1,0.008475
2,0,2,0.013559
3,0,3,0.050847
4,0,4,0.116949
...,...,...,...
160,14,6,0.162429
161,14,7,0.213277
162,14,8,0.283898
163,14,9,0.331921


In [78]:
def create_df(accsg, level):
    # need to get actual subjects
    order = [10, 7, 3, 11, 8, 4, 12, 9, 5, 13, 1, 14, 2, 6, 0]
    accsg_df = pd.DataFrame(accsg[order, :])
    accsg_df = accsg_df.stack().reset_index()
    accsg_df = accsg_df.rename(columns={'level_0': 'subject', 'level_1': 'Training ratio', 0: 'Validation accuracy'})
    accsg_df['level'] = [level] * len(accsg_df)
    
    return accsg_df

In [79]:
path = os.path.join('..', 'results', 'cichy_epoched', 'all_wavenet_semb_general', 'val_loss_general.npy')
accsg = np.load(path)
group_emb = create_df(accsg, 'group-emb')

path = os.path.join('..', 'results', 'cichy_epoched', 'all_wavenet_general', 'val_loss_general.npy')
accsg = np.load(path)
group = create_df(accsg, 'group')

In [80]:
df = pd.concat((accs_df, group_emb, group), ignore_index=True)

In [81]:
df['Training ratio'] = df['Training ratio'].astype(float)/10

In [29]:
p_values = []
for i in range(11):
    test1 = df['Validation accuracy'][(df['Training ratio'] == i) & (df['level'] == 'group all')]
    test2 = df['Validation accuracy'][(df['Training ratio'] == i) & (df['level'] == 'group-emb all')]
    
    p_values.append(stats.ttest_rel(test1, test2)[1] * 11)

In [30]:
p_values

[0.0060336492775827375,
 1.351510411629857,
 0.03169837884573075,
 4.3785466680206016,
 2.918205427079515,
 0.10703676926544689,
 0.020678270815296832,
 0.055122900986178884,
 0.11919219652472021,
 0.23726568570982257,
 nan]

In [83]:
%matplotlib widget
g = sns.relplot(
    data=df, kind="line", hue='level',
    x="Training ratio", y="Validation accuracy", n_boot=1000, aspect=1.2, ci=95
)
ax = g.axes[0][0]
plt.axhline(0.6, 0.02, 0.7, color='black')
plt.text(0,0.61,'group>subject (p<0.05)')
plt.ylim(0, 0.6)
plt.xlabel('Training set ratio')



ax.axhline(0.008, ls='-', color='black', label='chance')
ax.legend(loc='upper left')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

<matplotlib.legend.Legend at 0x7fa9cc264df0>

In [84]:
plt.savefig('generalization.pdf', format='pdf')