In [None]:
import jtap_mice
jtap_mice.set_jaxcache()
from jtap_mice.model import full_init_model, full_step_model, likelihood_model, stepper_model, get_render_args,is_ball_in_valid_position, red_green_sensor_readouts
from jtap_mice.inference import run_jtap, run_parallel_jtap, JTAPMiceData
from jtap_mice.viz import rerun_jtap_stimulus, rerun_jtap_single_run, jtap_plot_rg_lines, red_green_viz_notebook, create_log_frequency_heatmaps
from jtap_mice.utils import load_red_green_stimulus, JTAPMiceStimulus, ChexModelInput, d2r, i_, f_, slice_pt, init_step_concat, discrete_obs_to_rgb, load_original_jtap_results, stack_pytrees, concat_pytrees
from jtap_mice.evaluation import JTAP_Decision_Model_Hyperparams, jtap_compute_beliefs, jtap_compute_decisions, jtap_compute_decision_metrics, JTAP_Metrics, JTAPMice_Beliefs, JTAP_Decisions, JTAP_Results
from jtap_mice.distributions import truncated_normal_sample, discrete_normal_sample
from jtap_mice.core import SuperPytree

import os
import time
from functools import partial
from typing import List, Dict, Any, Tuple, NamedTuple

import genjax
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import rerun as rr
from genjax import gen, ChoiceMapBuilder as C
from jax.debug import print as jprint
from matplotlib import pyplot as plt
from tqdm import tqdm

# --- Stimulus configuration --------------------------------------------------
PIXEL_DENSITY = 10
SKIP_T = 4
COGSCI_TRIAL = 'E38'
# Multiplier applied to PIXEL_DENSITY for the high-resolution RGB-only load below.
HIGHRES_DENSITY_FACTOR = 5

# Root directory of the CogSci 2025 trial stimuli. Set the JTAP_STIMULI_DIR
# environment variable to point elsewhere instead of editing this
# machine-specific default.
STIMULI_DIR = os.environ.get(
    'JTAP_STIMULI_DIR',
    '/home/arijitdasgupta/jtap/assets/stimuli/cogsci_2025_trials',
)
stimulus_path = os.path.join(STIMULI_DIR, COGSCI_TRIAL)

jtap_stimulus = load_red_green_stimulus(stimulus_path, pixel_density=PIXEL_DENSITY, skip_t=SKIP_T)

# NOTE(review): rgb_video_highres is not consumed by any visible cell -- the
# single-run rerun cell passes rgb_video_highres=None. Confirm whether this
# (larger) load is still needed.
rgb_video_highres = load_red_green_stimulus(
    stimulus_path,
    pixel_density=PIXEL_DENSITY * HIGHRES_DENSITY_FACTOR,
    skip_t=SKIP_T,
    rgb_only=True,
)

In [None]:
# Send the loaded stimulus' discrete observations to the rerun viewer for a visual check.
# NOTE(review): the recording name is fixed and does not reflect COGSCI_TRIAL.
rerun_jtap_stimulus(
    stimulus_name="rg_stim1_discrete",
    discrete_obs=jtap_stimulus.discrete_obs,
)

In [None]:
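# Alternative ChexModelInput hyperparameter set, kept for reference only (this cell executes nothing).
# Model_Input = ChexModelInput(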
#     σ_pos=f_(0.75),
#     σ_speed=f_(0.15),
#     σ_NOCOL_direction=d2r(15),
#     σ_COL_direction=d2r(15),
#     # pixel_corruption_prob=f_(0.01),
#     pixel_corruption_prob=f_(0.45),
#     tile_size=i_(3),
#     σ_pixel_spatial=f_(1.0),
#     image_power_beta=f_(0.005),
#     max_speed=f_(1.0),
#     max_num_barriers=i_(10),
#     max_num_occ=i_(5),
#     num_x_grid=i_(6),
#     num_y_grid=i_(6),
#     grid_size_x=f_(0.2),
#     grid_size_y=f_(0.2),
#     max_num_col_iters=f_(2),
#     simulate_every=i_(1),
#     σ_pos_sim=f_(0.0005),
#     σ_speed_sim=f_(0.0005),
#     σ_NOCOL_direction_sim=d2r(0.8),
#     σ_COL_direction_sim=d2r(1.6),
#     σ_speed_occ=f_(0.0005),
#     σ_NOCOL_direction_occ=d2r(0.8),
#     σ_COL_direction_occ=d2r(1.6),
#     σ_pos_initprop=f_(0.02),
#     σ_speed_initprop=f_(0.1),
#     σ_speed_stepprop=f_(0.04),
#     σ_NOCOL_direction_initprop=d2r(0.3),
#     σ_NOCOL_direction_stepprop=d2r(4.0),
#     σ_COL_direction_prop=d2r(0.3),
#     σ_pos_stepprop=f_(0.01)
# )

In [None]:
# Shared noise scales reused by several hyperparameters below.
dir_val = 2        # direction noise; passed through d2r, so presumably degrees
speed_val = 0.05   # speed noise

# Parameters the author marked "can ignore" are collected at the end; presumably
# they have little or no effect for this experiment -- confirm before tuning.
Model_Input = ChexModelInput(
    # Base noise scales (σ_* without suffix)
    σ_pos=f_(0.05),
    σ_speed=f_(speed_val),
    σ_NOCOL_direction=d2r(dir_val),
    σ_COL_direction=d2r(dir_val),
    pixel_corruption_prob=f_(0.45),
    # Grid settings
    num_x_grid=i_(6),
    num_y_grid=i_(6),
    grid_size_x=f_(1.0),
    grid_size_y=f_(1.0),
    # *_sim noise scales
    σ_pos_sim=f_(0.01),
    σ_speed_sim=f_(0.03),
    σ_NOCOL_direction_sim=d2r(0.01),
    σ_COL_direction_sim=d2r(0.01),
    # Proposal scales (*prop); direction proposals use twice the base direction noise
    σ_pos_initprop=f_(0.02),  # marked "not consequential" by the author
    σ_speed_initprop=f_(speed_val),
    σ_speed_stepprop=f_(speed_val),
    σ_NOCOL_direction_initprop=d2r(2 * dir_val),
    σ_NOCOL_direction_stepprop=d2r(2 * dir_val),
    σ_COL_direction_prop=d2r(2 * dir_val),
    # --- marked "can ignore" by the author ---
    tile_size=i_(3),
    σ_pixel_spatial=f_(1.0),
    image_power_beta=f_(0.005),
    max_speed=f_(1.0),
    max_num_barriers=i_(10),
    max_num_occ=i_(5),
    # NOTE(review): wrapped with f_ while the other counts use i_ -- confirm intended.
    max_num_col_iters=f_(2),
    simulate_every=i_(1),
    σ_speed_occ=f_(0.0005),
    σ_NOCOL_direction_occ=d2r(0.8),
    σ_COL_direction_occ=d2r(0.8),
    σ_pos_stepprop=f_(0.01),
)

# Finalise hyperparameters, then attach the scene geometry of the loaded stimulus.
Model_Input.prepare_hyperparameters()
Model_Input.prepare_scene_geometry(jtap_stimulus)

# SMC configuration.
# Set SMC_KEY_SEED to an int to make inference reproducible. When it is None a
# fresh seed is drawn (the previous behaviour) and printed, so a run can still
# be reproduced afterwards by copying the printed value here.
SMC_KEY_SEED = None
smc_key_seed = SMC_KEY_SEED if SMC_KEY_SEED is not None else np.random.randint(0, 1000000)
print(f"smc_key_seed = {smc_key_seed}")
# Presumably the ESS threshold, as a fraction of num_particles, that triggers
# resampling -- confirm in run_parallel_jtap.
ESS_proportion = 0.05
num_particles = 50

In [None]:
# Run JTAP inference and report how long it took.
# NUM_JTAP_RUNS is the first positional argument of run_parallel_jtap (previously
# a bare 50, easy to confuse with num_particles). Presumably it is the number of
# independent runs indexed by jtap_run_idx later on -- confirm.
NUM_JTAP_RUNS = 50

# perf_counter is monotonic and high-resolution, unlike time.time().
start_time = time.perf_counter()
JTAPMICE_DATA = run_parallel_jtap(NUM_JTAP_RUNS, smc_key_seed, Model_Input, ESS_proportion, jtap_stimulus, num_particles)
# JAX dispatches work asynchronously: wait for any device arrays in the result so
# the measurement covers the computation, not just its dispatch. The timing may
# also include JIT compilation on the first run.
jax.block_until_ready(JTAPMICE_DATA)
end_time = time.perf_counter()
print(f"Time taken for parallel JTAP: {end_time - start_time:.2f} seconds")

In [None]:
# Compute model beliefs from the inference output and plot them against the stimulus.
# NOTE(review): this rebinds JTAPMice_Beliefs, a name also imported from
# jtap_mice.evaluation in the first cell, so the imported object is shadowed from
# here on. The decisions cell reads this name too; rename both together.
JTAPMice_Beliefs = jtap_compute_beliefs(JTAPMICE_DATA)
jtap_plot_rg_lines(
    JTAPMice_Beliefs,
    stimulus=jtap_stimulus,
    show="model",
    include_baselines=True,
    remove_legend=True,
    show_std_band=True,
)

In [None]:
# Convert model beliefs into (pseudo-participant) decisions, score them against the
# stimulus, and plot model decisions alongside baselines and human data.
# NOTE(review): the tuple hyperparameters are passed through opaquely; their field
# meanings are defined by JTAP_Decision_Model_Hyperparams, not in this notebook.
jtap_decision_model_hyperparams = JTAP_Decision_Model_Hyperparams(
    key_seed=123,
    pseudo_participant_multiplier=500,
    press_thresh_hyperparams=(0.3, 0.2, 0.1, 0.6),
    tau_press_hyperparams=(4.0, 1.0, np.arange(2, 6)),
    hysteresis_delay_hyperparams=(2.5, 1.0, np.arange(1, 4)),
    regular_delay_hyperparams=(2.5, 0.5, np.arange(2, 4)),
    starting_delay_hyperparams=(7.0, 4.0, np.arange(3, 10)),
)

jtap_decisions = jtap_compute_decisions(
    JTAPMice_Beliefs,
    jtap_decision_model_hyperparams,
    use_old_logic=False,
)

jtap_metrics = jtap_compute_decision_metrics(
    jtap_decisions,
    jtap_stimulus,
    partial_occlusion_in_targeted_analysis=True,
    ignore_uncertain_line=True,
)

jtap_plot_rg_lines(
    jtap_decisions,
    stimulus=jtap_stimulus,
    show="model",
    include_baselines=True,
    include_human=True,
    jtap_metrics=jtap_metrics,
)

In [None]:
# fig = create_log_frequency_heatmaps(jtap_metrics, targeted_analysis=True)
# fig.show()

In [None]:
# Log one JTAP run (index 12) to rerun. The index presumably must be smaller than
# the number of runs passed to run_parallel_jtap -- confirm.
# NOTE(review): rgb_video_highres is passed as None although the first cell loads a
# high-resolution RGB video under that name; confirm which is intended.
rerun_jtap_single_run(
    JTAPMICE_DATA,
    stimulus_name="jtap_single_runv3",
    jtap_run_idx=12,
    rgb_video_highres=None,
    tracking_dot_size_range=(0.01, 3),
    prediction_line_size_range=(0.01, 0.8),
)

In [None]:
# Inspect the `resampled` field of the inference output (shown as this cell's result).
JTAPMICE_DATA.inference.resampled