In [None]:
%%capture
!pip install ssqueezepy
!pip install timm
!pip install pytorch-lightning

In [None]:
# Root directory of the input dataset; the EEG CSV is read relative to this path.
main_path='../input/eeg-data-distance-learning-environment'

In [None]:
import pandas as pd
import os
# Load the raw EEG table from the dataset directory and preview it.
csv_path = os.path.join(main_path, 'EEG_data.csv')
df = pd.read_csv(csv_path)
df.head()

In [None]:
# Keep the first 16 columns (which contain the 14 raw EEG channels) plus the
# final column, and drop everything in between.
cols_remove=df.columns.tolist()[16:-1]
df=df.loc[:, ~df.columns.isin(cols_remove)]
# Remove the literal 'EEG.' prefix from the channel columns.
# NOTE: str.strip('EEG.') would treat its argument as a *set of characters*
# and strip any leading/trailing 'E', 'G' or '.', which can corrupt names;
# an anchored regex removes exactly the prefix and nothing else.
df.columns = df.columns.str.replace(r'^EEG\.', '', regex=True)
df.head()

In [None]:
# Distinct values present in the label column.
df['subject_understood'].unique()

In [None]:
# The data has to be reshaped into (subjects, trials, channels, length).
# First step: one group per (subject, video) recording.
groups = df.groupby(by=['subject_id', 'video_id'])
grp_keys = [key for key in groups.groups]
print(grp_keys)



In [None]:
# Pull out the first (subject, video) recording: keep its label column
# separately and leave only the signal columns in grp1.
grpno = grp_keys[0]
subject_id = grpno[0]
grp1 = groups.get_group(grpno).drop(columns=['subject_id', 'video_id'])
label = grp1['subject_understood']
grp1 = grp1.drop(columns='subject_understood')
grp1.head()

In [None]:
import mne
def convertDF2MNE(sub,sfreq=128,duration=3,overlap=2):
    """Turn one recording (samples x channels DataFrame) into fixed-length epochs.

    Parameters
    ----------
    sub : pandas.DataFrame
        One row per sample, one column per EEG channel. Column names are used
        as channel names and must exist in the 'standard_1020' montage.
    sfreq : float, default 128
        Sampling frequency in Hz.
    duration : float, default 3
        Epoch length in seconds.
    overlap : float, default 2
        Overlap between consecutive epochs in seconds.

    Returns
    -------
    numpy.ndarray
        Epoched data as returned by ``Epochs.get_data()``
        (presumably (epochs, channels, samples) — confirm against MNE docs).
    """
    ch_names=list(sub.columns)
    info = mne.create_info(ch_names, ch_types=['eeg'] * len(ch_names), sfreq=sfreq)
    info.set_montage('standard_1020')
    # RawArray expects (channels, samples), hence the transpose.
    data=mne.io.RawArray(sub.T, info)
    # Re-reference using MNE's default reference setting.
    data.set_eeg_reference()
    #data.filter(l_freq=1,h_freq=30)
    epochs=mne.make_fixed_length_epochs(data,duration=duration,overlap=overlap)
    return epochs.get_data()

In [None]:
# Epoch the first (subject, video) recording as a smoke test of convertDF2MNE.
test=convertDF2MNE(grp1)
# presumably (trials, channels, length), per the note in the scaleogram loop — confirm from the output
test.shape

In [None]:
# Samples per epoch: 128 Hz sampling rate (sfreq) x 3 s epoch duration.
128*3

In [None]:
!mkdir scaleogram

In [None]:
from glob import glob
import scipy.io
import torch.nn as nn
import torch
import numpy as np
import mne
from ssqueezepy import cwt
from ssqueezepy.visuals import plot, imshow
import os
import re
import pandas as pd

In [None]:
# Shape of the first row of the first epoch (presumably one channel's samples — confirm).
test[0][0].shape

In [None]:
# Continuous wavelet transform of the first epoch with the 'morlet' wavelet;
# cwt returns the transform and the scales it was evaluated at.
Wx, scales = cwt(test[0], 'morlet')
Wx.shape

In [None]:
# Visualise the first slice of the transform (presumably the first channel — confirm).
imshow(Wx[0])

In [None]:
%%capture
grpnos,labels,paths=[],[],[]
for i,grpno in enumerate(grp_keys):
    grp=groups.get_group(grpno).drop(['subject_id','video_id'],axis=1)
    label=int(grp['subject_understood'].unique())
    subject_id=grpno[0]
    grp=grp.drop('subject_understood',axis=1)
    data=convertDF2MNE(grp)#(trials, channels, length)
    for c,x in enumerate(data):#loop trials
        Wx, scales = cwt(x, 'morlet')
        Wx=np.abs(Wx)
        path=os.path.join('./scaleogram',f'subvideo_{grpno}/',)
        os.makedirs(path,exist_ok=True)
        path=path+f'trial_{c}.npy'
        np.save(path,Wx)
        
        grpnos.append(i)
        labels.append(label)
        paths.append(path)

In [None]:
# NOTE(review): `x` is the loop variable left over from the scaleogram loop
# above, i.e. the last epoch that was processed; this cell fails on a fresh
# kernel unless that loop has run.
Wx, scales = cwt(x, 'morlet')
imshow(Wx[0])

In [None]:
# Index of all saved scaleograms: file path, label and group index.
df_scale = pd.DataFrame({
    'path': paths,
    'label': labels,
    'group': grpnos,
})
df_scale.head()

In [None]:
import numpy as np
from pytorch_lightning import seed_everything, LightningModule, Trainer
from sklearn.utils import class_weight
import torch.nn as nn
import torch
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.callbacks import EarlyStopping,ModelCheckpoint,LearningRateMonitor
from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau,CosineAnnealingWarmRestarts,OneCycleLR,CosineAnnealingLR
import torchvision
from sklearn.metrics import classification_report,f1_score,accuracy_score,roc_curve,auc,roc_auc_score
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from glob import glob
from PIL import Image
import cv2
from torch.utils.data import DataLoader, Dataset,ConcatDataset
import torchmetrics
import timm
import random

In [None]:
#read data from folders
class DataReader(Dataset):
    """Dataset that loads saved scaleograms listed in a DataFrame.

    Parameters
    ----------
    dataset : pandas.DataFrame
        Must have a ``path`` column (location of a ``.npy`` file) and a
        ``label`` column.
    aug : bool or None
        When truthy, each of the last two axes is flipped independently with
        probability 0.5.

    Each item is ``(x, y)`` where ``x`` is the array min-max scaled to [0, 1]
    as contiguous float32 and ``y`` is the label value.
    """
    def __init__(self, dataset,aug=None):
        self.dataset = dataset
        self.aug=aug

    def __getitem__(self, index):
        # Samplers hand out positions 0..len-1, so index positionally; plain
        # Series[index] is label-based and breaks if the index was not reset.
        path=self.dataset.path.iloc[index]
        y=self.dataset.label.iloc[index]
        x=np.load(path)
        if self.aug:
            if random.uniform(0, 1)>0.5:
                x=np.flip(x,-1)
            if random.uniform(0, 1)>0.5:
                x=np.flip(x,-2)
        lo=np.min(x)
        span=np.max(x)-lo
        if span==0:
            # Constant input: avoid 0/0 producing NaNs.
            x=np.zeros_like(x)
        else:
            x=(x-lo)/span
        # float32 matches the model weights; contiguous memory is required
        # because flipped arrays are negative-stride views.
        x=np.ascontiguousarray(x,dtype=np.float32)
        return x, y

    def __len__(self):
        return len(self.dataset)

In [None]:
# Pull a single augmented batch to check the tensor shapes coming out of the loader.
test_loader = DataLoader(
    DataReader(df_scale, aug=True),
    batch_size=8,
)
test_batch = next(iter(test_loader))
(test_batch[0].shape, test_batch[1].shape)

In [None]:
import timm
class OurModel(LightningModule):
    """Binary classifier: a pretrained timm 'resnest26d' backbone whose first
    convolution is replaced to accept 14 input channels, followed by a small
    fully connected head that emits a single logit per sample.

    Parameters
    ----------
    train_split, val_split : pandas.DataFrame
        Frames with ``path`` and ``label`` columns, wrapped in ``DataReader``
        by the dataloader hooks. ``val_split`` is also used for testing.
    """
    def __init__(self,train_split,val_split):
        super(OurModel,self).__init__()

        self.train_split=train_split
        self.val_split=val_split
        #########TIMM#################
        model_name='resnest26d'
        self.model =  timm.create_model(model_name,pretrained=True)
        # Swap the stem convolution so the network takes 14 input channels.
        self.model.conv1[0]=nn.Conv2d(14, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)

        # Head: fc1 expects a 1000-dimensional output from the backbone.
        self.fc1=nn.Linear(1000,500)
        self.relu=nn.ReLU()
        self.fc2= nn.Linear(500,250)
        self.fc3= nn.Linear(250,1)
        self.drp=nn.Dropout(0.25)
        #parameters
        self.lr=1e-3
        self.batch_size=16
        self.numworker=2
        self.criterion=nn.BCEWithLogitsLoss()
        self.metrics=torchmetrics.Accuracy()

        # Per-epoch history, filled by the *_epoch_end hooks.
        self.trainloss,self.valloss=[],[]
        self.trainacc,self.valacc=[],[]

        self.sub_pred=0

    def forward(self,x):
        """Return raw logits of shape (batch, 1); no sigmoid is applied here."""
        x= self.model(x)
        x=self.fc1(x)
        x=self.relu(x)
        x=self.drp(x)
        x=self.fc2(x)
        x=self.relu(x)
        x=self.drp(x)
        x=self.fc3(x)
        return x

    def configure_optimizers(self):
        opt=torch.optim.AdamW(params=self.parameters(),lr=self.lr )
        return opt

    def _loss_and_acc(self,batch):
        """Shared train/val step: BCE-with-logits loss and batch accuracy."""
        image,label=batch
        logits=self(image).flatten()
        loss=self.criterion(logits,label.float())
        # The accuracy metric thresholds its input at 0.5, which is only
        # meaningful for probabilities, so convert the logits first.
        acc=self.metrics(torch.sigmoid(logits),label)
        return {'loss':loss,'acc':acc}

    @staticmethod
    def _epoch_mean(outputs,key):
        """Mean of ``key`` over the step outputs, rounded to 2 decimals (numpy)."""
        return torch.stack([x[key] for x in outputs]).mean().detach().cpu().numpy().round(2)

    def train_dataloader(self):
        return DataLoader(DataReader(self.train_split,False), batch_size = self.batch_size,
                          num_workers=self.numworker,pin_memory=True,shuffle=True)

    def training_step(self,batch,batch_idx):
        return self._loss_and_acc(batch)

    def training_epoch_end(self, outputs):
        loss=self._epoch_mean(outputs,"loss")
        acc=self._epoch_mean(outputs,"acc")
        self.trainloss.append(loss)
        self.trainacc.append(acc)
        self.log('train_loss', loss)

    def val_dataloader(self):
        ds=DataLoader(DataReader(self.val_split), batch_size = self.batch_size,
                      num_workers=self.numworker,pin_memory=True, shuffle=False)
        return ds

    def validation_step(self,batch,batch_idx):
        return self._loss_and_acc(batch)

    def validation_epoch_end(self, outputs):
        loss=self._epoch_mean(outputs,"loss")
        acc=self._epoch_mean(outputs,"acc")
        self.valloss.append(loss)
        self.valacc.append(acc)
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_dataloader(self):
        # Testing reuses the validation split.
        ds=DataLoader(DataReader(self.val_split), batch_size = self.batch_size,
                      num_workers=self.numworker,pin_memory=True, shuffle=False)
        return ds

    def test_step(self,batch,batch_idx):
        image,label=batch
        pred = self(image)

        return {'label':label,'pred':pred}

    def test_epoch_end(self, outputs):
        """Print AUC, accuracy and a classification report over the whole test set."""
        label=torch.cat([x["label"] for x in outputs])
        logits=torch.cat([x["pred"] for x in outputs]).flatten()
        # Work with probabilities: the 0.5 cut-off below (and inside the
        # accuracy metric) is wrong when applied to raw logits.
        prob=torch.sigmoid(logits.float())
        acc=self.metrics(prob,label)
        prob=prob.detach().cpu().numpy().ravel()
        label=label.detach().cpu().numpy().ravel()
        print('sklearn auc',roc_auc_score(label,prob))
        pred=np.where(prob>0.5,1,0).astype(int)
        print('torch acc',acc)
        print(classification_report(label,pred))
        print('sklearn',accuracy_score(label,pred))

        

In [None]:
from sklearn.model_selection import GroupKFold,LeaveOneGroupOut,StratifiedGroupKFold
# 5-fold cross-validation, stratified on the label and grouped by recording so
# that epochs of one (subject, video) pair never appear in both train and val.
gkf=StratifiedGroupKFold(5)
result=[]
valacc=[]
for train_index, val_index in gkf.split(df_scale.path,df_scale.label,  groups=df_scale.group):
    # deterministic=True alone does not fix the RNG state; seed every fold so
    # that shuffling, dropout and head initialisation are reproducible.
    seed_everything(42)
    train_df=df_scale.iloc[train_index].reset_index(drop=True)
    val_df=df_scale.iloc[val_index].reset_index(drop=True)


    lr_monitor = LearningRateMonitor(logging_interval='epoch')
    gpu=-1 if torch.cuda.is_available() else 0
    gpup=16 if torch.cuda.is_available() else 32
    model=OurModel(train_df,val_df)
    # NOTE(review): auto_lr_find / auto_scale_batch_size presumably only take
    # effect when trainer.tune(model) is called — confirm for the installed
    # pytorch-lightning version.
    trainer = Trainer(max_epochs=20, auto_lr_find=True, auto_scale_batch_size=True,
                        deterministic=True,
                        gpus=gpu,precision=gpup,
                        accumulate_grad_batches=2,
                        enable_progress_bar = True,
                        num_sanity_val_steps=0,
                        callbacks=[lr_monitor],

                        )
    trainer.fit(model)
    res=trainer.validate(model)
    result.append(res)
    # model.valacc also contains the extra entry appended by trainer.validate above.
    valacc.append(model.valacc)
    trainer.test(model)
    

In [None]:
# Batch size held by the model of the last cross-validation fold.
model.batch_size

In [None]:
# Accuracy curves of the last fold's model, one line per split.
fig, ax = plt.subplots()
for history, split_name in ((model.trainacc, 'train'), (model.valacc, 'val')):
    ax.plot(history, label=split_name)
ax.legend()

In [None]:
# Labels and group indices present in the validation split of the last fold.
val_df.label.unique(),val_df.group.unique()