In [None]:
import matplotlib.pyplot as plt
import numpy as np

from hrl.model import DecisionModel, RLDecisionModel



In [None]:
def generate_stimuli(
    lower_bounds: float | np.ndarray, upper_bounds: float | np.ndarray, num_trials: int
) -> np.ndarray:
    """
    Generate experimental stimuli using a uniform distribution.

    Args:
        lower_bounds: Lower bounds of each type of stimulus.
        upper_bounds: Upper bounds of each type of stimulus.
        num_trials: Number of trials to generate stimuli for.
    """
    num_stim = np.broadcast(lower_bounds, upper_bounds).size
    return np.random.uniform(lower_bounds, upper_bounds, size=(num_trials, num_stim))



In [None]:
# Alternative ground-truth model (LogisticDecisionModel is not imported in the
# import cell above, so it must be imported before this line can be re-enabled):
# true_model = LogisticDecisionModel(bias=0, weights=10)

# Ground-truth model whose parameters the later cells try to recover.
true_model = RLDecisionModel(mu=12.5, sigma=1, gamma_l=0.1, gamma_h=0.2)

# 10,000 trials of a single stimulus drawn uniformly from [9, 16).
# NOTE: generate_stimuli draws from NumPy's global RNG and no seed is set in the
# visible cells, so these values (and anything fitted to them) vary between runs.
stimuli = generate_stimuli(lower_bounds=9, upper_bounds=16, num_trials=10_000)
actions = true_model.simulate(stimuli)

# Sanity check: the model, then the shape and first ten rows of each array.
print(
    true_model,
    stimuli.shape,
    stimuli[:10],
    actions.shape,
    actions[:10],
    sep="\n",
)



In [None]:
# Curve of the true model's action probabilities over the observed stimulus range.
x = np.linspace(stimuli.min(), stimuli.max(), 100)
# Column 1 of action_probabilities — presumably P(action == 1); confirm against
# the model's action encoding.
y = true_model.action_probabilities(x)[:, 1]
plt.plot(x, y)
plt.ylim(0, 1)
plt.xlabel("Stimulus")
plt.ylabel("P(action = 1)")
# Render explicitly so the cell does not echo the (0.0, 1.0) tuple returned by ylim.
plt.show()


In [None]:
# Parameter recovery: build a default-constructed model of the same class as the
# ground truth, fit it to the simulated data, and print the fitted model.
model_class = type(true_model)
fit_model = model_class()
fit_model.fit(stimuli, actions)
print(fit_model)



In [None]:
def plot_log_likelihood_vs_num_trials(
    true_model: DecisionModel,
    stimuli: np.ndarray,
    actions: np.ndarray | None = None,
    num_trials_to_test: np.ndarray = np.array([10, 30, 100, 300, 1000, 3000, 10000]),
) -> None:
    """
    Plot the log likelihood of a fitted model as a function of the number of trials.

    For every entry of ``num_trials_to_test`` a default-constructed model of the
    same class as ``true_model`` is fitted to the first that-many trials and then
    scored on the *full* ``stimuli``/``actions`` set. The log likelihood of
    ``true_model`` on the full set is drawn as a dashed red reference line.

    Args:
        true_model: The model used to simulate the decisions.
        stimuli: The stimuli presented to the subject.
        actions: The actions taken by the subject. If None, the actions are simulated
            from ``true_model``.
        num_trials_to_test: The training-set sizes to fit with; any array-like of
            integers. The default array is never mutated.

    Raises:
        ValueError: If there are fewer stimuli than the largest requested size.
    """
    # Accept lists/tuples as well as arrays (.max() below needs an ndarray).
    num_trials_to_test = np.asarray(num_trials_to_test)
    if len(stimuli) < num_trials_to_test.max():
        raise ValueError("Not enough stimuli to test.")
    if actions is None:
        actions = true_model.simulate(stimuli)

    model_class = type(true_model)
    log_likelihoods = []
    for num_trials in num_trials_to_test:
        fit_model = model_class()
        # Train on the first `num_trials` trials only ...
        fit_model.fit(stimuli[:num_trials], actions[:num_trials])
        # ... but always evaluate on the complete data set so values are comparable.
        log_likelihoods.append(fit_model.log_likelihood(stimuli, actions))

    plt.semilogx(num_trials_to_test, log_likelihoods)
    plt.axhline(
        true_model.log_likelihood(stimuli, actions),
        color="red",
        linestyle="--",
        label="True Model",
    )
    plt.xlabel("# Trials")
    plt.ylabel("Log Likelihood")
    plt.legend()
    plt.show()


def plot_params_vs_num_trials(
    true_model: DecisionModel,
    stimuli: np.ndarray,
    actions: np.ndarray | None = None,
    num_trials_to_test: np.ndarray = np.array([10, 30, 100, 300, 1000, 3000, 10000]),
    num_repeats: int = 10,
) -> None:
    """
    Plot the fitted parameters of a model as a function of the number of trials.

    For every entry of ``num_trials_to_test``, ``num_repeats`` default-constructed
    models of the same class as ``true_model`` are fitted to the first that-many
    trials. Each parameter gets its own panel: a scatter of the fitted values
    against the number of trials (log x-axis), with the true value as a dashed
    red line.

    Args:
        true_model: The model used to simulate the decisions.
        stimuli: The stimuli presented to the subject.
        actions: The actions taken by the subject. If None, the actions are simulated
            from ``true_model``.
        num_trials_to_test: The training-set sizes to fit with; any array-like of
            integers. The default array is never mutated.
        num_repeats: The number of times to repeat the fit for each number of trials.

    Raises:
        ValueError: If ``num_repeats`` is less than 1, or if there are fewer stimuli
            than the largest requested size.
    """
    # Accept lists/tuples as well as arrays (.max() below needs an ndarray).
    num_trials_to_test = np.asarray(num_trials_to_test)
    if num_repeats < 1:
        raise ValueError("num_repeats must be at least 1.")
    if len(stimuli) < num_trials_to_test.max():
        raise ValueError("Not enough stimuli to test.")
    if actions is None:
        actions = true_model.simulate(stimuli)
    true_params = true_model.params

    model_class = type(true_model)
    fitted_params = []
    for num_trials in num_trials_to_test:
        # The training subset is the same for every repeat, so slice it once.
        # NOTE(review): repeats therefore only differ if fit() is itself
        # stochastic (e.g. random initialisation) — confirm this is intended.
        stim_i = stimuli[:num_trials]
        action_i = actions[:num_trials]
        for _ in range(num_repeats):
            fit_model = model_class()
            fit_model.fit(stim_i, action_i)
            # Copy as float so later jitter cannot fail on integer-typed params.
            # Assumes params is a 1-D sequence of numbers — TODO confirm.
            fitted_params.append(np.array(fit_model.params, dtype=float))

    # Rows are ordered by training size, then repeat: (sizes * repeats, n_params).
    params = np.array(fitted_params)
    # Small Gaussian jitter so coincident points remain distinguishable in the scatter.
    params = params + np.random.normal(0, 0.01, size=params.shape)
    x = np.repeat(num_trials_to_test, num_repeats)

    num_params = params.shape[-1]
    # squeeze=False keeps `axes` two-dimensional even for a one-parameter model;
    # otherwise plt.subplots returns a bare Axes that cannot be indexed.
    fig, axes = plt.subplots(
        1, num_params, figsize=(num_params * 5, 5), squeeze=False
    )
    for i, ax in enumerate(axes[0]):
        ax.scatter(x, params[:, i])
        ax.axhline(true_params[i], color="red", linestyle="--", label="True Model")
        ax.set_xscale("log")
        ax.set_xlabel("# Trials")
        ax.set_ylabel(fit_model.param_names[i])
        # Without a legend the "True Model" label above is never displayed.
        ax.legend()



In [None]:
# Log likelihood of refitted models versus training-set size (default sizes).
plot_log_likelihood_vs_num_trials(
    true_model=true_model,
    stimuli=stimuli,
    actions=actions,
)


In [None]:
# Fitted parameter values versus training-set size (default sizes and repeats).
plot_params_vs_num_trials(
    true_model=true_model,
    stimuli=stimuli,
    actions=actions,
)

