In [1]:
# Path
import os, sys

# Project root; overridable via the IBL_FM_ROOT environment variable so the
# notebook is not tied to one machine's home directory.
PROJECT_ROOT = os.environ.get('IBL_FM_ROOT', '/home/zwang34/IBL/iblfm_exp/IBL_foundation_model')
os.chdir(PROJECT_ROOT)
# Guard the append so re-running this cell does not stack duplicate entries.
if './src' not in sys.path:
    sys.path.append('./src')
print(sys.path)

['/home/zwang34/miniconda3/envs/ibl-fm/lib/python310.zip', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10/lib-dynload', '', '/home/zwang34/miniconda3/envs/ibl-fm/lib/python3.10/site-packages', './src']


In [2]:
# Lib
from datasets import load_dataset, concatenate_datasets
import numpy as np
from accelerate import Accelerator
from loader.make_loader import make_loader
from utils.eval_utils import bits_per_spike
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter1d
import argparse
import torch
import torch.nn as nn
from src.utils.utils import move_batch_to_device, metrics_list, plot_gt_pred, plot_neurons_r2
import wandb
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [44]:
# Args
# Session identifier used to build the neurofm123/{eid}* dataset names.
eid = '746d1902-fa59-4cab-b0aa-013be36060d5'

# Dataset-variant flags. The loading cell checks `unaligned` first, so it takes
# precedence when both are True.
unaligned = False
non_randomized = False

# Synthetic target signal parameters.
sigma = 200  # For smoothing (Gaussian width, in samples)
seed = 12    # For the random signal

In [None]:
# Dataset and dataloader
# Every variant is pulled into the same HF cache; force_redownload re-fetches on each run.
cache_dir = '/expanse/lustre/scratch/zwang34/temp_project/iTransformer/checkpoints/datasets_cache'
hf_kwargs = dict(cache_dir=cache_dir, download_mode='force_redownload')

if unaligned:
    # NOTE(review): `split_unaligned_dataset` is neither imported nor defined in
    # this notebook, so this branch raises NameError as written — TODO import it
    # from the project utils before setting `unaligned = True`.
    _al = load_dataset(f'neurofm123/{eid}_aligned', **hf_kwargs)
    _ual = load_dataset(f'neurofm123/{eid}', **hf_kwargs)
    dataset = split_unaligned_dataset(_al, _ual)
elif non_randomized:
    dataset = load_dataset(f'neurofm123/{eid}_nonrandomized', **hf_kwargs)
else:
    dataset = load_dataset(f'neurofm123/{eid}_aligned', **hf_kwargs)

train_dataset = dataset['train']
val_dataset = dataset['val']
test_dataset = dataset['test']

# The bin-size column is named either 'binsize' or 'bin_size' depending on the
# dataset; pick by name instead of catching every exception with a bare except.
bin_size_col = 'binsize' if 'binsize' in train_dataset.column_names else 'bin_size'
bin_size = train_dataset[bin_size_col][0]

print(train_dataset.column_names)
whole_dataset = concatenate_datasets([train_dataset, val_dataset, test_dataset])
# Largest trial start time across all splits; used to size the synthetic signal.
max_time = int(max(whole_dataset['start_times']))

n_neurons = len(train_dataset[0]['cluster_uuids'])
print(f'n neurons: {n_neurons}')

# Loader options shared by every split.
# NOTE(review): `bin_size` read above is not passed here; the loaders use a fixed
# 0.02 — confirm the two agree (and share units) for this session.
loader_kwargs = dict(
    target='start_times_raw',
    load_meta=True,
    pad_to_right=True,
    pad_value=-1,
    bin_size=0.02,
    max_time_length=100,
    max_space_length=n_neurons,
    dataset_name='ibl',
)

train_dataloader = make_loader(train_dataset, batch_size=16, shuffle=True, **loader_kwargs)
# batch_size=10000 exceeds the split sizes, so presumably val/test each load as one batch.
val_dataloader = make_loader(val_dataset, batch_size=10000, **loader_kwargs)
test_dataloader = make_loader(test_dataset, batch_size=10000, shuffle=False, **loader_kwargs)

Downloading readme: 100%|██████████| 2.54k/2.54k [00:00<00:00, 12.7MB/s]
Downloading data: 100%|██████████| 17.4M/17.4M [00:00<00:00, 71.8MB/s]
Downloading data: 100%|██████████| 2.74M/2.74M [00:00<00:00, 25.7MB/s]
Downloading data: 100%|██████████| 5.19M/5.19M [00:00<00:00, 42.6MB/s]
Generating train split: 100%|██████████| 453/453 [00:00<00:00, 640.77 examples/s]
Generating val split: 100%|██████████| 65/65 [00:00<00:00, 590.83 examples/s]
Generating test split: 100%|██████████| 130/130 [00:00<00:00, 657.29 examples/s]


['spikes_sparse_data', 'spikes_sparse_indices', 'spikes_sparse_indptr', 'spikes_sparse_shape', 'choice', 'reward', 'block', 'whisker-motion-energy', 'binsize', 'interval_len', 'eid', 'sampling_freq', 'cluster_regions', 'cluster_channels', 'cluster_depths', 'good_clusters', 'cluster_uuids', 'cluster_qc', 'start_times', 'end_times']
n neurons: 1337
len(dataset): 453


In [None]:
# Create a random smooth signal
import numpy as np
from scipy.ndimage import gaussian_filter1d

def generate_smooth_random_time_series(time_array, sigma=500, seed=None):
    """Return a smooth, standardised random signal with one value per entry of ``time_array``.

    Uniform [0, 1) noise is smoothed with a 1-D Gaussian filter and then
    z-scored (zero mean, unit standard deviation).

    Args:
        time_array: array whose first-axis length sets the number of samples;
            its values are not used.
        sigma: standard deviation of the Gaussian kernel, in samples.
        seed: if given, noise is drawn from a private ``np.random.RandomState(seed)``.
            This yields exactly the values ``np.random.seed(seed); np.random.rand(...)``
            would, but leaves NumPy's global RNG state untouched. If ``None``,
            the global RNG is used.

    Returns:
        1-D float array of length ``time_array.shape[0]``. If the smoothed series
        is constant (e.g. a single sample), the zero-mean series is returned
        instead of dividing by a zero standard deviation.
    """
    # Private RandomState: same stream as the legacy global seeding, but does not
    # clobber the global generator that other cells may rely on.
    rng = np.random if seed is None else np.random.RandomState(seed)
    random_values = rng.rand(time_array.shape[0])

    smooth_values = gaussian_filter1d(random_values, sigma=sigma)
    centred = smooth_values - smooth_values.mean()
    std = smooth_values.std()
    if std == 0:
        # Cannot scale a constant series to unit variance; avoid 0/0 -> NaN.
        return centred
    return centred / std


# Only the length of `time_array` is used: max_time + 2 samples, so integer
# indices 0..max_time+1 into the resulting signal are valid.
time_array = np.linspace(0, 10, max_time+2)
smooth_random_series = generate_smooth_random_time_series(time_array, sigma=sigma, seed=seed)

fig, ax = plt.subplots()
ax.plot(smooth_random_series)
ax.set(
    title=f'Synthetic target signal (sigma={sigma}, seed={seed})',
    xlabel='index (int-truncated time)',
    ylabel='standardised value',
)
plt.show()

In [None]:
# MLP model
## Model and training Args
# Each trial is flattened to (time bins x neurons); the 100 matches the
# max_time_length=100 given to the dataloaders — keep the two in sync.
input_size = 100 * n_neurons
hs1, hs2 = 128, 256  # widths of the two hidden layers
output_size = 1      # one scalar prediction per trial

# AdamW optimiser settings and training length.
lr, wd, eps = 1e-4, 1, 1e-8
epochs = 5

class MLP(nn.Module):
    """Feed-forward regressor: two (Linear -> ReLU -> Dropout) stages, then a Linear head.

    Args:
        input_size: number of input features.
        hidden_size1: width of the first hidden layer.
        hidden_size2: width of the second hidden layer.
        output_size: number of outputs.
        dropout_rate: dropout probability applied after each hidden activation.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, output_size, dropout_rate=0):
        super().__init__()
        # Attribute names and registration order are kept stable: they define
        # the state_dict keys and the printed module layout.
        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, output_size)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout_rate)

    def forward(self, x):
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.dropout(self.relu(layer(hidden)))
        return self.fc3(hidden)

# Seed torch so weight initialisation (and presumably the train loader's
# shuffling) is reproducible across Restart & Run All.
torch.manual_seed(seed)
model = MLP(input_size, hs1, hs2, output_size)
accelerator = Accelerator()
model = accelerator.prepare(model)
print(model)

# NOTE(review): only the model goes through accelerator.prepare; the optimizer and
# dataloaders are used unwrapped and training calls loss.backward() directly.
# Fine on a single device; revisit before enabling mixed precision / multi-GPU.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd, eps=eps)
# reduction='none' keeps per-element squared errors; the training loop sums them.
loss_fn = nn.MSELoss(reduction='none')

In [None]:
## Training
# Fit the MLP to the synthetic signal: each trial's flattened spikes predict the
# value of `smooth_random_series` at that trial's int-truncated target. Each
# epoch's validation predictions are plotted against ground truth, and the epoch
# with the lowest validation loss is kept in gt_best / pred_best.


def _flatten_spikes(batch):
    """Flatten every non-batch dimension of ``batch['spikes_data']`` to one feature vector per trial."""
    spikes = batch['spikes_data']
    return spikes.reshape(spikes.shape[0], -1)


def _trial_indices(batch):
    """Return ``batch['target']`` as an int32 NumPy array, used to index the synthetic signal."""
    return batch['target'].detach().cpu().numpy().astype(np.int32)


def _lookup_targets(signal, indices, device):
    """Gather ``signal`` at ``indices`` as a float32 tensor on ``device``."""
    return torch.tensor(signal[indices], device=device, dtype=torch.float32)


best_eval_loss = np.inf
best_eval_epoch = 0
# Stay None if no epoch ever improves (e.g. the eval loss is NaN), so the final
# plot can be skipped instead of raising NameError.
gt_best, pred_best = None, None

train_loss_list = []
eval_loss_list = []

for epoch in tqdm(range(epochs)):
    # --- train epoch ---
    model.train()
    train_loss = 0
    train_examples = 0
    for batch in train_dataloader:
        batch = move_batch_to_device(batch, accelerator.device)
        preds = model(_flatten_spikes(batch))
        tgt = _lookup_targets(smooth_random_series, _trial_indices(batch), preds.device)
        # Assumes tgt is 1-D (batch,) so unsqueeze(1) matches preds' (batch, 1) — TODO confirm.
        # The loss is summed over the batch and normalised per example after the epoch.
        loss = loss_fn(preds, tgt.unsqueeze(1)).sum()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        train_loss += loss.item()
        train_examples += tgt.shape[0]

    train_loss /= train_examples
    print(f"Epoch {epoch} training loss: {train_loss}")
    train_loss_list.append(train_loss)

    # --- eval epoch ---
    model.eval()
    eval_loss = 0
    eval_examples = 0
    gt_list, pred_list, idx_list = [], [], []
    with torch.no_grad():
        for batch in val_dataloader:
            batch = move_batch_to_device(batch, accelerator.device)
            preds = model(_flatten_spikes(batch))
            idx = _trial_indices(batch)
            tgt = _lookup_targets(smooth_random_series, idx, preds.device)
            loss = loss_fn(preds, tgt.unsqueeze(1)).sum()

            eval_loss += loss.item()
            eval_examples += tgt.shape[0]
            pred_list.append(preds)
            gt_list.append(tgt)
            # Keep every batch's indices so the sort below stays aligned with
            # gt/preds even when validation spans several batches.
            idx_list.append(idx)

    gt = torch.cat(gt_list, dim=0).detach().cpu().numpy()
    preds = torch.cat(pred_list, dim=0).detach().cpu().numpy()

    # Order validation trials by target index for a readable time-course plot.
    order = np.argsort(np.concatenate(idx_list).ravel(), kind='stable')
    gt = gt[order]
    preds = preds[order]

    fig, ax = plt.subplots()
    ax.plot(gt, label='ground truth')
    ax.plot(preds, label='prediction')
    ax.set(title=f'Epoch {epoch} validation', xlabel='val trial (sorted by target)', ylabel='signal')
    ax.legend()
    plt.show()  # display now; the inline backend then closes the figure

    eval_loss /= eval_examples
    if eval_loss < best_eval_loss:
        best_eval_loss = eval_loss
        best_eval_epoch = epoch
        gt_best = gt
        pred_best = preds

    print(f"Epoch {epoch} eval loss: {eval_loss}")
    eval_loss_list.append(eval_loss)

if gt_best is None:
    print("No epoch improved on the initial eval loss; nothing to plot.")
else:
    fig, ax = plt.subplots()
    ax.plot(gt_best, label='ground truth')
    ax.plot(pred_best, label='prediction')
    ax.set(title=f'Best epoch ({best_eval_epoch}) validation', xlabel='val trial (sorted by target)', ylabel='signal')
    ax.legend()
    plt.show()