In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io
import scipy.signal
from einops import reduce, rearrange
import numpy as np

In [2]:
from torch_geometric.data import InMemoryDataset, Data, DataLoader
from Electrodes import Electrodes
from tqdm import tqdm
class DEAPDatasetEEGFeatures(InMemoryDataset):
    
  def __init__(self, root, raw_dir,processed_dir, transform=None, pre_transform=None,include_edge_attr = False, undirected_graphs = True, add_global_connections=True, participant_from=1, participant_to=32,window_size=128, n_videos=40):
      self._raw_dir = raw_dir
      self._processed_dir = processed_dir
      self.participant_from = participant_from
      self.participant_to = participant_to
      self.n_videos = n_videos
      self.window_size = window_size
      # Whether or not to include edge_attr in the dataset
      self.include_edge_attr = include_edge_attr
      # If true there will be 1024 links as opposed to 528
      self.undirected_graphs = undirected_graphs
      # Instantiate class to handle electrode positions
      print('Using global connections' if add_global_connections else 'Not using global connections')
      self.electrodes = Electrodes(add_global_connections, expand_3d = False)
      super(DEAPDatasetEEGFeatures, self).__init__(root, transform, pre_transform)
      self.data, self.slices = torch.load(self.processed_paths[0])
      

  @property
  def raw_dir(self):
      return f'{self.root}/{self._raw_dir}'

  @property
  def processed_dir(self):
      return f'{self.root}/{self._processed_dir}'

  @property
  def raw_file_names(self):
      raw_names = [f for f in os.listdir(self.raw_dir)]
      raw_names.sort()
      return raw_names

  @property
  def processed_file_names(self):
      if not os.path.exists(self.processed_dir):
        os.makedirs(self.processed_dir)
      file_name = f'{self.participant_from}-{self.participant_to}' if self.participant_from is not self.participant_to else f'{self.participant_from}'
      return [f'deap_processed_graph.{file_name}.dataset']

  def process(self):
        # Number of nodes per graph
        n_nodes = len(self.electrodes.positions_3d)

        if self.undirected_graphs:
            source_nodes, target_nodes = np.repeat(np.arange(0,n_nodes),n_nodes), np.tile(np.arange(0,n_nodes),n_nodes)
        else:
            source_nodes, target_nodes = np.tril_indices(n_nodes,n_nodes)
        
        edge_attr = self.electrodes.adjacency_matrix[source_nodes,target_nodes]
        
        # Remove zero weight links
        mask = np.ma.masked_not_equal(edge_attr, 0).mask
        edge_attr,source_nodes,target_nodes = edge_attr[mask], source_nodes[mask], target_nodes[mask]

        edge_attr, edge_index = torch.FloatTensor(edge_attr), torch.tensor([source_nodes,target_nodes], dtype=torch.long)
        
        # Expand edge_index and edge_attr to match windows
        e_edge_index = edge_index.clone()
        e_edge_attr = edge_attr.clone()
        for i in range(128*60//self.window_size-1):
            a = edge_index + e_edge_index.max() + 1
            e_edge_index = torch.cat([e_edge_index,a],dim=1)
            e_edge_attr = torch.cat([e_edge_attr,edge_attr],dim=0)

        # List of graphs that will be written to file
        data_list = []
        pbar = tqdm(range(self.participant_from,self.participant_to+1))
        for participant_id in pbar:
            raw_name = [e for e in self.raw_file_names if str(participant_id).zfill(2) in e][0]
            pbar.set_description(raw_name)
            # Load raw file as np array
            participant_data = scipy.io.loadmat(f'{self.raw_dir}/{raw_name}')
            signal_data = torch.FloatTensor(participant_data['data'][:,:32,128*3:])
            processed = []
            for i, video in enumerate(signal_data[:self.n_videos,:,:]):
                video = video.reshape(-1,128)
                n = video[0].shape[-1]
                
                # Differential entropy features
#                 de_features = []
#                 for window in video:
                    
#                     fourier = np.fft.rfft(window)
#                     real_absolute_fft = 2.0/n * np.abs(fourier[:n//2])
#                     freq = np.fft.rfftfreq(n, d=1./128)

#                     delta_mask = np.logical_and(freq > 0.5 ,freq < 4)[:64]
#                     delta_values = real_absolute_fft[delta_mask]
#                     delta_entropy = scipy.stats.entropy(delta_values)

#                     theta_mask = np.logical_and(freq > 4 ,freq < 8)[:64]
#                     theta_values = real_absolute_fft[theta_mask]
#                     theta_entropy = scipy.stats.entropy(theta_values)

#                     alpha_mask = np.logical_and(freq > 8 ,freq < 12)[:64]
#                     alpha_values = real_absolute_fft[alpha_mask]
#                     alpha_entropy = scipy.stats.entropy(alpha_values)

#                     beta_mask = np.logical_and(freq > 12 ,freq < 30)[:64]
#                     beta_values = real_absolute_fft[beta_mask]
#                     beta_entropy = scipy.stats.entropy(beta_values)

#                     gamma_mask = np.logical_and(freq > 30 ,freq < 45)[:64]
#                     gamma_values = real_absolute_fft[gamma_mask]
#                     gamma_entropy = scipy.stats.entropy(gamma_values)
                
#                     window_features = torch.FloatTensor([delta_entropy, theta_entropy, alpha_entropy, beta_entropy, gamma_entropy])
#                     de_features.append(window_features)
                
#                 de_features = torch.stack(de_features)
#                 node_features = de_features
            
                # Power spectral density for each channel
                psd = scipy.signal.periodogram(video)[1]
                node_features = psd
                
                # Raw signals 
                # node_features = video
                
                # Should we add MinMax/Z scaler?
                data = Data(x=torch.FloatTensor(node_features),edge_attr=e_edge_attr,edge_index=e_edge_index, y=torch.FloatTensor([participant_data['labels'][i]])) if self.include_edge_attr else Data(x=torch.FloatTensor(node_features), edge_index=e_edge_index, y=torch.FloatTensor([participant_data['labels'][i]]))
                data_list.append(data) 
               
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

In [3]:
# Constants used to define data paths
ROOT_DIR = './'
RAW_DIR = 'data/matlabPREPROCESSED'
PROCESSED_DIR = 'data/graphProcessedData'

dataset = DEAPDatasetEEGFeatures(root= ROOT_DIR, raw_dir= RAW_DIR, processed_dir= PROCESSED_DIR)
# Subject-independent classification
# DEPENDING ON WHAT DATA IS USED THE NETWORK LEARNS BETTER OR WORSE.
# SHOULD WE TRY TO HAVE A BALANCE TRAINING SET?
dataset = dataset.shuffle()
dataset[0]

Using global connections


Data(edge_index=[2, 11640], x=[1920, 65], y=[1, 4])

In [4]:
# 880 used for training, 220 validation and 180 testing
splt_idx = 1100
# splt_idx = 35

# 85% used for train/val
train_dataset = dataset[:splt_idx]
test_dataset = dataset[splt_idx:]

train_dataset,test_dataset

(DEAPDatasetEEGFeatures(1100), DEAPDatasetEEGFeatures(180))

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

In [6]:

from torch_geometric.nn import GCNConv
class Model(torch.nn.Module):
    def __init__(self, in_channels):
        super(Model, self).__init__()
        self.gconv1 = GCNConv(in_channels, 256, aggr='add')
        self.gconv2 = GCNConv(256, 128, aggr='add')
        self.gconv3 = GCNConv(128, 1, aggr='add')
        
        self.conv1 = nn.Conv1d(60, 8, 1, stride=1)
        
        self.lin1 = nn.Linear(8*32,32)
        self.lin2 = nn.Linear(32,1)
        
    def forward(self, batch):
        bs = len(torch.unique(batch.batch))
        x, edge_index = batch.x, batch.edge_index
        
        x = self.gconv1(x, edge_index)
        x = torch.tanh(x)
        x = self.gconv2(x, edge_index)
        x = torch.tanh(x)
        x = self.gconv3(x, edge_index)
        x = torch.tanh(x)
        x = F.dropout(x, p=0.3, training=self.training)
        x = x.reshape(bs,-1,32)

        x = self.conv1(x)
        x = torch.relu(x)
        x = x.reshape(bs,-1)
        
        x = self.lin1(x)
        x.sigmoid()
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin2(x)

        return x.view(-1)

In [21]:
# %%timeit


model = Model(dataset[0].x[1].shape[0])     
pytorch_total_params = sum(p.numel() for p in model.parameters())
print(f'Model parameter count: {pytorch_total_params}')

# optimizer = torch.optim.Adadelta(model.parameters(), lr=.1, rho=0.9, eps=1e-06, weight_decay=1e-5)
# optimizer = torch.optim.SGD(model.parameters(),lr=1e-1, weight_decay=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5, weight_decay=1e-1)
# optimizer = torch.optim.RMSprop(model.parameters(), lr=0.001, weight_decay=1e-6)

criterion = nn.MSELoss()

model = model.to(device)

def train(loader, target = 0):
    model.train()
    losses = []
    right = 0
    tot = 0
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        y = batch.y[:,target] 
        out = model(batch)
        loss = criterion(out,y)
        loss.backward()
        losses.append(loss.item())
        optimizer.step()
        right += torch.eq(out > 5, y > 5).sum().item()
        tot += y.shape[0]
    return np.array(losses).mean(), right/tot

def test(loader,verbose=False, target = 0):
    model.eval()
    losses = []
    right = 0
    tot = 0
    for batch in loader:
        batch = batch.to(device)
        y = batch.y[:,0] # Arousal
        out = model(batch)
        if verbose:
            print(out,y)
        loss = criterion(out,y)
        losses.append(loss.item())
        right += torch.eq(out > 5, y > 5).sum().item()
        tot += y.shape[0]
    return np.array(losses).mean(), right/tot

best_val_loss = np.inf
esp = 0
MAX_ESP = 40

BS = 16

k_folds = 5
k_fold_size = len(train_dataset)/k_folds
current_fold = 0 # Ranges from 0 to k_folds-1

target = 0 # Valence
for epoch in range(1, 10000):    
    # KFOLD train/val split     
    if epoch %10 == 0:
        current_fold = current_fold+1 if current_fold < k_folds-1 else 0
    from_idx, to_idx = int(k_fold_size*current_fold), int(k_fold_size*(current_fold+1))
    kf_val_data = train_dataset[from_idx:to_idx]
    a = train_dataset[:from_idx]
    b = train_dataset[to_idx:]
    kf_train_data = a + b
    train_loader = DataLoader(kf_train_data, batch_size=BS, shuffle=False)
    val_loader = DataLoader(kf_val_data, batch_size=BS)
        
    # Training and validation
    train_loss, train_acc = train(train_loader, target = target)
    val_loss, val_acc = test(val_loader , target = target)
    print(f'Epoch {epoch} - Kfold:{current_fold} ;t loss: {train_loss:.5f} ;t acc: {train_acc:.2f} ;v loss: {val_loss:.5f} ;v acc: {val_acc:.2f}')

    # Early stopping and checkpoint
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        esp = 0
        torch.save(model.state_dict(),'./best_params') 
    else:
        esp += 1
        if esp >= MAX_ESP:
            break


print('Finished training')

Model parameter count: 58666
Epoch 1 - Kfold:0 ;t loss: 24.98012 ;t acc: 0.44 ;v loss: 19.26181 ;v acc: 0.48
Epoch 2 - Kfold:0 ;t loss: 14.39674 ;t acc: 0.44 ;v loss: 9.03530 ;v acc: 0.48
Epoch 3 - Kfold:0 ;t loss: 6.52481 ;t acc: 0.47 ;v loss: 4.56312 ;v acc: 0.48
Epoch 4 - Kfold:0 ;t loss: 4.94089 ;t acc: 0.51 ;v loss: 4.37789 ;v acc: 0.52
Epoch 5 - Kfold:0 ;t loss: 4.93372 ;t acc: 0.53 ;v loss: 4.37732 ;v acc: 0.52
Epoch 6 - Kfold:0 ;t loss: 4.89964 ;t acc: 0.52 ;v loss: 4.37830 ;v acc: 0.52
Epoch 7 - Kfold:0 ;t loss: 4.89751 ;t acc: 0.52 ;v loss: 4.38123 ;v acc: 0.54
Epoch 8 - Kfold:0 ;t loss: 4.78496 ;t acc: 0.54 ;v loss: 4.38649 ;v acc: 0.52
Epoch 9 - Kfold:0 ;t loss: 4.93338 ;t acc: 0.49 ;v loss: 4.39819 ;v acc: 0.52
Epoch 10 - Kfold:1 ;t loss: 4.79636 ;t acc: 0.50 ;v loss: 5.02512 ;v acc: 0.49
Epoch 11 - Kfold:1 ;t loss: 4.77461 ;t acc: 0.53 ;v loss: 5.05100 ;v acc: 0.49
Epoch 12 - Kfold:1 ;t loss: 4.66230 ;t acc: 0.54 ;v loss: 5.03992 ;v acc: 0.50
Epoch 13 - Kfold:1 ;t loss: 4

Epoch 105 - Kfold:0 ;t loss: 4.22951 ;t acc: 0.59 ;v loss: 4.02789 ;v acc: 0.57
Epoch 106 - Kfold:0 ;t loss: 4.12524 ;t acc: 0.61 ;v loss: 3.87939 ;v acc: 0.60
Epoch 107 - Kfold:0 ;t loss: 4.08284 ;t acc: 0.61 ;v loss: 3.91212 ;v acc: 0.57
Epoch 108 - Kfold:0 ;t loss: 4.11837 ;t acc: 0.60 ;v loss: 4.00846 ;v acc: 0.57
Epoch 109 - Kfold:0 ;t loss: 4.02578 ;t acc: 0.63 ;v loss: 3.95189 ;v acc: 0.59
Epoch 110 - Kfold:1 ;t loss: 4.00637 ;t acc: 0.64 ;v loss: 4.43183 ;v acc: 0.62
Epoch 111 - Kfold:1 ;t loss: 3.93885 ;t acc: 0.62 ;v loss: 4.35586 ;v acc: 0.61
Finished training


In [20]:
model.load_state_dict(torch.load('./best_params'))
test_loader = DataLoader(test_dataset, batch_size=BS)
loss, acc = test(train_loader, False)
print(f'Train loss: {loss} ; Train acc: {acc}')
loss, acc = test(val_loader, False)
print(f'Val loss: {loss} ; Val acc: {acc}')
loss, acc = test(test_loader, False)
print(f'Test loss: {loss} ; Test acc: {acc}')

# TODO: scheduler(?) Loss/acc records

Train loss: 0.9160561829805374 ; Train acc: 0.8488636363636364
Val loss: 0.8132187638963971 ; Val acc: 0.8409090909090909
Test loss: 7.466709534327189 ; Test acc: 0.5111111111111111
