In [1]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from matplotlib.animation import FuncAnimation
from IPython import display

import torch

import pickle
import os

In [2]:
from lib.dataset_config_parser.trained_model_config_parser import parse_prefit_glm_paths
from optimization_encoder.trial_glm import load_fitted_glm_families
from denoise_inverse_alg.glm_inverse_alg import PackedGLMTensors, reinflate_cropped_glm_model
import lib.data_utils.dynamic_data_util as ddu

In [3]:
# Render matplotlib figures directly below the cell that creates them.
%matplotlib inline

In [4]:
# Root directory of the demo data. Point the JITTER_DEMO_BASEPATH environment
# variable at your copy instead of editing this cell; the fallback is the path
# used when the notebook was written.
BASEPATH = os.environ.get(
    'JITTER_DEMO_BASEPATH',
    '/Volumes/Backup/Scratch/Users/wueric/SUBMISSION_DATA_reconstruction/jitter/',
)

# Height/width (pixels) of the frame that cropped spatial filters are placed
# back into and that the packed model tensors are built for.
HEIGHT = 160
WIDTH = 256

SAMPLES_PER_BIN = 20  # 1 ms bins @ 20 kHz sampling rate

In [None]:
# Device for the reconstructions. The demo was run on GPU 1; change
# PREFERRED_CUDA_INDEX for your machine. If that GPU does not exist, fall back
# to the first GPU, and then to the CPU.
# NOTE(review): the CPU path is untested here and the reconstruction cells will
# be far slower there.
PREFERRED_CUDA_INDEX = 1  # change this for your machine

if torch.cuda.device_count() > PREFERRED_CUDA_INDEX:
    device = torch.device(f'cuda:{PREFERRED_CUDA_INDEX}')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Demonstration for jittered natural movie reconstruction

### Load the demonstration dataset

In [None]:
# Load the bundled demo recording.
# NOTE: pickle.load can execute arbitrary code; only open files from a trusted source.
DEMO_DATA_PATH = os.path.join(BASEPATH, '2018_08_07_5_jittered_demo_data.p')
with open(DEMO_DATA_PATH, 'rb') as demo_file:
    demo_dataset = pickle.load(demo_file)

In [None]:
# Split the demo bundle into the recording data and its metadata
# (crop/downsample settings and the valid image region, used below).
demo_data_dict, demo_data_metadata = (demo_dataset[key] for key in ('data', 'metadata'))

In [None]:
# Wrap the demo data in a dataloader; indexing it yields a 5-tuple
# (history frames, target frames, frame transitions, spike bin times, binned spikes).
demo_dataloader = ddu.DemoJitteredMovieDataloader(demo_data_dict)

### Load the models

In [None]:
# Cell-type / cell-id bookkeeping ("what cell is what").
CELL_ORDERING_PATH = os.path.join(BASEPATH, 'pickles', 'reclassed.p')

################################################################
# Load the cell types and their reference ordering.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open(CELL_ORDERING_PATH, 'rb') as ordering_file:
    cells_ordered = pickle.load(ordering_file)  # type: OrderedMatchedCellsStruct
ct_order = cells_ordered.get_cell_types()

# Flatten the per-type reference orderings into one list, types in ct_order order.
cell_ids_as_ordered_list = [
    cell_id
    for cell_type in ct_order
    for cell_id in cells_ordered.get_reference_cell_order(cell_type)
]

In [None]:
# Bounding boxes of each cell's cropped spatial filter, keyed by cell type and
# indexed by the cell's position within its type.
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
BBOX_PATH = os.path.join(BASEPATH, 'pickles', 'cropped_glm_bbox.pickle')
with open(BBOX_PATH, 'rb') as bbox_file:
    bounding_boxes_by_type = pickle.load(bbox_file)

In [None]:
# Load the pre-fitted LNBRC model families (keyed by cell type) listed in the YAML manifest.
MODEL_MANIFEST_PATH = os.path.join(BASEPATH, 'models', 'group0.yaml')
fitted_glm_paths = parse_prefit_glm_paths(MODEL_MANIFEST_PATH)
fitted_glm_families = load_fitted_glm_families(fitted_glm_paths)

In [None]:
# Crop/downsample settings from the demo metadata; the packed model tensors are
# built with the same geometry as the demo stimulus.
crop_geometry = {
    'downsample_factor': demo_data_metadata['downsample_factor'],
    'crop_width_low': demo_data_metadata['crop_width_low'],
    'crop_width_high': demo_data_metadata['crop_width_high'],
    'crop_height_low': demo_data_metadata['crop_height_low'],
    'crop_height_high': demo_data_metadata['crop_height_high'],
}

# Pack the fitted models into tensors for the reconstruction code.
packed_glm_tensors = reinflate_cropped_glm_model(
    fitted_glm_families, bounding_boxes_by_type, cells_ordered, HEIGHT, WIDTH,
    **crop_geometry,
)

### Look at example model parameters

This is an example ON parasol cell.

In [None]:
# Example cell: ON parasol 1046. Look up its fitted parameters and the
# bounding box of its cropped spatial filter.
EXAMPLE_ID = 1046
example_family = fitted_glm_families['ON parasol']
EXAMPLE_PARAMS = example_family.fitted_models[EXAMPLE_ID]

# Bounding boxes are indexed by position within the cell type, so translate
# the cell id into its index among the ON parasols first.
example_type_idx = cells_ordered.get_idx_for_same_type_cell_id_list('ON parasol', [EXAMPLE_ID])[0]
EXAMPLE_BOUNDING_BOX = bounding_boxes_by_type['ON parasol'][example_type_idx]

In [None]:
# Basis sets of the ON parasol model family; each cell's fitted weights
# (e.g. EXAMPLE_PARAMS.timecourse_weights) are multiplied with these to form its filters.
on_parasol_family = fitted_glm_families['ON parasol']
timecourse_basis, feedback_basis, coupling_basis = (
    on_parasol_family.timecourse_basis,
    on_parasol_family.feedback_basis,
    on_parasol_family.coupling_basis,
)

In [None]:
# Slices locating the example cell's cropped spatial filter inside the full
# HEIGHT x WIDTH frame, given the demo's crop/downsample settings.
putback_rows, putback_cols = EXAMPLE_BOUNDING_BOX.make_precropped_sliceobj(
    downsample_factor=demo_data_metadata['downsample_factor'],
    crop_wlow=demo_data_metadata['crop_width_low'],
    crop_whigh=demo_data_metadata['crop_width_high'],
    crop_hlow=demo_data_metadata['crop_height_low'],
    crop_hhigh=demo_data_metadata['crop_height_high'],
)

# Zero everywhere except the bounding box, which holds the fitted spatial weights.
full_spatial_filter = np.zeros((HEIGHT, WIDTH), dtype=np.float32)
full_spatial_filter[putback_rows, putback_cols] = EXAMPLE_PARAMS.spatial_weights

In [None]:
LIM = 2e-2  # symmetric colour limit so that zero sits at the white midpoint of 'bwr'

fig, ax = plt.subplots()
ax.imshow(full_spatial_filter, cmap='bwr', vmin=-LIM, vmax=LIM)
ax.set(xticks=[], yticks=[], title=f'ON parasol {EXAMPLE_ID} spatial stimulus filter')
ax.axis('equal')
plt.show()

In [None]:
# Temporal stimulus filter = the cell's weights projected onto the family's timecourse basis.
timecourse_filter = (EXAMPLE_PARAMS.timecourse_weights @ timecourse_basis).squeeze(0)

# Build the time axis from the filter length instead of a hard-coded 250, so the
# plot does not fail with a shape mismatch for a basis of a different length.
n_timecourse_taps = timecourse_filter.shape[0]
timecourse_lags = np.arange(-n_timecourse_taps, 0)

fig, ax = plt.subplots(figsize=(5, 2))
ax.set_xlim([-(n_timecourse_taps + 2), 0])
# NOTE(review): the feedback and coupling filters are reversed ([::-1]) before
# plotting but this one is not -- confirm the storage order of the timecourse basis.
ax.plot(timecourse_lags, timecourse_filter, color='red', lw=2)
ax.set_title(f'ON parasol {EXAMPLE_ID} temporal stimulus filter')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Intensity [au]')
plt.show()

In [None]:
# Feedback filter = the cell's weights projected onto the family's feedback basis.
feedback_filter = (EXAMPLE_PARAMS.feedback_weights @ feedback_basis).squeeze(0)

# Build the time axis from the filter length instead of a hard-coded 250.
n_feedback_taps = feedback_filter.shape[0]

fig, ax = plt.subplots(figsize=(5, 2))
ax.set_xlim([-(n_feedback_taps + 2), 0])

# Reversed before plotting, the same way the coupling filters are plotted.
ax.plot(np.arange(-n_feedback_taps, 0), feedback_filter[::-1], color='black', lw=2)

ax.set_title(f'ON parasol {EXAMPLE_ID} feedback filter')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Intensity [au]')

plt.show()

In [None]:
# Project the coupling weights onto the family's coupling basis: row i of
# coupling_filters is the filter from neighbour coupling_cells[i].
coupling_weights, coupling_cells = EXAMPLE_PARAMS.coupling_cells_weights
coupling_filters = coupling_weights @ coupling_basis

# Bucket the filters by the neighbour's cell type. Every type gets a (possibly
# empty) list, so per-type lookups never raise KeyError.
COUPLED_CELL_TYPES = cells_ordered.get_cell_types()
coupled_filters_by_type = dict((cell_type, []) for cell_type in COUPLED_CELL_TYPES)

for row, neighbour_id in enumerate(coupling_cells):
    neighbour_type = cells_ordered.get_cell_type_for_cell_id(neighbour_id)
    coupled_filters_by_type[neighbour_type].append(coupling_filters[row, :])

In [None]:
DEMO_COUPLED_TYPE = 'ON parasol'
COUPLE_MAX = 1.0

fig, ax = plt.subplots(figsize=(5, 2))
ax.set_xlim([-252, 0])
ax.axhline(y=0, color='black', lw=0.5, xmin=0, xmax=1)

for cf in coupled_filters_by_type[DEMO_COUPLED_TYPE]:
    ax.plot(np.r_[-250:0], cf[::-1], lw=1)
ax.set_ylim([-COUPLE_MAX, COUPLE_MAX])

ax.set_title(f'Nearby {DEMO_COUPLED_TYPE} coupling filters')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Intensity [au]')

plt.show()

In [None]:
DEMO_COUPLED_TYPE = 'OFF parasol'
COUPLE_MAX = 1.0

fig, ax = plt.subplots(figsize=(5, 2))
ax.set_xlim([-252, 0])
ax.axhline(y=0, color='black', lw=0.5, xmin=0, xmax=1)

for cf in coupled_filters_by_type[DEMO_COUPLED_TYPE]:
    ax.plot(np.r_[-250:0], cf[::-1], lw=1)
ax.set_ylim([-COUPLE_MAX, COUPLE_MAX])

ax.set_title(f'Nearby {DEMO_COUPLED_TYPE} coupling filters')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Intensity [au]')

plt.show()

In [None]:
DEMO_COUPLED_TYPE = 'ON midget'
COUPLE_MAX = 1.0

fig, ax = plt.subplots(figsize=(5, 2))
ax.set_xlim([-252, 0])
ax.axhline(y=0, color='black', lw=0.5, xmin=0, xmax=1)

for cf in coupled_filters_by_type[DEMO_COUPLED_TYPE]:
    ax.plot(np.r_[-250:0], cf[::-1], lw=1)
ax.set_ylim([-COUPLE_MAX, COUPLE_MAX])

ax.set_title(f'Nearby {DEMO_COUPLED_TYPE} coupling filters')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Intensity [au]')

plt.show()

In [None]:
DEMO_COUPLED_TYPE = 'OFF midget'
COUPLE_MAX = 1.0

fig, ax = plt.subplots(figsize=(5, 2))
ax.set_xlim([-252, 0])
ax.axhline(y=0, color='black', lw=0.5, xmin=0, xmax=1)

for cf in coupled_filters_by_type[DEMO_COUPLED_TYPE]:
    ax.plot(np.r_[-250:0], cf[::-1], lw=1)
ax.set_ylim([-COUPLE_MAX, COUPLE_MAX])

ax.set_title(f'Nearby {DEMO_COUPLED_TYPE} coupling filters')
ax.set_xlabel('Time [ms]')
ax.set_ylabel('Intensity [au]')

plt.show()

### Look at example stimulus

In [None]:
# Pull the first demo trial: history and target frames, frame-transition times,
# spike-bin times and the binned spikes (one row per cell).
demo_trial = demo_dataloader[0]
hist_frame, target_frame, frame_transitions, spike_bin_times, binned_spikes = demo_trial

In [None]:
# Stitch the history and target frames into one movie for playback.
all_frames = np.concatenate([hist_frame, target_frame], axis=0)

fig, ax = plt.subplots()
ax.axis('off')
movie_image = ax.imshow(all_frames[0], vmin=-1, vmax=1, cmap='gray')


def animate(frame_num):
    """Show movie frame `frame_num`; return the updated artists."""
    movie_image.set_data(all_frames[frame_num])
    # FuncAnimation expects `func` to return an iterable of artists (required when blit=True).
    return (movie_image,)


# interval is the delay between frames in milliseconds (requires ffmpeg for to_html5_video).
anim = FuncAnimation(fig, animate, frames=all_frames.shape[0], interval=25)
video = anim.to_html5_video()
html = display.HTML(video)
display.display(html)
# Close this specific figure so the static first frame is not rendered in addition to the video.
plt.close(fig)

### Look at example rasters

In [None]:
# Bin indices of the nonzero entries of each cell's row: one event array per raster row.
full_example_event_acc = [
    np.nonzero(np.asarray(binned_spikes[cell_ix, :]))[0]
    for cell_ix in range(binned_spikes.shape[0])
]

In [None]:
# Raster of the example trial: one row per cell, x-axis in 1 ms bins.
fig, ax = plt.subplots()

ax.eventplot(full_example_event_acc, colors='black', linewidths=1)

# Minimal raster styling: keep only the time axis.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['bottom'].set_visible(True)
ax.spines['left'].set_visible(False)
ax.set_yticks([])

ax.set_xticks([0, 250, 500, 750, 1000])
ax.set_xlim([0, 1000])

# Shade a fixed-width window starting at frame RASTER_SHADE_START_FRAME.
# NOTE(review): presumably frame 60 is the first target frame (the number of
# history frames) and 500 ms covers the target period -- confirm.
RASTER_SHADE_START_FRAME = 60
RASTER_SHADE_WIDTH_MS = 500
# frame_transitions appear to be in samples; dividing by SAMPLES_PER_BIN
# (20 samples per 1 ms bin) converts to the raster's bin/ms units.
start_time = (frame_transitions[RASTER_SHADE_START_FRAME] - frame_transitions[0]) / SAMPLES_PER_BIN
# Span every raster row instead of a hard-coded cell count.
n_raster_rows = len(full_example_event_acc)
shade_patch = mpatches.Rectangle((start_time, 0), RASTER_SHADE_WIDTH_MS, n_raster_rows,
                                 color="gray", clip_on=False, alpha=0.2)
ax.add_patch(shade_patch)

ax.set_xlabel('Time [ms]')

plt.show()

### Simultaneously reconstruct images and estimate eye movements

(Warning: these cells take about 15 minutes to run on a modern GPU.)

In [None]:
from dejitter_recons.joint_em_estimation import create_gaussian_multinomial
from generate_joint_eye_movements_reconstructions import make_get_iterators, \
    generate_joint_eye_movement_trajectory_reconstructions
from dejitter_recons.estimate_image import noreduce_nomask_batch_bin_bernoulli_neg_LL

In [None]:
# Settings for the eye-movement estimate.
NUM_PARTICLES = 10
gaussian_multinomial = create_gaussian_multinomial(1.2, 2)

# Image-estimation hyperparameters.
RHO_START = 0.01778279410038923  # 10 ** -1.75
RHO_END = 3.1622776601683795  # 10 ** 0.5
PRIOR_WEIGHT = 0.15
EYE_MOVEMENT_WEIGHT = 0.5
HQS_NITER = 5

# Two iterator factories built from the rho settings: one from RHO_START to
# RHO_END with HQS_NITER, one held at RHO_END for a single iteration.
# NOTE(review): make_get_iterators semantics inferred from its arguments only.
rho_sweep_iters = make_get_iterators(RHO_START, RHO_END, HQS_NITER)
rho_final_iters = make_get_iterators(RHO_END, RHO_END, 1)

# NOTE(review): no random seed is set; if the estimation samples randomly,
# results will differ between runs.
ground_truth, reconstructions, eye_movement_trajectories = generate_joint_eye_movement_trajectory_reconstructions(
    packed_glm_tensors,
    demo_dataloader,
    SAMPLES_PER_BIN,
    30 * SAMPLES_PER_BIN,
    NUM_PARTICLES,
    gaussian_multinomial,
    PRIOR_WEIGHT,
    EYE_MOVEMENT_WEIGHT,
    noreduce_nomask_batch_bin_bernoulli_neg_LL,
    rho_sweep_iters,
    rho_final_iters,
    demo_data_metadata['valid_region'],
    device,
    init_noise_sigma=None,
    em_inner_opt_verbose=False,
    throwaway_log_prob=-6,
    compute_image_every_n=10,
)


In [None]:
# First joint reconstruction, with pixels outside the valid region zeroed.
masked_joint_recon = demo_data_metadata['valid_region'] * reconstructions[0, ...]
fig, ax = plt.subplots()
ax.imshow(masked_joint_recon, vmin=-1.0, vmax=1.0, cmap='gray')
ax.axis('off')
plt.show()

### Reconstruct images using the known eye movements

This cell takes about 1 minute to run on a modern GPU.

In [None]:
from generate_fixed_eye_movements_reconstructions import batch_generate_known_eye_movement_trajectory_reconstructions

In [None]:
# Image-estimation hyperparameters (same values as the joint run).
RHO_START = 0.01778279410038923  # 10 ** -1.75
RHO_END = 3.1622776601683795  # 10 ** 0.5
PRIOR_WEIGHT = 0.15
HQS_NITER = 5

known_rho_iters = make_get_iterators(RHO_START, RHO_END, HQS_NITER)

# Reconstruct using the known eye-movement trajectory.
# NOTE(review): meaning of the positional `1` is not visible here -- confirm.
ground_truth, known_eye_movements_reconstructions, known_eye_movements = \
    batch_generate_known_eye_movement_trajectory_reconstructions(
        packed_glm_tensors,
        demo_dataloader,
        SAMPLES_PER_BIN,
        30 * SAMPLES_PER_BIN,
        PRIOR_WEIGHT,
        noreduce_nomask_batch_bin_bernoulli_neg_LL,
        known_rho_iters,
        demo_data_metadata['valid_region'],
        1,
        device,
        use_exact_eye_movements=True,
    )

In [None]:
# First known-eye-movement reconstruction, with pixels outside the valid region zeroed.
masked_known_recon = demo_data_metadata['valid_region'] * known_eye_movements_reconstructions[0, ...]
fig, ax = plt.subplots()
ax.imshow(masked_known_recon, vmin=-1.0, vmax=1.0, cmap='gray')
ax.axis('off')
plt.show()

### Ignore eye movements while doing reconstruction

This cell takes about 1 minute to run on a modern GPU.

In [None]:
# Image-estimation hyperparameters (same values as the joint run).
RHO_START = 0.01778279410038923  # 10 ** -1.75
RHO_END = 3.1622776601683795  # 10 ** 0.5
PRIOR_WEIGHT = 0.15
HQS_NITER = 5

ignore_rho_iters = make_get_iterators(RHO_START, RHO_END, HQS_NITER)

# Reconstruct while ignoring eye movements (use_exact_eye_movements=False).
# NOTE(review): meaning of the positional `1` is not visible here -- confirm.
ground_truth, ignore_eye_movements_reconstructions, _ = \
    batch_generate_known_eye_movement_trajectory_reconstructions(
        packed_glm_tensors,
        demo_dataloader,
        SAMPLES_PER_BIN,
        30 * SAMPLES_PER_BIN,
        PRIOR_WEIGHT,
        noreduce_nomask_batch_bin_bernoulli_neg_LL,
        ignore_rho_iters,
        demo_data_metadata['valid_region'],
        1,
        device,
        use_exact_eye_movements=False,
    )

In [None]:
# First eye-movement-ignoring reconstruction, with pixels outside the valid region zeroed.
masked_ignore_recon = demo_data_metadata['valid_region'] * ignore_eye_movements_reconstructions[0, ...]
fig, ax = plt.subplots()
ax.imshow(masked_ignore_recon, vmin=-1.0, vmax=1.0, cmap='gray')
ax.axis('off')
plt.show()

### Plot all of the example images

In [None]:
# Valid-region mask from the demo metadata (presumably boolean, per its name);
# multiplied into every panel of the comparison grid to blank pixels outside it.
convex_hull_mask_matrix_bool = demo_data_metadata['valid_region']

In [None]:
# Display crop applied to both images and mask: drop the top 20 rows, the left
# 16 columns and the right 32 columns.
DISPLAY_CROP = (slice(20, None), slice(16, -32))

# Grid columns, left to right: (title, stack of images with one per example).
comparison_panels = [
    ("Stimulus", ground_truth),
    ("known-LNBRC-dCNN", known_eye_movements_reconstructions),
    ("zero-LNBRC-dCNN", ignore_eye_movements_reconstructions),
    ("joint-LNBRC-dCNN", reconstructions),
]
display_mask = convex_hull_mask_matrix_bool[DISPLAY_CROP]
n_examples = ground_truth.shape[0]

# squeeze=False keeps `axes` 2-D even for a single example, so axes[row, col]
# always works (the default squeeze=True returns a 1-D array for one row).
fig, axes = plt.subplots(n_examples, len(comparison_panels),
                         figsize=(9, n_examples * (8 / 5)), squeeze=False)
for ix in range(n_examples):
    for col, (title, images) in enumerate(comparison_panels):
        ax = axes[ix, col]
        ax.imshow(images[ix][DISPLAY_CROP] * display_mask, cmap='gray', vmin=-1.0, vmax=1.0)
        ax.axis('off')
        if ix == 0:
            ax.set_title(title, fontsize=14)

fig.tight_layout()
plt.show()