# Bayes error & ssi for natural images

purpose: computes each cell's bayesian decoding error and ssi for natural images.


execution time: 4 min per cell

## Setup 

activate fisher_info_limits2

```python
conda activate envs/fisher_info_limits2
python -m ipykernel install --user --name fisher_info_limits2 --display-name "fisher_info_limits2"
```


In [8]:
import numpy as np
import math
import scipy.io as sio
import matplotlib.pyplot as  plt
import matplotlib.gridspec as gridspec
from matplotlib.patches import Ellipse
from scipy.stats import wilcoxon, ks_2samp, multivariate_normal
from scipy.interpolate import griddata
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score
from scipy.stats import pearsonr, spearmanr
from numba import njit, prange
import os
import re

# setup project path
proj_path = "/home/steeve/steeve/idv/code/fisher-info-limits/"

# setup reproducibility 
np.random.seed(10)

# setup paths
CELLS_PATH = os.path.join(proj_path, 'data/contrast_cells/') # path containing cells .mat files (e.g., carlo_data_cellno2.mat? (205 MB))
DATA_PATH = os.path.join(proj_path, 'data/computed_contrast_cells/') # path containing cells .mat files (e.g., carlo_data_cellno2.mat? (205 MB))

In [4]:
def scatter_hist(x, y, ax, ax_histx, ax_histy, c = 'tab:blue', alpha = 1, htype = 'step'):
    # no labels
    ax_histx.tick_params(axis="x", labelbottom=False)
    ax_histy.tick_params(axis="y", labelleft=False)

    # the scatter plot:
    ax.scatter(x, y,s=1, color=c)

    # now determine nice limits by hand:
    binwidth = 0.2
    xymax = max(np.max(np.abs(x)), np.max(np.abs(y)))
    print(xymax)
    lim = (np.rint((xymax+1)/binwidth) + 1) * binwidth

    bins = np.arange(-lim, lim + binwidth, binwidth)
    axx = ax_histx.hist(x, bins=bins)
    ax_histx.set_ylim(0,(axx[0].max()//100+1)*100)
    axy = ax_histy.hist(y, bins=bins, orientation='horizontal')
    ax_histy.set_xlim(0,(axy[0].max()//100+1)*100)


def plot_gaussian_ellipse(mean, cov, ax, n_std=1.0, **kwargs):
    vals, vecs = np.linalg.eigh(cov)
    order = vals.argsort()[::-1]
    vals = vals[order]
    vecs = vecs[:, order]

    theta = np.degrees(np.arctan2(*vecs[:,0][::-1]))
    width, height = 2 * n_std * np.sqrt(vals)
    
    ellipse = Ellipse(xy=mean, width=width, height=height, angle=theta, **kwargs)
    ax.add_patch(ellipse)


def SUM_LOG_LIST(position):
    '''Given an integer n it recursively calculates log(n!)'''
    if position == 0:
        return np.array([0])
    if position == 1:
        return np.append(SUM_LOG_LIST(0), 0)
    new_list = SUM_LOG_LIST(position-1)
    return np.append(new_list, new_list[-1]+np.around(np.log(float(position)), 8))


def POISSON_2DCELL(tc_grid, max_firing=20):
    log_list = np.tile(SUM_LOG_LIST(max_firing)[:,None,None], tc_grid.shape)
    log_tc = np.around(np.log(tc_grid), 8)#, where=(mask==1), out = np.ones_like(tc_grid)*-100)
    log_likelihood = (np.array([(i*log_tc-tc_grid) for i in range(max_firing+1)])-log_list)
#     log_likelihood[0, mask==0]=0
#     log_likelihood[1:, mask==0]=-np.inf
    likelihood = np.exp(log_likelihood)
    likelihood = likelihood/np.sum(likelihood, axis=0)
    return likelihood    

    
# Define the network
class SimpleRegressor(nn.Module):
    def __init__(self):
        super(SimpleRegressor, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )

    def forward(self, x):
        return self.model(x)

In [None]:
%%time

# loop over the cells contained in the data path
for i,input_filename in enumerate(os.listdir(CELLS_PATH)):
    
    if '.mat' not in input_filename:
        continue
    
    # print the name of the cell data file
    print(input_filename)

    # ensure it contains cell number "no"
    match = re.search(r"no\d+", input_filename)
    cell_no = match.group()

    # setup the data filename to save 
    output_filename = f'BDEvSSI_{cell_no}.npz'

    # load the cell data
    mat = sio.loadmat(os.path.join(CELLS_PATH, input_filename))
    #bases = mat['U']
    imgs = mat['X']     # sl - image
    pcs = mat['X_lowd'] # sl - all principal components
    fit = mat['f']      # sl - cell tuning curve (average response or receptive field?)
    # spikes = mat['r']   # sl - spikes
    print(i, imgs.shape)
    #     plt.matshow(imgs)

    # calculate average and variance of 
    # principal component 1 and 2
    mu0 = pcs[0].mean()
    mu1 = pcs[1].mean()
    sigma = np.cov(pcs[0], pcs[1])

    X = np.copy(pcs[:2]).T          # x and y
    y = np.copy(np.array(fit[0]))   # r
    
    # normalize the inputs for better training 
    scaler = StandardScaler()   
    X_scaled = scaler.fit_transform(X)

    # create cross-validation splits
    X_train_np, X_test_np, y_train_np, y_test_np = train_test_split(
        X_scaled, y, test_size=0.2, random_state=42)

    # convert to PyTorch tensors
    X_train = torch.tensor(X_train_np, dtype=torch.float32, requires_grad=True)
    X_test  = torch.tensor(X_test_np, dtype=torch.float32, requires_grad=True)
    y_train = torch.tensor(y_train_np, dtype=torch.float32).view(-1, 1)
    y_test  = torch.tensor(y_test_np, dtype=torch.float32).view(-1, 1)

    # initialize model, loss, and optimizer
    model = SimpleRegressor()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Training loop
    n_epochs = 200
    for epoch in range(n_epochs):

        model.train()
        optimizer.zero_grad()
        y_pred = model(X_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        optimizer.step()

        if (epoch+1) % 20 == 0 or epoch == 0:
            model.eval()
            val_loss = criterion(model(X_test), y_test).item()
            print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {loss.item():.4f}, Val Loss: {val_loss:.4f}")

    # Inference: evaluate on a grid for visualization
    model.eval()

    #--------------------------------------------------------------------------------------
    # CALCULATE GRID - THIS GRID WILL BE USED THROUGHOUT

    grid_x, grid_y = np.meshgrid(np.linspace(-3, 3, 301), np.linspace(-3 ,3 , 301))
    grid_input = np.column_stack((grid_x.ravel(), grid_y.ravel()))
    grid_input_scaled = scaler.transform(grid_input)
    #--------------------------------------------------------------------------------------

    mus = pcs[:2].mean(axis=1)
    sigma = np.cov(pcs[:2])

    prior = multivariate_normal(mus, sigma)

    # grid = np.stack([grid_x.ravel(), grid_y.ravel()], axis=-1)
    grid_prior = prior.pdf(grid_input).reshape(grid_x.shape)/prior.pdf(grid_input).reshape(grid_x.shape).sum()

    with torch.no_grad():
        preds = model(torch.tensor(grid_input, dtype=torch.float32)).numpy()[:,0]

    baseline = 1e-2
    tc = preds.reshape(grid_x.shape)+baseline
    likelihood = POISSON_2DCELL(tc)
    evidence = np.sum(likelihood*np.tile(grid_prior[None,:,:], (likelihood.shape[0],1,1)), axis=(-1,-2))
    posterior = likelihood*np.tile(grid_prior[None,:,:], (likelihood.shape[0],1,1))/\
                            np.tile(evidence[:,None,None], (1,*likelihood.shape[1:])) 
    posterior[posterior==0] = 1e-50
    posterior = posterior/np.tile(posterior.sum(axis=(1,2))[:,None,None], (1, *grid_x.shape))

    post_entropy = -np.sum(posterior*np.around(np.log2(posterior),8), axis=(1,2))
    prior_entropy = -np.sum(grid_prior*np.around(np.log2(grid_prior), 8))
    ssi = prior_entropy - np.sum(likelihood * np.tile(post_entropy[:,None,None], (1, *grid_x.shape)) ,axis=0)

    # sl - compute mean squared error
    mse = np.empty_like(posterior[0])
    for i in range(likelihood.shape[1]):
        for j in range(likelihood.shape[2]):
            print(i,j, end = '\r')    
    #         post_ij = np.sum(prior*np.tile(likelihood[:,i,j][:,None,None], (1, *likelihood.shape[1:])), axis=0)
            post_ij = np.sum(posterior*np.tile(likelihood[:,i,j][:,None,None], (1,*likelihood.shape[1:])), axis=0)
            delta = (grid_x - grid_x[i,j])**2+(grid_y-grid_y[i,j])**2
            mse[i,j] = np.sum(post_ij*delta)
    rmse = np.sqrt(mse)

    # save computed data
    np.savez(os.path.join(DATA_PATH, output_filename), mse=mse, posterior=posterior, ssi = ssi, grid_x=grid_x, grid_y=grid_y)

carlo_data_cellno2.mat
0 (108, 108, 3160)
Epoch 1/200, Train Loss: 0.1037, Val Loss: 0.1473
Epoch 20/200, Train Loss: 0.0473, Val Loss: 0.0904
Epoch 40/200, Train Loss: 0.0356, Val Loss: 0.0761
Epoch 60/200, Train Loss: 0.0305, Val Loss: 0.0712
Epoch 80/200, Train Loss: 0.0276, Val Loss: 0.0673
Epoch 100/200, Train Loss: 0.0259, Val Loss: 0.0644
Epoch 120/200, Train Loss: 0.0248, Val Loss: 0.0627
Epoch 140/200, Train Loss: 0.0240, Val Loss: 0.0612
Epoch 160/200, Train Loss: 0.0234, Val Loss: 0.0605
Epoch 180/200, Train Loss: 0.0229, Val Loss: 0.0599
Epoch 200/200, Train Loss: 0.0225, Val Loss: 0.0592
300 300