In [None]:
# --- Colab: Minimal DDM widget (time scrubber + drift/noise/bound sliders), tuned so the bound is usually hit in < 1 s ---

import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
from ipywidgets import VBox, HBox, Layout
from IPython.display import display, clear_output

# If widgets don't show in Colab, run once:
# !pip -q install ipywidgets
# from google.colab import output
# output.enable_custom_widget_manager()

# --------------------------
# Defaults tuned for <1s hits
# --------------------------
# Starting slider values; the "New trial (reset)" button restores them.
DEFAULTS = {
    "mu": 0.35,     # drift rate
    "sigma": 0.75,  # diffusion noise
    "theta": 0.25,  # symmetric bound (crossing when |x| >= theta)
}

dt = 0.001  # simulation step (s)
T = 1.0     # trial duration (s): time-slider range and plot x-limit

# --------------------------
# Simulation
# --------------------------
def simulate_ddm(mu, sigma, theta, dt=dt, T=T, seed=None):
    """Simulate one drift-diffusion trial with a symmetric absorbing bound.

    The decision variable starts at 0 and follows Euler-Maruyama steps
    ``x[i] = x[i-1] + mu*dt + sigma*sqrt(dt)*z_i`` with ``z_i ~ N(0, 1)``.
    At the first step ``i >= 1`` where ``|x[i]| >= theta`` the trial is
    absorbed: every later sample is held at ``x[i]``.

    Parameters
    ----------
    mu : float
        Drift; the mean increment per step is ``mu * dt``.
    sigma : float
        Noise scale; the per-step increment has standard deviation
        ``sigma * sqrt(dt)``.
    theta : float
        Bound half-width; a crossing means ``abs(x) >= theta``.
    dt : float
        Step size in seconds (defaults to the module-level ``dt``).
    T : float
        Simulated duration in seconds (defaults to the module-level ``T``).
    seed : int or None
        Seed passed to ``np.random.default_rng``.

    Returns
    -------
    t : ndarray of shape (n,)
        Time grid ``np.arange(n) * dt`` with ``n = round(T / dt) + 1``.
    x : ndarray of shape (n,)
        Decision-variable trajectory, held constant after a crossing.
    hit : int or None
        Index of the first bound crossing, or None if no crossing occurs
        within the grid.
    """
    rng = np.random.default_rng(seed)
    n = int(np.round(T / dt)) + 1
    t = np.arange(n) * dt

    # Draw every increment at once and accumulate them, rather than calling
    # rng.normal() once per step inside a Python loop.
    increments = mu * dt + sigma * np.sqrt(dt) * rng.normal(size=n - 1)
    x = np.zeros(n, dtype=float)  # x[0] = 0 is the starting point
    x[1:] = np.cumsum(increments)

    # Only samples i >= 1 are tested against the bound (x[0] never is).
    crossings = np.flatnonzero(np.abs(x[1:]) >= theta)
    if crossings.size == 0:
        return t, x, None

    hit = int(crossings[0]) + 1
    x[hit:] = x[hit]  # absorb: hold the value reached at the crossing
    return t, x, hit

# --------------------------
# Widgets (small ranges)
# --------------------------
# Parameter sliders report only on release (continuous_update=False) because
# each change re-simulates a trial; the time slider updates live while dragged.
mu_slider = widgets.FloatSlider(
    value=DEFAULTS["mu"], min=-0.8, max=0.8, step=0.02,
    description='drift μ', continuous_update=False,
    layout=Layout(width='300px')
)

sigma_slider = widgets.FloatSlider(
    value=DEFAULTS["sigma"], min=0.2, max=1.2, step=0.02,
    description='noise σ', continuous_update=False,
    layout=Layout(width='300px')
)

theta_slider = widgets.FloatSlider(
    value=DEFAULTS["theta"], min=0.15, max=0.5, step=0.01,
    description='bound θ', continuous_update=False,
    layout=Layout(width='300px')
)

new_trial_btn = widgets.Button(
    description='New trial (reset)', button_style='primary',
    layout=Layout(width='160px')
)

# Time scrubber over [0, T] in steps of one simulation sample (dt).
k_slider = widgets.FloatSlider(
    value=0.0, min=0.0, max=T, step=dt,
    description='time (s)', continuous_update=True,
    # The default readout ('.2f') cannot display the 1 ms step, so adjacent
    # slider positions would show the same number.
    readout_format='.3f',
    layout=Layout(width='460px')
)

# Output area that draw() clears and re-renders into.
out = widgets.Output()

# --------------------------
# State + drawing
# --------------------------
# Mutable UI state shared by the callbacks:
#   t, x, hit -> current trial returned by simulate_ddm (filled in by regen)
#   seed      -> counter advanced before each simulation so trials differ
#   suspend   -> True while widget values are set programmatically, so the
#                observers skip their regen/redraw work
state = dict(t=None, x=None, hit=None, seed=0, suspend=False)

def regen():
    """Simulate a fresh trial from the current slider values into ``state``.

    The seed counter is advanced before simulating, so each call gets a new
    seed (1, 2, 3, ...) and therefore a different trajectory.
    """
    state["seed"] = state["seed"] + 1
    trial = simulate_ddm(
        mu_slider.value,
        sigma_slider.value,
        theta_slider.value,
        seed=state["seed"],
    )
    state["t"], state["x"], state["hit"] = trial

def draw(t_sec):
    """Render the current trial up to time ``t_sec`` into the ``out`` widget.

    The trajectory in ``state`` is drawn from t = 0 to the grid sample nearest
    ``t_sec`` (clamped to the grid). If the bound was hit earlier, drawing stops
    at the hit sample, so the trace, end marker and time readout freeze there.

    Parameters
    ----------
    t_sec : float
        Requested playback time in seconds (the time slider's value when
        called from ``on_time_change``).
    """
    with out:
        clear_output(wait=True)

        t = state["t"]
        x = state["x"]
        if t is None:
            # No trial simulated yet (regen() not called): leave output empty.
            return
        theta = theta_slider.value
        hit = state["hit"]

        k = int(np.round(t_sec / dt))
        k = max(0, min(k, len(t) - 1))

        # If bound was hit, freeze trajectory display at hit index
        k_eff = min(k, hit) if hit is not None else k

        fig, ax = plt.subplots(figsize=(6.6, 2.8))

        # bounds
        ax.axhline(+theta, linestyle='--', linewidth=1)
        ax.axhline(-theta, linestyle='--', linewidth=1)

        # trajectory up to current time, with a marker on its last point
        ax.plot(t[:k_eff+1], x[:k_eff+1], linewidth=2)
        ax.scatter([t[k_eff]], [x[k_eff]], s=26)

        # time counter on the right
        ax.text(0.99, 0.92, f"t = {t[k_eff]:.3f} s",
                transform=ax.transAxes, fontsize=13, va='top', ha='right')

        # minimal axes
        ax.set_xlim(0, T)
        ax.set_ylim(-1.2 * theta, 1.2 * theta)

        ax.set_yticks([-theta, 0.0, +theta])
        ax.set_yticklabels(["-θ", "0", "+θ"])

        # start / middle / end ticks derived from T so they stay aligned with
        # set_xlim above (previously hard-coded to 0, 0.5, 1.0)
        ax.set_xticks([0.0, T / 2, T])
        ax.set_xlabel("time (s)")
        ax.set_ylabel("DV")

        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)

        # minimal margins: mostly top/bottom
        fig.subplots_adjust(left=0.10, right=0.99, top=0.95, bottom=0.24)

        plt.show()

def safe_reset_all(_=None):
    """Restore default parameters, simulate a new trial, rewind time, redraw.

    Bound to the "New trial (reset)" button (the clicked button is ignored).
    Observers are muted through ``state["suspend"]`` while widget values are
    set programmatically, so the slider callbacks neither re-simulate nor
    redraw; a single draw at t = 0 follows once everything is reset.
    """
    state["suspend"] = True
    try:
        defaults_by_slider = (
            (mu_slider, "mu"),
            (sigma_slider, "sigma"),
            (theta_slider, "theta"),
        )
        for slider, key in defaults_by_slider:
            slider.value = DEFAULTS[key]

        regen()          # new trajectory with the restored parameters
        k_slider.value = 0.0  # rewind the time scrubber
    finally:
        state["suspend"] = False
    draw(0.0)

def on_param_change(_):
    """Observer for the drift/noise/bound sliders: new trial, rewind, redraw.

    Does nothing while ``state["suspend"]`` is set (a programmatic update is
    in progress). Otherwise it sets the flag itself, so rewinding ``k_slider``
    does not also trigger a redraw via ``on_time_change``.
    """
    if not state["suspend"]:
        state["suspend"] = True
        try:
            regen()
            k_slider.value = 0.0
        finally:
            state["suspend"] = False
        draw(0.0)

def on_time_change(change):
    """Observer for the time slider: redraw at the new time unless muted."""
    if not state["suspend"]:
        draw(change["new"])

# Parameter changes re-simulate, time changes only redraw, and the button
# resets everything.
for _param_slider in (mu_slider, sigma_slider, theta_slider):
    _param_slider.observe(on_param_change, names='value')
del _param_slider
k_slider.observe(on_time_change, names='value')
new_trial_btn.on_click(safe_reset_all)

# Initial render: simulate one trial and draw it at t = 0, then show the UI
# (parameter row, time row, plot output).
regen()
draw(0.0)

controls = HBox(children=[mu_slider, sigma_slider, theta_slider, new_trial_btn])
time_controls = HBox(children=[k_slider])
display(VBox(children=[controls, time_controls, out]))

VBox(children=(HBox(children=(FloatSlider(value=0.35, continuous_update=False, description='drift μ', layout=L…