Goal: recreate the 16 ECoG channels.
The autoencoder plays the role of PCA: the encoder walks down from 129 features to 16, and the decoder walks back up from 16 to 129.
Overall results: the training loss does not converge (it oscillates continuously), but it stays under 0.1, which is acceptable.

The Encoder is used to reduce the target, and the Decoder is used to demap the estimate.

In [None]:
# Colab-only: mount Google Drive at /content/gdrive (the driver calls below use paths under it).
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


In [None]:
!pip install torch torchvision mat73 pymatreader matplotlib tensorboard mne joblib

In [None]:
import torch
from torch.autograd import Variable
import statistics
import matplotlib.pyplot as plt
import numpy as np
import os

In [None]:
#@title Autoencoder Class
#class definition of model

class Encoder(torch.nn.Module):
  """Fully connected encoder mapping ``input_size`` features to ``bottleneck_size``.

  Hidden widths are powers of two: the first is half of the largest power of
  two not exceeding ``input_size`` (129 -> 64), and every following layer
  halves the width for as long as it stays >= ``bottleneck_size``. If that
  ladder does not end exactly on ``bottleneck_size`` (e.g. a bottleneck of 20,
  or one larger than the first hidden width), one extra layer projects the
  last width onto it, so the output always has ``bottleneck_size`` features.
  Every Linear layer is followed by Tanh, including the last one.

  Args:
    input_size: number of input features; must be >= 2.
    bottleneck_size: number of output features; must be >= 1.

  Raises:
    ValueError: if either size is out of range.
  """

  def __init__(self,input_size, bottleneck_size):
    super(Encoder,self).__init__()
    # A non-positive bottleneck would never terminate the halving loop below,
    # and input_size < 2 would produce a negative shift count.
    if input_size < 2 or bottleneck_size < 1:
      raise ValueError(
        f"Encoder needs input_size >= 2 and bottleneck_size >= 1, "
        f"got input_size={input_size}, bottleneck_size={bottleneck_size}")
    self.linear_list = torch.nn.ModuleList()
    self.activation_function = torch.nn.ModuleList()

    # shift_len ends up as the bit length of input_size (129 -> 8)
    shift_len = 0
    while (input_size >> shift_len) > 0:
      shift_len += 1
    # start two powers of two below that (129 -> 64)
    output_size = 1 << (shift_len - 2)

    while output_size >= bottleneck_size:
      self._add_layer(input_size, output_size)
      input_size = output_size
      output_size >>= 1

    # bottleneck is not on the power-of-two ladder: project onto it explicitly
    if input_size != bottleneck_size:
      self._add_layer(input_size, bottleneck_size)
    assert len(self.activation_function) == len(self.linear_list)

  def _add_layer(self, in_features, out_features):
    """Append one Linear(in_features, out_features) + Tanh pair."""
    print(f"Encoder: Created Linear=({in_features,out_features})")
    self.linear_list.append(torch.nn.Linear(in_features,out_features))
    self.activation_function.append(torch.nn.Tanh()) # restricts every layer output to (-1, 1)

  def forward(self,X):
    """Apply each Linear + Tanh pair in order and return the reduced tensor."""
    reduced_data = X
    for linear, activation in zip(self.linear_list, self.activation_function):
      reduced_data = activation(linear(reduced_data))
    return reduced_data

class Decoder(torch.nn.Module):
  """Fully connected decoder that widens ``input_size`` features to ``output_size``.

  The width is doubled layer by layer for as long as the doubled width stays
  strictly below ``hidden_limit``; a final layer then maps the last width
  onto ``output_size``. Every Linear layer is followed by Tanh, including the
  last one. With the default limit, Decoder(16, 129) builds
  16 -> 32 -> 64 -> 129.

  Args:
    input_size: number of latent features; must be >= 1.
    output_size: number of reconstructed features.
    hidden_limit: exclusive upper bound for the doubled hidden widths.
      Defaults to 128, the value that was previously hard-coded.

  Raises:
    ValueError: if ``input_size`` < 1 (doubling would never reach the limit).
  """

  def __init__(self,input_size, output_size, hidden_limit=128):
    super(Decoder,self).__init__()
    if input_size < 1:
      raise ValueError(f"Decoder needs input_size >= 1, got {input_size}")
    self.linear_list = torch.nn.ModuleList()
    self.activation_function = torch.nn.ModuleList()

    while (input_size << 1) < hidden_limit:
      print(f"Decoder: Created Linear=({input_size,input_size << 1})")
      self.linear_list.append(torch.nn.Linear(input_size,input_size << 1))
      self.activation_function.append(torch.nn.Tanh())
      input_size = input_size << 1

    # final projection from the widest hidden layer onto the requested width
    print(f"Decoder: Created Linear=({input_size,output_size})")
    self.linear_list.append(torch.nn.Linear(input_size,output_size))
    self.activation_function.append(torch.nn.Tanh())
    assert len(self.activation_function) == len(self.linear_list)

  def forward(self, X):
    """Apply each Linear + Tanh pair in order and return the widened tensor."""
    encoded_data = X
    for linear, activation in zip(self.linear_list, self.activation_function):
      encoded_data = activation(linear(encoded_data))
    return encoded_data

#expect auto-encoder to accept normalized(CRA_Sig(Y)) + dimensionally reduce
#normalized(CRA_Sig(Y)) => normalized(CRA_Sig(Y))_hat
class Autoencoder(torch.nn.Module):
  """Encoder/decoder pair that reconstructs its own input.

  The encoder maps ``input_size`` features to ``bottleneck`` features and the
  decoder maps them back to ``input_size``.
  """

  def __init__(self, input_size, bottleneck):
    super().__init__()
    # registration order (encoder, then decoder) fixes the state_dict key order
    self.encoder = Encoder(input_size, bottleneck)
    self.decoder = Decoder(bottleneck, input_size)

  def forward(self, X):
    """Return the reconstruction decoder(encoder(X))."""
    return self.decoder(self.encoder(X))

In [None]:
#how we obtain the PCA/demapped results of the ECoG; only used after full training is done
def get_latent_space(model, X):
  """Run only ``model.encoder`` on ``X`` with gradient tracking disabled."""
  with torch.no_grad():
    latent = model.encoder(X)
  return latent

def decode_latent(model, Y):
  """Run only ``model.decoder`` on ``Y`` with gradient tracking disabled."""
  with torch.no_grad():
    restored = model.decoder(Y)
  return restored

In [None]:
#@title Plotting Graphs-- General Methods
def plot_loss(model_loss,epoch_id, model_type, error_type,img_dir, range_key):
  """Plot two per-epoch loss curves via plot_error.

  ``model_loss`` must hold exactly two sequences; the x axis is the index
  range of the first one.
  """
  assert len(model_loss) == 2
  print(f'Plotting {model_type}_{epoch_id} {error_type} Loss in Range {range_key} ....')
  epoch_axis = list(range(len(model_loss[0])))
  plot_error(
    x_vals=epoch_axis,
    y_vals=model_loss,
    epoch_i=epoch_id,
    model_type=model_type,
    error_type=error_type,
    images_dir=img_dir,
    range_key=range_key,
  )

def plot_error(x_vals, y_vals, epoch_i, model_type, error_type, images_dir, range_key):
  """Draw up to three loss curves on one figure, save it as a PNG and show it.

  Curves in ``y_vals`` are labelled train / validation / test in that order
  and each must have the same length as ``x_vals``.
  """
  series_names = ["train","validation","test"]
  # matplotlib format strings: colour letter followed by line/marker style
  line_formats = ['r' + '-', 'b' + '--', 'y' + '.']

  plt.figure(figsize=(10,10))
  for idx, curve in enumerate(y_vals):
    assert len(curve) == len(x_vals)
    plt.plot(x_vals, curve, line_formats[idx], label=series_names[idx])

  plt.title(f'Autoencoder {model_type} Loss Curves until Epoch {epoch_i}')
  plt.xlabel('Epoch (#)')
  plt.ylabel(f'Error ({error_type})')
  plt.legend(loc='best')
  target_png = f"{images_dir}/{error_type}_{range_key}_curve_{epoch_i}_{model_type}.png"
  plt.savefig(target_png)
  plt.show()

def plot_test(model_loss,model_type,error_type,img_dir, range_key):
  """Histogram the per-minibatch losses, save the figure as a PNG and show it."""
  plt.figure(figsize=(10,10))
  # bin once with numpy, then replay the counts as weights so pyplot draws the same bars
  bin_counts, bin_edges = np.histogram(model_loss)
  plt.hist(bin_edges[:-1], bin_edges, weights=bin_counts)
  plt.title(f'Frequency of RMSE Loss in Minibatches of {model_type} Set')
  target_png = f"{img_dir}/{error_type}_{range_key}_histogram_{model_type}.png"
  plt.savefig(target_png)
  plt.show()

In [None]:
#@title Autoencoder Training/Testing Methods

from enum import Enum
class Mode(Enum):
  """Which kind of data pass is being run."""
  # members take the values 0, 1 and 2 in this order
  TRAIN, VALIDATION, TEST = range(3)

def one_epoch(data, model, cost_function, optimizer, mode, epoch,feature):
  """Run one pass over ``data`` and return the mean per-batch loss.

  Each batch is reconstructed by ``model`` and scored against itself with
  ``cost_function``. In Mode.TRAIN the loss is back-propagated and the
  optimizer steps after every batch; in Mode.VALIDATION only the loss is
  recorded.

  Args:
    data: sequence of batches (train_driver passes the tuple from torch.split).
    model: module whose output has the same shape as its input.
    cost_function: callable(prediction, target) returning a scalar tensor.
    optimizer: optimizer over ``model``'s parameters; only used in Mode.TRAIN.
    mode: Mode.TRAIN or Mode.VALIDATION (Mode.TEST is rejected).
    epoch: epoch number, used only in the progress message.
    feature: feature name, used only in the progress message.

  Returns:
    The arithmetic mean of the batch losses as a Python float.

  Raises:
    AssertionError: if ``mode`` is Mode.TEST or a reconstruction changes shape.
    statistics.StatisticsError: if ``data`` is empty.
  """
  assert mode != Mode.TEST
  is_training = mode == Mode.TRAIN
  loss_list = []
  print(f'Running {feature}_Epoch {epoch}....',end='')
  for batch in data:
    # original note: batch shape is 8x100xN (not checked here; confirm against the dataset)
    recreated_batch = model(batch)
    assert recreated_batch.shape == batch.shape
    batch_loss = cost_function(recreated_batch,batch)
    loss_list.append(batch_loss.item())
    if is_training:
      batch_loss.backward()
      optimizer.step()
      optimizer.zero_grad()
    print(f'\tBatch Loss = {loss_list[-1]}')
  total_loss = statistics.mean(loss_list)
  print(f"avg loss = {total_loss}")
  return total_loss

def train_driver(epochs, inp_dataset, optimizer, cost_function, model, checkpoint_dir, feature, img_dir,range_key,bookmark=15):
  """Train ``model`` with per-epoch validation, checkpointing as it goes.

  Both datasets are split into chunks of 100 along dim 1. The loop runs
  ``epochs + 1`` passes (epoch indices 0..epochs). The model's state_dict is
  saved and the loss curves are plotted every ``bookmark`` epochs and always
  after the final epoch, so the finished model is on disk even when
  ``epochs`` is not a multiple of ``bookmark``.

  Args:
    epochs: index of the last epoch to run.
    inp_dataset: pair ``[train_tensor, validation_tensor]``.
    optimizer: optimizer over ``model``'s parameters.
    cost_function: callable(prediction, target) returning a scalar tensor.
    model: the module to train.
    checkpoint_dir: directory that receives the ``.pt`` state_dict file.
    feature: feature name used in messages and file names.
    img_dir: directory that receives the loss-curve images.
    range_key: key embedded in the checkpoint and image file names.
    bookmark: checkpoint interval in epochs; a value <= 0 disables the
      periodic checkpoints and keeps only the final one.

  Returns:
    The trained ``model`` (the same object that was passed in).
  """
  train_loss = []
  val_loss = []
  inp_train_dataset = torch.split(inp_dataset[0],100,dim=1)
  inp_val_dataset = torch.split(inp_dataset[1],100,dim=1)

  for i in range(epochs + 1):
    model.train()
    train_loss.append(one_epoch(inp_train_dataset,model,cost_function,optimizer,Mode.TRAIN,i,feature))
    model.eval()
    with torch.no_grad():
      val_loss.append(one_epoch(inp_val_dataset, model, cost_function, optimizer, Mode.VALIDATION, i,feature))
    at_bookmark = bookmark > 0 and i > 0 and i % bookmark == 0
    if at_bookmark or i == epochs:
      print(f'Saving Autoencoder_{feature} Model....')
      torch.save(model.state_dict(),f'{checkpoint_dir}/autoencoder_{range_key}_{feature}_model.pt')
      plot_loss([train_loss,val_loss],i,feature,"MSE",img_dir,range_key)
  return model
  
def test_driver(inp_dataset, cost_function, model, checkpoint_dir,feature,img_dir, range_key):
  """Evaluate the autoencoder on the test set and plot a histogram of batch RMSE.

  The test tensor is split into chunks of 100 along dim 1. Each chunk is
  reconstructed and scored with ``cost_function``; the square roots of those
  losses are averaged, printed, and passed to plot_test.

  Args:
    inp_dataset: test tensor.
    cost_function: callable(prediction, target) returning a scalar tensor.
    model: trained module, or None to load it from ``checkpoint_dir``.
    checkpoint_dir: directory holding ``autoencoder_{range_key}_{feature}_model.pt``.
    feature: feature name used in messages and file names.
    img_dir: directory that receives the histogram image.
    range_key: key embedded in the checkpoint and image file names.
  """
  if model is None:
    checkpoint = torch.load(f'{checkpoint_dir}/autoencoder_{range_key}_{feature}_model.pt') #load model from file
    if isinstance(checkpoint, torch.nn.Module):
      model = checkpoint
    else:
      # train_driver saves a state_dict, not a module: rebuild the network
      # first. Linear weights are (out_features, in_features), so the first
      # decoder layer gives the bottleneck width and the last one the
      # input/output width.
      decoder_weights = [
        tensor for key, tensor in checkpoint.items()
        if key.startswith('decoder.linear_list.') and key.endswith('.weight')
      ]
      model = Autoencoder(decoder_weights[-1].shape[0], decoder_weights[0].shape[1])
      model.load_state_dict(checkpoint)
  model.eval()
  test_loss = []
  print(f'Testing {feature}....')
  with torch.no_grad():
    for batch in torch.split(inp_dataset,100,dim=1):
      estimate = model(batch)
      # (prediction, target) order, matching the training loop
      loss = cost_function(estimate,batch)
      print(f'\tBatch Loss = {loss.item()}')
      test_loss.append(loss.item())
  rmse_test_loss = np.sqrt(test_loss)
  print(f'Avg RMSE for {feature}: {statistics.mean(rmse_test_loss)}')
  plot_test(rmse_test_loss,feature,"RMSE",img_dir,range_key)

In [None]:
#@title Autoencoder Driver Methods
def auto_subdriver(current_directory,epochs,dataset_locations,input_size,hidden_size,learning_rate, feature, range_key):
  """Train and then test one autoencoder for a single feature.

  Creates ``autoencoder_{range_key}`` (checkpoints) and, inside it,
  ``autoencoder_{range_key}_{feature}_results`` (images) under
  ``current_directory`` when they do not exist yet. It then builds the model,
  an Adam optimizer and an MSE loss, loads the train / validation / test
  tensors from ``dataset_locations`` (in that order) and hands them to
  train_driver and test_driver.
  """
  checkpt_dir = os.path.join(current_directory,f"autoencoder_{range_key}")
  img_dir = os.path.join(checkpt_dir,f"autoencoder_{range_key}_{feature}_results")
  for directory in (checkpt_dir, img_dir):
    if not os.path.exists(directory):
      os.makedirs(directory)

  print(f'Driving Autoencoder_{feature}')
  # anomaly detection is switched on globally for the autograd engine
  torch.autograd.set_detect_anomaly(True)
  model = Autoencoder(input_size,hidden_size)
  optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
  cost_function = torch.nn.MSELoss()

  # load the stored tensors as float32
  train_path, validation_path, test_path = (dataset_locations[0], dataset_locations[1], dataset_locations[2])
  print(train_path)
  train_dataset = torch.load(train_path).float()
  print(f'Train Dim: {train_dataset.shape}')
  validation_dataset = torch.load(validation_path).float()
  print(f'Validation Dim: {validation_dataset.shape}')
  test_dataset = torch.load(test_path).float()
  print(f'Test Dim: {test_dataset.shape}')

  model = train_driver(
    epochs,
    [train_dataset,validation_dataset],
    optimizer,
    cost_function,
    model,
    checkpt_dir,
    feature,
    img_dir,
    range_key,
    bookmark=15,
  )
  with torch.no_grad():
    test_driver(test_dataset,cost_function,model,checkpt_dir,feature,img_dir,range_key)


def auto_driver(current_directory, epochs, data_key, range_key, specific_index = -1):
  """Train autoencoders for one or all of the four feature sets.

  Builds the train / validation / test file paths for each feature in
  ['power', 'zero', 'mean', 'data'] and calls auto_subdriver with a 129-wide
  input, a 16-wide bottleneck and a learning rate of 1e-3. A non-negative
  ``specific_index`` selects a single feature; any negative value runs all.
  """
  # NOTE(review): file names here are '{data_key}_{feature}_{split}.pt', while
  # generate_dataset_locations also inserts the norm key after data_key.
  # Confirm which pattern matches the files on disk.
  features = ['power','zero','mean','data']
  deeper_level = current_directory + f'/checkpoint3/normalized_ecog_{range_key}'
  locations = [
    [
      f"{deeper_level}/{feature_type}_dir/" + f'{data_key}_{feature_type}_{data_type}.pt'
      for data_type in ('train','validation','test')
    ]
    for feature_type in features
  ]

  if specific_index >= 0:
    assert specific_index < len(locations)
    selected = [specific_index]
  else:
    selected = range(len(locations))

  for idx in selected:
    auto_subdriver(current_directory,epochs,locations[idx],129,16,1e-3,features[idx],range_key)

In [None]:
# Train and test only features[0] ('power') with epochs=150 and range_key "01".
auto_driver("/content/gdrive/MyDrive/BCI_Project/Datasets",150,'ECoG_Norm',"01",specific_index=0)

In [None]:
#@title Reduction Driver
"""
A method which stores reduced ecog data into files to reduce the potential ram usage when using get_latent_space
"""
import torch

def generate_dataset_locations(current_directory,data_key,features,norm_key):
  """Return, per feature, the [train, validation, test] tensor file paths.

  Paths have the form
  ``{current_directory}/checkpoint3/normalized_ecog_{norm_key}/{feature}_dir/{data_key}_{norm_key}_{feature}_{split}.pt``.
  """
  deeper_level = current_directory + f'/checkpoint3/normalized_ecog_{norm_key}'
  splits = ("train","validation","test")
  return [
    [
      f"{deeper_level}/{feature_type}_dir/" + f'{data_key}_{norm_key}_{feature_type}_{data_type}.pt'
      for data_type in splits
    ]
    for feature_type in features
  ]

def generate_model_locations(current_directory,norm_key,features):
  """Return one checkpoint path per entry of ``features``, in the same order.

  Paths have the form
  ``{current_directory}/autoencoder_{norm_key}/autoencoder_{norm_key}_{feature}_model.pt``.
  Any number of features is accepted; the earlier hard-coded requirement of
  exactly four was dropped.
  """
  deeper_level = current_directory + f'/autoencoder_{norm_key}'
  return [
    f'{deeper_level}/autoencoder_{norm_key}_{feature_type}_model.pt'
    for feature_type in features
  ]

def autoencoder_enc_dec(autoencoder_model,split_ecog_tensor,function):
  """Apply ``function(autoencoder_model, chunk)`` to every chunk and rejoin along dim 1."""
  transformed_chunks = [function(autoencoder_model, chunk) for chunk in split_ecog_tensor]
  return torch.cat(transformed_chunks, dim=1)


def _reduced_tensor_path(source_path):
  """Return ``reduced_<stem>.pt`` in the same directory as ``source_path``."""
  directory, basename = os.path.split(source_path)
  # splitext removes only the final extension, so dotted names stay intact
  stem, _ = os.path.splitext(basename)
  return os.path.join(directory, f"reduced_{stem}.pt")


def reduction_sub_driver(ecog_data, model_loc):
  """Encode every tensor file in ``ecog_data`` with a saved autoencoder.

  A 129 -> 16 Autoencoder is built and its weights are loaded from
  ``model_loc``. Each file is loaded as float, split into chunks of 100 along
  dim 1, passed through the encoder via get_latent_space, and the joined
  result is saved next to the source file as ``reduced_<name>.pt``.

  Args:
    ecog_data: sequence of tensor file paths.
    model_loc: path of the state_dict checkpoint to load.
  """
  autoencode = Autoencoder(129,16)
  autoencode.load_state_dict(torch.load(model_loc))
  autoencode.eval()

  for tensor_path in ecog_data:
    ecog_tensor = torch.load(tensor_path).float()
    print(f"Shape of {tensor_path}: {ecog_tensor.shape}")
    split_ecog_tensor = list(torch.split(ecog_tensor,100,dim=1))
    reduced_ecog_tensor = autoencoder_enc_dec(autoencode,split_ecog_tensor,get_latent_space)
    print(f"\tReduced Shape:{reduced_ecog_tensor.shape}")
    final_abs_filename = _reduced_tensor_path(tensor_path)
    print(f"\tSaving tensor in {final_abs_filename}")
    torch.save(reduced_ecog_tensor,final_abs_filename)

def reduction_caller(ecog_data_list, model_list, specific_index = -1):
  """Run reduction_sub_driver for one (dataset paths, model path) pair or for all.

  The two lists must have equal length. A non-negative ``specific_index``
  selects a single pair; any negative value processes every pair in order.
  """
  assert len(ecog_data_list) == len(model_list)

  if specific_index >= 0:
    assert specific_index < len(ecog_data_list)
    reduction_sub_driver(ecog_data_list[specific_index], model_list[specific_index])
    return

  for position, model_path in enumerate(model_list):
    reduction_sub_driver(ecog_data_list[position], model_path)

def reduction_driver(current_directory,data_key,norm_key,specific_index_top = -1):
  """Build the dataset and checkpoint paths for the four features and reduce them.

  ``specific_index_top`` is forwarded to reduction_caller as its
  ``specific_index``.
  """
  features = ['power','zero','mean','data']
  locations = generate_dataset_locations(
    current_directory=current_directory,
    data_key=data_key,
    features=features,
    norm_key=norm_key,
  )
  print(locations)
  reduction_caller(
    locations,
    generate_model_locations(current_directory, norm_key, features),
    specific_index=specific_index_top,
  )

In [None]:
# Reduce only features[3] ('data') using the checkpoint for norm_key '01'.
reduction_driver("/content/gdrive/MyDrive/BCI_Project/Datasets",'ECoG_Norm','01',specific_index_top=3)