### Direct Smoothing

In [2]:
# Path
import os, sys
os.chdir('/home/zwang34/IBL/iblfm_exp/IBL_foundation_model')
sys.path.append('./src')
print(sys.path)

import logging
logging.getLogger().setLevel(logging.ERROR)

# Lib
from datasets import load_dataset, concatenate_datasets
import numpy as np
from loader.make_loader import make_loader
from utils.eval_utils import bits_per_spike
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

['/home/zwang34/miniconda3/envs/ibl-fm/lib/python310.zip', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10/lib-dynload', '', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10/site-packages', './src']


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Fix Args
EID = '671c7ea7-6726-4fbe-adeb-f89c2c8e489b'
kernel_sigma_list = [1, 2, 3, 8, 16, 32]
fr_filter = 0.1 # unit: 1 Hz

In [10]:
## Prepare data
dataset = load_dataset(f'neurofm123/{EID}_aligned', cache_dir='/expanse/lustre/scratch/zwang34/temp_project/iTransformer/checkpoints/datasets_cache')
train_dataset = dataset['train']
val_dataset = dataset['val']
test_dataset = dataset['test']
n_neurons = len(test_dataset[0]['cluster_uuids'])

train_dataloader = make_loader(
    train_dataset,
    target=None,
    load_meta=True,
    batch_size=10000,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
    shuffle=True,
)

val_dataloader = make_loader(
    val_dataset,
    target=None,
    load_meta=True,
    batch_size=10000,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
    shuffle=False,
)

test_dataloader = make_loader(
    test_dataset,
    target=None,
    load_meta=True,
    batch_size=10000,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
    shuffle=False,
)

for batch in train_dataloader:
    train_data = batch['spikes_data'].detach().cpu().numpy()
for batch in val_dataloader:
    val_data = batch['spikes_data'].detach().cpu().numpy()
for batch in test_dataloader:
    test_data = batch['spikes_data'].detach().cpu().numpy()

# Use a fr Filter
whole_data = np.concatenate([train_data, val_data, test_data], axis=0)
mean_fr = np.mean(whole_data, axis=(0,1)) * 50   # hz
valid_idx = (mean_fr >= fr_filter)

train_data = train_data[:, :, valid_idx]
val_data = val_data[:, :, valid_idx]
test_data = test_data[:, :, valid_idx]
print(f'valid neuron: {sum(valid_idx)}, invalid neuron: {len(valid_idx)-sum(valid_idx)}')

for k, kernel_sigma in enumerate(kernel_sigma_list):       
    gt_spikes = test_data
    smoothed_spikes = gaussian_filter1d(gt_spikes, sigma=kernel_sigma, axis=1)
    fr_stat = np.log(np.mean(gt_spikes, axis=(0, 1))+1e-9)
    
    bps_smth_list = []
    bps_gt_list = []
    for i in range(sum(valid_idx)):
        bps_smth_list.append(bits_per_spike(smoothed_spikes[:, :, [i]], gt_spikes[:, :, [i]]))
        bps_gt_list.append(bits_per_spike(gt_spikes[:, :, [i]], gt_spikes[:, :, [i]]))
    population_smth_bps = bits_per_spike(smoothed_spikes, gt_spikes)
    population_upb_bps = bits_per_spike(gt_spikes, gt_spikes)

    mean_smth_bps = np.nanmean(bps_smth_list)
    mean_upb_bps = np.nanmean(bps_gt_list)
    print(f'{kernel_sigma*20}ms smoothing bps: {mean_smth_bps}, gt bps: {mean_upb_bps}. Population bps: {population_smth_bps}, {population_upb_bps}')

    
    '''
    scatter = axes.flat[k].scatter(bps_smth_list, bps_gt_list, c=fr_stat)
    axes.flat[k].set_title(f'kernel sigma={kernel_sigma*20}ms')
    axes.flat[k].set_xlim([-10, 15])
    axes.flat[k].set_ylim([-10, 15])
    axes.flat[k].plot([-10, 15], [-10, 15], c='grey')
    '''

'''
axes.flat[0].set_ylabel('gt vs. gt bps')
axes.flat[3].set_ylabel('gt vs. gt bps')
axes.flat[3].set_xlabel('gt vs. smoothed bps')
axes.flat[4].set_xlabel('gt vs. smoothed bps')
axes.flat[5].set_xlabel('gt vs. smoothed bps')
fig.colorbar(scatter)
plt.savefig('smoothing_direct.png')
'''

len(dataset): 559
len(dataset): 80
len(dataset): 160
valid neuron: 540, invalid neuron: 128
20ms smoothing bps: 3.720978651248588, gt bps: 4.6170632896923385. Population bps: 1.7530588882369578, 2.280131022648492
40ms smoothing bps: 3.131915272451578, gt bps: 4.608511572732067. Population bps: 1.4496777084966443, 2.280131022648492
60ms smoothing bps: 2.835967394774532, gt bps: 4.608511572732067. Population bps: 1.3107413520144735, 2.280131022648492
160ms smoothing bps: 2.241143864721745, gt bps: 4.608511572732067. Population bps: 1.0649939663353993, 2.280131022648492
320ms smoothing bps: 1.8848178554006771, gt bps: 4.608511572732067. Population bps: 0.9141381802525307, 2.280131022648492
640ms smoothing bps: 1.5774835660577733, gt bps: 4.608511572732067. Population bps: 0.7626687600713836, 2.280131022648492


"\naxes.flat[0].set_ylabel('gt vs. gt bps')\naxes.flat[3].set_ylabel('gt vs. gt bps')\naxes.flat[3].set_xlabel('gt vs. smoothed bps')\naxes.flat[4].set_xlabel('gt vs. smoothed bps')\naxes.flat[5].set_xlabel('gt vs. smoothed bps')\nfig.colorbar(scatter)\nplt.savefig('smoothing_direct.png')\n"

### Smoothing + Poisson GLM

In [8]:
# Path
import os, sys
os.chdir('/home/zwang34/IBL/iblfm_exp/IBL_foundation_model')
sys.path.append('./src')
print(sys.path)

import logging
logging.getLogger().setLevel(logging.ERROR)

# Lib
from datasets import load_dataset, concatenate_datasets
import numpy as np
from loader.make_loader import make_loader
from utils.eval_utils import bits_per_spike
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d

['/home/zwang34/miniconda3/envs/ibl-fm/lib/python310.zip', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10/lib-dynload', '', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10/site-packages', './src', '/scratch/zwang34/job_33291653/tmpkupvmri4', './src', './src']


In [20]:
## Fix Args
EID = '03d9a098-07bf-4765-88b7-85f8d8f620cc'
randomized = False
kernel_sigma = 2  # unit: 20 ms
heldout_ratio = 0.1 
fr_filter = 0.1  # unit: 1 Hz

seed = 28

In [21]:
np.random.seed(seed)

## Prepare data
if randomized == True:
    dataset = load_dataset(f'neurofm123/{EID}_aligned', cache_dir='/expanse/lustre/scratch/zwang34/temp_project/iTransformer/checkpoints/datasets_cache')
    train_dataset = dataset['train']
    val_dataset = dataset['val']
    test_dataset = dataset['test']
    n_neurons = len(test_dataset[0]['cluster_uuids'])
else:
    dataset = load_dataset(f'neurofm123/{EID}_nonrandomized', cache_dir='/expanse/lustre/scratch/zwang34/temp_project/iTransformer/checkpoints/datasets_cache')
    train_dataset = dataset['train']
    val_dataset = dataset['val']
    test_dataset = dataset['test']
    n_neurons = len(test_dataset[0]['cluster_uuids'])

train_dataloader = make_loader(
    train_dataset,
    target=None,
    load_meta=True,
    batch_size=10000,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
    shuffle=True,
)

val_dataloader = make_loader(
    val_dataset,
    target=None,
    load_meta=True,
    batch_size=10000,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
    shuffle=False,
)

test_dataloader = make_loader(
    test_dataset,
    target=None,
    load_meta=True,
    batch_size=10000,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
    shuffle=False,
)

for batch in train_dataloader:
    train_data = batch['spikes_data'].detach().cpu().numpy()
for batch in val_dataloader:
    val_data = batch['spikes_data'].detach().cpu().numpy()
for batch in test_dataloader:
    test_data = batch['spikes_data'].detach().cpu().numpy()

# Use a fr Filter
whole_data = np.concatenate([train_data, val_data, test_data], axis=0)
mean_fr = np.mean(whole_data, axis=(0,1)) * 50   # hz
valid_idx = (mean_fr >= fr_filter)
valid_idx_number = np.where(np.array(valid_idx)==1)[0]
train_data = train_data[:, :, valid_idx]
val_data = val_data[:, :, valid_idx]
test_data = test_data[:, :, valid_idx]

# Randomly select heldout neurons
n_valid_neurons = train_data.shape[-1]
heldout_idxs = np.random.choice(range(n_valid_neurons), size=int(n_valid_neurons*heldout_ratio), replace=False)
heldout_idxs_raw = valid_idx_number[heldout_idxs]
mask = np.ones(n_valid_neurons, dtype=bool)
mask[heldout_idxs] = False

print(f'held-out neurons : held-in neurons = {heldout_idxs.shape[0]} : {n_valid_neurons-heldout_idxs.shape[0]}')
print(f'held-out idxs: {heldout_idxs_raw}')

train_spikes_heldin = train_data[:, :, mask]
train_spikes_heldout = train_data[:, :, ~mask]
test_spikes_heldin = test_data[:, :, mask]
test_spikes_heldout = test_data[:, :, ~mask]
val_spikes_heldin = val_data[:, :, mask]
val_spikes_heldout = val_data[:, :, ~mask]

len(dataset): 397
len(dataset): 57
len(dataset): 114
held-out neurons : held-in neurons = 41 : 374
held-out idxs: [152 521 300 424 190 354  69  82 385 314 487 114 381 303  81  89 128 501
  65 132 442 155 313  50 459 230 200 326 367 426   8 512 333 411 444 264
 506 161 236 499  58]


In [22]:
## Adapt. We use Val or Test split as the eval set here.
eval_spikes_heldin = test_spikes_heldin
eval_spikes_heldout = test_spikes_heldout

In [23]:
## Copied from NLB'21 repo
## Define helper function for training Poisson regressor

from sklearn.linear_model import PoissonRegressor

def fit_poisson(train_input, eval_input, train_output, alpha=0.0):
    train_pred = []
    eval_pred = []
    # train Poisson GLM for each output column
    for chan in range(train_output.shape[1]):
        pr = PoissonRegressor(alpha=alpha, max_iter=500)
        pr.fit(train_input, train_output[:, chan])
        train_pred.append(pr.predict(train_input))
        eval_pred.append(pr.predict(eval_input))
    train_pred = np.vstack(train_pred).T
    eval_pred = np.vstack(eval_pred).T
    return train_pred, eval_pred

## Smooth spikes

# Assign useful variables
tlength = train_spikes_heldin.shape[1]
num_train = train_spikes_heldin.shape[0]
num_eval = eval_spikes_heldin.shape[0]
num_heldin = train_spikes_heldin.shape[2]
num_heldout = train_spikes_heldout.shape[2]

'''
# Smooth spikes with 40 ms std gaussian
import scipy.signal as signal
kern_sd_ms = 20
kern_sd = int(round(kern_sd_ms / 20))
window = signal.gaussian(kern_sd * 6, kern_sd, sym=True)
window /= np.sum(window)
filt = lambda x: np.convolve(x, window, 'same')

train_spksmth_heldin = np.apply_along_axis(filt, 1, train_spikes_heldin)
eval_spksmth_heldin = np.apply_along_axis(filt, 1, eval_spikes_heldin)
'''

# use more convenient smoothing function
train_spksmth_heldin = gaussian_filter1d(train_spikes_heldin, kernel_sigma, axis=1)
eval_spksmth_heldin = gaussian_filter1d(eval_spikes_heldin, kernel_sigma, axis=1)

## Generate rate predictions

# Reshape data to 2d for regression
train_spksmth_heldin_s = train_spksmth_heldin.reshape(-1, train_spksmth_heldin.shape[2])
eval_spksmth_heldin_s = eval_spksmth_heldin.reshape(-1, eval_spksmth_heldin.shape[2])
train_spikes_heldout_s = train_spikes_heldout.reshape(-1, train_spikes_heldout.shape[2])

# Train Poisson regressor from log of held-in smoothed spikes to held-out spikes
train_spksmth_heldout_s, eval_spksmth_heldout_s = fit_poisson(
    np.log(train_spksmth_heldin_s + 1e-4), # add constant offset to prevent taking log of 0
    np.log(eval_spksmth_heldin_s + 1e-4),
    train_spikes_heldout_s,
    alpha=0.1,
)

# Reshape data back to the same 3d shape as the input arrays
train_rates_heldin = train_spksmth_heldin_s.reshape((num_train, tlength, num_heldin))
train_rates_heldout = train_spksmth_heldout_s.reshape((num_train, tlength, num_heldout))
eval_rates_heldin = eval_spksmth_heldin_s.reshape((num_eval, tlength, num_heldin))
eval_rates_heldout = eval_spksmth_heldout_s.reshape((num_eval, tlength, num_heldout))

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)
STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
  self.n_iter_ = _check_optimize_result("lbfgs", opt_res)


#### Note
- held-in neuron rates in both train and eval split are obtained by directly smoothing.
- held-out neuron rates in both train and eval split are obtained by GLM regression.

In [24]:
## Population Bps
train_bps = bits_per_spike(train_rates_heldout, train_spikes_heldout)
eval_bps = bits_per_spike(eval_rates_heldout, eval_spikes_heldout)
## Neuron Average Bps
train_bps_list = []
eval_bps_list = []
for i in range(train_rates_heldout.shape[-1]):
    train_bps_list.append(bits_per_spike(train_rates_heldout[:, :, [i]], train_spikes_heldout[:, :, [i]]))
    eval_bps_list.append(bits_per_spike(eval_rates_heldout[:, :, [i]], eval_spikes_heldout[:, :, [i]]))

print(f"(Population) train_heldout_bps: {train_bps}, eval_heldout_bps: {eval_bps}")
print(f'(Neuron average) train_heldout_bps: {np.mean(train_bps_list)}, eval_heldout_bps: {np.mean(eval_bps_list)}')

(Population) train_heldout_bps: 0.3650479659630438, eval_heldout_bps: 0.22753139228508354
(Neuron average) train_heldout_bps: 0.8325431212139421, eval_heldout_bps: 0.42697066213162116
