# Testing for UNION - DEMO
First, define the UUID of the experiment to use.

In [None]:
# UUID of the experiment to test; it is passed to loadExperiment and loadBooster below.
# Left empty on purpose: fill it in before running the notebook.
experimentId=""

-------------------------------------------------------

In [None]:
#external libraries
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.colors as clt
import plotly
import plotly.subplots as sb
import plotly.express as px
import plotly.graph_objects as go
import dotenv
import pandas as pd
import scipy.fft as fft
import scipy.signal as sg
import scipy.io as sio
import pickle as pkl
import xgboost as xgb
import time
import sklearn.metrics as skm

#project library
from spinco import *

#project variables
# base folder handed to the loaders below: the current working directory of the process
demopath=os.getcwd()


## Load DREAMS

In [None]:
# sample rate handed to the feature/label loaders and to the label post-processing below
# (presumably in Hz — TODO confirm against the dataset)
samplerate=200  #Should rethink this

In [None]:
# load the demo signals, their annotations and the signals metadata from demopath
signals, annotations, signalsMetadata = loadDREAMSSpindlesDemo(demopath)

In [None]:
# preview the first rows of the loaded annotations
annotations.head()

In [None]:
# preview the first rows of the signals metadata
signalsMetadata.head()

## Load experiment results

In [None]:
# Load the models table and the feature selection of the chosen experiment.
# The experiment folder is built with os.path.join rather than the literal '\DREAMS':
# '\D' is an invalid escape sequence (a warning on recent Python versions) and a
# hard-coded backslash only forms a valid path on Windows.
experimentModels, featureSelection = loadExperiment(experimentId,os.path.join(demopath,'DREAMS'))

In [None]:
# display the table of models returned by loadExperiment
experimentModels

In [None]:
#we show the difference in class imbalance for the annotation criteria considered:
#summary statistics of spindleTimeRate, grouped by criteriumName
experimentModels[['criteriumName','spindleTimeRate']].groupby('criteriumName').describe()

In [None]:
# display the feature selection returned by loadExperiment
featureSelection

## Hyperparameter definition
These values should come from a previous evaluation notebook.

In [None]:
# post-processing parameters, passed to labelingProcess together with the sample rate
hyperClose=0.1
hyperDuration=0.3
# predicted probabilities greater than or equal to this value are labelled positive
hyperThres=0.3
# upper bound of the iteration_range used when predicting with each booster
hyperDepth=40

## Testing with the UNION criterion
We test the optimal points for the prediction threshold and the number of boosting iterations in the different validation groups.

In [None]:
# keep only the models trained with the 'union' criterium and renumber the rows from 0
isUnion=experimentModels['criteriumName']=='union'
experimentModels=experimentModels.loc[isUnion].reset_index(drop=True)
experimentModels

In [None]:
# Evaluate every model of the experiment on its own test subject.
# For each model three sets of metrics are collected:
#   - raw:       sample-by-sample metrics of the thresholded probabilities
#   - processed: sample-by-sample metrics after labelingProcess
#   - event:     by-event metrics returned by annotationPairToMetrics
# The metrics are finally stored as new columns of experimentModels.
# gtAnnotations and processedAnnotations keep the values of the last iteration
# and are used by the CODA cell at the end of the notebook.

def _sampleMetrics(predictedLabels,trueLabels):
    """Return (f1, precision, recall) computed sample by sample.

    The two arguments are multiplied element-wise, so they are assumed to be
    label vectors of equal length holding 0/1 (or boolean) values.
    Zero denominators are not guarded; the division behaves exactly as it
    does for the operand types involved.
    """
    truePositives=np.sum(predictedLabels*trueLabels)
    falsePositives=np.sum(predictedLabels*(1-trueLabels))
    falseNegatives=np.sum((1-predictedLabels)*trueLabels)
    f1=2*truePositives/(2*truePositives+falsePositives+falseNegatives)
    precision=truePositives/(truePositives+falsePositives)
    recall=truePositives/(truePositives+falseNegatives)
    return f1,precision,recall

rawF1s=[]
rawPrecisions=[]
rawRecalls=[]

f1s=[]
precisions=[]
recalls=[]

eventF1s=[]
eventPrecisions=[]
eventRecalls=[]

totalModels=len(experimentModels)

for ind,row in experimentModels.iterrows():
    print('*************************')
    print(f'{ind+1} of {totalModels}')
    #load model
    model=loadBooster(row.modelId,experimentId,demopath+'/DREAMS')

    testSubjectId=row.test
    #Define annotations criterium: keep only the labelers listed for this model
    usedAnnotations=annotations[annotations.labelerId.isin(row.labelerIdList)].reset_index(drop=True)
    #Load features and labels of the test subject
    testFeatures=loadFeatureMatrix([testSubjectId],featureSelection,signalsMetadata,samplerate,demopath)
    testLabels=loadLabelsVector([testSubjectId],usedAnnotations,signalsMetadata,samplerate)

    #Predict using the boosting iterations in [0, hyperDepth)
    testDMatrix=xgb.DMatrix(data=testFeatures)
    probabilities=model.predict(testDMatrix,iteration_range=(0,hyperDepth))
    rawLabels=probabilities>=hyperThres

    #Raw metrics
    rawF1,rawPrecision,rawRecall=_sampleMetrics(rawLabels,testLabels)
    rawF1s.append(rawF1)
    rawPrecisions.append(rawPrecision)
    rawRecalls.append(rawRecall)

    #Processed metrics
    processedLabels=labelingProcess(rawLabels,hyperClose,hyperDuration,samplerate)
    f1,precision,recall=_sampleMetrics(processedLabels,testLabels)
    f1s.append(f1)
    precisions.append(precision)
    recalls.append(recall)

    #By-event metrics
    processedAnnotations=labelVectorToAnnotations(processedLabels,samplerate)
    gtAnnotations=labelVectorToAnnotations(testLabels,samplerate)   #<- or just filter the annotations
    #annotationPairToMetrics returns the values in the order F1, recall, precision
    eventF1,eventRecall,eventPrecision=annotationPairToMetrics(gtAnnotations,processedAnnotations)
    eventF1s.append(eventF1)
    eventPrecisions.append(eventPrecision)
    eventRecalls.append(eventRecall)

#include metrics in the dataframe
experimentModels['rawF1']=rawF1s
experimentModels['rawPrecision']=rawPrecisions
experimentModels['rawRecall']=rawRecalls

experimentModels['f1']=f1s
experimentModels['precision']=precisions
experimentModels['recall']=recalls

experimentModels['eventF1']=eventF1s
experimentModels['eventPrecision']=eventPrecisions
experimentModels['eventRecall']=eventRecalls


In [None]:
# processed sample F1 against raw sample F1, one point per model, coloured by test subject
identityTrace=go.Scatter(
    x=experimentModels['rawF1'],
    y=experimentModels['rawF1'],
    name="identity",
    mode='lines',
    fill="toself",
)
fig=px.scatter(
    experimentModels,
    x='rawF1',
    y='f1',
    color='test',
    hover_name='modelId',
    marginal_y="histogram",
)
# reference line y=x drawn on top of the scatter
fig.add_trace(identityTrace)
fig.show()

In [None]:
# by-event F1 against raw sample F1, one point per model, coloured by test subject
identityTrace=go.Scatter(
    x=experimentModels['rawF1'],
    y=experimentModels['rawF1'],
    name="identity",
    mode='lines',
    fill="toself",
)
fig=px.scatter(
    experimentModels,
    x='rawF1',
    y='eventF1',
    color='test',
    hover_name='modelId',
    marginal_y="histogram",
)
# reference line y=x drawn on top of the scatter
fig.add_trace(identityTrace)
fig.show()

In [None]:
# by-event precision against by-event F1, one point per model, coloured by test subject
identityTrace=go.Scatter(
    x=experimentModels['eventF1'],
    y=experimentModels['eventF1'],
    name="identity",
    mode='lines',
    fill="toself",
)
fig=px.scatter(
    experimentModels,
    x='eventF1',
    y='eventPrecision',
    color='test',
    hover_name='modelId',
    marginal_y="histogram",
)
# reference line y=x drawn on top of the scatter
fig.add_trace(identityTrace)
fig.show()

In [None]:
# by-event recall against by-event F1, one point per model, coloured by test subject
identityTrace=go.Scatter(
    x=experimentModels['eventF1'],
    y=experimentModels['eventF1'],
    name="identity",
    mode='lines',
    fill="toself",
)
fig=px.scatter(
    experimentModels,
    x='eventF1',
    y='eventRecall',
    color='test',
    hover_name='modelId',
    marginal_y="histogram",
)
# reference line y=x drawn on top of the scatter
fig.add_trace(identityTrace)
fig.show()

In [None]:
# summary statistics of the event metrics for each test subject;
# percentiles=[0.5] limits the reported percentiles to the median
experimentModels[['test','eventF1','eventPrecision','eventRecall']].groupby('test').describe(percentiles=[0.5])

### event metrics

In [None]:
# mean event metrics of the models of each test subject (one row per subject)
experimentModels[['test','eventF1','eventPrecision','eventRecall']].groupby('test',as_index=False).mean()

In [None]:
# average, across test subjects, of the per-subject mean event metrics
experimentModels[['test','eventF1','eventPrecision','eventRecall']].groupby('test',as_index=True).mean().mean()

In [None]:
# standard deviation, across test subjects, of the per-subject mean event metrics
experimentModels[['test','eventF1','eventPrecision','eventRecall']].groupby('test',as_index=True).mean().std()

### sample metrics

In [None]:
# mean processed sample metrics of the models of each test subject (one row per subject)
experimentModels[['test','f1','precision','recall']].groupby('test',as_index=False).mean()

In [None]:
# average, across test subjects, of the per-subject mean processed sample metrics
experimentModels[['test','f1','precision','recall']].groupby('test',as_index=True).mean().mean()

In [None]:
# standard deviation, across test subjects, of the per-subject mean processed sample metrics
experimentModels[['test','f1','precision','recall']].groupby('test',as_index=True).mean().std()

In [None]:
# Long-format table for the precision/recall trade-off plot: one row per model and
# metric, holding the metric name, its value and the event F1 of the same model.
auxPrecision,auxRecall=(
    pd.DataFrame({
        'metric':metricLabel,
        'value':experimentModels[sourceColumn],
        'event F1':experimentModels.eventF1
    })
    for metricLabel,sourceColumn in (
        ('event precision','eventPrecision'),
        ('event recall','eventRecall'),
    )
)
# precision rows first, recall rows second; the original row labels are kept
visualTradeoff=pd.concat([auxPrecision,auxRecall])

In [None]:
# event precision and event recall against event F1, coloured by metric
identityTrace=go.Scatter(
    x=experimentModels['eventF1'],
    y=experimentModels['eventF1'],
    name="identity",
    mode='lines',
    fill="toself",
)
fig=px.scatter(
    visualTradeoff,
    x='event F1',
    y='value',
    color='metric',
    marginal_y="histogram",
)
# reference line y=x drawn on top of the scatter
fig.add_trace(identityTrace)
fig.show()

### CODA: by-event evaluation summary of the last tested fold

In [None]:
# gtAnnotations and processedAnnotations still hold the values assigned in the
# last iteration of the evaluation loop, so this refers to the last model tested
annotationPairToGraph(gtAnnotations,processedAnnotations)

Congratulations on reaching this point, enjoy!