In [9]:
import random
import time
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy as sk
import pickle
import mne
import tqdm
from tqdm import trange
import os

DEVICE='cuda'
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.optim as optim

from torch import nn

In [10]:
def extract_data(sub_range,verbose=False,feature_func=None,**kwargs):
    first=sub_range[0]
    last=sub_range[-1]

    for sub in trange(first,last+1): 
        #read data
        if sub <=9:
            path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
        else:
            path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'
        datatemp=np.load(path,allow_pickle=True)
        datatemp=np.permute_dims(datatemp,[1,0,2])

        #if we want to convert features
        if not feature_func==None:
            datatemp=feature_func(datatemp,info=info)

        #cat data along new dimension
        datatemp=torch.tensor(datatemp).unsqueeze(0)
        if sub==first:
            data=datatemp
        else:
            data=torch.cat([data,datatemp],axis=0)

        if verbose:
            print(f'subject {sub} data extracted, vector size: {data.shape}')
    return data

def shuffle_and_split_data(X,y):
    #get shuffled train/test/val in 8:1:1 ratio

    N=X.size(0)
    shuffled_indices=torch.randperm(N) #get shuffled indices
    X=X[shuffled_indices]
    y=y[shuffled_indices]

    # split 80% train set
    train_size=int(0.8*N)
    X_train=X[:train_size]
    y_train=y[:train_size]
    X_rest=X[train_size:]
    y_rest=y[train_size:]

    #split remainder by half for test and val
    splitind=int(len(X_rest)/2)
    X_val=X_rest[:splitind]
    y_val=y_rest[:splitind]
    X_test=X_rest[splitind:]
    y_test=y_rest[splitind:]

    return X_test,y_test,X_val,y_val,X_train,y_train

#create labels
#sad-dis-fear-neu-joy-ten-ins correspond to 0-6 respectively, each sample has 21 trials, for (7 emotions x 3 trials)
labels=np.array([])
for i in range(0,7):
    for n in range(0,3):
        labels=np.concatenate([labels,np.array([i])],axis=0)
labels=torch.tensor(labels)

  path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-0{sub}\ses-vid\eeg\sub-0{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'
  path=f'ds005540\derivatives\sub-{sub}\ses-vid\eeg\sub-{sub}_ses-vid_task-emotion_reorder.npy'


In [11]:
data=extract_data([1,30])

100%|██████████| 30/30 [00:07<00:00,  4.21it/s]


In [12]:
def process_data(data):
    '''0-10s, 5-15s, 10-20s, 15-25, 20-30s ---> 5 samples per trial'''
    for interval in tqdm.tqdm(range(0,5)):
        datatrial=data[:,:,:,interval*1000:(interval+2)*1000]
        datatrial=torch.unsqueeze(datatrial,2)
        if interval==0:
            newdata=datatrial
        else:
            newdata=torch.concatenate((newdata,datatrial),2)
    return newdata

In [13]:
data=process_data(data)
data.shape #subject,label,trials,electrodes,timepoints

100%|██████████| 5/5 [00:03<00:00,  1.33it/s]


torch.Size([30, 21, 5, 64, 2000])

In [14]:
for label in tqdm.tqdm(range(21)): #combine labels and trials
    temp=data[:,label,:,:,:]
    if label==0:
        X=temp
        y=labels[label]*torch.ones(5,1)
    else:
        X=torch.concatenate((X,temp),dim=1)
        y=torch.concatenate((y,labels[label]*torch.ones(5,1)),dim=0)

#combine subjects
for subject in tqdm.tqdm(range(X.size(0))):
    temp=X[subject,:,:,:]
    if subject==0:
        X_full=temp
        y_full=y
    else:
        X_full=torch.concatenate((X_full,temp),dim=0)
        y_full=torch.concatenate((y_full,y),dim=0)

print(X_full.shape)
print(y_full.shape)

#split data
X_test,y_test,X_val,y_val,X_train,y_train=shuffle_and_split_data(X_full.unsqueeze(1),y_full.squeeze())

print(X_test.shape,y_test.shape,
      X_val.shape,y_val.shape,
      X_train.shape,y_train.shape)

100%|██████████| 21/21 [00:09<00:00,  2.28it/s]
100%|██████████| 30/30 [00:24<00:00,  1.21it/s]


torch.Size([3150, 64, 2000])
torch.Size([3150, 1])
torch.Size([315, 1, 64, 2000]) torch.Size([315]) torch.Size([315, 1, 64, 2000]) torch.Size([315]) torch.Size([2520, 1, 64, 2000]) torch.Size([2520])


In [15]:
#now the model
class Net(nn.Module):
    def __init__(self,dropouts):
        super(Net,self).__init__()
        self.conv1=nn.Conv2d(1,8,9,stride=1,padding=4)
        self.conv2=nn.Conv2d(8,12,9,stride=1,padding=4)
        self.conv3=nn.Conv2d(12,16,5,stride=1,padding=2)
        self.conv4=nn.Conv2d(16,12,5,stride=1,padding=2)
        self.conv5=nn.Conv2d(12,2,5,stride=1,padding=2)
        self.pool1=nn.MaxPool2d(4)
        self.fc1=nn.Linear(1000,350)
        self.fc2=nn.Linear(350,350)
        self.fc3=nn.Linear(350,7)
        self.dropout1=nn.Dropout2d(dropouts[0])
        self.dropout2=nn.Dropout(dropouts[1])
        self.batchnorm1=nn.BatchNorm2d(8)
        self.batchnorm2=nn.BatchNorm2d(12)
        self.batchnorm3=nn.BatchNorm2d(16)
        self.batchnorm4=nn.BatchNorm2d(2)

    def forward(self,x):
        x=self.conv1(x)
        x=self.batchnorm1(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv2(x)
        x=self.batchnorm2(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.pool1(x)
        x=self.conv3(x)
        x=self.batchnorm3(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv4(x)
        x=self.batchnorm2(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.conv5(x)
        x=self.batchnorm4(x)
        x=self.dropout1(x)
        x=nn.functional.relu(x)
        x=self.pool1(x)
        x=torch.flatten(x,1)
        x=self.fc1(x)
        x=self.dropout2(x)
        x=nn.functional.relu(x)
        x=self.fc2(x)
        x=self.dropout2(x)
        x=nn.functional.relu(x)
        x=self.fc3(x)
        x=nn.functional.relu(x)
        return x
    


def train_test(net,epochs,train_loader,test_loader,device):
    criterion=nn.CrossEntropyLoss()
    optimizer=optim.Adam(net.parameters(),lr=0.1)
    train_acc=[]
    train_loss=[]
    test_acc=[]
    test_loss=[]
    net.to(device)
    for epoch in tqdm.tqdm(range(epochs)):
        net.train()
        running_loss=0.0
        correct,total=0,0
        for i,data in enumerate(train_loader,start=0):
            inputs,labels=data
            inputs=inputs.to(device).float()
            labels=labels.to(device).long()
            
            #train
            optimizer.zero_grad()
            outputs=net.forward(inputs)
            loss=criterion(outputs,labels)
            loss.backward()
            optimizer.step()

            running_loss+=loss.item()
            #training accuracy
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        train_loss.append(running_loss/len(train_loader))
        train_acc.append(correct/total)
        print(f"epoch {epoch} --> TRAIN loss: {running_loss/len(train_loader):.6f}, TRAIN accuracy: {correct/total:.2f}")

        #eval on test
        net.eval()
        running_loss=0.0
        correct,total=0,0
        for inputs,labels in test_loader:
            inputs,labels=inputs.to(device).float(),labels.to(device).long()
            outputs=net.forward(inputs)
            loss=criterion(outputs,labels)
            running_loss+=loss.item()

            #test acc
            _,predicted=torch.max(outputs,1)
            total+=labels.size(0)
            correct+=(predicted==labels).sum()
        test_loss.append(running_loss/len(test_loader))
        test_acc.append(correct/total)
        print(f"epoch {epoch} --> TEST loss: {running_loss/len(train_loader):.2f}, TEST accuracy: {correct/total:.2f}")

    return train_loss,train_acc,test_loss,test_acc


batch_size=75
test_data = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_data, batch_size=batch_size,
                         shuffle=False,num_workers=2
                         )

train_data = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_data,
                          batch_size=batch_size,
                          drop_last=False,
                          shuffle=True,num_workers=2
                          )

In [16]:
net=Net(dropouts=[0.4,0.4])
train_loss,train_acc,test_loss,test_acc=train_test(net,60,train_loader=train_loader,test_loader=test_loader,device=DEVICE)

  0%|          | 0/60 [00:00<?, ?it/s]

epoch 0 --> TRAIN loss: 3.507743, TRAIN accuracy: 0.14


  2%|▏         | 1/60 [00:35<35:13, 35.82s/it]

epoch 0 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 1 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


  3%|▎         | 2/60 [00:57<26:47, 27.72s/it]

epoch 1 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 2 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


  5%|▌         | 3/60 [01:20<24:16, 25.55s/it]

epoch 2 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 3 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


  7%|▋         | 4/60 [01:44<23:04, 24.72s/it]

epoch 3 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 4 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


  8%|▊         | 5/60 [02:07<22:14, 24.26s/it]

epoch 4 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 5 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 10%|█         | 6/60 [02:30<21:30, 23.91s/it]

epoch 5 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 6 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 12%|█▏        | 7/60 [02:54<20:57, 23.73s/it]

epoch 6 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 7 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 13%|█▎        | 8/60 [03:17<20:31, 23.69s/it]

epoch 7 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 8 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 15%|█▌        | 9/60 [03:41<20:00, 23.53s/it]

epoch 8 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 9 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 17%|█▋        | 10/60 [04:04<19:27, 23.36s/it]

epoch 9 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 10 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 18%|█▊        | 11/60 [04:27<19:00, 23.27s/it]

epoch 10 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 11 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 20%|██        | 12/60 [04:51<18:50, 23.55s/it]

epoch 11 --> TEST loss: 0.29, TEST accuracy: 0.10
epoch 12 --> TRAIN loss: 1.945910, TRAIN accuracy: 0.14


 20%|██        | 12/60 [05:13<20:52, 26.09s/it]


KeyboardInterrupt: 