<a href="https://colab.research.google.com/github/sccn/sound2meg/blob/main/Spatial_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
import os
from torch.autograd import Variable
import math
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.optim as optim
from scipy.io import loadmat
from dataset_loading import Sound2MEGDataset
from torch.utils.data import Dataset, DataLoader, random_split

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [None]:
#import scipy.io
#data = scipy.io.loadmat('/content/drive/MyDrive/data.mat')

In [None]:
#x = data['data']
#x = x[:, 1:3600:10, :]
#print(x.shape)   

(273, 360, 7)


In [None]:
def cos_vector(k, K, x, y):
  cos_v = torch.zeros(273, K)
  for l in range(K):
    cos_v[:,l] = torch.cos(2*math.pi*(k*x+l*y))
  return cos_v
def sin_vector(k, K, x, y):
  sin_v = torch.zeros(273, K)
  for l in range(K):
    sin_v[:,l] = torch.sin(2*math.pi*(k*x+l*y))
  return sin_v

In [4]:
def SpatialAttentionSoftmax(in_channels, out_channels, X, a):
  #z_r = z.real
  #z_i = z.imag
  #a = (torch.mm(z_r.float(), torch.transpose(cos, 0, 1).float()) + torch.mm(z_i.float(), torch.transpose(sin, 0, 1).float())).to(device)
  SA = torch.randn(out_channels, 360)
  for j in range(out_channels):
    exp1 = torch.mm(torch.exp(a[j, :]).unsqueeze(0), X)
    exp2 = torch.sum(torch.exp(a[j, 0:out_channels]))
    SA[j] = exp1/exp2
  return SA

In [12]:
class SpatialAttention(nn.Module):
  def __init__(self,in_channels, out_channels, K, path):
    super(SpatialAttention, self).__init__()
    self.out = out_channels
    self.input = in_channels
    self.K = K
    self.z = Parameter(torch.randn(self.out, K*K, dtype = torch.cfloat)/(32*32))
    self.z.requires_grad = True
    self.positions = loadmat(path + 'electrode_positions.mat')
    self.positions = self.positions['positions']
    self.x = torch.tensor(self.positions[:, 0]).to(device)
    self.y = torch.tensor(self.positions[:, 1]).to(device)
    self.cos_v = []
    self.sin_v = []
    self.cos = []
    self.sin = []
    for i in range(in_channels):
      self.cos_v = []
      self.sin_v = []
      for k in range(K):
        for l in range(K):
          self.cos_v.append(torch.cos(2*math.pi*(k*self.x[i]+l*self.y[i])))
          self.sin_v.append(torch.sin(2*math.pi*(k*self.x[i]+l*self.y[i])))
      self.cos.append(torch.stack(self.cos_v))
      self.sin.append(torch.stack(self.sin_v))
    self.cos = torch.stack(self.cos).to(device)
    self.sin = torch.stack(self.sin).to(device)
  def forward(self, X):
    N = X.size()[0]
    SA = torch.zeros(N, 270, 360)
    z_r = self.z.real
    z_i = self.z.imag
    a = (torch.mm(z_r.float(), torch.transpose(self.cos, 0, 1).float()) + torch.mm(z_i.float(), torch.transpose(self.sin, 0, 1).float())).to(device)
    for i in range(N):
      SA[i] = SpatialAttentionSoftmax(self.input, self.out, X[i], a)
    return SA

In [None]:
def SpatialAttentionFunc(in_channels, out_channels, X, z, K, cos, sin):
  a = torch.randn(out_channels, in_channels).to(device)
  #positions = loadmat('/content/drive/MyDrive/electrode_positions.mat')
  #positions = positions['positions']
  #x = torch.tensor(positions[:, 0])
  #y = torch.tensor(positions[:, 1])
  for j in range(out_channels):
    cos_sum = torch.zeros(in_channels).to(device)
    sin_sum = torch.zeros(in_channels).to(device)
    for k in range(K):
      z_r = z[j, k, :].real
      z_r = z_r.unsqueeze(0)
      z_i = z[j, k, :].imag
      z_i = z_i.unsqueeze(0)      
      #cos_k = torch.transpose(cos_vector(k, 32, x, y), 0, 1)
      #sin_k = torch.transpose(sin_vector(k, 32, x, y), 0, 1)
      cos_sum = cos_sum + torch.mm(z_r, cos[k])
      sin_sum = sin_sum + torch.mm(z_i, sin[k])
    a[j, :] = cos_sum + sin_sum 
  SA = torch.randn(out_channels, 360)
  for j in range(out_channels):
    exp1 = torch.mm(torch.exp(a[j, :]).unsqueeze(0), X)
    exp2 = torch.sum(torch.exp(a[j, 0:out_channels]))
    SA[j] = exp1/exp2
  return SA

In [6]:
class SubjectLayer(nn.Module):
  def __init__(self):
    super(SubjectLayer, self).__init__()
    self.layers = []

    for i in range(124): #124 subjects
      layer = nn.Conv2d(270, 270, 1)
      self.layers.append(layer)
      
  def forward(self, x, s_idx):
    for i in range(len(x)):
      x[i] = self.layers[s_idx[i]](x[i].clone())
    return x

In [None]:
subject = SubjectLayer()

x = torch.randn(3, 270, 360, 1)

print(x.shape)
output = subject(x, [0, 1, 2, 3])
print(output.shape)

torch.Size([3, 270, 360, 1])
torch.Size([3, 270, 360, 1])


In [None]:
class SpatialAttention(nn.Module):
  def __init__(self,in_channels, out_channels, K, path):
    super(SpatialAttention, self).__init__()
    self.positions = loadmat(path + 'electrode_positions.mat')
    self.positions = self.positions['positions']
    self.x = torch.tensor(self.positions[:, 0]).to(device)
    self.y = torch.tensor(self.positions[:, 1]).to(device)
    self.cos = []
    self.sin = []
    for k in range(32):
      self.cos.append(torch.transpose(cos_vector(k, 32, self.x, self.y), 0, 1))
      self.sin.append(torch.transpose(sin_vector(k, 32, self.x, self.y), 0, 1))
    self.cos = torch.stack(self.cos).to(device)
    self.sin = torch.stack(self.sin).to(device)
    self.out = out_channels
    self.input = in_channels
    self.K = K
    self.z = Parameter(torch.randn(out_channels, K, K, dtype = torch.cfloat)/(32*32))
    self.z.requiresGrad = True
  def forward(self, X):
    N = X.size()[0]
    SA = torch.zeros(N, 270, 360)
    for i in range(N):
      SA[i] = SpatialAttentionFunc(self.input, self.out, X[i], self.z, self.K, self.cos, self.sin)
    return SA

In [7]:
class Net(nn.Module):
  def __init__(self, path):
    super(Net, self).__init__()
    self.SA = SpatialAttention(273, 270, 32, path)
    self.Subject = SubjectLayer()
  def forward(self, y, s_idx):
    x1 = self.SA(y).unsqueeze(0)
    x2 = torch.permute(x1, (1, 2, 3, 0)) # subject attention?
    x3 = nn.Conv2d(270, 270, (1, 1))(x2)
    x = self.Subject(x3, s_idx)
    for k in range(1,6):
      p = pow(2,(2*k)%5)
      q = pow(2,(2*k+1)%5)
      if k == 1:
        x = nn.Conv2d(270, 320, (3, 1), dilation = 1, padding = (1, 0))(x)
        x = nn.BatchNorm2d(320)(x)
        x = nn.GELU()(x)
        x = nn.Conv2d(320, 320, (3, 1), dilation = 1, padding = (1, 0))(x)
        x = nn.BatchNorm2d(320)(x)
        x = nn.GELU()(x)
        x = nn.Conv2d(320, 640, (3, 1), dilation = 2, padding = (2, 0))(x)
        x = torch.transpose(x, 3, 1)
        x = nn.GLU()(x)
        x = torch.transpose(x, 3, 1)
      else:
        x1 = nn.Conv2d(320, 320, (3, 1), dilation = p, padding = (p, 0))(x)
        x1 = nn.BatchNorm2d(320)(x1)
        x1 = nn.GELU()(x1)
        x2 = x + x1
        x3 = nn.Conv2d(320, 320, (3, 1), dilation = q, padding = (q, 0))(x2)
        x3 = nn.BatchNorm2d(320)(x3)
        x3 = nn.GELU()(x3)
        x4 = x2 + x2
        x_out = nn.Conv2d(320, 640, (3, 1), dilation = 2, padding = (2, 0))(x4)
        x_out = torch.transpose(x_out, 3, 1)
        x_out = nn.GLU()(x_out)
        x_out = torch.transpose(x_out, 3, 1)
    x_out = nn.Conv2d(320, 640, (1, 1))(x_out)
    x_out = nn.GELU()(x_out)
    x_out = nn.Conv2d(640, 120, (1, 1))(x_out)
    return x_out

In [8]:
def CLIP_loss(Z, Y):
  N = Y.size(dim = 0)
  #inner_product = torch.zeros(N, N)
  log_softmax = torch.zeros(N).to(device)
  Z_row = torch.reshape(Z, (N, -1)).to(device)
  Y_row = torch.reshape(Y, (N, -1)).to(device)
  inner_product = (torch.mm(Z_row, torch.transpose(Y_row, 1, 0))/(N*N)).to(device)
  for i in range(N):
    inn = inner_product[i, :].to(device)
    log_softmax[i] = torch.log(nn.functional.softmax(inn, -1))[i]
  return sum(-1*log_softmax)

In [10]:
import sys
sys.tracebacklimit = 0

In [15]:
Dataset = Sound2MEGDataset('/content/drive/MyDrive/sound2meg/')
training_data, validation_data, test_data = random_split(Dataset, [0.7, 0.2, 0.1], generator=torch.Generator().manual_seed(42))
Training_Data_Batches = DataLoader(training_data, batch_size = 16, shuffle = True)
BrainModule = Net('/content/drive/MyDrive/sound2meg/')
BrainModule.to(device)
optimizer = optim.Adam(BrainModule.parameters(), lr = 0.0003)
loss_train = []
loss_val = []
for i in range(1):
  loss_t = 0
  for MEG, WAV, Sub in Training_Data_Batches:
    Sub = Sub.tolist()
    Z = BrainModule(MEG.to(device), Sub)
    Z = Z[:, :, :, 0]
    WAV.to(device)
    loss = CLIP_loss(Z, WAV.abs().to(device))
    torch.autograd.set_detect_anomaly(True)
    optimizer.zero_grad()
    loss.backward()
    print(loss.item())
    loss_t = loss_t + loss.item()
    optimizer.step()
  loss_train.append(loss_t/len(training_data))

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



IndexError: index 108 is out of bounds for axis 1 with size 99

During handling of the above exception, another exception occurred:

AttributeError: 'IndexError' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

AssertionError


IndexError: ignored

In [25]:
MEG = torch.rand(16, 273, 360).to(device)
WAV = torch.rand(16, 120, 360).to(device)
Z = BrainModule(MEG, 16*[1])[:, :, :, 0]
loss = CLIP_loss(Z, WAV)
optimizer.zero_grad()
loss.backward()
optimizer.step()

In [22]:
l = [12]
print(16*l)

[12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]


In [None]:
MEG_val = []
WAV_val = []
Sub_val = []
for i in range(len(validation_data)):
  MEG_val.append(validation_data[i][0])
  WAV_val.append(validation_data[i][1])
  Sub_val.append(validation_data[i][2])
MEG_val = torch.stack(MEG_val)
WAV_val = torch.stack(WAV_val)

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



KeyboardInterrupt

During handling of the above exception, another exception occurred:

AttributeError: 'KeyboardInterrupt' object has no attribute '_render_traceback_'

During handling of the above exception, another exception occurred:

AssertionError


KeyboardInterrupt: ignored