In [None]:
# Set MPS fallback to enable operations not supported natively on Apple Silicon
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

import bopt
import h5py as h5
import utils

In [2]:
# Sampling rate used throughout the notebook. It was previously assigned twice
# under two spellings; keep a single source of truth so the two names cannot
# drift apart. `samplerate` is the spelling later cells use.
sample_rate = 50
samplerate = sample_rate

# Let the project helper choose the compute device (it prints a per-device
# summary and the selected device, see this cell's output).
# To pin a GPU manually instead, assign e.g. device = 'cuda:7'.
device = bopt.cuda_init()

# Context-window sizes handed to the dataset as num_before / num_after.
num_before = 25
num_after = 5

# Seed torch's global RNG so stochastic steps are reproducible.
seed = 2222
torch.random.manual_seed(seed)

--- device:0 has 1
--- device:1 has 1
--- device:2 has 3
--- device:3 has 1
--- device:4 has 1
--- device:5 has 1
--- device:6 has 0
--- device:7 has 0
Selected device: cuda:6


<torch._C.Generator at 0x7f90e8102d50>

In [3]:
# Path of the reduced H5 recording and the cluster IDs that are later passed
# to the dataset as `which_clusters`.
# NOTE(review): absolute, user-specific path — adjust when running elsewhere.
h5_filepath = '/home/sunnyliu1220/git/latent-gaze/data/charm_50_rec_reduced.h5'
charmander_clusters = [
    4, 15, 41, 42, 43, 50, 62, 107,
    121, 168, 225, 226, 245, 251, 259, 261,
    263, 271, 282, 294, 302, 327, 334, 340,
    342, 347, 363, 364, 367, 375, 400, 555,
]

In [4]:
# Inspect the layout of the reduced H5 file: its top-level keys, the
# series -> epoch -> group hierarchy under 'data', and the entries under 'meta'.
with h5.File(h5_filepath, 'r') as f:
    # High-level structure
    print("Keys in the reduced h5 file:", list(f.keys()))

    # Walk the three levels below 'data'
    print("\nData structure:")
    for series_key, series_group in f['data'].items():
        print(f"  Series: {series_key}")
        for epoch_key, epoch_group in series_group.items():
            print(f"    Epoch: {epoch_key}")
            for group_key in epoch_group.keys():
                print(f"      Group: {group_key}")

    # Names stored under 'meta'
    print("\nMetadata structure:")
    for meta_key in f['meta'].keys():
        print(f"  {meta_key}")

# Size on disk. This needs only the path, so it runs after the file handle is
# closed; `os` comes from the notebook's import cell rather than being
# re-imported here.
print(f"\nFile size: {os.path.getsize(h5_filepath) / (1024*1024):.2f} MB")

Keys in the reduced h5 file: ['data', 'meta']

Data structure:
  Series: series_008
    Epoch: epoch_001
      Group: firing_rates
      Group: signals
      Group: stimulus
  Series: series_009
    Epoch: epoch_001
      Group: firing_rates
      Group: signals
      Group: stimulus

Metadata structure:
  cluster_ids
  reconstruction

File size: 636.14 MB


In [5]:
# Which stimulus variant to evaluate on. This value is now actually forwarded
# to the dataset as `stimulus_key` (it used to be defined but ignored while the
# key was hard-coded separately).
direction = 'shifted'

# Recordings to evaluate on; add 'series_009/epoch_001' to include the second series.
test_series = ['series_008/epoch_001']

# [start, end] index pairs. `test_idxs` starts 10 * samplerate samples from the
# end and is not used in this cell; `test_all` is the pair passed to the
# dataset below as start_idx / end_idx.
test_idxs = [-samplerate * 10, -1]
test_all = [0, -1]

test_dataset_shifted = bopt.CorticalDataset(
    h5_filepath,
    test_series,
    num_before=num_before,
    num_after=num_after,
    start_idx=test_all[0],
    end_idx=test_all[1],
    stimulus_key=direction,
    grayscale=True,
    normalize_signals=False,
    signals=['locomotion', 'azimuth'],
    which_clusters=charmander_clusters,
    zero_blinks=True,
)

# Deterministic iteration order (no shuffling) for evaluation.
test_loader_shifted = torch.utils.data.DataLoader(
    test_dataset_shifted,
    batch_size=256,
    shuffle=False,
)



Zeroing out blinks in stimulus (at init).


In [6]:
model_path = '/home/sunnyliu1220/git/latent-gaze/models/final_model.pt'

# Load the trained model. The checkpoint holds a whole pickled module (the
# result is used directly as a model below), not just a state_dict, so
# `weights_only=False` is passed explicitly: newer PyTorch releases default to
# `weights_only=True`, which refuses to unpickle arbitrary module classes.
# NOTE: full unpickling can execute arbitrary code — only load checkpoints
# from a trusted source.
model = torch.load(model_path, map_location=device, weights_only=False)

# Switch to inference mode (disables dropout) and display the architecture.
model.eval()

  model = torch.load(model_path, map_location=device)


CNNComponent(
  (layers): ModuleDict(
    (conv0): Conv2d(30, 24, kernel_size=(7, 7), stride=(1, 1), padding=valid)
    (layernorm0): LayerNorm((24, 62, 96), eps=1e-05, elementwise_affine=False)
    (dropout0): Dropout(p=0.1, inplace=False)
    (nl0): Softplus(beta=1.0, threshold=20.0)
    (conv1): Conv2d(24, 24, kernel_size=(7, 7), stride=(1, 1), padding=valid)
    (layernorm1): LayerNorm((24, 56, 90), eps=1e-05, elementwise_affine=False)
    (dropout1): Dropout(p=0.1, inplace=False)
    (nl1): Softplus(beta=1.0, threshold=20.0)
    (conv2): Conv2d(24, 24, kernel_size=(7, 7), stride=(1, 1), padding=valid)
    (layernorm2): LayerNorm((24, 50, 84), eps=1e-05, elementwise_affine=False)
    (dropout2): Dropout(p=0.1, inplace=False)
    (nl2): Softplus(beta=1.0, threshold=20.0)
    (conv3): Conv2d(24, 24, kernel_size=(7, 7), stride=(1, 1), padding=valid)
    (layernorm3): LayerNorm((24, 44, 78), eps=1e-05, elementwise_affine=False)
    (dropout3): Dropout(p=0.1, inplace=False)
    (nl3): S

In [None]:
def model_log_lkhd(x, y, z_grid_masked, model, device=device):
    """
    Score candidate eye positions by the Gaussian log likelihood of the data.

    The stimulus window is shifted once per candidate eye position, each
    shifted copy is fed through the model, and the prediction is compared
    with the observed activity. The eye position is assumed not to change
    during the window.

    Parameters:
    x (torch.Tensor): Originally shifted stimulus. Shape (T, H, W).
    y (torch.Tensor): Ground truth neural activity. Shape (N).
    z_grid_masked (torch.Tensor): Masked grid of latent eye positions,
        documented as shape (M, 2).
        NOTE(review): the body passes z_grid_masked[0] and z_grid_masked[1]
        to utils.shift_stimulus, i.e. the first two entries along dim 0.
        That matches a (2, M) layout; with an (M, 2) layout it would select
        the first two grid points instead of the two coordinate columns.
        Confirm the intended layout against utils.shift_stimulus.
    model (torch.nn.Module): The trained model.
    device (torch.device): The device to run the model on. The default is
        the notebook-level `device` as it was when this cell was executed.

    Returns:
    torch.Tensor: -0.5 times the squared error between prediction and `y`,
        summed over dim 1 of the model output (one value per candidate if
        the model output is (M, N), as the inline comments assume). Constant
        terms of the Gaussian log likelihood are omitted.
    """
    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        stimulus = x.to(device)
        target = y.to(device)
        eye_grid = z_grid_masked.to(device)

        # One shifted copy of the stimulus per candidate eye position.
        first_coord, second_coord = eye_grid[0], eye_grid[1]
        shifted = utils.shift_stimulus(stimulus, first_coord, second_coord)  # (T, M, H, W)
        candidates = shifted.transpose(0, 1)  # (M, T, H, W)

        # Model prediction for every candidate.
        prediction = model(candidates)  # (M, N)

        # Broadcast the observation over candidates and sum the squared
        # residuals over dim 1 (Gaussian noise assumption).
        residual = prediction - target.unsqueeze(0)
        scores = -0.5 * (residual ** 2).sum(dim=1)
    return scores