In [4]:
## https://drive.google.com/drive/folders/1vokakruDSSCE33GKZlis25OTSSw95mAD?usp=sharing

In [5]:
# Imports

from scipy.io import loadmat
from scipy.signal import fftconvolve
import numpy as np
import os
import random
import gc as garbageCollector
from multiprocessing import Pool

'''Script to extract preprocessed signals and save to temporary numpy memap files to be used for model training'''

########################################################################################################################
# Function to compute consistent train/validation/test split by a constant random seed

def computeTrainValidationTestRecords(dataPath, foldName='Default'):
    """Split the record directories under ``dataPath`` into train/validation/test lists.

    A fixed-seed shuffle of the indices 0..993 defines which positions of the
    (sorted) record listing go to each split, so a given fold name always yields
    the same split for the same set of records.

    NOTE: the listing is sorted so the split no longer depends on the arbitrary
    order of ``os.listdir``; splits can therefore differ from those produced by
    older versions of this function.

    Args:
        dataPath: directory containing one sub-directory per record. Joined with
            ``os.path.join``, so a trailing separator is optional.
        foldName: one of 'Default', 'Auxiliary1', 'Auxiliary2', 'Auxiliary3',
            'Auxiliary4'.

    Returns:
        (trainRecordList, validationRecordList, testRecordList): lists of record
        directory names.

    Raises:
        ValueError: if ``foldName`` is not a known fold.
        IndexError: if fewer than 994 record directories exist (the index map
            assumes 994 records).
    """
    print('////////////////////')
    print('Computing train/validation/test data split, Test results are:')

    # Define consistent mappings for the index from training data to the actual returned value to ensure shuffling
    r = random.Random()
    r.seed(29)
    requestRecordMap = list(range(994))
    r.shuffle(requestRecordMap)

    # (test, validation, train) index lists for every fold
    m = requestRecordMap
    folds = {
        'Default': (m[0:200], m[894:], m[200:894]),
        'Auxiliary1': (m[0:100], m[100:200], m[200:]),
        'Auxiliary2': (m[0:100], m[300:400], m[100:300] + m[400:]),
        'Auxiliary3': (m[0:100], m[600:700], m[100:600] + m[700:]),
        'Auxiliary4': (m[0:100], m[894:], m[100:894]),
    }
    if foldName not in folds:
        raise ValueError('Unknown foldName: ' + repr(foldName))
    testIndices, validationIndices, trainIndices = folds[foldName]

    # Sorted so the index -> record mapping is reproducible across machines/filesystems
    recordList = sorted(name for name in os.listdir(dataPath)
                        if os.path.isdir(os.path.join(dataPath, name)))
    trainRecordList = [recordList[ind] for ind in trainIndices]
    validationRecordList = [recordList[ind] for ind in validationIndices]
    testRecordList = [recordList[ind] for ind in testIndices]

    # Small test to make sure the sets are pairwise non-overlapping and together use all of the data
    trainSet, validationSet, testSet = set(trainRecordList), set(validationRecordList), set(testRecordList)
    areUnique = not (trainSet & validationSet or trainSet & testSet or validationSet & testSet)
    isComplete = (trainSet | validationSet | testSet) == set(recordList)

    print('Uniqueness Test: ' + str(areUnique) + '  Completeness Test: ' + str(isComplete))
    print('////////////////////\n')

    return trainRecordList, validationRecordList, testRecordList

def computeUnseenRecords(dataPath):
    """Return the sorted names of all record sub-directories of ``dataPath``.

    Returns a list (the original returned a one-shot lazy ``filter`` object,
    unlike ``computeTrainValidationTestRecords``). ``dataPath`` is joined with
    ``os.path.join``, so a trailing separator is optional.
    """
    return sorted(name for name in os.listdir(dataPath)
                  if os.path.isdir(os.path.join(dataPath, name)))


########################################################################################################################
# Load signals or annotations from a specific file in the source files

# Convenience function to load signals
def loadSignals(recordName, dataPath):
    """Load the raw signal matrix of one record.

    Reads ``<dataPath><recordName>/<recordName>.mat`` and returns the array
    stored under the MATLAB variable ``val``. ``dataPath`` is concatenated
    directly, so it must end with a path separator.
    """
    matFile = dataPath + recordName + '/' + recordName + '.mat'
    val = loadmat(matFile)['val']
    # The rest of the loaded .mat dict is already unreferenced; reclaim it now.
    garbageCollector.collect()
    return val

# Convenience function to load annotations
def loadAnnotations(recordName, arousalAnnotationPath, apneaHypopneaAnnotationPath, sleepWakeAnnotationPath):
    """Load the three annotation tracks of one record.

    Returns:
        (arousal, apneaHypopnea, sleepStage):
          * ``arousal``: squeezed int32 array read from the ``data.arousals`` field of
            ``<arousalAnnotationPath><recordName>-arousal.mat``;
          * ``apneaHypopnea`` / ``sleepStage``: in-memory float64 copies of the raw int32
            files ``<path>apneaHypopneaAnnotation_<recordName>.dat`` and
            ``<path>sleepWakeAnnotation_<recordName>.dat``.
    All paths are concatenated directly, so they must end with a separator.
    """
    arousalStruct = loadmat(arousalAnnotationPath + recordName + '-arousal.mat')['data']
    arousals = np.squeeze(arousalStruct['arousals'][0][0].astype(np.int32))
    del arousalStruct
    garbageCollector.collect()

    def readInt32File(filePath):
        # Copy the memory-mapped int32 file into a plain in-memory float64 array
        return np.array(np.memmap(filePath, dtype='int32', mode='r'), dtype=np.float64)

    apneaHypopnea = readInt32File(apneaHypopneaAnnotationPath + 'apneaHypopneaAnnotation_' + str(recordName) + '.dat')
    sleepStage = readInt32File(sleepWakeAnnotationPath + 'sleepWakeAnnotation_' + str(recordName) + '.dat')

    return arousals, apneaHypopnea, sleepStage

########################################################################################################################
# Load signals or annotations from a specific file in the source files

# 31-tap symmetric low-pass FIR used as an anti-aliasing filter before the 4x decimation
# in extractWholeRecord.
_ANTIALIAS_FIR_COEFFS = np.array([0.00637849379422531, 0.00543091599801427, -0.00255136650039784, -0.0123109503066702,
                                  -0.0137267267561505, -0.000943230632358082, 0.0191919895027550, 0.0287148886882440,
                                  0.0123598773891149, -0.0256928886371578, -0.0570987715759348, -0.0446385294777459,
                                  0.0303553522906817, 0.148402006671856, 0.257171285176269, 0.301282456398562,
                                  0.257171285176269, 0.148402006671856, 0.0303553522906817, -0.0446385294777459,
                                  -0.0570987715759348, -0.0256928886371578, 0.0123598773891149, 0.0287148886882440,
                                  0.0191919895027550, -0.000943230632358082, -0.0137267267561505, -0.0123109503066702,
                                  -0.00255136650039784, 0.00543091599801427, 0.00637849379422531])


def _normalizeChannels(signals, skipColumn=None, kernelSize=(50*18*60)+1):
    """Globally z-score every column, then remove a rolling mean and a rolling RMS.

    Args:
        signals: float64 array, rows = samples, columns = channels.
        skipColumn: column excluded from the rolling mean/RMS removal (the SaO2 column),
            or None to process every column. It is still globally z-scored.
        kernelSize: rolling-window length in samples.

    Returns:
        A new float64 array with the same shape as ``signals``.
    """
    # Remove DC bias and scale for FFT convolution
    center = np.mean(signals, axis=0)
    scale = np.std(signals, axis=0)
    scale[scale == 0] = 1.0
    signals = (signals - center) / scale

    window = np.ones(shape=(kernelSize,)) / kernelSize
    rollingColumns = [n for n in range(signals.shape[1]) if n != skipColumn]

    # Compute and remove the moving average with FFT convolution
    center = np.zeros(signals.shape)
    for n in rollingColumns:
        center[::, n] = fftconvolve(signals[::, n], window, mode='same')
    center[np.isnan(center) | np.isinf(center)] = 0.0
    signals = signals - center

    # Compute and remove the rolling RMS with FFT convolution of the squared signal
    scale = np.ones(signals.shape)
    for n in rollingColumns:
        meanSquare = fftconvolve(np.square(signals[::, n]), window, mode='same')

        # Mathematically never negative, but FFT round-off can produce tiny negative values
        meanSquare[meanSquare < 0] = 0.0

        # Replace NaN/inf by the largest valid value
        invalid = np.isnan(meanSquare) | np.isinf(meanSquare)
        meanSquare[invalid] = 0.0
        meanSquare[invalid] = np.max(meanSquare)

        scale[::, n] = np.sqrt(meanSquare)

    # To correct for record 12 that has a zero amplitude chest signal (and any other flat channel)
    scale[(scale == 0) | np.isinf(scale) | np.isnan(scale)] = 1.0
    return signals / scale


def extractWholeRecord(recordName,
                       dataPath,
                       arousalAnnotationPath,
                       apneaHypopneaAnnotationPath,
                       sleepWakeAnnotationPath,
                       extractAnnotations=True,
                       keepChannels=None,
                       saO2Channel=11):
    """Load one record, low-pass filter and decimate it 4x, and normalise every kept channel.

    Args:
        recordName, dataPath: forwarded to ``loadSignals``.
        arousalAnnotationPath, apneaHypopneaAnnotationPath, sleepWakeAnnotationPath:
            forwarded to ``loadAnnotations``; only used when ``extractAnnotations`` is True.
        extractAnnotations: also load, decimate and append the three annotation tracks.
        keepChannels: raw channel indices to keep, in output order. Defaults to
            [0, 1, 2, 3, 4, 5] (the EEG channels kept by the original code).
        saO2Channel: raw index of the SaO2 channel. When it is kept, that column is rescaled
            to [-0.5, 0.5] and excluded from the rolling normalisation. Pass None to disable.

    Returns:
        float32 array with rows = 50 Hz samples and columns = kept channels, plus three
        annotation columns (arousal, apnea/hypopnea, sleep/wake) when ``extractAnnotations``.
    """
    # Keep all EEG channels by default
    keepChannels = [0, 1, 2, 3, 4, 5] if keepChannels is None else list(keepChannels)
    # Column of SaO2 among the kept channels. The original indexed column 11 unconditionally,
    # which raised IndexError when only 6 channels were kept.
    saO2Column = keepChannels.index(saO2Channel) if saO2Channel in keepChannels else None

    if extractAnnotations:
        arousalAnnotations, apneaHypopneaAnnotations, sleepStageAnnotations = loadAnnotations(
            recordName, arousalAnnotationPath, apneaHypopneaAnnotationPath, sleepWakeAnnotationPath)

    # Transpose so rows are samples and channels are columns (the code below indexes channels on axis 1)
    signals = np.transpose(loadSignals(recordName, dataPath)).astype(np.float64)

    # Apply the antialiasing FIR filter to each kept channel and downsample to 50Hz (every 4th sample)
    signals = signals[::, keepChannels]
    for n in range(signals.shape[1]):
        signals[::, n] = np.convolve(signals[::, n], _ANTIALIAS_FIR_COEFFS, mode='same')
    signals = signals[0::4]

    if extractAnnotations:
        arousalAnnotations = arousalAnnotations[0::4]
        apneaHypopneaAnnotations = apneaHypopneaAnnotations[0::4]
        sleepStageAnnotations = sleepStageAnnotations[0::4]

    garbageCollector.collect()

    if saO2Column is not None:
        # Scale SaO2 to sit between -0.5 and 0.5, a good range for input to neural network
        signals[::, saO2Column] += -32768.0
        signals[::, saO2Column] /= 65535.0
        signals[::, saO2Column] -= 0.5

    # Normalize the other channels by removing the mean and the rms in an 18 minute rolling window.
    # 18 minute window is used because baseline breathing is established in 2 minute window according
    # to AASM standards; normalizing over 18 minutes ensures a 90% overlap between the beginning and end
    # of the baseline window.
    signals = _normalizeChannels(signals, skipColumn=saO2Column, kernelSize=(50*18*60)+1)

    garbageCollector.collect()

    # Convert to 32 bits
    signals = signals.astype(np.float32)

    if not extractAnnotations:
        return signals

    annotationColumns = [np.expand_dims(track, axis=1).astype(np.float32)
                         for track in (arousalAnnotations, apneaHypopneaAnnotations, sleepStageAnnotations)]
    return np.concatenate([signals] + annotationColumns, axis=1)

########################################################################################################################
# Functions to extract all three datasets

def _fitToLength(record, sampleDataLimit):
    """Return ``record`` (rows = samples) chopped or zero-padded to exactly ``sampleDataLimit`` rows.

    Padded rows are zero in the signal columns and -1 in the last three (annotation)
    columns, presumably so the training loss (built with ignore_index=-1) skips them.
    """
    if record.shape[0] < sampleDataLimit:
        # Zero Pad
        extension = np.zeros(shape=(sampleDataLimit - record.shape[0], record.shape[1]), dtype=record.dtype)
        extension[::, -3::] = -1.0
        return np.concatenate([record, extension], axis=0)
    # Chop (a no-op slice when the length already matches)
    return record[0:sampleDataLimit, ::]


def extractBatchedData(datasetName,
                       dataPath,
                       savePath,
                       foldName,
                       dataLimitInHours=7,
                       useMultiprocessing=False,
                       arousalAnnotationPath=None,
                       apneaHypopneaAnnotationPath=None,
                       sleepWakeAnnotationPath=None):
    """Extract one split into the float32 memmap ``<savePath><datasetName>_data.dat``.

    The memmap has shape (numRecords, dataLimitInHours*3600*50, numColumns), where
    numColumns is the column count returned by ``extractWholeRecord`` (kept channels + 3
    annotation columns). Records are processed in batches of 11, optionally in a Pool.

    Args:
        datasetName: 'Training' or 'Validation'.
        dataPath: record directory, passed to the split computation and to extractWholeRecord.
        savePath: output prefix (concatenated directly, so keep the trailing separator).
        foldName: fold passed to computeTrainValidationTestRecords.
        dataLimitInHours: every record is chopped / zero-padded to this length at 50 Hz.
        useMultiprocessing: extract each batch with a multiprocessing Pool.
        arousalAnnotationPath, apneaHypopneaAnnotationPath, sleepWakeAnnotationPath:
            annotation locations forwarded to extractWholeRecord (required).

    Raises:
        ValueError: unknown ``datasetName`` or a missing annotation path.
    """
    from functools import partial

    if datasetName not in ('Training', 'Validation'):
        raise ValueError('Invalid data set name for batched data extraction!')
    if any(path is None for path in (arousalAnnotationPath, apneaHypopneaAnnotationPath, sleepWakeAnnotationPath)):
        raise ValueError('arousalAnnotationPath, apneaHypopneaAnnotationPath and '
                         'sleepWakeAnnotationPath must all be provided')

    # Get list of records and divide into train, validation and testing based on pre defined randomly divided folds
    trainRecordList, validationRecordList, _ = computeTrainValidationTestRecords(dataPath, foldName=foldName)
    recordList = trainRecordList if datasetName == 'Training' else validationRecordList

    # Set up
    numRecordsPerBatch = 11
    sampleDataLimit = dataLimitInHours*3600*50
    extractDatasetFilePath = savePath + datasetName + '_'

    # Bind every argument except the record name: map()/pool.map() pass exactly one item.
    # The original mapped extractWholeRecord directly, which raised TypeError.
    extractRecord = partial(extractWholeRecord,
                            dataPath=dataPath,
                            arousalAnnotationPath=arousalAnnotationPath,
                            apneaHypopneaAnnotationPath=apneaHypopneaAnnotationPath,
                            sleepWakeAnnotationPath=sleepWakeAnnotationPath)

    pool = Pool(processes=numRecordsPerBatch) if useMultiprocessing else None
    fp = None
    try:
        # Extract data in batches of numRecordsPerBatch
        for batchCount, start in enumerate(range(0, len(recordList), numRecordsPerBatch)):
            print('Extracting Batch: ' + str(batchCount))

            batchRecords = recordList[start:start + numRecordsPerBatch]
            if pool is not None:
                data = pool.map(extractRecord, batchRecords)
            else:
                data = [extractRecord(recordName) for recordName in batchRecords]

            # Enforce dataLimitInHours hour length with chopping / zero padding for memory usage stability and efficiency in cuDNN
            for m in range(len(data)):
                originalLength = data[m].shape[0]/(3600*50)
                data[m] = _fitToLength(data[m], sampleDataLimit)
                print('Original Length: ' + str(originalLength) + '  New Length: ' + str(data[m].shape[0]/(3600*50)))

            garbageCollector.collect()

            if fp is None:
                # Allocated lazily so its width matches what extractWholeRecord actually returns
                # (the original hard-coded 15 columns).
                fp = np.memmap(extractDatasetFilePath + 'data.dat', dtype='float32', mode='w+',
                               shape=(len(recordList), sampleDataLimit, data[0].shape[1]))

            print('Saving Batch: ' + str(batchCount))

            # Save batch of extracted records to extracted data path
            for m, record in enumerate(data):
                fp[start + m, ::, ::] = record
            fp.flush()

            garbageCollector.collect()
    finally:
        if pool is not None:
            pool.close()
            pool.join()
        # Dropping the last reference closes the memory map
        del fp

# Extract training and validation sets from backing store.
# YOU NEED TO SET dataPath AND savePath BASED ON YOUR ENVIRONMENT (keep the trailing path
# separator: savePath is concatenated directly with the output file name).
dataPath = "UPDATE BASED ON YOUR ENVIRONMENT/"
savePath = "UPDATE BASED ON YOUR ENVIRONMENT/"

# The guard keeps worker processes started with the 'spawn' method (Windows, macOS) from
# re-running the extraction when they import this module.
if __name__ == '__main__':
    extractBatchedData(datasetName='Training', dataPath=dataPath, savePath=savePath,
                       foldName='Auxiliary4', useMultiprocessing=True)
    extractBatchedData(datasetName='Validation', dataPath=dataPath, savePath=savePath,
                       foldName='Auxiliary4', useMultiprocessing=True)

## Compute the covariance matrices for different records

We do this in an unsupervised setting (no labels are needed at this stage).

In [6]:
# we used pyriemann library which is a machine learning library based on scikit-learn API
import sys
from pyriemann.estimation import Covariances
from pyriemann.clustering import Kmeans
from pyriemann.utils.distance import pairwise_distance
from sklearn.manifold import TSNE
from pyriemann.tangentspace import TangentSpace
from pyriemann.tangentspace import tangent_space


## Note: if a covariance matrix is not SPD, at least one of the variables
## can be expressed as a linear combination of the others.
from scipy.stats import zscore

# Record names come from the command line.
# NOTE(review): inside a Jupyter kernel sys.argv holds the kernel's own arguments
# (e.g. '-f kernel.json'); set recordList explicitly when running as a notebook.
recordList = sys.argv[1:]

# Shrinkage added to the diagonal so every covariance estimate is strictly positive definite.
# NOTE(review): the original used an undefined `sigma`; 1e-6 is a placeholder value - tune it.
sigma = 1e-6

covarianceList = []
for record in recordList:

    # Signals only (no annotations); records are read from ./<record>/<record>.mat.
    processedSignal = extractWholeRecord(str(record), './', None, None, None, extractAnnotations=False)

    # Standardise each channel (column) over the whole record. The original called .apply on a
    # NumPy array and discarded the result. Constant channels give NaN, so map those to 0.
    xtrain = np.nan_to_num(zscore(processedSignal, axis=0))

    # In order to create symmetric positive definite matrices we add sigma * I (shrinkage model)
    covarianceList.append(np.cov(xtrain.T) + sigma * np.identity(processedSignal.shape[1]))

# Stack into (n_records, n_channels, n_channels), the layout pyriemann estimators expect.
COV = np.array(covarianceList)


### Clustering on a Riemannian manifold

In [7]:
import matplotlib.pyplot as plt

# Number of patient sub-groups. The original asked for 10 clusters, while the colour list,
# the later per-cluster training (K = 5) and the notes all assume 5.
nClusters = 5
kmeans = Kmeans(nClusters).fit(COV)
clusters = kmeans.predict(COV)


# Plotting the clustering results: embed the pairwise Riemannian distances in 2-D with t-SNE.
distmatrix = pairwise_distance(COV, metric='riemann')
# init='random' is required with metric='precomputed' in recent scikit-learn (the 'pca' default is
# rejected there), and perplexity must stay below the number of records.
X_embedded = TSNE(n_components=2, metric='precomputed', init='random',
                  perplexity=min(30.0, len(COV) - 1.0)).fit_transform(distmatrix)

fig, ax = plt.subplots(figsize=(8, 8), facecolor='white')

# One colour per cluster. The original zipped against a 5-colour tuple, silently dropping extra clusters.
for colourIndex, label in enumerate(np.unique(clusters)):
    members = clusters == label
    ax.scatter(X_embedded[members, 0], X_embedded[members, 1],
               color=plt.cm.tab10(colourIndex % 10), label=label)

ax.set_xlabel(r'$\varphi_1$', fontsize=16)
ax.set_ylabel(r'$\varphi_2$', fontsize=16)
ax.set_title('Sub-grouping the population using a Riemannian Distance', fontsize=16)

ax.grid(False)
ax.legend()
plt.show()

### Training a CNN model for each cluster 

#### Mapping covariance matrices to tangent space


Tangent space projection (a scikit-learn `TransformerMixin`).

Tangent space projection maps a set of covariance matrices to their tangent space. The tangent space projection can be seen as a kernel operation. After projection, each matrix is represented as a vector of size $N(N+1)/2$, where $N$ is the dimension of the covariance matrices.

Tangent space projection is useful for converting covariance matrices into Euclidean vectors while preserving the inner structure of the manifold. After projection, standard processing and vector-based classification can be applied.

Tangent space projection is a local approximation of the manifold. It takes one parameter, the reference point, which is usually estimated as the geometric mean of the set of covariance matrices being projected. If `fit` is not called, the identity matrix is used as the reference point, which can seriously degrade performance. The approximation error is larger when the matrices are scattered across the manifold and smaller when they are grouped in a small region of it.

After projection, it is possible to go back to the manifold using the inverse transform.

In [8]:
# Input:  COV : ndarray, shape (n_records, n_channels, n_channels)
# Output: Features : ndarray, shape (n_records, n_channels*(n_channels+1)/2), the tangent space projection.

from pyriemann.utils.mean import mean_ale

# One common reference point for every record: the mean of the whole set of covariance matrices,
# estimated with mean_ale. The original computed a "mean" from each single matrix and called
# TangentSpace.tangent_space, which does not exist.
Cref = mean_ale(COV, tol=1e-06, maxiter=50, sample_weight=None)

# Project every record's matrix onto the tangent space at Cref. The result is locally Euclidean,
# so standard vector-based classifiers can be applied.
Features = tangent_space(COV, Cref)

## CNN training
The model is similar to the one used in "Supervised and unsupervised machine learning for automated scoring of sleep-wake and cataplexy in a mouse model of narcolepsy", but we use 5-stage classification.

All algorithms implemented in the cited paper are available upon request.
Therefore, we have released only the trained models, subject to our authorization.

In [3]:
import numpy as np
import torch
import multiprocessing as mltp
import gc as garbageCollector
from random import Random

########################################################################################################################
# Training-script setup: constants, imports, GPU selection and the asynchronous training/validation loaders.

# load X_tr and Y_tr (corresponding to each record and 30-sec)
# train a CNN model 
# compute confusion matrix 

# Number of patient clusters to train a separate model for
K = 5 


import torch
import torch.optim as optim
import os
import gc as garbageCollector

from ModelDefinition import Sleep_model_MultiTarget

from TrainingDataManager import asyncTrainingDataLoader
from ValidationDataManager import asyncValidationDataLoader, evalValidationAccuracy

import timeit

# Let cuDNN benchmark and cache the fastest convolution algorithms; this pays off when input shapes stay fixed.
torch.backends.cudnn.benchmark = True

# Settings

# Control parameters
gpuSetting = 1     # Which GPU to use
# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised. The first explicit
# .cuda() call comes later (model creation); presumably the loaders below do not touch the GPU - confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = str(gpuSetting)

# NOTE(review): foldName is not passed to the loaders below, and trainModel writes checkpoints to
# './Models/Auxiliary1'; keep the two in sync by hand.
foldName = 'Auxiliary1'

# Number of signal channels fed to the network.
# NOTE(review): extractWholeRecord keeps channels [0..5] (6 channels) by default - confirm this matches the extracted data.
numChannels = 12

########################################################################################################################
# Training and validation data asynchronous managers

extractedDataPath = "UPDATE BASED ON YOUR ENVIRONMENT"
# reductionFactor semantics are defined in TrainingDataManager / ValidationDataManager (not shown) - TODO confirm.
# numRecords=100 matches the size of the 'Auxiliary1' validation split (indices 100:200).
trainingData = asyncTrainingDataLoader(extractedDataPath=extractedDataPath, reductionFactor=50, numChannels=numChannels)
validationData = asyncValidationDataLoader(extractedDataPath=extractedDataPath, numRecords=100, reductionFactor=50, numChannels=numChannels)
garbageCollector.collect()

########################################################################################################################

# Trains the model
def trainModel(model, maxEpochs=None, checkpointDir='./Models/Auxiliary1', numBatchesPerEpoch=100):
    """Train ``model`` jointly on the arousal, apnea/hypopnea and sleep/wake targets.

    Each epoch does the following:
      * runs ``numBatchesPerEpoch`` optimisation steps on batches from the module-level
        ``trainingData`` loader;
      * evaluates with ``evalValidationAccuracy`` on ``validationData``;
      * prints timings, average loss and validation metrics;
      * saves the whole model to ``<checkpointDir>/checkpointModel_<epoch>.pkl``.

    Args:
        model: network called as ``model(batchFeats)`` that returns three output tensors
            (arousal, apnea/hypopnea, sleep/wake). Must already be on the GPU, because the
            batches are moved there with ``.cuda()``.
        maxEpochs: number of epochs to run. None (the default) trains forever, as the
            original did.
        checkpointDir: directory for per-epoch checkpoints. Created if missing.
        numBatchesPerEpoch: optimisation steps per epoch.

    Returns:
        The trained model. This is only reached when ``maxEpochs`` is not None.
    """
    optimizer = optim.Adam(model.parameters())

    # ignore_index=-1 skips padded / unscored samples. reduction='mean' replaces the deprecated reduce=True.
    arousalCriterion = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean').cuda()
    apneaCriterion = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean').cuda()
    wakeCriterion = torch.nn.CrossEntropyLoss(ignore_index=-1, reduction='mean').cuda()

    bestArousalAUC = 0.5
    bestArousalAP = 0.0

    bestApneaAUC = 0.5
    bestApneaAP = 0.0

    bestWakeAUC = 0.5
    bestWakeAP = 0.0

    os.makedirs(checkpointDir, exist_ok=True)

    i_epoch = 0
    while maxEpochs is None or i_epoch < maxEpochs:
        # Put in train mode
        trainingStartTime = timeit.default_timer()

        model.train()
        runningLoss = 0.0

        for _ in range(numBatchesPerEpoch):

            # Fetch pre-loaded batch from asynchronous data loader
            batchFeats, batchArousalTargs, batchApneaTargs, batchWakeTargs = trainingData.getNextItem()

            # Send batch to GPU
            batchFeats = batchFeats.cuda()
            batchArousalTargs = batchArousalTargs.cuda()
            batchApneaTargs = batchApneaTargs.cuda()
            batchWakeTargs = batchWakeTargs.cuda()

            # Compute the network outputs on the batch
            arousalOutputs, apneaHypopneaOutputs, sleepStageOutputs = model(batchFeats)

            # Move the class axis last and flatten to (N, 2) rows, as CrossEntropyLoss expects
            arousalOutputs = arousalOutputs.permute(0, 2, 1).contiguous().view(-1, 2)
            apneaHypopneaOutputs = apneaHypopneaOutputs.permute(0, 2, 1).contiguous().view(-1, 2)
            sleepStageOutputs = sleepStageOutputs.permute(0, 2, 1).contiguous().view(-1, 2)

            arousalLoss = arousalCriterion(arousalOutputs, batchArousalTargs)
            apneaHypopneaLoss = apneaCriterion(apneaHypopneaOutputs, batchApneaTargs)
            sleepStageLoss = wakeCriterion(sleepStageOutputs, batchWakeTargs)

            # Arousal loss is weighted twice as heavily as the other two targets
            loss = ((2*arousalLoss) + apneaHypopneaLoss + sleepStageLoss) / 4.0

            # Backpropagation and one optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            runningLoss += loss.item()

        i_epoch += 1

        trainingEndTime = timeit.default_timer()

        # Get validation accuracy
        arousalAUC, arousalAP, apneaAUC, apneaAP, wakeAUC, wakeAP = evalValidationAccuracy(model, validationData=validationData)

        validationEndTime = timeit.default_timer()

        bestArousalAUC = max(bestArousalAUC, arousalAUC)
        bestArousalAP = max(bestArousalAP, arousalAP)
        bestApneaAUC = max(bestApneaAUC, apneaAUC)
        bestApneaAP = max(bestApneaAP, apneaAP)
        bestWakeAUC = max(bestWakeAUC, wakeAUC)
        bestWakeAP = max(bestWakeAP, wakeAP)

        print('////////////////////')
        print('Epoch Number: ' + str(i_epoch) + '  Training Time: ' + str(trainingEndTime - trainingStartTime) + '  Validation Time: ' + str(validationEndTime - trainingEndTime))
        print('Average Training Loss: ' + str(runningLoss / float(numBatchesPerEpoch)))
        print('Validation  Arousal AUC: ' + str(arousalAUC) + '  AP: ' + str(arousalAP)
              + ' | Apnea AUC: ' + str(apneaAUC) + '  AP: ' + str(apneaAP)
              + ' | Wake AUC: ' + str(wakeAUC) + '  AP: ' + str(wakeAP))
        print('Best        Arousal AUC: ' + str(bestArousalAUC) + '  AP: ' + str(bestArousalAP)
              + ' | Apnea AUC: ' + str(bestApneaAUC) + '  AP: ' + str(bestApneaAP)
              + ' | Wake AUC: ' + str(bestWakeAUC) + '  AP: ' + str(bestWakeAP))

        checkpointPath = os.path.join(checkpointDir, 'checkpointModel_' + str(i_epoch) + '.pkl')
        with open(checkpointPath, 'wb') as f:
            torch.save(model, f)

    return model

########################################################################################################################
# Create and train model

# Instantiate the multi-target network for numChannels input signals and move it to the GPU.
# (The optimizer is created inside trainModel, not here.)
model = Sleep_model_MultiTarget(numSignals=numChannels).cuda()
model.train()

# Train and evaluate the model; whatever trainModel returns is bound back to `model`.
model = trainModel(model=model)

import pickle

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.stats import zscore
from sklearn.metrics import classification_report, confusion_matrix

# NOTE(review): this cell relies on names that are not defined anywhere in this notebook:
# RecordName, clusters_total, fold, Subjects, Labels, the N_files_train / N_files_test /
# N_files_vld lists, and the trained per-cluster classifier `mlp` (its training code is not
# included). They must be provided before the cell can run.


def _stackClusterRecords(recordIds):
    """Stack the z-scored feature rows and the labels of every record in ``recordIds``.

    Returns (X, Y) as NumPy arrays: X has ``Features.shape[1]`` columns, Y is 1-D.
    Assumes ``Features`` / ``Labels`` are pandas objects row-aligned with ``Subjects`` - TODO confirm.
    """
    stackedX = np.empty((0, Features.shape[1]))
    stackedY = np.empty((0))
    for recordId in recordIds:
        print(recordId)
        rowMask = Subjects.subject == recordId
        # DataFrame.apply returns a new frame; the original discarded it, so nothing was z-scored.
        recordX = Features[rowMask].apply(zscore)
        stackedX = np.concatenate((stackedX, recordX.values), axis=0)
        recordY = Labels[rowMask]
        stackedY = np.concatenate((stackedY, recordY.values.reshape(len(recordY))))
    return stackedX, stackedY


def _plotConfusionMatrix(yTrue, yPred, clusterIndex):
    """Plot, save as ConfusionMatrix_Cluster<k>.pdf and show the row-normalised confusion matrix."""
    # The union of true and predicted labels keeps the tick labels aligned with the matrix axes.
    labels = np.union1d(yTrue, yPred)
    cm = confusion_matrix(yTrue, yPred, labels=labels).astype('float')
    # Row-normalise; guard against classes that never occur in yTrue (all-zero rows).
    cm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1.0)
    plt.figure(figsize=(8, 8))
    sns.heatmap(cm, xticklabels=labels, yticklabels=labels, annot=True, fmt="f")
    plt.title("Confusion matrix")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('ConfusionMatrix_Cluster' + str(clusterIndex) + '.pdf', bbox_inches='tight', pad_inches=0)
    plt.show()


for k in range(0, K):

    print("%%%%%%%%%%%%%%%%%%%%%%%%% working on cluster " + str(k))

    Cluster = RecordName[clusters_total == k]

    TrainFold = pd.unique(fold[fold.subset == 'train'].id)
    TestFold = pd.unique(fold[fold.subset == 'test'].id)
    VldFold = pd.unique(fold[fold.subset == 'cv'].id)

    ########## Across studies: keep only this cluster's records in each fold
    TrainFiles = list(set(TrainFold).intersection(Cluster))
    TestFiles = list(set(TestFold).intersection(Cluster))
    VldFiles = list(set(VldFold).intersection(Cluster))

    # The original appended len(TrainFiles) to all three counters
    N_files_train.append(len(TrainFiles))
    N_files_test.append(len(TestFiles))
    N_files_vld.append(len(VldFiles))

    # create training, test and validation datasets for the cluster
    Xtr, Ytr = _stackClusterRecords(TrainFiles)
    print("working on test records")
    Xte, Yte = _stackClusterRecords(TestFiles)
    print("working on validation records")
    Xvld, Yvld = _stackClusterRecords(VldFiles)

    with open('TestDataset_features_cluster' + str(k) + '.pkl', 'wb') as file:
        pickle.dump(Xte, file)
        pickle.dump(Yte, file)

    with open('VldDataset_features_cluster' + str(k) + '.pkl', 'wb') as file:
        pickle.dump(Xvld, file)
        pickle.dump(Yvld, file)

    # No oversampling in this cell. The original referenced X_tr_resampled / Y_tr_resampled,
    # which are never defined here.
    X_tr_resampled, Y_tr_resampled = Xtr, Ytr

    # Randomly Shuffle the training data - 5 class classification (seeded per cluster).
    # The original used np.arange, i.e. no shuffle. Order is irrelevant for validation/test.
    s_tr = np.random.default_rng(k).permutation(X_tr_resampled.shape[0])
    xtrain = X_tr_resampled[s_tr]
    ytrain = Y_tr_resampled[s_tr]

    xvald, yvald = Xvld, Yvld
    xtest, ytest = Xte, Yte

    ## Multi-Stage sleep staging (5-class classification)
    ytr = ytrain.copy()
    yvld = yvald.copy()

    # NOTE(review): the per-cluster model training that should produce `mlp` is missing here.

    with open('TrainedClassifier_Cluster' + str(k) + '.pkl', 'wb') as file:
        pickle.dump(mlp, file)
    # NOTE(review): the original predicted on xtest1 / ytest1, which are not defined anywhere;
    # the test split built above is used instead.
    predictions = mlp.predict(xtest)
    report = classification_report(ytest, predictions)
    print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% cluster" + str(k))
    print(report)
    print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")

    _plotConfusionMatrix(ytest, predictions, k)


In [None]:
# Setup for the per-cluster training cell: cluster count, imports and cuDNN configuration.
# load X_tr and Y_tr (corresponding to each record and 30-sec)
# train a CNN model 
# compute confusion matrix 

# Number of patient clusters; the loop below handles clusters 0..K-1
K = 5 


# NOTE(review): none of the torch / model / data-manager imports below are used by the loop in this cell.
import torch
import torch.optim as optim
import os
import gc as garbageCollector

from ModelDefinition import Sleep_model_MultiTarget

from TrainingDataManager import asyncTrainingDataLoader
from ValidationDataManager import asyncValidationDataLoader, evalValidationAccuracy

import timeit

# Let cuDNN benchmark and cache the fastest convolution algorithms for the input shapes it sees
torch.backends.cudnn.benchmark = True

# training CNN for each cluster: each cluster has training, validation, and test subjects

import pickle

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from scipy.stats import zscore
from sklearn.metrics import classification_report, confusion_matrix

# NOTE(review): this cell relies on names that are not defined anywhere in this notebook:
# RecordName, clusters_total, fold, Subjects, Labels, the N_files_train / N_files_test /
# N_files_vld lists, the trained per-cluster classifier `mlp`, and `ADSYN` (presumably
# imbalanced-learn's ADASYN, imblearn.over_sampling.ADASYN - import it before running).


def _stackClusterRecords(recordIds):
    """Stack the z-scored feature rows and the labels of every record in ``recordIds``.

    Returns (X, Y) as NumPy arrays: X has ``Features.shape[1]`` columns, Y is 1-D.
    Assumes ``Features`` / ``Labels`` are pandas objects row-aligned with ``Subjects`` - TODO confirm.
    """
    stackedX = np.empty((0, Features.shape[1]))
    stackedY = np.empty((0))
    for recordId in recordIds:
        print(recordId)
        rowMask = Subjects.subject == recordId
        # DataFrame.apply returns a new frame; the original discarded it, so nothing was z-scored.
        recordX = Features[rowMask].apply(zscore)
        stackedX = np.concatenate((stackedX, recordX.values), axis=0)
        recordY = Labels[rowMask]
        stackedY = np.concatenate((stackedY, recordY.values.reshape(len(recordY))))
    return stackedX, stackedY


def _plotConfusionMatrix(yTrue, yPred, clusterIndex):
    """Plot, save as ConfusionMatrix_Cluster<k>.pdf and show the row-normalised confusion matrix."""
    # The union of true and predicted labels keeps the tick labels aligned with the matrix axes.
    labels = np.union1d(yTrue, yPred)
    cm = confusion_matrix(yTrue, yPred, labels=labels).astype('float')
    # Row-normalise; guard against classes that never occur in yTrue (all-zero rows).
    cm = cm / np.maximum(cm.sum(axis=1, keepdims=True), 1.0)
    plt.figure(figsize=(8, 8))
    sns.heatmap(cm, xticklabels=labels, yticklabels=labels, annot=True, fmt="f")
    plt.title("Confusion matrix")
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.savefig('ConfusionMatrix_Cluster' + str(clusterIndex) + '.pdf', bbox_inches='tight', pad_inches=0)
    plt.show()


for k in range(0, K):

    print("%%%%%%%%%%%%%%%%%%%%%%%%% working on cluster " + str(k))

    Cluster = RecordName[clusters_total == k]

    TrainFold = pd.unique(fold[fold.subset == 'train'].id)
    TestFold = pd.unique(fold[fold.subset == 'test'].id)
    VldFold = pd.unique(fold[fold.subset == 'cv'].id)

    ########## Across studies: keep only this cluster's records in each fold
    TrainFiles = list(set(TrainFold).intersection(Cluster))
    TestFiles = list(set(TestFold).intersection(Cluster))
    VldFiles = list(set(VldFold).intersection(Cluster))

    # The original appended len(TrainFiles) to all three counters
    N_files_train.append(len(TrainFiles))
    N_files_test.append(len(TestFiles))
    N_files_vld.append(len(VldFiles))

    # create training, test and validation datasets for the cluster
    Xtr, Ytr = _stackClusterRecords(TrainFiles)
    print("working on test records")
    Xte, Yte = _stackClusterRecords(TestFiles)
    print("working on validation records")
    Xvld, Yvld = _stackClusterRecords(VldFiles)

    with open('TestDataset_features_cluster' + str(k) + '.pkl', 'wb') as file:
        pickle.dump(Xte, file)
        pickle.dump(Yte, file)

    with open('VldDataset_features_cluster' + str(k) + '.pkl', 'wb') as file:
        pickle.dump(Xvld, file)
        pickle.dump(Yvld, file)

    # Oversample the training data only. The original bound y_tr_resampled but then read Y_tr_resampled.
    X_tr_resampled, Y_tr_resampled = ADSYN().fit_resample(Xtr, Ytr)

    # Randomly Shuffle the training data - 5 class classification (seeded per cluster).
    # The original used np.arange, i.e. no shuffle. Order is irrelevant for validation/test.
    s_tr = np.random.default_rng(k).permutation(X_tr_resampled.shape[0])
    xtrain = X_tr_resampled[s_tr]
    ytrain = Y_tr_resampled[s_tr]

    xvald, yvald = Xvld, Yvld
    xtest, ytest = Xte, Yte

    ## Multi-Stage sleep staging (5-class classification)
    ytr = ytrain.copy()
    yvld = yvald.copy()

    ### Training 5 CNN models for each cluster
    # NOTE(review): the per-cluster model training that should produce `mlp` is missing here.

    with open('TrainedClassifier_Cluster' + str(k) + '.pkl', 'wb') as file:
        pickle.dump(mlp, file)
    # NOTE(review): the original predicted on xtest1 / ytest1, which are not defined anywhere;
    # the test split built above is used instead.
    predictions = mlp.predict(xtest)
    report = classification_report(ytest, predictions)
    print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% cluster" + str(k))
    print(report)
    print("%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%")

    _plotConfusionMatrix(ytest, predictions, k)