In [1]:
import mne
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import torch
import netrep.metrics as met
from typing import Literal, Tuple, Optional, List
import itertools

In [2]:
device=torch.device("cuda")

In [3]:
def get_events(file,event_duration):
    events=mne.events_from_annotations(file)
    epochs = mne.Epochs(file,events[0],events[1],tmin=-0.001, tmax=event_duration,event_repeated="drop",preload=True, picks=file.ch_names[:-2])
    epochs_scaled=mne.decoding.Scaler(epochs.info).fit_transform(epochs.get_data())
    
    signal_by_stim={}
    stimulus_ids=epochs.event_id
    
    for stim_id, stim_code in stimulus_ids.items():
        stim_events=epochs_scaled[epochs.events[:,2]==stim_code]
    
        signal_by_stim[stim_id]=stim_events
    
    return signal_by_stim

In [4]:
def flatten_samples(x):
    dims_count=len(x.shape)
    return x.reshape(*x.shape[:-(dims_count-1)], -1)

In [5]:
def get_mean_cov(x):
    new_x=flatten_samples(x)
    
    mu=np.mean(new_x, axis=0)
    cov=np.cov(new_x, rowvar=False)
    
    return np.expand_dims(mu,0),np.expand_dims(cov,0)

In [6]:
def average_per_person(folder, aggregate=lambda x: np.mean(x, axis=0)):
    per_session_avg_s16=[]
    per_session_avg_s32=[]
    for i in [1,2,3]:
        file=mne.io.read_raw_brainvision(f"neuro/{folder}/gonogo{i}.vhdr")
        events=get_events(file, 0.5)
        s16=events["Stimulus/S 16"]
        s32=events["Stimulus/S 32"]
        if aggregate:
            per_session_avg_s16.append(aggregate(s16))
            per_session_avg_s32.append(aggregate(s32))
        else:
            per_session_avg_s16.append(s16)
            per_session_avg_s32.append(s32)
        

    return np.array(per_session_avg_s16), np.array(per_session_avg_s32)

In [None]:
s16_all=[]
s32_all=[]
for i in tqdm(range(1, 27)):
    s16, s32= average_per_person(f"VN00{i}")
    s16_all.append(s16)
    s32_all.append(s32)

In [32]:
s16=np.array(s16_all)
s32=np.array(s32_all)

In [33]:
s16.shape

(26, 3, 28, 502)

In [34]:
np.save("S32.npy", s32)
np.save("S16.npy", s16)

In [63]:
s32=np.load("S32.npy")
s16=np.load("S16.npy")

# 2nd order isometric calculations

In [60]:
s16=s16.reshape(26, 3, 28*502)
s32=s32.reshape(26, 3, 28*502)

In [64]:
from sklearn.decomposition import PCA

In [66]:
s16_pca=PCA(n_components=28)
s16_comp=s16_pca.fit_transform(s16[0][0])

In [68]:
s16_comp.shape

(28, 28)

In [12]:
s32_pca=PCA(n_components=78)
s32_comp=s32_pca.fit_transform(s32)

In [13]:
s16_comp=s16_comp.reshape(78,1,78)
s32_comp=s32_comp.reshape(78,1,78)

In [14]:
Xs=np.concatenate((s16_comp, s32_comp))

In [35]:
Ys=np.array([0]*78 + [1]*78)

In [43]:
lm=met.LinearMetric(alpha=0,center_columns=False, score_method="euclidean")

In [44]:
dist_matrix, _ = lm.pairwise_distances(Xs)

Parallelizing 12090 distance calculations with 40 processes.


Computing distances: 100%|██████████| 12090/12090 [00:56<00:00, 215.08it/s]


In [45]:
from sklearn.manifold import MDS

In [47]:
mds=MDS(n_components=2,dissimilarity="precomputed")

In [48]:
points=mds.fit_transform(dist_matrix)



In [None]:
plt.scatter(points[:,0], points[:,1],c=Ys)

# Custom similarity loss

In [69]:
import torch
from torch import Tensor
from torch.nn.functional import pad

from typing import Literal, Tuple, Optional, List


class LinearMeasure(torch.nn.Module):
    def __init__(self,
                 alpha=1, center_columns=True, dim_matching='zero_pad', svd_grad=True, reduction='mean', no_svd=True):
        super(LinearMeasure, self).__init__()
        self.register_buffer('alpha', torch.tensor(alpha))
        assert dim_matching in [None, 'none', 'zero_pad', 'pca']
        self.dim_matching = dim_matching
        self.center_columns = center_columns
        self.svd_grad = svd_grad
        self.reduction = reduction
        self.no_svd=no_svd        

    
    def partial_fit(self, X: Tensor) -> Tuple[Tensor, Tensor]:
        """Computes the mean centered columns. Can be replaced later by whitening transform for linear invarariances."""
        if self.center_columns:
            mx = torch.mean(X, dim=1, keepdim=True)
        else:
            mx = torch.zeros(X.shape[2], dtype=X.dtype, device=X.device)
        wx = X - mx
        
        return mx, wx

    def fit(self, X: Tensor, Y: Tensor) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]:
        mx, wx = self.partial_fit(X)
        my, wy = self.partial_fit(Y)            

        if self.svd_grad:
            wxy = torch.bmm(wx.transpose(1, 2), wy)
            U, _, Vt = torch.linalg.svd(wxy)
        else:
            with torch.no_grad():
                wxy = torch.bmm(wx.transpose(1, 2), wy)
                U, _, Vt = torch.linalg.svd(wxy)
        wx = U
        wy = Vt.transpose(1, 2)
        return (mx, wx), (my, wy)

    def project(self, X: Tensor, m: Tensor, w: Tensor):
        if self.center_columns:
            return torch.bmm((X - m), w)
        else:
            return torch.bmm(X, w)

    def forward(self, X: Tensor, Y: Tensor):
        if X.shape[:-1] != Y.shape[:-1] or X.ndim != 3 or Y.ndim != 3:
            raise ValueError('Expected 3D input matrices to much in all dimensions but last.'
                             f'But got {X.shape} and {Y.shape} instead.')

        if X.shape[-1] != Y.shape[-1]:
            if self.dim_matching is None or self.dim_matching == 'none':
                raise ValueError(f'Expected same dimension matrices got instead {X.shape} and {Y.shape}. '
                                 f'Set dim_matching or change matrix dimensions.')
            elif self.dim_matching == 'zero_pad':
                size_diff = Y.shape[-1] - X.shape[-1]
                if size_diff < 0:
                    raise ValueError(f'With `zero_pad` dimension matching expected X dimension to be smaller then Y. '
                                     f'But got {X.shape} and {Y.shape} instead.')
                X = pad(X, (0, size_diff))
            elif self.dim_matching == 'pca':
                raise NotImplementedError
            else:
                raise ValueError(f'Unrecognized dimension matching {self.reduction}')
        
        if self.no_svd and X.shape[1]==1:
            mx, wx = self.partial_fit(X)
            my, wy = self.partial_fit(Y)

            x_norm=torch.linalg.norm(wx, dim=(1,2))
            y_norm=torch.linalg.norm(wy, dim=(1,2))

            norms= torch.sqrt( x_norm**2 + y_norm **2 - 2*(x_norm *y_norm))

        else:
            X_params, Y_params = self.fit(X, Y)
            norms = torch.linalg.norm(self.project(X, *X_params) - self.project(Y, *Y_params), ord="fro", dim=(1, 2))

        if self.reduction == 'mean':
            return norms.mean()
        elif self.reduction == 'sum':
            return norms.sum()
        elif self.reduction == 'none' or self.reduction is None:
            return norms
        else:
            raise ValueError(f'Unrecognized reduction {self.reduction}')


class EnergyMetric(torch.nn.Module):

    def __init__(self, n_iter=100, tol=1e-6, dim_matching='zero_pad', reduction='mean'):
        super(EnergyMetric,self).__init__()
        self.n_iter=n_iter
        self.tol=torch.tensor(tol)
        assert dim_matching in [None, 'none', 'zero_pad', 'pca']
        self.dim_matching = dim_matching
        self.reduction = reduction
        
    @torch.no_grad()
    def fit(self, X: torch.Tensor, Y:torch.Tensor):

        n_x=X.shape[2]
        n_y=Y.shape[2]
        
        X=X.repeat_interleave(n_y, dim=2).flatten(start_dim=1, end_dim=2)
        Y=Y.tile(dims=(1, 1, n_x, 1)).flatten(start_dim=1, end_dim=2)

        if X.shape[1] != Y.shape[1]:
            raise ValueError(f"After permutation got {X.shape} and {Y.shape}")
        
        w=torch.ones(X.shape[0], X.shape[1])
        
        batch_loss=[torch.mean(torch.linalg.norm(X-Y, dim=-1), dim =-1)]
        for i in range(self.n_iter):
            T=self.get_orth_matrix(w[:,:, None] *X, w[:, : , None]*Y)
            iter_result=torch.linalg.norm(X-torch.bmm(Y, T),dim=-1)
            batch_loss.append(torch.mean(iter_result, dim=-1))
            w=1/torch.maximum(torch.sqrt(iter_result), self.tol)

        return w, T, batch_loss, X, Y
        
    def get_orth_matrix(self, X:torch.Tensor , Y:torch.Tensor):
        U, _, Vt= torch.linalg.svd(torch.bmm(X.transpose(1,2), Y))
        return torch.bmm(Vt.transpose(1,2) , U.transpose(1,2))

    def get_dist_energy(self,X:torch.Tensor): 
        n=X.shape[2]
        combs = torch.combinations(torch.arange(n))
        X1= torch.flatten(X[:, :, combs[:, 0], :], start_dim=1, end_dim=2)
        X2=torch.flatten(X[:, :, combs[:, 1], :], start_dim=1, end_dim=2)
        
        return torch.mean(torch.linalg.norm(X2-X1, dim=-1), dim=-1)
        
        
    def forward(self,X: torch.Tensor , Y:torch.Tensor ):

        """Expected tensors to be of the form batch x class x repeats x activations"""
        
        if X.shape[:-2] != Y.shape[:-2] or X.ndim != 4 or Y.ndim != 4:
            raise ValueError('Expected 4D input matrices to much in all dimensions but last two.'
                             f'But got {X.shape} and {Y.shape} instead.')

        if X.shape[-1] != Y.shape[-1]:
            if self.dim_matching is None or self.dim_matching == 'none':
                raise ValueError(f'Expected same dimension matrices got instead {X.shape} and {Y.shape}. '
                                 f'Set dim_matching or change matrix dimensions.')
            elif self.dim_matching == 'zero_pad':
                size_diff = Y.shape[-1] - X.shape[-1]
                if size_diff < 0:
                    raise ValueError(f'With `zero_pad` dimension matching expected X dimension to be smaller then Y. '
                                     f'But got {X.shape} and {Y.shape} instead.')
                X = pad(X, (0, size_diff))
            elif self.dim_matching == 'pca':
                raise NotImplementedError
            else:
                raise ValueError(f'Unrecognized dimension matching {self.reduction}')
                
        #return self.fit(X,Y)
        w,T,fit_loss, X_prod, Y_prod= self.fit(X,Y)
        
        e_xx=self.get_dist_energy(X)
        e_yy=self.get_dist_energy(Y)
        Y_proj=torch.bmm(Y_prod, T)
        e_xy=torch.mean(torch.linalg.norm(X_prod-Y_proj, dim=-1),dim=-1)

        norms= torch.sqrt(torch.nn.functional.relu(e_xy-0.5*(e_xx+e_yy)))

        if self.reduction == 'mean':
            return norms.mean()
        elif self.reduction == 'sum':
            return norms.sum()
        elif self.reduction == 'none' or self.reduction is None:
            return norms
        else:
            raise ValueError(f'Unrecognized reduction {self.reduction}')


In [77]:
s32.shape

(26, 3, 28, 502)

In [None]:
x=torch.rand(1, 26*3, 28*502).to(device)
y=torch.rand(1, 26*3, 28*502).to(device)

In [73]:
lm=LinearMeasure(svd_grad=False)
opt=torch.optim.Adam(model.parameters())

In [None]:
s16_t=torch.tensor(s16_across_trials, dtype=torch.float32).to(device).requires_grad_(True)
s32_t=torch.tensor(s32_across_trials, dtype=torch.float32).to(device)

In [None]:
s16_lower_dim=model(s16_t)
s32_lower_dim=model(s32_t)

In [None]:
loss=lm(s16_t,s32_t)

In [None]:
train_loss_batch=[]
test_loss_batch=[]
for epochs in range(100):
    with torch.no_grad():
        test_out=model(s16_t)
        test_loss=crit(test_out[:,0], y_test_gpu)
        test_loss_batch.append(test_loss.item())   
        
    opt.zero_grad()
    out=model(x_batch)
    loss=crit(out[:,0], y_batch)
    loss.backward()
    opt.step()
    
    train_loss_batch.append(loss.item())