In [None]:
# import models
from sleep_classif.CNNmultitaper import ConvNetMultitaper
from sleep_classif.LSTMConv import LSTM_Conv
from sleep_classif.CNNadvanced import CNN_Advanced
from sleep_classif.CNNmodel import SimpleCNN

# import loaders and other functions
from sleep_classif.preprocessing import compute_tapers
from sleep_classif.dataloaders import MultiTaperSet, RawDataSet, FFT_Raw_DataSet
from sleep_classif.trainer import Trainer

# import from other libraries
import torch
import torch.nn as nn

## Select the compute device (CUDA if available)


In [None]:
# Prefer the GPU when PyTorch can see one; later cells pass `device` to the
# datasets, the models and the Trainer.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

## Train a basic CNN

In [None]:
# Raw HDF5 recordings for both splits, plus the training label file.
data_path_train = './data/raw_data/X_train.h5'
data_path_test = './data/raw_data/X_test.h5'
target_path = './data/raw_data/y_train.csv'

# Release unused cached GPU memory held by PyTorch's allocator before building
# datasets that are given `device`.
torch.cuda.empty_cache()

raw_train_set = RawDataSet(device=device, data_path=data_path_train, target_path=target_path)
# NOTE(review): the test split is also given y_train.csv as its target file; if
# RawDataSet reads labels from it, test samples get training labels — confirm.
raw_test_set = RawDataSet(device=device, data_path=data_path_test, target_path=target_path)

In [None]:
# Hold out ~10% of the training data for validation.
# random_split draws from the global RNG by default and no seed is set anywhere,
# so the partition changed on every kernel restart; a seeded generator fixes it.
train_ratio = 0.9
SPLIT_SEED = 42
data_set = raw_train_set
train_len = int(len(data_set)*train_ratio)
train_set, validation_set = torch.utils.data.random_split(
    data_set,
    [train_len, len(data_set) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True)
# The validation set is only evaluated, so iterate it in a fixed order; shuffling
# it only adds run-to-run noise (assumes Trainer just aggregates metrics over it — confirm).
validationloader = torch.utils.data.DataLoader(validation_set, batch_size=16, shuffle=False)

In [None]:
# Baseline CNN, moved to the selected device.
simple_cnn = SimpleCNN()
simple_cnn = simple_cnn.to(device)

# Adam with default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(simple_cnn.parameters())
criterion = nn.CrossEntropyLoss()
# One weight per class (5 classes). The reciprocals of these values sum to ~1, so
# they look like inverse class frequencies — NOTE(review): confirm their source.
class_weights = torch.Tensor([8.081897, 22.222222, 2.756846, 3.765060, 4.927727])
trainer = Trainer(
    simple_cnn,
    criterion,
    optimizer,
    trainloader,
    device,
    valid_data_loader=validationloader,
    class_weights=class_weights,
)

In [None]:
# Train for a fixed number of epochs, keeping the (loss, accuracy) pair that
# trainer.train_epoch() returns for each one.
n_epochs = 25
loss_list, accuracy_list = [], []
for epoch in range(n_epochs):
    loss, accuracy = trainer.train_epoch()
    loss_list += [loss]
    accuracy_list += [accuracy]

## Train CNN + Multitaper Model

### Compute multitaper features

In [None]:
# Precompute multitaper features with sleep_classif.preprocessing.compute_tapers.
# NOTE(review): presumably this writes the ./data/pre_processed_data/Multitaper_*.npy
# files read by the next cell — confirm; if it is expensive, consider skipping it
# when those files already exist.
compute_tapers()

In [None]:
# Precomputed multitaper feature files (`eeg` and `position` variants) per split.
features_eeg_path_train = './data/pre_processed_data/Multitaper_eeg_train.npy'
features_position_path_train = './data/pre_processed_data/Multitaper_position_train.npy'
features_eeg_path_test = './data/pre_processed_data/Multitaper_eeg_test.npy'
features_position_path_test = './data/pre_processed_data/Multitaper_position_test.npy'

target_path = './data/raw_data/y_train.csv'

# Release unused cached GPU memory held by PyTorch's allocator before building
# datasets that are given `device`.
torch.cuda.empty_cache()

taper_train_set = MultiTaperSet(
    device=device,
    features_eeg_path=features_eeg_path_train,
    features_position_path=features_position_path_train,
    target_path=target_path,
)
# NOTE(review): the test split is also given y_train.csv as its target file; if
# MultiTaperSet reads labels from it, test samples get training labels — confirm.
taper_test_set = MultiTaperSet(
    device=device,
    features_eeg_path=features_eeg_path_test,
    features_position_path=features_position_path_test,
    target_path=target_path,
)


In [None]:
# Hold out ~10% of the multitaper training data for validation.
# random_split draws from the global RNG by default and no seed is set anywhere,
# so the partition changed on every kernel restart; a seeded generator fixes it.
train_ratio = 0.9
SPLIT_SEED = 42
data_set = taper_train_set
train_len = int(len(data_set)*train_ratio)
train_set, validation_set = torch.utils.data.random_split(
    data_set,
    [train_len, len(data_set) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
# The whole validation set is evaluated as one batch, in a fixed order: shuffling
# an evaluation-only loader adds nothing (assumes Trainer only evaluates it — confirm).
validationloader = torch.utils.data.DataLoader(validation_set, batch_size=len(validation_set), shuffle=False)

In [None]:
# CNN over the multitaper features, moved to the selected device.
CNN_taper_model = ConvNetMultitaper()
CNN_taper_model = CNN_taper_model.to(device)

In [None]:
# Adam with default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(CNN_taper_model.parameters())
criterion = nn.CrossEntropyLoss()
# One weight per class (5 classes). The reciprocals of these values sum to ~1, so
# they look like inverse class frequencies — NOTE(review): confirm their source.
class_weights = torch.Tensor([8.081897, 22.222222, 2.756846, 3.765060, 4.927727])
trainer = Trainer(
    CNN_taper_model,
    criterion,
    optimizer,
    trainloader,
    device,
    valid_data_loader=validationloader,
    class_weights=class_weights,
)


In [None]:
# Train for a fixed number of epochs, keeping the (loss, accuracy) pair that
# trainer.train_epoch() returns for each one.
n_epochs = 25
loss_list, accuracy_list = [], []
for epoch in range(n_epochs):
    loss, accuracy = trainer.train_epoch()
    loss_list += [loss]
    accuracy_list += [accuracy]

## Train an advanced CNN network

In [None]:
# NOTE(review): fftpack is not used by any visible cell; kept in case another part
# of the notebook relies on it.
from scipy import fftpack

# Raw HDF5 recordings for both splits, plus the training label file.
data_path_train = './data/raw_data/X_train.h5'
data_path_test = './data/raw_data/X_test.h5'
target_path = './data/raw_data/y_train.csv'

raw_train_set = FFT_Raw_DataSet(device=device, data_path=data_path_train, target_path=target_path)
# NOTE(review): the test split is also given y_train.csv as its target file; if
# FFT_Raw_DataSet reads labels from it, test samples get training labels — confirm.
raw_test_set = FFT_Raw_DataSet(device=device, data_path=data_path_test, target_path=target_path)






In [None]:
# Hold out ~10% of the FFT/raw training data for validation.
# random_split draws from the global RNG by default and no seed is set anywhere,
# so the partition changed on every kernel restart; a seeded generator fixes it.
train_ratio = 0.9
SPLIT_SEED = 42
data_set = raw_train_set
train_len = int(len(data_set)*train_ratio)
train_set, validation_set = torch.utils.data.random_split(
    data_set,
    [train_len, len(data_set) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
# The whole validation set is evaluated as one batch, in a fixed order: shuffling
# an evaluation-only loader adds nothing (assumes Trainer only evaluates it — confirm).
validationloader = torch.utils.data.DataLoader(validation_set, batch_size=len(validation_set), shuffle=False)

In [None]:
# Input channel counts for CNN_Advanced, hard-coded rather than read from the data.
# NOTE(review): raw_train_set.feature_shape() may be able to supply these four
# values — confirm the literals still match the FFT_Raw_DataSet features.
num_classes = 5
raw_feat, fft_feat = 5, 5
raw_pos_feat, fft_pos_feat = 3, 3

In [None]:
# Advanced CNN over raw + FFT features, moved to the selected device.
# NOTE(review): the trailing 0.5 is presumably a dropout rate — confirm against CNN_Advanced.
CNN_Advanced_model = CNN_Advanced(raw_feat, fft_feat, raw_pos_feat, fft_pos_feat, num_classes, 0.5)
CNN_Advanced_model = CNN_Advanced_model.to(device)


In [None]:
# Adam with default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(CNN_Advanced_model.parameters())
criterion = nn.CrossEntropyLoss()
# One weight per class (5 classes). The reciprocals of these values sum to ~1, so
# they look like inverse class frequencies — NOTE(review): confirm their source.
class_weights = torch.Tensor([8.081897, 22.222222, 2.756846, 3.765060, 4.927727])
trainer = Trainer(
    CNN_Advanced_model,
    criterion,
    optimizer,
    trainloader,
    device,
    valid_data_loader=validationloader,
    class_weights=class_weights,
)




In [None]:
# Train for a fixed number of epochs, keeping the (loss, accuracy) pair that
# trainer.train_epoch() returns for each one.
n_epochs = 25
loss_list, accuracy_list = [], []
for epoch in range(n_epochs):
    loss, accuracy = trainer.train_epoch()
    loss_list += [loss]
    accuracy_list += [accuracy]

In [None]:
# Fresh multitaper CNN (replaces any earlier CNN_taper_model), moved to the device.
CNN_taper_model = ConvNetMultitaper()
CNN_taper_model = CNN_taper_model.to(device)

In [None]:
# Training set uses MultiTaperSet's default paths; the test set points at the
# *_test.npy feature files explicitly and passes no target_path.
multitaper_train_set = MultiTaperSet(device=device)
multitaper_test_set = MultiTaperSet(
    device=device,
    features_eeg_path='./data/pre_processed_data/Multitaper_eeg_test.npy',
    features_position_path='./data/pre_processed_data/Multitaper_position_test.npy',
)

In [None]:
# Hold out ~10% of the multitaper training data for validation.
# random_split draws from the global RNG by default and no seed is set anywhere,
# so the partition changed on every kernel restart; a seeded generator fixes it.
train_ratio = 0.9
SPLIT_SEED = 42
data_set = multitaper_train_set
train_len = int(len(data_set)*train_ratio)
train_set, validation_set = torch.utils.data.random_split(
    data_set,
    [train_len, len(data_set) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
# The whole validation set is evaluated as one batch, in a fixed order: shuffling
# an evaluation-only loader adds nothing (assumes Trainer only evaluates it — confirm).
validationloader = torch.utils.data.DataLoader(validation_set, batch_size=len(validation_set), shuffle=False)

In [None]:
# Adam with default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(CNN_taper_model.parameters())
criterion = nn.CrossEntropyLoss()
# One weight per class (5 classes). The reciprocals of these values sum to ~1, so
# they look like inverse class frequencies — NOTE(review): confirm their source.
class_weights = torch.Tensor([8.081897, 22.222222, 2.756846, 3.765060, 4.927727])
trainer = Trainer(
    CNN_taper_model,
    criterion,
    optimizer,
    trainloader,
    device,
    valid_data_loader=validationloader,
    class_weights=class_weights,
)

In [None]:
# Train for a fixed number of epochs, keeping the (loss, accuracy) pair that
# trainer.train_epoch() returns for each one.
n_epochs = 25
loss_list, accuracy_list = [], []
for epoch in range(n_epochs):
    loss, accuracy = trainer.train_epoch()
    loss_list += [loss]
    accuracy_list += [accuracy]

## Train CNN + LSTM

In [None]:
# Raw HDF5 recordings for both splits, plus the training label file.
data_path_train = './data/raw_data/X_train.h5'
data_path_test = './data/raw_data/X_test.h5'
target_path = './data/raw_data/y_train.csv'

raw_train_set = RawDataSet(device=device, data_path=data_path_train, target_path=target_path)
# NOTE(review): the test split is also given y_train.csv as its target file; if
# RawDataSet reads labels from it, test samples get training labels — confirm.
raw_test_set = RawDataSet(device=device, data_path=data_path_test, target_path=target_path)





In [None]:
# Hold out ~10% of the raw training data for validation.
# random_split draws from the global RNG by default and no seed is set anywhere,
# so the partition changed on every kernel restart; a seeded generator fixes it.
train_ratio = 0.9
SPLIT_SEED = 42
data_set = raw_train_set
train_len = int(len(data_set)*train_ratio)
train_set, validation_set = torch.utils.data.random_split(
    data_set,
    [train_len, len(data_set) - train_len],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

trainloader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True)
# The whole validation set is evaluated as one batch, in a fixed order: shuffling
# an evaluation-only loader adds nothing (assumes Trainer only evaluates it — confirm).
validationloader = torch.utils.data.DataLoader(validation_set, batch_size=len(validation_set), shuffle=False)

In [None]:
num_classes = 5
# Input feature size as reported by the raw training dataset; passed to LSTM_Conv.
raw_feat = raw_train_set.feature_shape()


In [None]:
# Convolutional + LSTM model over the raw signals, moved to the selected device.
LSTM_Conv_model = LSTM_Conv(raw_feat, num_classes).to(device)

In [None]:
# Adam with default hyperparameters over all model parameters.
optimizer = torch.optim.Adam(LSTM_Conv_model.parameters())
criterion = nn.CrossEntropyLoss()
# One weight per class (5 classes). The reciprocals of these values sum to ~1, so
# they look like inverse class frequencies — NOTE(review): confirm their source.
class_weights = torch.Tensor([8.081897, 22.222222, 2.756846, 3.765060, 4.927727])
# NOTE(review): nn.CrossEntropyLoss expects unnormalized logits. If requires_softmax
# makes Trainer apply softmax before this loss, the loss sees probabilities — confirm.
trainer = Trainer(
    LSTM_Conv_model,
    criterion,
    optimizer,
    trainloader,
    device,
    valid_data_loader=validationloader,
    class_weights=class_weights,
    requires_softmax=True,
)


In [None]:
# Train for a fixed number of epochs, keeping the (loss, accuracy) pair that
# trainer.train_epoch() returns for each one.
n_epochs = 25
loss_list, accuracy_list = [], []
for epoch in range(n_epochs):
    loss, accuracy = trainer.train_epoch()
    loss_list += [loss]
    accuracy_list += [accuracy]