<a href="https://colab.research.google.com/github/sccn/sound2meg/blob/main/Spatial_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
print(torch.__version__)
print(torch.cuda.is_available())
print(device)

2.0.0
True
cuda:0


In [1]:
# -*- coding: utf-8 -*-
"""Spatial_Attention.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/github/sccn/sound2meg/blob/main/Spatial_Attention.ipynb
"""

import torch
import os
from torch.autograd import Variable
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.optim as optim
from scipy.io import loadmat
from dataset_loading import Sound2MEGDataset
from torch.utils.data import Dataset, DataLoader, random_split
import gc
import sys
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
sys.tracebacklimit = 0

class SubjectLayer(nn.Module):
  def __init__(self):
    super(SubjectLayer, self).__init__()
    self.layers = []

    for i in range(124): #124 subjects
      layer = nn.Conv2d(270, 270, 1).to(device)
      self.layers.append(layer)
      
  def forward(self, x, s_idx):
    x = x.unsqueeze(1)
    for i in range(len(x)):
      x[i] = self.layers[s_idx[i]](x[i].clone())
    return x[:, 0, :, :]

class SpatialAttention(nn.Module):
  def __init__(self,in_channels, out_channels, K, path):
    super(SpatialAttention, self).__init__()
    self.out = out_channels
    self.input = in_channels
    self.K = K
    self.z = Parameter(torch.randn(self.out, K*K, dtype = torch.cfloat)/(32*32))
    self.z.requires_grad = True
    self.positions = loadmat(path + 'electrode_positions.mat')
    self.positions = self.positions['positions']
    self.x = torch.tensor(self.positions[:, 0]).to(device)
    self.y = torch.tensor(self.positions[:, 1]).to(device)
    self.cos_v = []
    self.sin_v = []
    self.cos = []
    self.sin = []
    for i in range(in_channels):
      self.cos_v = []
      self.sin_v = []
      for k in range(K):
        for l in range(K):
          self.cos_v.append(torch.cos(2*math.pi*(k*self.x[i]+l*self.y[i])))
          self.sin_v.append(torch.sin(2*math.pi*(k*self.x[i]+l*self.y[i])))
      self.cos.append(torch.stack(self.cos_v))
      self.sin.append(torch.stack(self.sin_v))
    self.cos = torch.stack(self.cos).to(device)
    self.sin = torch.stack(self.sin).to(device)
  def forward(self, X):
    N = X.size()[0]
    SA = torch.zeros(N, 270, 360).to(device)
    z_r = self.z.real
    z_i = self.z.imag
    a = (torch.mm(z_r.float(), torch.transpose(self.cos, 0, 1).float()) + torch.mm(z_i.float(), torch.transpose(self.sin, 0, 1).float())).to(device)
    exp2 = torch.sum(torch.exp(a[:, 0:self.out]), 1).to(device)
    exp2 = torch.transpose(exp2.unsqueeze(0), 0, 1)
    exp2 = torch.mm(exp2, torch.ones(1, 360).to(device))
    for i in range(N):
      exp1 = torch.mm(torch.exp(a), X[i]).to(device)
      SA[i] = (exp1/exp2).to(device)
      #SA[i] = SpatialAttentionSoftmax(self.input, self.out, X[i], a)
    return SA

class Net(nn.Module):
  def __init__(self, path, F):
    super(Net, self).__init__()
    self.SA = SpatialAttention(273, 270, 32, path)
    self.Subject = SubjectLayer().to(device)
    self.F = F
    self.conv1 = nn.Conv2d(270, 270, (1, 1)).to(device)
    self.conv2 = nn.Conv2d(320, 640, (1, 1)).to(device)
    self.conv3 = nn.Conv2d(640, self.F, (1, 1)).to(device)
    self.gelu = nn.GELU().to(device)
    self.loop_convs = []
    self.loop_batchnorms = []
    self.loop_gelus = []
    self.loop_glus = []
    for k in range(1, 6):
      p = pow(2,(2*k)%5)
      q = pow(2,(2*k+1)%5)
      self.convs = []
      self.batchnorms = []
      self.gelus = []
      self.convs.append(nn.Conv2d(320, 320, (3, 1), dilation = p, padding = (p, 0)).to(device))
      self.convs.append(nn.Conv2d(320, 320, (3, 1), dilation = q, padding = (q, 0)).to(device))
      self.convs.append(nn.Conv2d(320, 640, (3, 1), dilation = 2, padding = (2, 0)).to(device))
      for i in range(2):
        self.batchnorms.append(nn.BatchNorm2d(320).to(device))
        self.gelus.append(nn.GELU().to(device))
      self.loop_convs.append(self.convs)
      self.loop_batchnorms.append(self.batchnorms)
      self.loop_gelus.append(self.gelus)
      self.loop_glus.append(nn.GLU().to(device))
    self.loop_convs[0][0] = nn.Conv2d(270, 320, (3, 1), dilation = 1, padding = (1, 0)).to(device)
  def forward(self, x, s_idx):
    x = self.SA(x).unsqueeze(3)
    x = self.conv1(x)
    x = self.Subject(x, s_idx)
    for i in range(5):
      if i == 0:
        x = self.loop_convs[0][0](x)
        x = self.loop_batchnorms[0][0](x)
        x = self.loop_gelus[0][0](x)
        x = self.loop_convs[0][1](x)
        x = self.loop_batchnorms[0][1](x)
        x = self.loop_gelus[0][1](x)
        x = self.loop_convs[0][2](x)
        x = torch.transpose(x, 3, 1)
        x = self.loop_glus[0](x)
        x = torch.transpose(x, 3, 1)
      else:
        x1 = self.loop_convs[i][0](x)
        x1 = self.loop_batchnorms[i][0](x)
        x1 = self.loop_gelus[i][0](x)
        x2 = x + x1
        x3 = self.loop_convs[i][1](x2)
        x3 = self.loop_batchnorms[i][1](x3)
        x3 = self.loop_gelus[i][1](x3)
        x4 = x2 + x3
        x5 = self.loop_convs[i][2](x4)
        x5 = torch.transpose(x5, 3, 1)
        x5 = self.loop_glus[i](x5)
        x = torch.transpose(x5, 3, 1)
    x_out = self.conv2(x)
    x_out = self.gelu(x_out)
    x_out = self.conv3(x_out)
    return x_out

# def CLIP_loss(Z, Y):
#   N = Y.size(dim = 0)
#   #inner_product = torch.zeros(N, N)
#   log_softmax = torch.zeros(N).to(device)
#   Z_row = torch.reshape(Z, (N, -1)).to(device)
#   Y_row = torch.reshape(Y, (N, -1)).to(device)
#   inner_product = (torch.mm(Z_row, torch.transpose(Y_row, 1, 0))).to(device)
#   for i in range(N):
#     inn = inner_product[i, :].to(device)
#     log_softmax[i] = torch.log(nn.functional.softmax(inn, -1).clamp(min=1e-4))[i]
#   return sum(-1*log_softmax)

def CLIP_loss(Z, Y, device):
    '''
    New loss using cross entropy implementation
    '''
    N = Y.size(dim = 0) # batch size
    log_softmax = torch.zeros(N, device=device)
    Z_row = torch.reshape(Z, (N, -1)) # flatten to be N x F
    Y_row = torch.reshape(Y, (N, -1)) # flatten to be N x F
    inner_product = torch.mm(Z_row, Y_row.T)/(N*N) # N x N. The normalization?

    target = torch.arange(N, device=device)
    loss_brain = torch.nn.functional.cross_entropy(inner_product, target)
    loss_sound = torch.nn.functional.cross_entropy(inner_product.T, target)
    loss = (loss_brain + loss_sound)/2
    
    return loss

In [None]:
dataset = Sound2MEGDataset('/expanse/projects/nsg/external_users/public/arno/')
training_data, validation_data, test_data = random_split(dataset, [11497, 3285, 1642], generator=torch.Generator().manual_seed(42))
Training_Data_Batches = DataLoader(training_data, batch_size = 128, shuffle = True)
Validation_Data_Batches = DataLoader(validation_data, batch_size = 128, shuffle = True)
BrainModule = Net('/expanse/projects/nsg/external_users/public/arno/', 120)
BrainModule.to(device)
optimizer = optim.Adam(BrainModule.parameters(), lr = 0.0003)
loss_train = []
loss_val = []

for i in range(50):
  loss_t = 0
  loss_v = 0
  for MEG, WAV, Sub in Training_Data_Batches:
    Sub = Sub.tolist()
    optimizer.zero_grad()
    Z = BrainModule(MEG.to(device), Sub)
    Z = Z[:, :, :, 0]
    loss = CLIP_loss(Z.float(), WAV.abs().float().to(device), device)
    loss.backward()
    loss_t = loss_t + loss.item()
    optimizer.step()
  loss_train.append(loss_t/(len(Training_Data_Batches)))
  print("Train loss:",loss_train)
  for MEG_val, WAV_val, Sub_val in Validation_Data_Batches:
    with torch.no_grad():
      Z_val = BrainModule(MEG_val.to(device), Sub_val)
      loss = CLIP_loss(Z_val.float(), WAV_val.abs().float().to(device), device)
    loss_v = loss_v + loss.item()
  loss_val.append(loss_v/len(Validation_Data_Batches))
  print("Val loss:",loss_val)
  gc.collect()
  torch.cuda.empty_cache()

print(loss_train)
print(loss_val)


Train loss: [4.849829763836331]
Val loss: [4.836285150968111]
Train loss: [4.849829763836331, 4.849829408857557]
Val loss: [4.836285150968111, 4.836284765830407]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097]
Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544]
Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046, 4.836283206939697]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544, 4.849814383188884]
Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046, 4.836283206939697, 4.83628170306866]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544, 4.849814383188884, 4.849751334720188]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544, 4.849814383188884, 4.849751334720188, 4.84974873330858]
Val loss: [4.836285150968111, 4.8362847

Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046, 4.836283206939697, 4.83628170306866, 4.836259016623864, 4.836259255042443, 4.836244289691631, 4.836200108894935, 4.836186849153959, 4.836246142020593, 4.836316090363723, 4.8362516439878025, 4.836211626346294, 4.836219659218421, 4.8363142380347615, 4.836036315331092, 4.836402764687171, 4.836247444152832, 4.836460113525391, 4.836268003170307]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544, 4.849814383188884, 4.849751334720188, 4.84974873330858, 4.849783256318834, 4.8497604264153376, 4.849728107452393, 4.849700111813015, 4.849664497375488, 4.849654097027249, 4.849597915013631, 4.849529012044271, 4.849509366353353, 4.8494378884633385, 4.8494542492760555, 4.849411741892497, 4.849388064278497, 4.849312114715576, 4.849234469731649]
Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046, 4.836283206939697, 4.83628170306866, 4.836259016623864, 4.836259255042443, 4.83624428969

Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544, 4.849814383188884, 4.849751334720188, 4.84974873330858, 4.849783256318834, 4.8497604264153376, 4.849728107452393, 4.849700111813015, 4.849664497375488, 4.849654097027249, 4.849597915013631, 4.849529012044271, 4.849509366353353, 4.8494378884633385, 4.8494542492760555, 4.849411741892497, 4.849388064278497, 4.849312114715576, 4.849234469731649, 4.849193599489, 4.84919802347819, 4.849103736877441, 4.849090327156914, 4.849005450142754, 4.848889912499322, 4.848411189185248, 4.846921814812554, 4.8452170001135935]
Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046, 4.836283206939697, 4.83628170306866, 4.836259016623864, 4.836259255042443, 4.836244289691631, 4.836200108894935, 4.836186849153959, 4.836246142020593, 4.836316090363723, 4.8362516439878025, 4.836211626346294, 4.836219659218421, 4.8363142380347615, 4.836036315331092, 4.836402764687171, 4.836247444152832, 4.836460113525391, 4.83626

Val loss: [4.836285150968111, 4.836284765830407, 4.836284490732046, 4.836283206939697, 4.83628170306866, 4.836259016623864, 4.836259255042443, 4.836244289691631, 4.836200108894935, 4.836186849153959, 4.836246142020593, 4.836316090363723, 4.8362516439878025, 4.836211626346294, 4.836219659218421, 4.8363142380347615, 4.836036315331092, 4.836402764687171, 4.836247444152832, 4.836460113525391, 4.836268003170307, 4.836495142716628, 4.836570336268498, 4.836349597344031, 4.836390256881714, 4.836256155600915, 4.836112902714656, 4.836172562379104, 4.835977994478666, 4.836156166516817, 4.836175845219539, 4.836454978355994, 4.836473685044509, 4.838052382835975, 4.838922867408166, 4.840368270874023, 4.843229733980619]
Train loss: [4.849829763836331, 4.849829408857557, 4.849828423394097, 4.849826018015544, 4.849814383188884, 4.849751334720188, 4.84974873330858, 4.849783256318834, 4.8497604264153376, 4.849728107452393, 4.849700111813015, 4.849664497375488, 4.849654097027249, 4.849597915013631, 4.8495