In [None]:
from lib.utils import *
from lib.models import *
from lib.ekyn import *
from lib.env import *
from lib.datasets import *
def load_spindle_eeg_label_pair_2s(cohort='A',subject='1'):
    """Load one SPINDLE recording together with its scoring at half-label resolution.

    The 'EEG1' channel of ``Cohort{cohort}/recordings/{cohort}{subject}.edf`` is
    resampled to 86400*500 samples and cut into rows of 5000 samples. Column 1
    of the scoring CSV is turned into categorical codes; every scored label
    covers 2000 resampled samples, i.e. two windows of 1000 samples (2 s each
    if the resampled rate is read as 500 samples/s). The codes are then
    collapsed to three classes using the category orderings noted per
    recording group below: 0 <- 'r','3','a'; 1 <- 'n','2'; 2 <- 'w','1'.

    Args:
        cohort: cohort letter used to build both file paths.
        subject: subject number within the cohort.

    Returns:
        X: float tensor with 5000 columns, zero-padded with 9//2 rows at each end.
        y: float one-hot tensor with exactly 3 columns, one row per 1000-sample window.

    Raises:
        ValueError: if the scoring contains a code outside the lookup table
            chosen for this recording (including the -1 pandas uses for
            missing labels).
    """
    raw = read_raw_edf(f'../data/spindle/Cohort{cohort}/recordings/{cohort}{subject}.edf')
    eeg = raw.get_data('EEG1').squeeze()
    eeg = resample(eeg,86400*500)
    X = torch.from_numpy(eeg.reshape(-1,5000)).float()

    df = pd.read_csv(f'../data/spindle/Cohort{cohort}/scorings/{cohort}{subject}.csv',header=None)
    codes = pd.Categorical(df[1]).codes
    # Each label spans 2000 samples = two whole 1000-sample windows, so every
    # window is constant and its mode is just the label: repeat each code twice
    # instead of expanding to one entry per sample.
    y = torch.from_numpy(np.repeat(codes,2)).long()

    # code_to_stage[i] is the class for categorical code i (categories sort alphabetically).
    recording = f'{cohort}{subject}'
    if recording in ['D1','D2','D3','C1','C2','C3','C4','C5','C6','C7','C8']:
        # ['1', 'n', 'r', 'w']
        code_to_stage = [2,1,0,2]
    elif recording in ['D4','D5','D6']:
        # ['n', 'r', 'w']
        code_to_stage = [1,0,2]
    elif recording in ['A2','B1']:
        # ['1', '2', '3', 'a', 'n', 'r', 'w']
        code_to_stage = [2,1,0,0,1,0,2]
    else:
        # ['1', '2', '3', 'n', 'r', 'w']
        code_to_stage = [2,1,0,1,0,2]

    # Fail loudly rather than let a negative index wrap around or an unknown code slip through.
    if y.numel() and (y.min() < 0 or y.max() >= len(code_to_stage)):
        raise ValueError(f'unexpected scoring codes for recording {recording}')
    y = torch.tensor(code_to_stage)[y]

    # num_classes pins the width to 3 even when a stage never occurs in this recording.
    y = torch.nn.functional.one_hot(y,num_classes=3).float()
    X = torch.cat([zeros(9//2,5000),X,zeros(9//2,5000)])
    return X,y

In [None]:
class Gandalf(nn.Module):
    """Sequence classifier: a Frodo encoder applied per window, followed by a bidirectional LSTM.

    The input is viewed as (-1, 9, 1, 5000), i.e. 9 windows of 5000 samples per
    item. Each window is passed through the encoder with classification=False
    (the LSTM input size implies 16 features per window -- confirm against
    Frodo), the 9 encodings form the LSTM sequence, and the output at the last
    sequence position (64 values) is either projected to 3 values by ``fc1``
    or returned as is.
    """
    def __init__(self) -> None:
        super().__init__()
        self.encoder = Frodo(n_features=5000,device=DEVICE).to(DEVICE)
        self.lstm = nn.LSTM(16,32,bidirectional=True)
        self.fc1 = nn.Linear(64,3)
    def forward(self,x_2d,classification=True):
        """Return ``fc1`` of the last LSTM output, or that output itself when classification is False."""
        windows = x_2d.view(-1,9,1,5000)
        # encode every window, then stack along a new leading (sequence) axis
        encoded = [
            self.encoder(windows[:,step,:,:],classification=False)
            for step in range(windows.size(1))
        ]
        sequence = torch.stack(encoded,dim=0)
        out,_ = self.lstm(sequence)
        last_step = out[-1]
        if not classification:
            return last_step
        return self.fc1(last_step)
# Build the model and restore the fold-0 checkpoint. map_location='cpu' matches
# the per-fold loading further down and lets a checkpoint saved from another
# device load on a CPU-only machine; the model is moved to DEVICE afterwards.
model = Gandalf()
model.load_state_dict(torch.load('../spindle_gandalfs/gandalf_spindle_fold_00/best_model.pt',map_location='cpu'))
model.to(DEVICE)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Choose the held-out recording for this fold; after the pop, `ids` holds the other recordings.
FOLD = 0
ids = [
    'A1','A2','A3','A4',
    'B1','B2','B3','B4',
    'C1','C2','C3','C4','C5','C6','C7','C8',
    'D1','D2','D3','D4','D5','D6',
]
test_id = ids.pop(FOLD)
print(test_id)

In [None]:
# Load only the held-out recording (first character of the id = cohort, second = subject)
# and wrap it in an unshuffled loader; range(8640) is handed to SSDataset as its index set.
subjects = [load_spindle_eeg_label_pair(cohort=test_id[0],subject=test_id[1])]
Xs,ys = [],[]
for pair in subjects:
    Xs.append(pair[0])
    ys.append(pair[1])
devloader = DataLoader(dataset=SSDataset(Xs,ys,range(8640)),batch_size=32,shuffle=False)

In [None]:
# Evaluate the loaded model on devloader. `evaluate` comes from the lib star-imports; its five
# return values are unpacked by position -- presumably loss, a metric report, true labels,
# predicted labels and per-class scores (TODO confirm against lib).
loss,report,y_true,y_pred,y_logits = evaluate(devloader,model,criterion,DEVICE)

In [None]:
# Display the report returned by evaluate() for the held-out recording.
report

In [None]:
# Four stacked panels for epochs lower..upper of the held-out recording:
# reference stages, predicted stages, EEG trace, and the per-class scores.

# Confidence of a prediction = the largest entry in its row of y_logits
# (the same value as indexing the row at its argmax).
confidences = [row.max().item() for row in y_logits]
fig,axes = plt.subplots(ncols=1,nrows=4,figsize=(8.5,11),dpi=200,gridspec_kw={'height_ratios': [1,1,1,2]})
lower = 1000
upper = 1360
caption_font = {'family': 'serif',
        'color':  'darkred',
        'weight': 'normal',
        'size': 16,
        }

# true and predicted label signals share identical styling
for ax,stages,caption in ((axes[0],y_true,'Reference'),(axes[1],y_pred,'Predicted')):
    ax.plot(stages[lower:upper],'black')
    ax.set_yticks(np.arange(3),labels=['P','S','W'])
    ax.set_ylabel('Stage')
    ax.text(200,1,caption,fontdict=caption_font)
    ax.spines[['right', 'top']].set_visible(False)
    ax.margins(.01,.01)

# eeg signal
# NOTE(review): the +8 offset and the stride of 5 select every fifth row of Xs[0];
# confirm this lines up with the epochs shown in the other panels.
axes[2].plot(Xs[0][lower+8:upper+8:5].flatten())
# NOTE(review): limits of +/-0.0002 look like volts rather than microvolts -- confirm the unit of Xs.
axes[2].set_ylim([-.0002,.0002])
axes[2].margins(0,0)
axes[2].set_ylabel('microVolts')

# predicted probability distribution
axes[3].stackplot(torch.arange(len(y_logits[lower:upper])),y_logits[lower:upper].T)
axes[3].plot(confidences[lower:upper],'black')
axes[3].set_ylabel('Probability')
axes[3].margins(0,0)

# Save before showing: once plt.show() has displayed the figure, pyplot no longer
# has it as the current figure and a later plt.savefig writes an empty canvas.
fig.savefig('out.svg',bbox_inches='tight')
plt.show()

In [None]:
# Leave-one-recording-out evaluation: for every fold, load that fold's checkpoint,
# score the held-out recording at the model's native epoch length, then re-score
# the same predictions against the finer-grained labels from
# load_spindle_eeg_label_pair_2s (5 label windows per prediction) and against
# pairs of those windows. Per-fold results are collected in lists and
# concatenated once after the loop.
ids = ['A1','A2','A3','A4','B1','B2','B3','B4','C1','C2','C3','C4','C5','C6','C7','C8','D1','D2','D3','D4','D5','D6']
# one fold per recording: fold k holds out ids[k]
FOLDS = list(range(len(ids)))
N_EPOCHS = 8640
STAGE_NAMES = {0:'P',1:'S',2:'W'}
criterion = torch.nn.CrossEntropyLoss()

def _bout_stats(labels,kind):
    """Bout statistics for integer stage labels, melted to long form and tagged with `kind`."""
    # map() builds the string column directly instead of assigning strings into an integer column
    stages = pd.DataFrame(labels)[0].map(STAGE_NAMES)
    stats = get_bout_statistics_for_predictions(stages).reset_index().melt(id_vars='index')
    stats['type'] = kind
    return stats

def _stage_proportions(labels,kind):
    """Fraction of the N_EPOCHS epochs in each of the 3 stages, followed by the `kind` tag."""
    # minlength keeps the list 3 long when the highest stage never occurs, so the tag stays in row 3
    return (torch.bincount(labels.long(),minlength=3)/N_EPOCHS).tolist()+[kind]

metric_columns = []
metric_columns_2s = []
metric_columns_4s = []
proportion_columns = []
stat_frames = []
all_confidences = []
recall_cm = torch.zeros((3,3))
recall_cm_2s = torch.zeros((3,3))
recall_cm_4s = torch.zeros((3,3))
precision_cm = torch.zeros((3,3))
precision_cm_2s = torch.zeros((3,3))
precision_cm_4s = torch.zeros((3,3))
for FOLD in FOLDS:
    test_id = ids[FOLD]
    print(test_id)
    subjects = [load_spindle_eeg_label_pair(cohort=test_id[0],subject=test_id[1])]
    Xs = [subject[0] for subject in subjects]
    ys = [subject[1] for subject in subjects]
    testloader = DataLoader(dataset=SSDataset(Xs,ys,range(N_EPOCHS)),batch_size=32,shuffle=False)
    model = Gandalf()
    model.load_state_dict(torch.load(f'../spindle_gandalfs/gandalf_spindle_fold_{FOLD:02d}/best_model.pt',map_location='cpu'))
    model.to(DEVICE)
    loss,report,y_true,y_pred,y_logits = evaluate(testloader,model,criterion,DEVICE)

    # finer-grained reference labels; only the labels of this second load are used
    subjects = [load_spindle_eeg_label_pair_2s(cohort=test_id[0],subject=test_id[1])]
    Xs = [subject[0] for subject in subjects]
    ys = [subject[1] for subject in subjects]
    y_true_2s = ys[0].argmax(axis=1)
    # repeat every prediction 5 times in place so it lines up with the finer labels
    y_pred_2s = y_pred.repeat(5,1).T.flatten()
    report_2s = metrics(y_true_2s,y_pred_2s)
    # collapse consecutive pairs of windows by their mode
    # NOTE(review): with 5 windows per prediction, some pairs straddle two predictions -- confirm intended.
    y_true_4s = y_true_2s.reshape(-1,2).mode(dim=1).values
    y_pred_4s = y_pred_2s.reshape(-1,2).mode(dim=1).values
    report_4s = metrics(y_true_4s,y_pred_4s)

    # mean of the per-epoch maximum score (same value as indexing each row at its argmax)
    confidences = [row.max().item() for row in y_logits]
    all_confidences.append(torch.tensor(confidences).mean())

    recall_cm += confusion_matrix(y_true=y_true,y_pred=y_pred,normalize='true')
    recall_cm_2s += confusion_matrix(y_true=y_true_2s,y_pred=y_pred_2s,normalize='true')
    recall_cm_4s += confusion_matrix(y_true=y_true_4s,y_pred=y_pred_4s,normalize='true')
    precision_cm += confusion_matrix(y_true=y_true,y_pred=y_pred,normalize='pred')
    precision_cm_2s += confusion_matrix(y_true=y_true_2s,y_pred=y_pred_2s,normalize='pred')
    precision_cm_4s += confusion_matrix(y_true=y_true_4s,y_pred=y_pred_4s,normalize='pred')

    metric_columns.append(pd.Series(report,name=f'{FOLD}'))
    metric_columns_2s.append(pd.Series(report_2s,name=f'{FOLD}'))
    metric_columns_4s.append(pd.Series(report_4s,name=f'{FOLD}'))
    proportion_columns.append(pd.Series(_stage_proportions(y_pred,'Predicted'),name=f'{FOLD}'))
    proportion_columns.append(pd.Series(_stage_proportions(y_true,'Reference'),name=f'{FOLD}'))
    # Reference first, Predicted second: later cells rely on this alternating order
    stat_frames.append(_bout_stats(y_true,'Reference'))
    stat_frames.append(_bout_stats(y_pred,'Predicted'))

# one column per fold (two per fold for the proportions); bout statistics stay side by side
all_metrics = pd.concat(metric_columns,axis=1)
all_metrics_2s = pd.concat(metric_columns_2s,axis=1)
all_metrics_4s = pd.concat(metric_columns_4s,axis=1)
stage_propotions = pd.concat(proportion_columns,axis=1)
all_stats = pd.concat(stat_frames,axis=1)

In [None]:
# Summary statistics across folds (one row per fold after the transpose) for the metrics
# computed on pairs of label windows.
all_metrics_4s.T.describe()

In [None]:
# Element-wise harmonic mean of the recall- and precision-normalised confusion matrices.
# Both matrices are sums over folds, so after the division by len(FOLDS) this equals
# 2*R*P/(R+P) of the fold-averaged matrices; cells where both are 0 come out as NaN.
((2 * recall_cm_4s * precision_cm_4s)/(recall_cm_4s + precision_cm_4s))/len(FOLDS)

In [None]:
# Number of folds the summed confusion matrices are averaged over.
len(FOLDS)

In [None]:
# Summary statistics across folds (one row per fold after the transpose) for the metrics
# computed on the individual label windows.
all_metrics_2s.T.describe()

In [None]:
# One row per fold: mean prediction confidence ('c') next to that fold's 'f1' metric.
c = pd.DataFrame({
    'c': [confidence.item() for confidence in all_confidences],
    'f1': all_metrics.T['f1'].to_list(),
})

In [None]:
# 2x2 summary figure over all folds:
#   [0,0] box plot of each metric across folds      [0,1] mean confidence vs. f1 per fold
#   [1,0] fold-averaged recall confusion matrix     [1,1] predicted vs. reference stage proportions
import seaborn as sns
sns.set_theme('paper')
sns.set_style('whitegrid')
fig,axes = plt.subplots(nrows=2,ncols=2,figsize=(8,8),dpi=200)
# row 3 of stage_propotions holds the 'Predicted'/'Reference' tag, used as hue
sns.boxplot(data=stage_propotions.T.melt(id_vars=3),x='variable',y='value',hue=3,ax=axes[1,1])
sns.boxplot(data=all_metrics.T.melt(),y='value',x='variable',ax=axes[0,0])
sns.heatmap(recall_cm/len(FOLDS),annot=True,fmt='.2f',cbar=False,cmap='Blues',ax=axes[1,0])
sns.regplot(data=c,x='c',y='f1',ax=axes[0,1])

axes[0,0].set_ylim(.5,1)
axes[1,1].set_ylim(-.05,.6)
axes[0,0].set_ylabel('Score')
# take the fold count from FOLDS so the label cannot drift from the data it describes
axes[0,0].set_xlabel(f'Metric (averaged over n={len(FOLDS)} folds)')
axes[1,0].set_xticklabels(labels=['P','S','W'])
axes[1,0].set_yticklabels(labels=['P','S','W'])
axes[1,1].set_xticklabels(labels=['P','S','W'])
axes[0,1].set_xlabel('Average Confidence')
axes[0,1].set_ylabel('macro f1-score')
axes[1,0].set_xlabel('Predicted')
axes[1,0].set_ylabel('Reference')
axes[1,1].set_xlabel('Stage')
axes[1,1].set(ylabel='Proportion of total recording time')
axes[1,1].legend()
axes[0,1].set_xticks([])
axes[0,1].set_yticks([])
plt.margins(x=0,y=0)
plt.savefig('spindle_grid.svg',bbox_inches='tight')

In [None]:
# all_stats holds the 4-column stats frames side by side; reading it row-major in
# groups of 4 gives one record per row: 0 = statistic, 1 = stage, 2 = value,
# 3 = 'Reference'/'Predicted'.
tmp = pd.DataFrame(all_stats.to_numpy().reshape(-1,4)).astype({2: float})
fig,axes = plt.subplots(nrows=3,ncols=1,figsize=(8.5,11),dpi=200,gridspec_kw={'height_ratios': [1,1,1]},sharex=True)

# one split violin panel per stage, reference vs. predicted
violin_panels = [
    ('P','Paradoxical (seconds)'),
    ('S','Slow-wave (seconds)'),
    ('W','Wakefulness (seconds)'),
]
for ax,(stage,ylabel) in zip(axes,violin_panels):
    sns.violinplot(data=tmp[tmp[1] == stage],x=0,y=2,hue=3,split=True,ax=ax)
    ax.set_ylabel(ylabel)
    ax.set_xlabel('')

# keep a single legend, on the top panel
axes[0].legend()
for ax in axes[1:]:
    ax.get_legend().remove()

In [None]:
# 3x3 grid of reference-vs-predicted regressions of the bout statistics.
# Columns are the statistics named in the titles, rows are the stages named in
# the y labels; every panel plots Reference on x and Predicted on y, and the
# axis labels now say so.
import seaborn as sns
sns.set_theme('paper')
sns.set_style('whitegrid')
fig,axes = plt.subplots(nrows=3,ncols=3,figsize=(9,9),dpi=200)

column_stats = [('total','Total Bout Duration'),('average','Average Bout Duration'),('counts','Number of Bouts')]
row_stages = [('P','paradoxical'),('S','slow wave'),('W','wakefulness')]
# shared x/y limits for each (statistic, stage) panel
panel_limits = {
    ('total','P'): [0,150],    ('total','S'): [300,900],  ('total','W'): [500,1200],
    ('average','P'): [0,150],  ('average','S'): [50,350], ('average','W'): [100,500],
    ('counts','P'): [0,160],   ('counts','S'): [100,400], ('counts','W'): [100,400],
}

for col,(stat,title) in enumerate(column_stats):
    axes[0,col].set_title(title)
    for row,(stage,stage_name) in enumerate(row_stages):
        ax = axes[row,col]
        # the matching records alternate Reference, Predicted per fold, so pairs of
        # consecutive values form one (Reference, Predicted) row
        values = tmp[(tmp[0] == stat) & (tmp[1] == stage)][2].to_numpy().reshape(-1,2)
        pairs = pd.DataFrame(values,columns=['Reference','Predicted'])
        sns.regplot(data=pairs,x='Reference',y='Predicted',ax=ax)
        ax.set_xlim(panel_limits[stat,stage])
        ax.set_ylim(panel_limits[stat,stage])
        # label only the outer edge of the grid
        ax.set_xlabel('Reference' if row == len(row_stages)-1 else '')
        ax.set_ylabel(f'Predicted ({stage_name})' if col == 0 else '')

plt.margins(x=0,y=0)
fig.savefig('reg.svg',bbox_inches='tight')