In [None]:
# Optional: uncomment both lines to force JAX onto the CPU. The environment variable
# has to be set before JAX is imported (jtap presumably imports it), so keep these first.
# import os
# os.environ["JAX_PLATFORM_NAME"] = "cpu"

# Project package. set_jaxcache() presumably configures JAX's compilation cache
# (TODO confirm in jtap); it is called before any other JAX work in this notebook.
import jtap
jtap.set_jaxcache()
from jtap.utils import init_step_concat

# Third-party imports
# NOTE(review): `copy` is standard library, not third-party. copy, jax, jnp, np, plt, tqdm,
# C and init_step_concat are not referenced in the visible cells - confirm before pruning.
import copy
import jax
import jax.numpy as jnp
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm

# GenJAX imports
from genjax import ChoiceMapBuilder as C

# LR imports
# The star import supplies the helpers used below: generate_stimuli_and_trials, viz_trial,
# plot_* and create_*_interactive* functions.
# NOTE(review): prefer importing these names explicitly once the required set is confirmed.
from lr_generation_code import *

In [None]:
# CONFIGURE THE FOLLOWING VARIABLES TO GENERATE STIMULI PROCEDURALLY
###########################################################
# START OF CONFIGURATION
###########################################################

# ==== EXPERIMENT METADATA ====
STIMULI_NAME = 'lr_v1'  # Also used as the output file name: f"{STIMULI_NAME}.json"
INITIAL_RANDOM_SEED = 42

# ==== TEMPORAL PARAMETERS ====
FRAMES_PER_SECOND = 20
MAX_TRIAL_SECONDS = 10
MIN_TRIAL_SECONDS = 6
MIN_SECONDS_BETWEEN_SWITCHES = None  # Set to None (or 0) to disable
MAX_NUM_SWITCHES = None  # Set to None (or a very large number) to disable

# ==== SWITCHING BEHAVIOR ====
USE_SEMI_MARKOV_SWITCHING = False

# Semi-Markov switching distribution type and parameters
# (relevant when USE_SEMI_MARKOV_SWITCHING is True)
# Options: 'geometric', 'negative_binomial', 'discrete_normal'

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# NOTE: Because the geometric distribution is memoryless, 'geometric' essentially reduces
# to the regular Markov process with a fixed per-frame probability of flipping direction
# (the role DIRECTION_FLIP_PROB plays when semi-Markov switching is off). Choosing it is
# therefore much like setting USE_SEMI_MARKOV_SWITCHING to False.
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
SEMI_MARKOV_DISTRIBUTION_TYPE = 'discrete_normal'

# Parameters for each distribution type (only the relevant ones will be used based on SEMI_MARKOV_DISTRIBUTION_TYPE)
# For 'geometric': only 'prob' is used
# For 'negative_binomial': 'rate' and 'prob' are used
# For 'discrete_normal': 'mean' and 'std' are used
SEMI_MARKOV_DISTRIBUTION_PARAMS = {
    'rate': 75.0,    # Rate parameter for negative binomial (n parameter)
    'prob': 0.3,     # Probability parameter for geometric and negative binomial (p parameter)
    'mean': 100.0,   # Mean parameter for discrete normal distribution
    'std': 75.0      # Standard deviation parameter for discrete normal distribution
}

# Non-semi-Markov switching parameters (only used if USE_SEMI_MARKOV_SWITCHING is False)
DIRECTION_FLIP_PROB = 0.025

# ==== SPATIAL PARAMETERS ====
SPEED = 0.065  # Ball speed in units of ball diameters per frame
LEFT_RIGHT_LENGTH = 13  # Track length in units of ball diameters

# Starting position configuration
UNIFORM_STARTING_POSITION = True
UNIFORM_STARTING_FRACTIONAL_RANGE_FROM_MIDPOINT = 1.0  # Fractional range from midpoint (0.5 = 25%-75% of track)
UNIFORM_STARTING_POSITION_BINNING_LENGTH = 0.5  # Bin size for uniform sampling in ball diameters
NON_UNIFORM_STARTING_POSITION = None  # Fixed starting position if not using uniform (None = middle of track)

# ==== TRIAL GENERATION ====
REQUESTED_NUM_TRIALS = 528
BATCH_SIZE = 1000  # Batch size for simulation
TOTAL_NUM_TRIALS = 100000  # Total number of trials to generate during simulation

###########################################################
# END OF CONFIGURATION
###########################################################

# Sanity-check the configuration so mistakes fail here, immediately, rather than deep
# inside the (slow) generation run or silently producing the wrong stimuli.
VALID_SEMI_MARKOV_DISTRIBUTION_TYPES = ('geometric', 'negative_binomial', 'discrete_normal')
if USE_SEMI_MARKOV_SWITCHING:
    assert SEMI_MARKOV_DISTRIBUTION_TYPE in VALID_SEMI_MARKOV_DISTRIBUTION_TYPES, (
        f"SEMI_MARKOV_DISTRIBUTION_TYPE must be one of {VALID_SEMI_MARKOV_DISTRIBUTION_TYPES}, "
        f"got {SEMI_MARKOV_DISTRIBUTION_TYPE!r}"
    )
else:
    assert 0.0 <= DIRECTION_FLIP_PROB <= 1.0, "DIRECTION_FLIP_PROB must be a probability in [0, 1]"
assert FRAMES_PER_SECOND > 0, "FRAMES_PER_SECOND must be positive"
assert 0 <= MIN_TRIAL_SECONDS <= MAX_TRIAL_SECONDS, "Need 0 <= MIN_TRIAL_SECONDS <= MAX_TRIAL_SECONDS"
assert 0 < REQUESTED_NUM_TRIALS <= TOTAL_NUM_TRIALS, "REQUESTED_NUM_TRIALS must be in (0, TOTAL_NUM_TRIALS]"
assert BATCH_SIZE > 0, "BATCH_SIZE must be positive"
if UNIFORM_STARTING_POSITION:
    # 1.0 spans the whole track (per the 0.5 -> 25%-75% example above); larger would leave it.
    assert 0.0 < UNIFORM_STARTING_FRACTIONAL_RANGE_FROM_MIDPOINT <= 1.0, (
        "UNIFORM_STARTING_FRACTIONAL_RANGE_FROM_MIDPOINT must be in (0, 1]"
    )

#### Semi-Markov Switching Behavior (Discrete-time renewal process) via different duration distributions 

The plot below visualizes the distribution of sojourn times (for us, the time between direction switches), which can serve as the duration distribution of a semi-Markov process.

In [None]:
# Visualize the sojourn-time distribution (time between direction switches) that matches
# SEMI_MARKOV_DISTRIBUTION_TYPE from the configuration cell. The plotting helpers come from
# lr_generation_code. Lambdas defer the name lookup, so only the selected helper is resolved.
sojourn_plotters = {
    'geometric': lambda: create_geometric_interactive_plot(),
    'negative_binomial': lambda: create_negative_binomial_interactive_plot(),
    'discrete_normal': lambda: create_truncated_normal_interactive(),
}

# Kept as the cell's last top-level expression so any returned widget/figure is
# auto-displayed, exactly as the direct call did before.
sojourn_plotters[SEMI_MARKOV_DISTRIBUTION_TYPE]()

## Run the Cell Below to Generate Trials

In [None]:
# Run the procedural generator with every setting from the configuration cell.
# A None result means generation failed (reported below as too few trials to counterbalance).
result = generate_stimuli_and_trials(
    # metadata
    stimuli_name=STIMULI_NAME,
    initial_random_seed=INITIAL_RANDOM_SEED,
    # timing
    frames_per_second=FRAMES_PER_SECOND,
    max_trial_seconds=MAX_TRIAL_SECONDS,
    min_trial_seconds=MIN_TRIAL_SECONDS,
    min_seconds_between_switches=MIN_SECONDS_BETWEEN_SWITCHES,
    max_num_switches=MAX_NUM_SWITCHES,
    # switching
    use_semi_markov_switching=USE_SEMI_MARKOV_SWITCHING,
    semi_markov_distribution_type=SEMI_MARKOV_DISTRIBUTION_TYPE,
    semi_markov_distribution_params=SEMI_MARKOV_DISTRIBUTION_PARAMS,
    direction_flip_prob=DIRECTION_FLIP_PROB,
    # space
    speed=SPEED,
    left_right_length=LEFT_RIGHT_LENGTH,
    uniform_starting_position=UNIFORM_STARTING_POSITION,
    uniform_starting_fractional_range_from_midpoint=UNIFORM_STARTING_FRACTIONAL_RANGE_FROM_MIDPOINT,
    uniform_starting_position_binning_length=UNIFORM_STARTING_POSITION_BINNING_LENGTH,
    non_uniform_starting_position=NON_UNIFORM_STARTING_POSITION,
    # trial counts
    requested_num_trials=REQUESTED_NUM_TRIALS,
    batch_size=BATCH_SIZE,
    total_num_trials=TOTAL_NUM_TRIALS,
)

if result is None:
    print("Failed to generate trials - insufficient trials for counterbalancing")
else:
    # Expose the generator's outputs as notebook-level names for the save/plot cells below.
    final_trial_data = result['final_trial_data']
    final_switch_data = result['final_switch_data']
    all_positions_over_time = result['all_positions_over_time']
    all_ending_indices = result['all_ending_indices']
    all_ending_indices_inclusive = result['all_ending_indices_inclusive']
    all_same_side_over_time = result['all_same_side_over_time']
    trial_indices_dict = result['trial_indices_dict']
    hypers = result['hypers']
    config = result['config']
    uniform_range_start = result['uniform_range_start']
    uniform_range_end = result['uniform_range_end']
    uniform_range_length = result['uniform_range_length']
    num_uniform_bins = result['num_uniform_bins']
    num_valid_trials_per_bin = result['num_valid_trials_per_bin']
    enough_trials = result['enough_trials']
    mid_point = result['mid_point']
    diameter = result['diameter']
    selected_trial_indices = result['selected_trial_indices']


In [None]:
import json

# Save the generation config and the selected trials to <STIMULI_NAME>.json so they can be
# reloaded later (e.g. by viz_trial in the next cell).
if result is not None:
    data_to_save = {
        'config': config,
        # JSON object keys must be strings; each trial's data is converted via .tolist()
        # (presumably numpy/JAX arrays - TODO confirm against generate_stimuli_and_trials).
        'trial_data': {str(k): v.tolist() for k, v in final_trial_data.items()}
    }

    filename = f"{STIMULI_NAME}.json"

    # Serialize fully before opening the file. json.dump streams into an already-truncated
    # file, so a non-serializable value would leave a partial, unparseable JSON behind
    # (and destroy any previous good file). Now a failure raises here and the file is untouched.
    serialized_data = json.dumps(data_to_save, indent=2)
    with open(filename, 'w', encoding='utf-8') as f:
        f.write(serialized_data)

    print(f"Data saved to {filename}")
    print(f"File contains {len(final_trial_data)} trials with configuration parameters")
else:
    print("No data to save - generation failed")


In [None]:
import os

# Visualize one trial from a saved stimuli JSON file. The path defaults to the file written
# by the save cell above; point viz_json_path at any other stimuli file to inspect that instead.
# NOTE: Trial numbers are 1-indexed in the JSON file, so trial 1 is the first trial.
viz_json_path = f"{STIMULI_NAME}.json"
viz_trial_number = 3

# On a fresh run where generation failed, nothing was saved. Report that clearly instead of
# relying on whatever viz_trial does with a missing path.
if not os.path.isfile(viz_json_path):
    raise FileNotFoundError(
        f"Stimuli file {viz_json_path!r} not found - run the generation and save cells first, "
        "or point viz_json_path at an existing file."
    )

# Last top-level expression, so any returned figure/animation is displayed as before.
viz_trial(json_path=viz_json_path, trial_number=viz_trial_number, figure_size=18, pixel_density=50)

## Diagnostic plots: how the generated trials are spread across key parameters

In [None]:
# Distribution of trial ending times, shown against the configured min/max trial durations.
if result is None:
    print("No data available for plotting - generation failed")
else:
    plot_trial_ending_times_distribution(
        all_ending_indices_inclusive,
        FRAMES_PER_SECOND,
        MAX_TRIAL_SECONDS,
        MIN_TRIAL_SECONDS,
    )


In [None]:
# How the starting position (binned over the uniform starting range) relates to trial outcome.
if result is None:
    print("No data available for plotting - generation failed")
else:
    plot_starting_position_vs_trial_outcome(
        trial_indices_dict,
        uniform_range_start,
        UNIFORM_STARTING_POSITION_BINNING_LENGTH,
        num_uniform_bins,
        num_valid_trials_per_bin,
        enough_trials,
        LEFT_RIGHT_LENGTH,
        mid_point,
        diameter,
    )


In [None]:
# Summary of direction-switching behaviour across the selected trials.
if result is None:
    print("No data available for plotting - generation failed")
else:
    plot_switching_behavior_distribution(
        final_switch_data,
        REQUESTED_NUM_TRIALS,
        FRAMES_PER_SECOND,
    )