In [None]:
import pandas as pd
import numpy as np
from utils2 import RTModel
import matplotlib.pyplot as plt
import arviz as az
import seaborn as sns
import jax
from numpyro.infer import Predictive
from numpyro import distributions as dist
import os
from plot_utils import plot_posterior, plot_RT_predictions

In [None]:
# Input CSV and the response column that is modelled/plotted.
# NOTE: run_model_combo hard-codes these same two values for its worker processes.
data_path, dataname = 'exp1_test(in).csv', 'logRT_norm'
# One distinct tab10 colour per repetition (6 repetitions).
palette = sns.color_palette("tab10", 6)

In [None]:
from numpyro import distributions as dist

# Candidate priors for each model site. Every list may hold several
# alternatives; the next cell takes the Cartesian product over the lists,
# so with one entry per site there is exactly one prior combination.
prior_options = dict(
    a=[dist.Normal(0, 1)],
    b=[dist.HalfCauchy(1.0)],
    sig_t=[dist.HalfNormal(0.25)],
    sigma=[dist.HalfCauchy(1.0)],
)


In [None]:
import itertools

keys = prior_options.keys()
values = prior_options.values()
# One tuple per prior configuration; element j of each tuple belongs to the
# j-th key of prior_options (dict iteration order).
combinations = [combo for combo in itertools.product(*values)]

In [None]:
def prior_dict_to_str(prior_dict):
    """Encode a ``{site_name: distribution}`` mapping as a filename tag.

    Each entry becomes ``"<site>-<Name>(<loc>,<scale>)"``, where ``Name`` is the
    object's class name with any ``"Distribution"`` substring removed and
    ``loc`` / ``scale`` are the object's attributes of those names formatted
    with two decimals. A missing attribute, or one that cannot be converted to
    a float, is rendered as ``"NA"``. Entries are joined with ``"_"`` in the
    mapping's iteration order; an empty mapping yields ``""``.
    """
    import numbers

    def _two_decimals(value):
        # Real numbers are formatted directly; anything else (e.g. 0-d arrays)
        # is coerced with float() first. None (attribute absent) and
        # non-scalar values fail here and fall back to "NA".
        try:
            if not isinstance(value, numbers.Number):
                value = float(value)
            return f"{value:.2f}"
        except Exception:
            return "NA"  # safer than '?'

    tags = []
    for site, prior in prior_dict.items():
        label = prior.__class__.__name__.replace("Distribution", "")
        loc, scale = (getattr(prior, attr, None) for attr in ("loc", "scale"))
        tags.append(f"{site}-{label}({_two_decimals(loc)},{_two_decimals(scale)})")
    return "_".join(tags)


In [None]:
from joblib import Parallel, delayed

def _plot_posterior_hist(ax, samples, symbol, color, nb, peak=None):
    """Draw a histogram+KDE of posterior draws on ``ax`` with mean/median (and optional peak) markers.

    Parameters
    ----------
    ax : matplotlib Axes to draw on.
    samples : posterior draws for one scalar site (presumably 1-D, flattened
        over chains by ``mcmc.get_samples()`` -- TODO confirm).
    symbol : LaTeX symbol without surrounding ``$`` (e.g. ``r"\\mu_t"``),
        used in the title and x-label.
    color : histogram colour.
    nb : number of histogram bins.
    peak : optional mode estimate; if given, it is marked and annotated too.
    """
    import numpy as np
    import seaborn as sns

    mean = np.mean(samples)
    median = np.median(samples)

    sns.histplot(samples, bins=nb, kde=True, stat="density", color=color, ax=ax)
    ax.set_title(f"Posterior Distribution of ${symbol}$")
    ax.set_xlabel(f"${symbol}$")
    ax.set_ylabel("PDF")
    ax.grid(True)
    ax.axvline(mean, color='red', linestyle='--', label=f'Mean = {mean:.2f}')
    ax.axvline(median, color='green', linestyle=':', label=f'Median = {median:.2f}')
    if peak is not None:
        ax.axvline(peak, color='blue', linestyle=':', label=f'Peak = {peak:.2f}')

    # Value annotations at staggered fractions of the y-range so they don't overlap.
    y_top = ax.get_ylim()[1]
    ax.text(mean, y_top * 0.9, f'{mean:.2f}', color='red', ha='center')
    ax.text(median, y_top * 0.75, f'{median:.2f}', color='green', ha='center')
    if peak is not None:
        ax.text(peak, y_top * 0.6, f'{peak:.2f}', color='blue', ha='center')
    ax.legend()


def run_model_combo(ot, phase2model, group2model, ns, nc, prior_idx, combo):
    """Fit RTModel for one (outlier option, phase, group, prior combo) and save outputs.

    Posterior samples are written to
    ``Results/<folder>_<prior tag>/post_group<g>_phase<p>_sample<ns>_chain<nc>.csv``
    and a 2x3 diagnostic figure to the matching ``Figures/...`` directory.
    If that CSV already exists the whole run is skipped.

    Parameters
    ----------
    ot : outlier option; only ``1`` (output folder ``'all'``) is implemented.
    phase2model, group2model : values matched against the ``phase`` / ``group``
        columns of ``model.df``.
    ns, nc : MCMC samples per chain and number of chains.
    prior_idx : 0-based index of ``combo``; used only in the progress message.
    combo : tuple of distributions, one per key of the global ``prior_options``.

    Raises
    ------
    ValueError
        If ``ot`` is not a supported outlier option.
    """
    # Imports are repeated here so the function is self-contained when
    # joblib ships it to a worker process.
    from utils2 import RTModel
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    import pandas as pd
    import os

    dataname = 'logRT_norm'
    palette = sns.color_palette("tab10", 6)
    nb = 35  # histogram bins

    if ot == 1:
        output_folder = 'all'
    else:
        # Without this, `output_folder` / `model` would be unbound below.
        raise ValueError(f"Unsupported outlier option ot={ot!r}; only ot=1 is implemented")

    prior_dict = dict(zip(prior_options.keys(), combo))
    filename_prior_str = prior_dict_to_str(prior_dict)

    csv_dir = f"Results/{output_folder}_{filename_prior_str}"
    os.makedirs(csv_dir, exist_ok=True)
    csv_file = f"{csv_dir}/post_group{group2model}_phase{phase2model}_sample{ns}_chain{nc}.csv"
    if os.path.exists(csv_file):
        return  # already fitted -- skip before paying for model construction

    model = RTModel(dataname, phase2model, group2model, 'exp1_test(in).csv', model_name='model')

    print(f"Running MCMC: group={group2model}, phase={phase2model}, prior#{prior_idx+1}")
    # NOTE(review): the seed ot+phase+group collides for some (phase, group)
    # pairs and is shared by all prior combos; kept for reproducibility.
    mcmc = model.run_mcmc(num_warmup=500, num_samples=ns, num_chains=nc,
                          prior=prior_dict, rng_key=ot + phase2model + group2model)

    posterior_samples = mcmc.get_samples()
    pd.DataFrame(posterior_samples).to_csv(csv_file, index=False)

    # ---------- posterior summaries ----------
    a_samples = posterior_samples['a']
    b_samples = posterior_samples['b']
    mu_t_samples = posterior_samples['mu_t']
    peak_t = RTModel.get_mode_kde(mu_t_samples)  # KDE mode of mu_t
    a_peak = RTModel.get_mode_kde(a_samples)
    b_peak = RTModel.get_mode_kde(b_samples)

    # Piecewise curve from the posterior modes: decreases by b_peak per
    # repetition index up to t_star (rounded mode of mu_t), flat afterwards.
    t_star = np.round(peak_t).astype("int32")
    repi_idx_all = np.arange(6)
    mu_t_all = np.where(repi_idx_all <= t_star,
                        a_peak - b_peak * repi_idx_all,
                        a_peak - b_peak * t_star)

    # ---------- raw data for this phase/group ----------
    df = model.df
    subset = df[(df["phase"] == phase2model) & (df["group"] == group2model)]
    num_subjects = subset['name'].nunique()
    reps = np.arange(1, 7)
    rt_by_rep = [subset.loc[subset["repi"] == rep, dataname].values for rep in reps]
    means = np.array([np.mean(rt) for rt in rt_by_rep])
    # SEM over rows (trials) of each repetition, not over subjects.
    sems = np.array([np.std(rt) / np.sqrt(len(rt)) for rt in rt_by_rep])
    # Repetition r is drawn at x = r - 1, i.e. on the repi_idx_all grid.
    # NOTE(review): assumes RTModel uses repi_idx = repi - 1 -- TODO confirm.
    x_raw = reps - 1

    fig, axes = plt.subplots(2, 3, figsize=(10, 6))
    try:
        # (0, 0) posterior of mu_t, with its KDE mode marked as "Peak".
        _plot_posterior_hist(axes[0, 0], mu_t_samples, r"\mu_t", 'skyblue', nb, peak=peak_t)

        # (0, 1) raw mean RT +/- SEM per repetition vs. the piecewise mu_t curve.
        ax = axes[0, 1]
        ax.errorbar(x_raw, means, yerr=sems, fmt='s--', linewidth=1.5, capsize=4,
                    ecolor='black', alpha=0.9, label=r'Raw mean RT $\pm$ SEM')
        ax.errorbar(repi_idx_all, mu_t_all, fmt='o-', color='darkred', ecolor='gray',
                    elinewidth=2, capsize=4, label=r'$\mu_t$')
        ax.set_title(r"Piecewise $\mu_t$(repi_idx)")
        ax.set_xlabel("repi_idx")
        ax.set_ylabel(r"$\mu_t$")
        ax.grid(True)
        ax.set_xticks(repi_idx_all)
        ax.legend()

        # (0, 2) empirical CDFs of raw RT per repetition, fading with repetition.
        ax = axes[0, 2]
        for rep, rt in enumerate(rt_by_rep, start=1):
            sns.ecdfplot(rt, label=f"RT{rep}", ax=ax, color=palette[rep - 1],
                         alpha=(1 - 0.8 * rep / 6))
        ax.set_title(f"Raw RT CDFs by Repetition\nN subjects: {num_subjects}")
        ax.set_xlabel("RT")
        ax.set_ylabel("CDF")
        ax.grid(True)
        ax.legend(title="Repetition")

        # (1, 0) and (1, 1) posteriors of a and b; (1, 2) is unused.
        _plot_posterior_hist(axes[1, 0], a_samples, "a", 'cornflowerblue', nb)
        _plot_posterior_hist(axes[1, 1], b_samples, "b", 'cornflowerblue', nb)
        axes[1, 2].axis("off")

        fig.tight_layout()
        fig_dir = f"Figures/{output_folder}_{filename_prior_str}"
        os.makedirs(fig_dir, exist_ok=True)
        fig.savefig(f"{fig_dir}/Dist_group{group2model}_phase{phase2model}_sample{ns}_chain{nc}.png", dpi=300)
    finally:
        # joblib reuses worker processes; close the figure so they don't accumulate.
        plt.close(fig)


In [None]:
from joblib import Parallel, delayed
ns = 5000        # MCMC samples per chain
nc = 4           # number of chains
N_COMBOS = 12    # just run the first N prior combos to test
ot = 1           # outlier option (only 1 is implemented in run_model_combo)
PHASES = [1, 2, 3]
GROUPS = [1, 5, 7]

# One flat job list so joblib can run every (phase, group, prior) fit
# concurrently. Launching one Parallel call per (phase, group) ran them in
# serial batches -- with a single prior combo, each batch held just one job.
jobs = [
    delayed(run_model_combo)(ot, phase2model, group2model, ns, nc, i, combo)
    for phase2model in PHASES
    for group2model in GROUPS
    for i, combo in enumerate(combinations[:N_COMBOS])
]
print(f"Launching {len(jobs)} jobs for phases {PHASES}, groups {GROUPS}, OutlierType {ot}")
Parallel(n_jobs=-1)(jobs);  # use all available cores; results are written to disk
