In [3]:
import json
import numpy as np
import os
import random
import pandas as pd
import pickle
import seaborn as sns
import typing 

from IPython.display import clear_output
from misc import _get_usleep_token
from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
from pprint import pprint
from scipy import io, special, stats
from sklearn import metrics, model_selection
from statistics import mode
from tqdm import tqdm
from usleep_api import USleepAPI


In [120]:
# GLOBALS 

MS_MAPPING = {"Wake": 0, "MS": 1}
AASM_MAPPING = {"Wake": 0, "N1": 1, "N2": 2, "N3": 3, "REM": 4}

float_formatter = "{:.2f}".format
np.set_printoptions(formatter={'float_kind':float_formatter})

%matplotlib widget

## Notebook functions
1. Helper functions
2. Plotting functions
3. BernLabels class for handling labelled data from BERN group


In [121]:
# Helper functions 

def get_probs(file):
    out = np.load(file)
    probs = special.softmax(out, 1)
    return probs

def aasm_to_wake_sleep(aasm: np.array):
    out = np.zeros(aasm.shape)
    out[aasm > 0] = 1
    return out


def resample_usleep_preds(y_pred: np.array, data_per_pred: int, org_fs: int = 200, usleep_fs: int = 128):
    return np.repeat(y_pred, np.floor(data_per_pred * (org_fs / usleep_fs)),0) 


def load_pickle_from_file(file):
    with open(file, 'rb') as handle:
        return pickle.load(handle)
    
def write_to_pickle_file(obj, file):
    with open(file, 'wb') as handle:
        pickle.dump(obj, handle, protocol=pickle.HIGHEST_PROTOCOL)

def _find_singles(y,idx):
    
    pidx = idx-1
    if pidx[0] < 0:
        prev = y[pidx[1:]]
        prev = np.append(0,prev)
    else:
        prev = y[pidx]

    nidx = idx+1
    if nidx[-1] >= len(y):
        nxt = y[nidx[:-1]]
        nxt = np.append(nxt, 0)
    else:
        nxt = y[nidx]

    single_idx = np.logical_not(prev) * np.logical_not(nxt)
    singles = idx[single_idx]
    return singles


def get_target_label_start_and_stop_indices(labels, target):

    # Find indices of the target in labels
    idx = np.where(labels==target)[0]

    # Calculate where there are islands of target in labels
    islands = np.diff(idx)==1
    ffy = np.pad(islands, pad_width=(1, 1), mode="constant", constant_values=(0,0)).astype(int)

    # Calculate where the islands start (>0) and end (<0) by calculating the differnece of the island indices
    diff_idx = np.diff(ffy)
    
    # Find where they start and end
    start_idx = np.where(diff_idx > 0)[0]
    stop_idx = np.where(diff_idx < 0)[0]

    start = idx[start_idx]
    stop = idx[stop_idx] 
    
    # Also find single targets
    if np.any(idx):
        singles = _find_singles(labels, idx)
    else:
        singles = np.empty([0])
        
    return start, stop, singles

def remove_invalid_labels(labels, target = 1, min_duration = 3, max_duration = 15, fs = 1, verbose = True):
    
    start, stop, singles = get_target_label_start_and_stop_indices(labels, target)
    target_time = (stop - start) / fs
    
    too_short = np.where(target_time < min_duration)[0]
    too_long = np.where(target_time > max_duration)[0]
    
    unit = "samples" if fs == 1 else "seconds"
    
    if verbose:
        print(f"{len(too_short)} labels are shorter than {min_duration} {unit}\n"
              f"{len(too_long)} labels are longer than {max_duration} {unit}\n")

    invalid_idx = np.hstack([too_short, too_long])
    
    fixed_labels = labels
    for i in invalid_idx:
        fixed_labels[start[i]:stop[i]+1] = 0 
    
    if np.any(singles):
        if verbose:
            print(f"{len(singles)} singletons were found and removed") 
        fixed_labels[singles] = 0
        
    
    return fixed_labels

#def _fill_label_gaps(labels, target, limit, fs = 1):
#    
#    start, stop = get_target_label_start_and_stop_indices(labels, target)
#    label_gap = (start[1:] - stop[0:-1]) / fs
#    idx = np.where(ms_gap <= limit)[0]

#    filled_gaps = np.copy(labels)
#    for i in idx:
#        filled_gaps[stop[i]+1:start[i+1]] = 1
#    
#    return filled_gaps

def rolling_window(array, window_size,freq):
    shape = (array.shape[0] - window_size + 1, window_size)
    strides = (array.strides[0],) + array.strides
    rolled = np.lib.stride_tricks.as_strided(array, shape=shape, strides=strides)
    return rolled[np.arange(0,shape[0],freq)]

def plot_roc_curve(y_true, y_probs, pos_label=1, ax = None):
    
    if ax is None:
        _, ax = plt.subplots()
    
    fpr, tpr, thresholds = metrics.roc_curve(y_true, y_probs[:,pos_label], pos_label = pos_label)
    
    ax.plot([0,1],[0,1])
    ax.plot(fpr, tpr)
    ax.set_xlabel("False Positive Rate (FPR)")
    ax.set_ylabel("True Positive Rate (TPR)")
 
    return fpr, tpr, ax


def plot_precision_recall_curve(y_true, y_probs, pos_label=1, ax = None):
    if ax is None:
        _, ax = plt.subplots()
    
    precision, recall, thresholds = metrics.precision_recall_curve(y_true, y_probs[:,pos_label], pos_label = pos_label)
    
    if ax is not None:

        baseline = np.sum(y_true == pos_label) / len(y_true) 
        ax.plot([0,1],[baseline, baseline], linestyle='--')
        ax.plot(recall, precision)
        ax.set_xlabel("Recall")
        ax.set_ylabel("Precision")

    return precision, recall, ax



def compute_performance_metrics(y_true, y_pred, y_probs,
                                labels = [0,1], classes = ["Wake","MS"],
                                minority_label = 1, pos_label = 1, plot_on = True):
    
    if plot_on:
        fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=[9,4])
    else:
        ax1 = None
        ax2 = None
    
    # Get general classification report
    report = metrics.classification_report(y_true, y_pred, labels=labels, target_names=classes,
                                           output_dict=True)
     
    # Calculate ROC metics    
    roc_fpr, roc_tpr, roc_ax = plot_roc_curve(y_true, y_probs, pos_label, ax1)
    roc_auc = metrics.roc_auc_score(y_true, y_probs[:,pos_label])
    
    # Calculate PR curve metrics
    _precision, _recall, pr_ax = plot_precision_recall_curve(y_true, y_probs, pos_label, ax2)
    pr_auc = metrics.auc(_recall, _precision)
    
    
    # Compute Matthews correlation coefficent
    mcc = metrics.matthews_corrcoef(y_true, y_pred)
    
    # Compute Cohen's Kappa
    cohen_kappa = metrics.cohen_kappa_score(y_true, y_pred)
    
    # Store results in report
    report["roc_auc"]     = roc_auc
    report["pr_auc"]      = pr_auc
    report["mcc"]         = mcc
    report["cohen_kappa"] = cohen_kappa
    
    # A bit extra spicy
    if plot_on:
        ax1.annotate(f"ROC AUC = {roc_auc:.2f}", (0.7,0.2), fontsize=8)
        ax2.annotate(f"PR AUC = {pr_auc:.2f}", (0.7,0.7), fontsize=8)
    
    return report


In [122]:
# Plotting functions

def format_ax(ax, labs):
    ax.set_xlabel("Period number")
    ax.set_ylabel("Sleep stage")
    ax.set_yticks(range(len(labs)))
    ax.set_yticklabels(labs)
    ax.invert_yaxis()
    line = ax.lines[0]
    ids = line.get_xdata()
    ax.set_xlim(1, ids[-1]+1)
    l = ax.legend(loc=3)
    l.get_frame().set_linewidth(0)

def ghost_poly(face_color=[1,1,1], edge_color=[0,0,0]):
    _xy = np.empty([4,2])
    _xy[:] = np.nan
    _poly = Polygon(_xy, facecolor=face_color, edgecolor=edge_color)
    return _poly

def plot_probs(ax, probs: np.array, labs, fs = 1):

    av = np.cumsum(probs, axis=1)
    c = sns.color_palette("tab10", len(labs)-1)

    # Create 'ghost' patch for 'Wake'
    _poly = ghost_poly()
    ax.add_patch(_poly)

    for i in range(probs.shape[1]-1):
        xy = np.zeros([av.shape[0] * 2, 2]) 
        xy[:av.shape[0], 0] = np.arange(av.shape[0]) / fs
        xy[av.shape[0]:, 0] = np.flip(np.arange(av.shape[0]), axis=0) / fs
        xy[:av.shape[0], 1] = av[:, i]
        xy[av.shape[0]:, 1] = np.flip(av[:, i+1], axis=0)

        poly = Polygon(xy, facecolor=c[i], edgecolor=None)
        ax.add_patch(poly)

    ax.set_ylabel("Probability")
    ax.legend(labs, loc='lower left')
    ax.set_xlim([0, av.shape[0]])

def plot_label_patches(labels, mapping, y_min=0, y_max=1, ax = None, fs = 1):
    
    if ax is None:
        _, ax = plt.subplots()
    
    def create_patches(ax, start, stop, cc, y_min=0, y_max=1, fs = 1):
        for start_idx, stop_idx in zip(start, stop):

            x = np.arange(start_idx, stop_idx+1) / fs
            xn = len(x)
            xy = np.zeros([xn * 2, 2])
            xy[:xn, 0] = x
            xy[xn:, 0] = np.flip(x, axis=0)

            xy[:xn, 1] = np.ones([xn]) * y_min
            xy[xn:, 1] = np.ones([xn]) * y_max
            poly = Polygon(xy, facecolor=cc, edgecolor=None, alpha=1)
            patch = ax.add_patch(poly)

        return patch
    
    handles = []
    handle_names = []
    colors = sns.color_palette("tab10", len(mapping.keys())-1)
    for label_name, label in mapping.items():
        
        if label == 0 or label_name == "Wake":
            _poly = ghost_poly()
            handle = ax.add_patch(_poly)
            handles.append(handle)
            handle_names.append(label_name)
            continue
        
        start, stop, singles = get_target_label_start_and_stop_indices(labels, label)
        # Convert to time axis with fs
        start = start 
        stop = stop 
        if not np.any(start):
            print("No start")
            continue
        else:
            handle = create_patches(ax, start, stop, colors[label-1], y_min, y_max, fs = fs)
            handles.append(handle)
            handle_names.append(label_name)
            
        if np.any(singles):
            for x in singles:
                _ = ax.axvline(x=x/fs, ymin=y_min, ymax=y_max, color=colors[label-1])
        
    return handles, handle_names


def MatplotlibClearMemory():
    allfignums = plt.get_fignums()
    for i in allfignums:
        fig = plt.figure(i)
        fig.clear()
        plt.close( fig )


In [123]:
class BernLabels(object):
    
    fs            = 200
    folder        = "labels"
    raw_mapping   =  {"Wake": 0,
                      "MSE": 1,
                      "MSEc": 2,
                      "ED": 3}
    
    min_duration = 3
    max_duration = 15
    
    
    def __init__(self, file: str, mapping: dict, include_unilateral = False):
        
        # File settings and read
        self.file = file
        self.path = os.path.join(self.folder, file) if not os.path.exists(self.file) else file
        self.__raw = io.loadmat(self.path)
        
        # Disseminate raw data
        self.raw_O1 = np.squeeze(self.__raw['labels']['O1'][0][0])
        self.raw_O2 = np.squeeze(self.__raw['labels']['O2'][0][0])
        self.raw_labels = np.vstack([self.raw_O1, self.raw_O2])
        self.num_labels = self.raw_labels.shape[1]
        self.time = np.arange(0, (self.num_labels)/self.fs, 1/self.fs)
        
        # Apply mapping (i.e. convert Bern to Wake vs Sleep)
        self.mapping = mapping
        self.include_unilateral = include_unilateral
        self.label_mapping = {k: i for i, k in enumerate(self.mapping.keys())}
        self.convert_labels(mapping, include_unilateral)
        
        
    def __repr__(self): 
        cls = self.__class__.__name__
        return f"{cls}(file='{self.file}', mapping={self.mapping}, include_bilateral={self.include_unilateral})"
    
    def append(self, other):
        
        assert self.mapping == other.mapping
        assert self.include_unilateral == other.include_unilateral
        
        self.file = [self.file, other.file]
        self.__raw = [self.__raw, other.__raw]
        func = lambda x, y: np.hstack([x,y])
        self.raw_O1 = func(self.raw_O1, other.raw_O2)
        self.raw_O2 = func(self.raw_O2, other.raw_O2)
        self.raw_labels = func(self.raw_labels, other.raw_labels)
        self.num_labels = self.num_labels + other.num_labels
        self.time = np.arange(0, (self.num_labels)/self.fs, 1/self.fs) #Incorrect but needed for plotting hypnograms
        
        self.labels = func(self.labels, other.labels)
        
        return self
    
    def convert_labels(self, mapping = dict, include_unilateral = False):
        
        print_on = False
        
        self.labels = np.empty(self.raw_O1.shape)
        
        if include_unilateral: 
            func = lambda x: np.any(x, 0)
        else:
            func = lambda x: np.all(x, 0)
            
        wake_idx = func(np.isin(self.raw_labels, mapping["Wake"])) 
        sleep_idx = func(np.isin(self.raw_labels, mapping["MS"]))
        
        total_idx = (np.sum(wake_idx)+np.sum(sleep_idx))
        
        if (((total_idx) != self.num_labels) and print_on):
            print(f"Warning! {self.num_labels - total_idx} labels disagree... \nWill be overwritten as Wake")
            print(self.raw_labels[:,self.raw_O1 != self.raw_O2].T)
        
        self.labels[wake_idx] = self.label_mapping["Wake"]
        self.labels[sleep_idx] = self.label_mapping["MS"]
        
        self.mapping = mapping
        self.include_unilateral = include_unilateral
        
        return
    
    def apply_time_critera(self, min_duration = 3, max_duration = 15, replace = True):
        
        ms_index = self.label_mapping["MS"]
        copy_labels = np.copy(self.labels)
        fixed_labels = remove_invalid_labels(copy_labels, target=ms_index,
                                             min_duration = min_duration, max_duration = max_duration,
                                             fs = self.fs)
        
        if replace:
            print("Overwriting labels!")
            self.labels = fixed_labels
            return
        else:
            return fixed_labels
        
    def apply_rolling_func(self, win = 0.2, step = 0.2, func=None, replace = False):
        y = np.copy(self.labels)
        win_samples = int(win * self.fs)
        step_samples = int(step * self.fs)
        arr = np.array(rolling_window(y, win_samples, step_samples))
        
        if func is None:
            y_func = np.median(arr,1)
        else:
            y_func = func(arr)
            
        if replace:
            self.labels = y_func
            self.prev_fs = self.fs
            self.fs = 1/step
        else:
            return (y_func, rolling_window(self.time, win_samples, step_samples))
    
    def get_labels_per_num_seconds(self, num, func = lambda x: mode(x)):
    
        samps_per_label = self.fs * num
        labels_reshaped = np.reshape(self.labels, [int(len(self.labels)/samps_per_label),samps_per_label])
        tmp = [func(x) for x in b]
        tmp = np.array(tmp)
        return np.repeat(c,self.fs*num)
    
    def plot_raw_labels(self, ax = None, as_hypnogram = False, fs=200):
        
        
        if ax is None:
            fig, ax = plt.subplots()

        if as_hypnogram:
            ids = np.arange(self.num_labels)
            ax.step(self.time, self.raw_O1, "-", linewidth=1, color="darkred", label="Raw [O1]")
            ax.step(self.time, self.raw_O2, "--", linewidth=1, color="darkgray", label="Raw [O2]")
            ax.legend()
            format_ax(ax, self.raw_mapping.keys())
        
        else:
            _, _ = plot_label_patches(labels=self.raw_O1, mapping=self.raw_mapping, fs=fs,
                                                         y_min=0, y_max=0.5, ax=ax)
            _, _ = plot_label_patches(labels=self.raw_O2, mapping=self.raw_mapping, fs=fs,
                                                         y_min=0.5, y_max=1, ax=ax,)
            ax.autoscale(enable=True, axis = "both", tight = True)
            ax.plot([*ax.get_xlim()],[0.5, 0.5],'k-',linewidth=0.5)
            handles = self.__proxy_legend(ax)
            ax.legend(handles, self.raw_mapping.keys())
            ax.set_yticks([0.25, 0.75])            
            ax.set_yticklabels(["Raw O1","Raw O2"])
            ax.set_ylim([0, 1])
        
        xlab = "Time [s]" if fs == self.fs else "Sample #"
        ax.set_xlabel(xlab)
        ax.autoscale(enable=True, axis = "both", tight = True)
        return ax
            
        
    def plot_labels(self, ax = None, as_hypnogram = False, fs = 200):
        
        if ax is None:
            fig, ax = plt.subplots()
        
        if as_hypnogram:
            ids = np.arange(self.num_labels)
            ax.step(self.time, self.labels, "-", linewidth = 1, color = "darkblue", label="New Labels")
            ax.legend()
            format_ax(ax, self.mapping.keys())
            ax.set_xlabel("Time [s]")
        
        else:
            hdl, names = plot_label_patches(labels=self.labels, mapping=self.label_mapping, ax=ax, fs = fs)
            ax.legend(hdl, names, loc="upper left")
            ax.set_yticks([0.5])
            ax.set_yticklabels(["New labels"])
        
        xlab = "Time [s]" if fs == self.fs else "Sample #"
        ax.set_xlabel(xlab)
        ax.autoscale(enable=True, axis = "both", tight = True)
        return ax
        
        
    def __proxy_legend(self, ax):
        _colors = sns.color_palette("tab10", len(self.raw_mapping.keys())-1)
        _handles = [ax.add_patch(ghost_poly())]
        for i in range(len(self.raw_mapping.keys())-1):
            _poly = ghost_poly(face_color=_colors[i], edge_color=[1,1,1])
            _patch = ax.add_patch(_poly)
            _handles.append(_patch)
        return _handles


## Subset data
Select a subset (n = 5) of data to use for pre-liminary analysis of microsleep predictions using a fully trained U-Sleep model.
Set the seed to 42

In [146]:
all_edf_files = np.array(sorted(os.listdir("edf_data/")))
all_labels_files = np.array(sorted(os.listdir("labels/")))
all_names = [x.replace(".edf","") for x in all_edf_files]
assert [x.replace(".edf", "") for x in all_edf_files] == [x.replace(".mat","") for x in all_labels_files]

use_bern_splits = True

if use_bern_splits:
    with open("skorucack_splits.json","r") as f:
        splits = json.loads(f.read())

    train_names = np.array(splits['train'])
    train_idx = np.isin(all_names, train_names)
    test_names = np.array(splits['test'])
    test_idx = np.isin(all_names, val_names)
    test_edf_files = all_edf_files[test_idx]
    test_label_files = all_labels_files[test_idx]
else:    
    random.seed(42)
    n_files = 5
    train_idx = random.sample(range(0, len(all_edf_files)), n_files)


edf_files = all_edf_files[train_idx]
label_files = all_labels_files[train_idx]
names = [x.replace(".edf", "") for x in edf_files]

## Preliminary analysis
* Run all pre-liminary recordings through U-Sleep API
* Read and concatinate output files
* Compute statistics
    * ROC Curves and ROC AUC 
    * PR Curves and PR AUC
    * Cohen's kappa

In [None]:
from ut_commands import predict_one
input_dir              = "edf_data"
output_dir             = "predictions"

out_files = []
samples_per = [1,2,4,8,16]
for s in samples_per:
    hz = int(128/s)
    for edf in all_edf_files:
        name = edf.replace(".edf","")
        inp_file = os.path.join(input_dir, edf)
        out_file = os.path.join(output_dir, f"{hz}_hz", f"{name}.npy")
        out_files.append(out_file)

        #predict_one(input_file = inp_file, output_file=out_file, data_per_prediction=s)
        clear_output(wait=True)


In [None]:


# Set some parameters
model                  = "U-Sleep v1.0"
input_dir              = "edf_data"
output_dir             = "predictions"
predictions_per_second = 1
channel_groups         = [['EEG O1-M2', 'EOG LOC-M1'],
                          ['EEG O2-M1', 'EOG LOC-M1'],
                          ['EEG O1-M2', 'EOG ROC-M1'],
                          ['EEG O2-M1', 'EOG ROC-M1']]
#try:
#    api = USleepAPI(api_token=os.environ["USLEEP_API_TOKEN"])
#except:
#    api = USleepAPI(api_token=_get_usleep_token())

out_files = []
pred_window = str(int(predictions_per_second**-1)) + "sec"
for edf in edf_files:
    
    name = edf.replace(".edf","")
    
    inp_file = os.path.join(input_dir, edf)
    out_file = os.path.join(output_dir, f"{predictions_per_second}_hz",f"{name}.npy")
    print(f"Running: {out_file}")
    
    #api.quick_predict(
    #     input_file_path=inp_file,
    #     model = model,
    #     output_file_path=out_file,
    #     data_per_prediction=128 * (1/predictions_per_second),
    #     channel_groups=channel_groups,
    #     with_confidence_scores=True
    # )

    out_files.append(out_file)


In [None]:

y_true = np.empty(shape=[0])
y_probs = np.empty(shape=[0,5])

mapping = {"Wake": [0, 2, 3], "MS": [1]}
uni_on = True
replace = False
pred_freq = 1

for idx, rec in enumerate(out_files):
    
    probs = get_probs(rec)
    probs_resampled = resample_usleep_preds(probs, data_per_pred = 1)#128 * (1/predictions_per_second))
    y_probs = np.concatenate([y_probs, probs_resampled])
    #y_probs = np.concatenate([y_probs, probs])
    
    if idx == 0:
        BERN = BernLabels(label_files[idx], mapping, uni_on)
        BERN.apply_time_critera(replace=replace)
    else:
        tmp = BernLabels(label_files[idx], mapping, uni_on)
        tmp.apply_time_critera(replace=replace)
        BERN.append(tmp)
    
    #assert BERN.num_labels == len(y_probs)
    clear_output(wait=True)
    
y_pred = np.argmax(y_probs,1)
y_pred = aasm_to_wake_sleep(y_pred)

# remove y_pred where outside MS definition
#_y_pred = remove_invalid_labels(y_pred, 1, fs=200)

y_true = BERN.labels

# Add predictions using cumulative sleep probabilites 
yy_probs = np.empty([y_probs.shape[0],2])
yy_probs[:,0] = y_probs[:,0]
yy_probs[:,1] = 1-y_probs[:,0]#np.sum(y_probs[:,1:],1)
yy_pred = np.argmax(yy_probs,1)
#_yy_pred = remove_invalid_labels(yy_pred, 1, fs=200)


In [None]:
# Highest sleep class vs wake
y_bin = np.empty([y_probs.shape[0],2])
y_bin[:,0] = y_probs[:,0]
high_sleep_idx = np.argmax(y_probs[:,1:5],1) + 1
for val in np.unique(high_sleep_idx):
    idx = np.where(high_sleep_idx == val)
    y_bin[idx,1] = y_probs[idx,val]

yb_pred = np.argmax(y_bin,1)

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(nrows=3, ncols=1, figsize=[9,9], sharex=True)
lab_fs = 200
pred_fs = 1

if lab_fs == 200:
    xlab = "Time [s]"
else:
    xlab = "Sample #"
    
# Plot hypnodensity
plot_probs(ax1, y_probs, ["Wake", "N1", "N2", "N3", "REM"], fs = pred_fs)
ax1.autoscale(enable=True, axis='both', tight=True)

# Plot the predicted hypnogram
_,_=plot_label_patches(y_pred, {"Wake": 0, "MS": 1}, y_min=2/3, ax=ax2, fs= pred_fs)
_,_=plot_label_patches(yb_pred, {"Wake": 0, "MS": 1}, y_min=1/3, y_max=2/3, ax=ax2, fs = pred_fs)
hdl,names=plot_label_patches(yy_pred, {"Wake": 0, "MS": 1}, y_max=1/3, ax=ax2, fs= pred_fs)

ax2.legend(hdl, names, loc=1)
ax2.autoscale(enable=True, axis='both', tight=True)
ax2.plot([*ax2.get_xlim()],[1/3, 1/3],'-k',linewidth=0.5)
ax2.plot([*ax2.get_xlim()],[2/3, 2/3],'-k',linewidth=0.5)
ax2.set_yticks([0.5/3, 1.5/3, 2.5/3])
ax2.set_yticklabels(["yy_pred","yb_pred","y_pred"])

# Plot the true hypnogram
hdl, names = plot_label_patches(y_true, {"Wake": 0, "MS": 1}, ax = ax3, fs = lab_fs)
ax3.autoscale(enable=True, axis='both', tight=True)
ax3.legend(hdl, names, loc = 1)
ax3.set_xlabel(xlab)
clear_output(wait=False)
#plt.savefig(fig_path("overall"), format="png")

In [None]:
bp, br, bf, bt = calc_stats(TP,FP,FN)
out = {"Precision": bp, "Recall": br, "F1-Score": bf, "Threshold": bt}
pprint(out)

In [None]:
preds = 1*(yy_probs[:,1]>bt)
fig, (ax, ax1) = plt.subplots(nrows=2,ncols=1,sharex=True)
h,l=plot_label_patches(preds,{"Wake":0, "MS":1},ax=ax, fs = 200)
ax.legend(h,l)
h,l=plot_label_patches(y_true, {"Wake":0, "MS":1},ax=ax1, fs = 200)
ax1.legend(h,l,loc=1)
ax1.autoscale(enable=True, axis="x", tight=True)
[x.set_yticks([]) for x in [ax,ax1]]


## U-Sleep with BERN manuscript
The BERN scoring criteria is more detailed and conservative than other methods used to identify microsleep episodes (MSE), therfore it deviates from our common definition in numerous ways:

__3 Classes of Microsleep__
1. _MSE_: Analogous to the common definiton of MSE, but including a requirement of >80% the interval with closed eyes (from videography). This class is additionally scored either unilaterally or bilaterally. This class has a strict and conservative scoring definition.
2. _MSEc_: Shares similar features as MSE and has high likelihood of being an MSE, but does not fulfill all criteria of an MSE. Less conservative than MSE.
3. _ED_: There is not a clear definition of change in frequency in EEG, quick alterations of regular and irregular activity and morphology of the EEG. When MSE and MSEc criteria are not fulfilled but the episode does not resemble wakefulness. 

__Duration = 1 to 15+ seconds__
+ They define MSE shorter than the 3 second common minimum duration, and reported more than 40% of their scored MSEs shorter than 3 seconds.
+ They also define some MSE episodes longer than 15 seconds in rare cases where an AASM sleep is not scored due to the MSE spanning over two epochs.


### Mitigation tactics:
__Three Classes to MS vs Wake:__
* Bern Labels
    + We can merge the BERN labels to MS or Wake as follows
    | Option  | MS              | Wake             |   |   |
    |---------|-----------------|------------------|---|---|
    | 1       | {MSE}           | {MSEc, ED, Wake} |   |   |
    | 2       | {MSE, MSEc}     | {ED, Wake}       |   |   |
    | 3       | {MSE, MSEc, ED} | {ED}             |   |   |
    + Options 1. to 3. omit the notion of unilateral events and scores them simply as an event
    + In recent automatic classification work by the Bern group, Option 1 showed the highest classification problem and on multiple accounts they state that MSEc is more closely related to MSE and ED with Wake.

* U-Sleep output
We convert the AASM scoring to MS by
    + Choosing the AASM class probability with the highest value, being scored as MS if it is a sleep stage otherwise Wake.
    + Thresholding the highest AASM sleep probability {N1, N2, N3 or REM} otherwise as Wake.
    + Combining the probabilites of the AASM sleep stages as probabilites of MS and then threholding the cominbed sleep probabilites.

__Duration:__
* BERN Labels
    + Any MS event longer than 15 seconds should be __omitted__ (as it is not a commonly defined MS)
    + Combine MS events that are shorter than 3 seconds (MS<3) where the duration from the preceding MS<3 onset is $\leq$ 3 seconds from the next MS<3.
         + _This must be followed with an assertion that the concatenation does not generate MS events > 15 seconds_

* U-Sleep output
    + Any MS event longer than 15 seconds should be __omitted__ (as it is not a commonly defined MS)
    + ??? _Should the same criteria be applied to MS<3_ ???
        + _We aim to detect microsleep, so it should be enoguh to know that it occured rather than perfectly overlapping it. So by applying the same criteria as above we are interferring with the predictions and extrapolating?_
        
### Evaluation:
The performance of the predictions will be evaluated analogous to A. Brink-Kjaer et al. (2020)
* MS predictions overlapping the true MS events are considered __true positives__
* MS predictions not overlapping the true MS events are considered __false positivies__
* Wake predictions overlapping true MS eveents are considerded __false negatives__

The thresholding will be perfomed using precision recall curves for different thresholds and an F1-score on a training (dev) set and the optimal threshold applied to a validation (test) set. We will use the training split from Malafeev et al. 2020 as dev and validation split as test, to preserve the test for later analysis in the project.

The evalaution will be made by concatenating the predictions and true labels of all the recordings

Additionally, the overlap metrics can be further quantified as Iou (intersection over union) or average precision.
https://medium.com/@timothycarlen/understanding-the-map-evaluation-metric-for-object-detection-a07fe6962cf3

A per sample evaluation can also be made, however this would require brutly upsampling the U-sleep predictions to match the labels (200 Hz), and would be a more harsh evaluation metric.


## Manuscript for preliminary U-Sleep evaluation on Bern data
The comparison will be made with the results from Skorucack et al., 2020 (RF, SVM, LSTM classifiers).
The evaluation used in the reference study:
- bMSE vs Wake (also some other classification problems e.g. bMSE vs {Wake, uMSE, uMSEc, uED})
- True labels were converted from 200 Hz to 5 Hz (200 ms resolution) using a 9 second median filter with 200 ms step size.
- Predictions from the RF and SVM were converted to 5 Hz resolution using same median filter.
- MS predictions after resampling which were shorter than 1 second were excluded
- Calculate sensitivity, specificity, accuracy, precision, cohen's kappa.
    
In this analysis, we will use the same dev/test split (53/23) and adapt the U-Sleep output to fit the same evaluation scheme as used in the reference paper.
The hyperparameters of the U-Sleep model are:
* Data per prediction (prediction rate): 1, 2, 4, 8, 16
* Post-procesing of probabilities: y_argmax (not w/ tunable threshold), y_max_sleep, and y_sum_sleep
* MS Threshold: 0.025:0.025:1.0

Therefore, making 5x3 = 15 models.
Since the U-Sleep model is pre-trained, the only "training" part is the threshold tuning. The optimal threshold will be determined by the highest f1-score analogous to Brink-Kjær. The optimal model will be found by using a 5-fold CV validation where a model will be trained (tune threshold) on K-1 folds and validated against the remaining fold. The model with the highest f1-score will be chosen and re-trained on the entire dev set before evaluating it on the test set.




In [125]:
def get_all_probs(rec):
    probs = get_probs(rec)
    probs_sum = np.column_stack([probs[:,0], np.sum(probs[:,1:5],axis=1)])
    probs_max = np.column_stack([probs[:,0], np.max(probs[:,1:5],axis=1)])
    
    return probs, probs_sum, probs_max


def psuedo_resample(y_org, first_last):
    if len(y_org.shape) > 1:
        return np.array([np.median(y_org[:,x[0]:x[1]],1) for x in first_last]).T
    else:
        return np.array([np.median(y_org[x[0]:x[1]]) for x in first_last])
    
make_first_last = lambda time_pos, hz: np.array([[np.floor(x[0]*hz), np.ceil(x[-1]*hz)+1] for x in time_pos], dtype=int)

_map = {"Wake": [0,2,3], "MS": [1]}
_uni = False

# thresholds
# Initialization
tstep = 0.025
tstart = 0.025
tmax = 1.0
tnum = ((tmax - tstart) / tstep) + 1
thresholds = np.linspace(tstart,1.0,np.round(tnum).astype(int))

In [148]:
## PROCESSING PRELIM ANALYSIS

HZ = [8, 16, 32, 64, 128]
for hz in HZ:
    resampled_labels = dict.fromkeys(all_names)
    resampled_first_last = dict.fromkeys(all_names)
    entries = []
        
    print(f"Dataframe creation for {hz} Hz")
    for edf, lab in zip(all_edf_files, all_labels_files):

        _id = edf.replace(".edf","")
        _type = "train" if _id in splits["train"] else "test"

        _edf = os.path.join("edf_data",edf)
        _labels = os.path.join("labels",lab)
        _preds = os.path.join("predictions", f"{hz}_hz",f"{_id}.npy")

        _tmp = BernLabels(lab, _map, _uni)
        _any_ms = np.sum(_tmp.labels) > 1

        _ms_200, _time_pos = _tmp.apply_rolling_func(win=0.2, step=0.2)
        _ms_200[_ms_200 == 0.5] == 1
        _any_ms_200 = np.sum(_ms_200) > 1

        entry = {"type": _type, "id": _id, "edf": _edf, "labels": _labels, "preds": _preds,
                 "ms": _any_ms, "ms_200": _any_ms_200}
        entries.append(entry)

        fixed_resampled_labels = _ms_200
        fixed_resampled_labels[fixed_resampled_labels==0.5] == 1
        
        resampled_labels[_id] = np.array(fixed_resampled_labels, dtype=int)
        resampled_first_last[_id] = make_first_last(_time_pos, hz)

    df = pd.DataFrame.from_records(entries)
    df.to_csv(f"prelim_data/corrected_{hz}_info_df.csv")

#     processed_recs = dict.fromkeys(all_names)
#     print(f"Processing recs for {hz} Hz")
    
#     for i, row in df.iterrows():
#         print(f"{i+1}/{df.shape[0]}")
#         p1,p2,p3 = get_all_probs(row.preds)
#         fl = resampled_first_last[row.id]

#         # Argmax
#         preds_argmax = aasm_to_wake_sleep(np.argmax(p1,axis=1))
#         resampled_preds_argmax=psuedo_resample(preds_argmax, fl)
#         resampled_preds_argmax[resampled_preds_argmax==0.5] = 1
        
#         # Sum
#         preds_sum = np.array([p2[:,1] > t for t in thresholds])*1
#         resampled_preds_sum=psuedo_resample(preds_sum, fl)
#         resampled_preds_sum[resampled_preds_sum==0.5] = 1
        
#         # Max
#         preds_max = np.array([p3[:,1] > t for t in thresholds])*1
#         resampled_preds_max=psuedo_resample(preds_max, fl)
#         resampled_preds_max[resampled_preds_max==0.5] = 1
            
#         # Store
#         entry = {"preds_argmax": resampled_preds_argmax,
#                  "preds_sum": resampled_preds_sum,
#                  "preds_max": resampled_preds_max,
#                  "labels": resampled_labels[row.id]}
#         processed_recs[row.id] = entry
#         clear_output(wait=True)


#     pickle_file = f'{hz}_processed_recs2.pickle'
#     write_to_pickle_file(processed_recs, pickle_file)
    


Dataframe creation for 8 Hz
Dataframe creation for 16 Hz
Dataframe creation for 32 Hz
Dataframe creation for 64 Hz
Dataframe creation for 128 Hz


In [152]:
def my_collector(collection, ids, key, rm=True):
    #print(f"Removing invalid labels: {rm}")
    i = 0
    for k, sub_collection in collection.items():
        if k not in ids:
            #print(f"Skipping: {k}")
            continue

        v = sub_collection[key]
        _y = sub_collection["labels"]
        _yy = remove_invalid_labels(_y, min_duration=1, max_duration=np.inf, fs = 5, verbose=False)
        
        if rm:
            v = np.array([remove_invalid_labels(vx,min_duration=1, max_duration=np.inf, fs=5, verbose=False) for vx in v])
        if i == 0:
            y_hat = v
            y = _y
        else:
            y_hat = np.column_stack([y_hat, v]) if len(y_hat.shape) > 1 else np.hstack([y_hat, v])
            y = np.hstack([y, _y])
        i += 1
    return y_hat, y

In [127]:
class Tracker(object):
    
    
    def __init__(self, k, train_y_true, train_y_hat, val_y_true, val_y_hat):
    
        self.k = k
        
        tuning = len(train_y_hat.shape) > 1
        
        # Calculate training metrics
        loop_func = lambda yt, yh, f: [f(yt,y) for y in yh]
        static_func = lambda yt, yh, f: f(yt,yh)
        func = loop_func if tuning else static_func

        self._train_precision = func(train_y_true, train_y_hat, metrics.precision_score)
        self._train_recall    = func(train_y_true, train_y_hat, metrics.recall_score)
        self._train_f1        = func(train_y_true, train_y_hat, metrics.f1_score)
        self._train_kappa     = func(train_y_true, train_y_hat, metrics.cohen_kappa_score)
       
        if tuning:
            opt_idx = np.argmax(self._train_kappa) if len(train_y_hat.shape) > 1 else 0
            self.opt_idx = opt_idx
            self.opt_threshold = thresholds[opt_idx]

            self.train_opt_kappa     = self._train_kappa[opt_idx]
            self.train_opt_f1        = self._train_f1[opt_idx]
            self.train_opt_recall    = self._train_recall[opt_idx]
            self.train_opt_precision = self._train_precision[opt_idx]

        else:
            self.opt_idx = np.nan
            self.opt_threshold = np.nan
            
            self.train_opt_kappa     = self._train_kappa
            self.train_opt_f1        = self._train_f1
            self.train_opt_recall    = self._train_recall
            self.train_opt_precision = self._train_precision
        

        # Calculate validation metrics
        val_y_opt     = val_y_hat[opt_idx,:] if tuning else val_y_hat
        
        self.val_precision = metrics.precision_score(val_y_true, val_y_opt)
        self.val_recall    = metrics.recall_score(val_y_true, val_y_opt)
        self.val_f1        = metrics.f1_score(val_y_true, val_y_opt)
        self.val_kappa     = metrics.cohen_kappa_score(val_y_true, val_y_opt)
        
    def to_dict(self):
        d = vars(self)
        out = {}
        for k,v in d.items():
            if not k.startswith("_"):
                out[k] = v
        return out

class EvaluationTracker(object):
    
    
    def __init__(self, y_true, y_hat, opt_idx = None):
    
        if np.any(opt_idx) and len(y_hat.shape) > 1:
            y_hat = y_hat[opt_idx,:]
        
        try:
            tn, fp, fn, tp = metrics.confusion_matrix(y_true, y_hat).ravel()
            self.specificity = tn / (tn+fp)
        except:
            pass
        
        self.accurcy     = metrics.accuracy_score(y_true, y_hat)
        self.precision   = metrics.precision_score(y_true, y_hat)
        self.recall      = metrics.recall_score(y_true, y_hat)
        self.f1          = metrics.f1_score(y_true, y_hat)
        self.kappa       = metrics.cohen_kappa_score(y_true, y_hat)
        
        

        


In [None]:
_df = pd.read_csv("prelim_data/corrected_8_info_df.csv")
test_df = _df[_df.type=="test"].reset_index(drop=True)
dev_df = _df[_df.type=="train"].reset_index(drop=True)


In [191]:

seed = 42
k = 5
skf = model_selection.StratifiedKFold(k, shuffle=True, random_state=seed)

pred_keys = ["preds_argmax","preds_sum","preds_max"]
HZ = [128, 64, 32, 16, 8]
i = 0
for hz in HZ:
    processed_recs_file = f"prelim_data/{hz}_processed_recs2.pickle"
    processed_recs = load_pickle_from_file(processed_recs_file)
    for pk in pred_keys:
        for k, (train_idx, val_idx) in enumerate(skf.split(dev_df.index, dev_df.ms)):   

            print(f"K: {k} - Hz: {hz} - Method: {pk}")
            train_id = dev_df.id[train_idx].values
            val_id = dev_df.id[val_idx].values
            
            train_yhat, train_y = my_collector(collection = processed_recs, ids = train_id, key=pk, rm = True)
            val_yhat, val_y     = my_collector(collection = processed_recs, ids = val_id, key=pk, rm = True)

            k_tracker = Tracker(k, train_y, train_yhat, val_y, val_yhat)
            k_dict = k_tracker.to_dict()
            k_dict["method"] = pk
            k_dict["hz"] = hz

            if i == 0:
                k_df = pd.DataFrame(k_dict, index=[i])
            else:
                k_df = pd.concat([k_df, pd.DataFrame(k_dict, index=[i])])

            clear_output(wait=True)
            i += 1
        k_df.to_csv("corrected_prelim2_df.csv")


K: 4 - Hz: 8 - Method: preds_max


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


### Diagnostics...
Go to post_hoc_analysis.ipynb

In [None]:
# bern_mapping = {"Wake": [0,2,3], "MS": [1]}
# mp = {"Wake": 0, "MS": 1}
# unilateral = False
# pat = test_df.iloc[3]
# tmp = BernLabels(pat.labels, bern_mapping, include_unilateral=unilateral)
# _, time_pos = tmp.apply_rolling_func(win=0.2, step=0.2)
# fl = np.array([[np.floor(x[0]*best_hz), np.ceil(x[-1]*best_hz)+1] for x in time_pos], dtype=int)

# _, axs = plt.subplots(3,1,sharex=True,figsize=[9,9])

# y_probs,y_sum_probs,y_max_probs = get_all_probs(f"predictions/{best_hz}_hz/{pat.id}.npy")
# plot_probs(ax=axs[0], probs=y_probs, labs=AASM_MAPPING.keys(), fs=best_hz)
# axs[0].axhline(1-t.opt_threshold,color="k",linestyle='--')

# # show prediction workflow from probabilties, median sampling, short removal
# y_pred, _y=my_collector(collection=collection, ids=pat.id, key=best_method)
# y_opt = y_pred[t.opt_idx,:]
# y_sum_preds = (y_sum_probs[:,1] > t.opt_threshold)*1
# y_sum_resampled = np.array([np.median(y_sum_preds[x[0]:x[1]]) for x in fl])
# y_sum_resampled[y_sum_resampled == 0.5] = 1
# y_sum_resampled_and_removed = remove_invalid_labels(np.copy(y_sum_resampled),  1, min_duration=1, max_duration=np.inf, fs=5)
# plot_label_patches(y_sum_preds, mp, ax=axs[1], fs=best_hz, y_min=3/4, y_max=1)
# plot_label_patches(y_sum_resampled, mp, ax=axs[1], fs=5, y_min=2/4, y_max=3/4)
# plot_label_patches(y_sum_resampled_and_removed, mp, ax=axs[1], fs = 5, y_min=1/4, y_max=2/4)
# plot_label_patches(y_opt, mp, ax=axs[1], fs=5, y_min=0, y_max=1/4)
# axs[1].set_yticks([0])
# axs[1].set_ylabel(f"Predictions with \n {t.opt_threshold} threshold")

# tmp.apply_rolling_func(win=0.2, step=0.2,replace=True)
# tmp.labels[tmp.labels==0.5] = 1
# tmp.apply_time_critera(replace=True, min_duration=1, max_duration=np.inf)
# tmp.plot_labels(ax=axs[2], fs=5)


# for ax in axs:
#     ax.autoscale(tight=True)

# #plot_label_patches(y_opt, mp, ax=ax2, fs=5, y_min=i/40, y_max=(i+1)/40)





In [None]:
# def find2(y,target):
#     r,c = np.where(y==target)
#     idx=np.empty(y.shape)
#     idx[r,c] = 1
#     islands = np.diff(idx)
#     zs = np.zeros(islands.shape[0])
#     ffy = np.column_stack([zs, islands, zs])
#     diff_idx = np.diff(ffy)

#     start_r, start_c = np.where(diff_idx > 0)
#     stop_r, stop_c = np.where(diff_idx < 0)

#     start = [(start_c[start_r==i]) for i in range(40)]
#     stop = [(stop_c[stop_r==i]) for i in range(40)]
    
#     return start, stop

In [108]:
import pandas as pd
from utils import *
from scipy.stats import mode

df = pd.read_csv("prelim_data/corrected_8_info_df.csv")
count = []
durs = []
tt = []

for i, row in df.iterrows():
    tmp = BernLabels(row.labels, {"Wake": [0,2,3], "MS":[1]}, include_unilateral=True)
    tmp.apply_rolling_func(win=0.2, step=0.2,replace=True)
    tmp.labels[tmp.labels==0.5]=1
    start, stop, singles = get_target_label_start_and_stop_indices(tmp.labels, 1)
    y=tmp.labels
    dur = (stop-start)/tmp.fs
    total_dur = np.sum(dur / 60)
    count.append(np.sum(y==1))
    durs.append(total_dur)
    tt.append(row.type)
    
    
#     for j, arr in enumerate([tmp.raw_O1, tmp.raw_O2]):
#         start, stop, singles = get_target_label_start_and_stop_indices(arr, 1)
#         dur = (stop-start)/tmp.fs
#         total_dur = np.sum(dur / 60)
#         count.append(len(start))
#         durs.append(total_dur)
#         tt.append(row.type)
#         sides.append(i)
    
    
stats = pd.DataFrame({"type": tt, "duration": durs, "count": count})
stats.groupby(["type"]).sum()

Unnamed: 0_level_0,duration,count
type,Unnamed: 1_level_1,Unnamed: 2_level_1
test,55.113333,16881
train,157.706667,48214


In [115]:
a=np.empty([1,10])
a[1:5] = 1
a

array([[0., 0., 1., 0., 1., 1., 0., 1., 0., 0.]])