In [1]:
import numpy as np
import math
import scipy.io as sio
import matplotlib.pyplot as  plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Ellipse
from scipy.stats import wilcoxon, ks_2samp, multivariate_normal
from scipy.interpolate import griddata
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr
from numba import njit, prange
import os
import re

np.random.seed(10)


In [2]:
def scatter_hist(x, y, ax, ax_histx, ax_histy, c = 'tab:blue', alpha = 1, htype = 'step'):
    # no labels
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # the scatter plot:
    ax.scatter(x, y,s=1, color=c)

    # now determine nice limits by hand:
    binwidth = 0.2
    xymax = max(np.max(np.abs(x)), np.max(np.abs(y)))
    print(xymax)
    lim = (np.rint((xymax+1)/binwidth) + 1) * binwidth

    bins = np.arange(-lim, lim + binwidth, binwidth)
    axx = ax_histx.hist(x, bins=bins)
    ax_histx.set_ylim(0,(axx[0].max()//100+1)*100)
    axy = ax_histy.hist(y, bins=bins, orientation='horizontal')
    ax_histy.set_xlim(0,(axy[0].max()//100+1)*100)

def plot_gaussian_ellipse(mean, cov, ax, n_std=1.0, **kwargs):
    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    theta = np.degrees(np.arctan2(*vecs[:,0][::-1]))
    width, height = 2 * n_std * np.sqrt(vals)
    
    ellipse = Ellipse(xy=mean, width=width, height=height, angle=theta, **kwargs)
    ax.add_patch(ellipse)

def SUM_LOG_LIST(position):
    '''Given an integer n it recursively calculates log(n!)'''
    if position == 0:
        return np.array([0])
    if position == 1:
        return np.append(SUM_LOG_LIST(0), 0)
    new_list = SUM_LOG_LIST(position-1)
    return np.append(new_list, new_list[-1]+np.around(np.log(float(position)), 8))

def POISSON_2DCELL(tc_grid, max_firing=20):
    log_list = np.tile(SUM_LOG_LIST(max_firing)[:,None,None], tc_grid.shape)
    log_tc = np.around(np.log(tc_grid), 8)#, where=(mask==1), out = np.ones_like(tc_grid)*-100)
    log_likelihood = (np.array([(i*log_tc-tc_grid) for i in range(max_firing+1)])-log_list)
#     log_likelihood[0, mask==0]=0
#     log_likelihood[1:, mask==0]=-np.inf
    likelihood = np.exp(log_likelihood)
    likelihood = likelihood/np.sum(likelihood, axis=0)
    return likelihood    

    
# Define the network
class SimpleRegressor(nn.Module):
    def __init__(self):
        super(SimpleRegressor, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.model(x)

In [None]:
for i,input_filename in enumerate(os.listdir('./Contrast_cells/')):
    if '.mat' not in input_filename:
        continue
    print(input_filename)
    match = re.search(r"no\d+", input_filename)
    cell_no = match.group()

    output_filename = f'BDEvSSI_{cell_no}.npz'
    mat = sio.loadmat(f'./Contrast_cells/{input_filename}')

    bases = mat['U']
    imgs = mat['X']
    pcs = mat['X_lowd']
    fit = mat['f']
    spikes = mat['r']
    print(i, imgs.shape)
    #     plt.matshow(imgs)

    # calculate average and variance of pc1 and pc2
    mu0 = pcs[0].mean()
    mu1 = pcs[1].mean()
    sigma = np.cov(pcs[0], pcs[1])

    X = np.copy(pcs[:2]).T  # x and y
    y = np.copy(np.array(fit[0]))   # r
    # Optional: Normalize inputs for better training
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)

    X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
        X_scaled, y, test_size=0.2, random_state=42)

    # Convert to PyTorch tensors
    X_train = torch.tensor(X_train_np, dtype=torch.float32, requires_grad=True)
    X_test  = torch.tensor(X_test_np, dtype=torch.float32, requires_grad=True)
    y_train = torch.tensor(y_train_np, dtype=torch.float32).view(-1, 1)
    y_test  = torch.tensor(y_test_np, dtype=torch.float32).view(-1, 1)


    # Initialize model, loss, and optimizer
    model = SimpleRegressor()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    n_epochs = 200
    for epoch in range(n_epochs):
        model.train()
        optimizer.zero_grad()
        y_pred = model(X_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        if (epoch+1) % 20 == 0 or epoch == 0:
            model.eval()
            val_loss = criterion(model(X_test), y_test).item()
            print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

    # Inference: evaluate on a grid for visualization
    model.eval()

    #--------------------------------------------------------------------------------------
    # CALCULATE GRID - THIS GRID WILL BE USED THROUGHOUT

    grid_x, grid_y = np.meshgrid(np.linspace(-3, 3, 301), np.linspace(-3 ,3 , 301))
    grid_input = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    grid_input_scaled = scaler.transform(grid_input)
    #--------------------------------------------------------------------------------------

    mus = pcs[:2].mean(axis=1)
    sigma = np.cov(pcs[:2])

    prior = multivariate_normal(mus, sigma)

    # grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)
    grid_prior = prior.pdf(grid_input).reshape(grid_x.shape)/prior.pdf(grid_input).reshape(grid_x.shape).sum()

    with torch.no_grad():
        preds = model(torch.tensor(grid_input, dtype=torch.float32)).numpy()[:,0]

    baseline = 1e-2
    tc = preds.reshape(grid_x.shape)+baseline
    likelihood = POISSON_2DCELL(tc)
    evidence = np.sum(likelihood*np.tile(grid_prior[None,:,:], (likelihood.shape[0],1,1)), axis=(-1,-2))
    posterior = likelihood*np.tile(grid_prior[None,:,:], (likelihood.shape[0],1,1))/\
                            np.tile(evidence[:,None,None], (1,*likelihood.shape[1:])) 
    posterior[posterior==0] = 1e-50
    posterior = posterior/np.tile(posterior.sum(axis=(1,2))[:,None,None], (1, *grid_x.shape))

    post_entropy = -np.sum(posterior*np.around(np.log2(posterior),8), axis=(1,2))
    prior_entropy = -np.sum(grid_prior*np.around(np.log2(grid_prior), 8))
    ssi = prior_entropy - np.sum(likelihood * np.tile(post_entropy[:,None,None], (1, *grid_x.shape)) ,axis=0)

    mse = np.empty_like(posterior[0])
    for i in range(likelihood.shape[1]):
        for j in range(likelihood.shape[2]):
            print(i,j, end = '\r')    
    #         post_ij = np.sum(prior*np.tile(likelihood[:,i,j][:,None,None], (1, *likelihood.shape[1:])), axis=0)
            post_ij = np.sum(posterior*np.tile(likelihood[:,i,j][:,None,None], (1,*likelihood.shape[1:])), axis=0)
            delta = (grid_x - grid_x[i,j])**2+(grid_y-grid_y[i,j])**2
            mse[i,j] = np.sum(post_ij*delta)
    rmse = np.sqrt(mse)
    np.savez(f'./Computed Contrast Cells/{output_filename}', mse=mse, posterior=posterior, ssi = ssi, grid_x=grid_x, grid_y=grid_y)

carlo_data_cellno10.mat
0 (108, 108, 3160)
Epoch 1/200, Train Loss: 4.3321, Val Loss: 4.4022
Epoch 20/200, Train Loss: 2.1287, Val Loss: 2.0252
Epoch 40/200, Train Loss: 1.7039, Val Loss: 1.6256
Epoch 60/200, Train Loss: 1.5373, Val Loss: 1.5195
Epoch 80/200, Train Loss: 1.4067, Val Loss: 1.4131
Epoch 100/200, Train Loss: 1.3093, Val Loss: 1.3393
Epoch 120/200, Train Loss: 1.2366, Val Loss: 1.2913
Epoch 140/200, Train Loss: 1.1805, Val Loss: 1.2586
Epoch 160/200, Train Loss: 1.1348, Val Loss: 1.2368
Epoch 180/200, Train Loss: 1.0949, Val Loss: 1.2201
Epoch 200/200, Train Loss: 1.0589, Val Loss: 1.2079
carlo_data_cellno170.mat
1 (108, 108, 3160)
Epoch 1/200, Train Loss: 2.0321, Val Loss: 2.5623
Epoch 20/200, Train Loss: 0.6661, Val Loss: 0.8163
Epoch 40/200, Train Loss: 0.2682, Val Loss: 0.3250
Epoch 60/200, Train Loss: 0.2039, Val Loss: 0.2578
Epoch 80/200, Train Loss: 0.1840, Val Loss: 0.2432
Epoch 100/200, Train Loss: 0.1776, Val Loss: 0.2361
Epoch 120/200, Train Loss: 0.1752, Val Lo

carlo_data_cellno86.mat
14 (108, 108, 3160)
Epoch 1/200, Train Loss: 7.3573, Val Loss: 5.9294
Epoch 20/200, Train Loss: 5.1816, Val Loss: 3.9620
Epoch 40/200, Train Loss: 3.4606, Val Loss: 2.5411
Epoch 60/200, Train Loss: 2.6407, Val Loss: 1.9536
Epoch 80/200, Train Loss: 2.1385, Val Loss: 1.5341
Epoch 100/200, Train Loss: 1.7385, Val Loss: 1.2236
Epoch 120/200, Train Loss: 1.4171, Val Loss: 1.0000
Epoch 140/200, Train Loss: 1.1681, Val Loss: 0.8423
Epoch 160/200, Train Loss: 0.9716, Val Loss: 0.7246
Epoch 180/200, Train Loss: 0.8187, Val Loss: 0.6374
Epoch 200/200, Train Loss: 0.6961, Val Loss: 0.5670
carlo_data_cellno133.mat
15 (108, 108, 3160)
Epoch 1/200, Train Loss: 2.6757, Val Loss: 3.4193
Epoch 20/200, Train Loss: 1.4310, Val Loss: 1.8097
Epoch 40/200, Train Loss: 0.7888, Val Loss: 0.9128
Epoch 60/200, Train Loss: 0.4587, Val Loss: 0.4906
Epoch 80/200, Train Loss: 0.2748, Val Loss: 0.2997
Epoch 100/200, Train Loss: 0.2099, Val Loss: 0.2407
Epoch 120/200, Train Loss: 0.1829, Val 