In [1]:
import jtap_mice
jtap_mice.set_jaxcache()
from jtap_mice.inference import run_jtap, run_parallel_jtap, JTAPMiceData
from jtap_mice.viz import rerun_jtap_stimulus, rerun_jtap_single_run, jtap_plot_rg_lines, red_green_viz_notebook, plot_proposal_direction_outlier_pdf, draw_stimulus_image
from jtap_mice.utils import load_left_right_stimulus, JTAPMiceStimulus, ChexModelInput, d2r, i_, f_, slice_pt, init_step_concat, discrete_obs_to_rgb, stack_pytrees, concat_pytrees
from jtap_mice.evaluation import jtap_compute_beliefs, JTAPMice_Beliefs

import os
import time
from functools import partial
from pathlib import Path
from typing import List, Dict, Any, Tuple, NamedTuple

import rerun as rr
import genjax
from genjax import gen, ChoiceMapBuilder as C
import jax
import jax.numpy as jnp
from jax.debug import print as jprint
import numpy as np
from tqdm import tqdm
import jax.tree_util as jtu
from matplotlib import pyplot as plt

# Stimulus-loading knobs, forwarded as `pixel_density` / `skip_t` to
# load_left_right_stimulus when the stimulus is loaded.
PIXEL_DENSITY, SKIP_T = 10, 4

In [2]:
# Which left/right stimulus config file and which trial within it to load.
LR_CONFIG = 'lr_v1'
LR_TRIAL = 1

# Directory containing the left/right task stimulus JSON files. Set the
# JTAP_MICE_STIMULUS_DIR environment variable to point elsewhere instead of
# editing this machine-specific absolute path.
STIMULUS_DIR = Path(os.environ.get(
    "JTAP_MICE_STIMULUS_DIR",
    "/home/arijitdasgupta/jtap-mice/notebooks/left_right_task",
))
# Kept as a plain string, matching what load_left_right_stimulus received before.
stimulus_path = str(STIMULUS_DIR / f"{LR_CONFIG}.json")

jtap_stimulus = load_left_right_stimulus(
    stimulus_path,
    pixel_density=PIXEL_DENSITY,
    skip_t=SKIP_T,
    trial_number=LR_TRIAL,
)

# Same trial at 5x the pixel density; `rgb_only=True` presumably returns only the
# RGB frames (confirm in load_left_right_stimulus).
rgb_video_highres = load_left_right_stimulus(
    stimulus_path,
    pixel_density=PIXEL_DENSITY * 5,
    skip_t=SKIP_T,
    rgb_only=True,
    trial_number=LR_TRIAL,
)

In [3]:
# Send the discretized observations of the loaded stimulus to rerun
# (via jtap_mice.viz.rerun_jtap_stimulus) for visual inspection.
rerun_jtap_stimulus(
    discrete_obs=jtap_stimulus.discrete_obs,
    stimulus_name="rg_stim1_discrete",
)

In [4]:
# Model / inference hyperparameters. Values are based on the defaults in
# ChexModelInput (datastrucs.py).

# Used for both the model's direction-flip probability and the step proposal's
# flip probability, so the two stay in sync.
direction_flip_prob = 0.025

Model_Input = ChexModelInput(
    model_outlier_prob = 0.05,
    proposal_direction_outlier_tau = d2r(40.),
    proposal_direction_outlier_alpha = 3.5,
    σ_pos=10000.0,
    σ_speed=0.075,
    model_direction_flip_prob=direction_flip_prob,
    pixel_corruption_prob=0.01,
    tile_size=3,
    σ_pixel_spatial=1.0,
    image_power_beta=0.005,
    max_speed=1.0,
    max_num_occ=5,
    num_x_grid=8,
    num_y_grid=8,
    grid_size_bounds=(0.05, 0.95),
    simulate_every=1,
    σ_pos_sim=0.05,
    σ_speed_sim=0.075,
    σ_direction_sim=d2r(5.0),
    σ_pos_initprop=0.02,
    σ_direction_stepprop_flip_prob=direction_flip_prob,
    σ_speed_stepprop=0.3,
    σ_pos_stepprop=0.01
)
# PREPARE INPUT: both calls are made for their effect on Model_Input (return
# values are not used); the second one uses the loaded stimulus.
Model_Input.prepare_hyperparameters()
Model_Input.prepare_scene_geometry(jtap_stimulus)

# Forwarded to run_parallel_jtap; presumably the ESS fraction that triggers
# resampling — confirm in jtap_mice.inference.
ESS_proportion = 0.09

# Set SMC_KEY_SEED to an int to reproduce a previous run. With None, a fresh
# random seed is drawn; it is printed so the run can be reproduced later.
SMC_KEY_SEED = None
smc_key_seed = np.random.randint(0, 1000000) if SMC_KEY_SEED is None else SMC_KEY_SEED
print(f"SMC key seed: {smc_key_seed}")

num_particles = 50

In [6]:
num_jtap_runs = 50

start_time = time.perf_counter()
# NOTE(review): the second return value is unused in this notebook.
JTAPMICE_DATA, xx = run_parallel_jtap(num_jtap_runs, smc_key_seed, Model_Input, ESS_proportion, jtap_stimulus, num_particles)
# JAX dispatches work asynchronously. Wait until the outputs are actually
# computed so the measured time covers inference, not just dispatch. On a
# first call this may also include JIT compilation time.
jax.block_until_ready(JTAPMICE_DATA)
end_time = time.perf_counter()

# Assumes ESS is laid out as (run, time) — TODO confirm; averaging per run and
# then across runs gives the overall mean ESS.
mean_ESS = np.mean(JTAPMICE_DATA.inference.ESS.mean(axis=1))
print(f"Mean ESS: {100 * mean_ESS / num_particles:.1f}% of {num_particles} particles")

# Fraction of resampling events pooled over all runs and all frames.
resampled = JTAPMICE_DATA.inference.resampled
resampled_pct = 100 * np.mean(resampled)
print(
    f"Resampling occurred in {resampled_pct:.1f}% of frames "
    f"(averaged over {num_jtap_runs} runs of a {jtap_stimulus.num_frames}-frame stimulus)"
)
print(f"Time taken for parallel JTAP: {end_time - start_time:.2f} seconds")

Mean ESS: 20.7% of 50 particles
Resampling occurred in 28.4% of 181 frames
Time taken for parallel JTAP: 1.4455361366271973 seconds


In [7]:
# Display the inference output returned by run_parallel_jtap in the previous cell.
JTAPMICE_DATA

In [None]:
JTAPMice_Beliefs = jtap_compute_beliefs(JTAPMICE_DATA)
jtap_run_idx = None
show_all_lines = False
show_std_band = True
jtap_plot_rg_lines(JTAPMice_Beliefs, stimulus = jtap_stimulus, show = "model", include_baselines=True, remove_legend=True, show_std_band=show_std_band, jtap_run_idx = jtap_run_idx, include_start_frame=True, show_all_beliefs=show_all_lines, plot_stat = "median", include_stimulus=True)

In [None]:
DECISION_MODEL_VERSION = "v4"

jtap_decision_model_hyperparams = JTAP_Decision_Model_Hyperparams(
    key_seed=123,
    pseudo_participant_multiplier=50,
    press_thresh_hyperparams=(0.5253067868385957, 0.030245679028864637, 0.0, 1.0),
    tau_press_hyperparams=(1.0, 0.01, np.arange(30)),
    hysteresis_delay_hyperparams=None,
    regular_delay_hyperparams=(4.0, 1.2362500260780824, np.arange(30)),
    starting_delay_hyperparams=(11.0, 3.794423573053696, np.arange(30))
)

jtap_decisions, jtap_decision_model_params = jtap_compute_decisions(
    JTAPMice_Beliefs,
    jtap_decision_model_hyperparams,
    decision_model_version=DECISION_MODEL_VERSION
)
# jtap_metrics = jtap_compute_decision_metrics(
#     jtap_decisions,
#     jtap_stimulus,
#     partial_occlusion_in_targeted_analysis=True,
#     ignore_uncertain_line=True
# )

jtap_metrics = None

jtap_plot_rg_lines(
    jtap_decisions,
    stimulus=jtap_stimulus,
    show="model",
    include_baselines=True,
    include_human=True,
    jtap_metrics=jtap_metrics,
    jtap_run_idx=jtap_run_idx,
    remove_legend=True,
    include_start_frame=True,
    plot_stat="mean",
    include_stimulus=True
)

In [None]:
# Which of the parallel JTAP runs to visualize; presumably must lie in
# [0, num_jtap_runs).
jtap_run_idx_viz = 9

rerun_jtap_single_run(
    JTAPMICE_DATA,
    # NOTE(review): rgb_video_highres is loaded in the stimulus cell but not
    # passed here — confirm that None is intended.
    rgb_video_highres=None,
    stimulus_name="jtap_single_runv3",
    tracking_dot_size_range=(0.5, 2),
    prediction_line_size_range=(0.05, 0.4),
    jtap_run_idx=jtap_run_idx_viz,
    grid_dot_radius=0.3,
    render_grid=True,
    show_velocity=True,
)

In [None]:
# Draw the loaded stimulus with frame=1 (whether frames are 0- or 1-indexed is
# not visible here — confirm in jtap_mice.viz.draw_stimulus_image).
draw_stimulus_image(jtap_stimulus, frame=1)