 ## Preliminaries

In [None]:
#%% [code]

# Includes


import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import numpy as np
import arviz as az
import pytensor.tensor as at



# Print NumPy floats in fixed-point notation (no scientific notation) with 3 decimals.
np.set_printoptions(suppress=True,precision=3)

In [None]:
#%% [code]

# Load data
#
# `load_data` is presumably a project-local module (not shown here). Only the
# single-experiment loader is called; the multi-experiment loader is imported
# for the commented-out alternative further down.

from load_data import load_data_for_experiments, load_data_for_experiment

# Experiment version whose data is analysed in this notebook.
experiment_version = "V1.0_pilot"
data = load_data_for_experiment(experiment_version)

# Alternative: load several experiment versions at once.
#experiment_versions = ["V0.3_pilot", "V1.0_pilot", "V1.1_pilot"]
#data = load_data_for_experiments(experiment_versions)

# Observed counts that every model below conditions on.
# NOTE(review): presumably a 2-D (subjects x categories) array -- confirm against load_data.
counts = data['counts']

In [None]:
#%% [code]

# Prepare data

# SYNTHETIC DATA for model recoverability.
# Every synthetic data set has 30 rows and 6 columns; each row is a fixed count
# pattern, either the same one for all rows ("hybrid") or a different one per
# group of rows ("mixture").


def _replicate(pattern, n_rows):
    """Return a 2-D array made of `n_rows` identical copies of a 1-D count pattern."""
    row = np.array(pattern)
    return np.repeat(row[np.newaxis, :], n_rows, axis=0)


# Count patterns used as building blocks.
_pattern_hybrid = [10, 5, 5, 1, 1, 4]
_pattern_hybrid_2 = [10, 5, 5, 1, 10, 4]
_pattern_peak_col0 = [20, 1, 1, 1, 1, 4]
_pattern_peak_col1_2 = [2, 10, 10, 1, 1, 4]
_pattern_peak_col4 = [1, 1, 1, 1, 20, 4]

# One homogeneous group of 30 rows.
synthetic_hybrid_counts = _replicate(_pattern_hybrid, 30)

# Two groups of 15 rows.
synthetic_mixture_counts = np.concatenate(
    [_replicate(_pattern_peak_col0, 15), _replicate(_pattern_peak_col1_2, 15)],
    axis=0,
)

# Three groups of 10 rows.
synthetic_mixture_3_counts = np.concatenate(
    [
        _replicate(_pattern_peak_col0, 10),
        _replicate(_pattern_peak_col1_2, 10),
        _replicate(_pattern_peak_col4, 10),
    ],
    axis=0,
)

# One homogeneous group of 30 rows, with an additional peak in column 4.
synthetic_hybrid_2_counts = _replicate(_pattern_hybrid_2, 30)

# Two groups of 15 rows: the hybrid pattern and the column-4 peak.
synthetic_mixture_4_counts = np.concatenate(
    [_replicate(_pattern_hybrid, 15), _replicate(_pattern_peak_col4, 15)],
    axis=0,
)
# Uncomment to fit the models to synthetic data instead of the loaded counts.
#counts = synthetic_mixture_3_counts

# S = number of rows of `counts`, K = number of columns. The models below treat
# rows as subjects and pass (S, K) as the shape of the observed variable.
S, K = counts.shape
# Row totals of `counts`; used as the `n` argument of the Dirichlet-Multinomial
# likelihood.
N = counts.sum(axis=1)

In [None]:
#%% [code]

# Hybrid model: GPI zero + Policy reuse cued
#
# A single Dirichlet–Multinomial shared by all subjects, whose base proportions
# follow the tied pattern [n, m, m, e, e, 4e].

with pm.Model() as hybrid_model:
    # Positive raw weights for the tied pattern [n, m, m, e, e, 4e].
    # u_n and u_m are drawn uniformly above u_e, so neither can fall below it.
    u_e = pm.Uniform("u_e", lower=0.0, upper=1/9)
    u_n = pm.Uniform("u_n", lower=u_e, upper=1.0)
    u_m = pm.Uniform("u_m", lower=u_e, upper=1.0)

    # Total concentration (how similar subjects are). Because theta is
    # normalised to sum to 1, the Dirichlet parameters alpha sum to c.
    c = pm.LogNormal("c", mu=0.0, sigma=1.5)

    # Construct tied base proportions and normalise them.
    theta_raw = at.stack([u_n, u_m, u_m, u_e, u_e, 4*u_e])
    theta = theta_raw / theta_raw.sum()

    # Dirichlet parameters
    alpha = c * theta

    # Vectorized Dirichlet–Multinomial over subjects (one row of counts each).
    pm.DirichletMultinomial("x", a=alpha, n=N, shape=(S, K), observed=counts)

    hybrid_trace = pm.sample(
        10000,
        tune=10000,
        target_accept=0.95,
        chains=8,
        # Fixed seed: without it the draws, and hence the LOO estimate and the
        # model ranking, change on every run.
        random_seed=42,
        # Pointwise log-likelihood is required by az.loo / az.compare.
        idata_kwargs={"log_likelihood": True},
        return_inferencedata=True,
    )

# The steps below only read the trace, so they do not need the model context.
# Check convergence
print("Hybrid model convergence diagnostics:")
print(az.summary(hybrid_trace))
# Predictive accuracy (LOO computed from the stored pointwise log-likelihood).
loo_hybrid = az.loo(hybrid_trace)

Initializing NUTS using jitter+adapt_diag...
This usually happens when PyTensor is installed via pip. We recommend it be installed via conda/mamba/pixi instead.
Alternatively, you can use an experimental backend such as Numba or JAX that perform their own BLAS optimizations, by setting `pytensor.config.mode == 'NUMBA'` or passing `mode='NUMBA'` when compiling a PyTensor function.
For more options and details see https://pytensor.readthedocs.io/en/latest/troubleshooting.html#how-do-i-configure-test-my-blas-library
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [u_e, u_n, u_m, c]


Output()

Sampling 8 chains for 10_000 tune and 10_000 draw iterations (80_000 + 80_000 draws total) took 130 seconds.


Hybrid model convergence diagnostics:
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
u_e  0.084  0.021   0.045    0.111      0.000    0.000   21766.0   26452.0   
u_n  0.240  0.084   0.083    0.397      0.001    0.000   26249.0   23304.0   
u_m  0.262  0.083   0.100    0.415      0.001    0.000   25447.0   23355.0   
c    4.135  0.660   2.957    5.390      0.003    0.003   40713.0   45295.0   

     r_hat  
u_e    1.0  
u_n    1.0  
u_m    1.0  
c      1.0  


In [None]:
#%% [code]

# Mixture:
#   1) GPI zero
#   2) Policy reuse cued
#
# Each subject's counts come from one of two Dirichlet–Multinomial components;
# component membership is marginalised out by pm.Mixture.

with pm.Model() as mixture_model:
    # Mixture weights: Dirichlet(1, 1), i.e. uniform over the 2-simplex.
    w = pm.Dirichlet('w', a=np.ones(2))
    # Raw weights (positive). Each component has its own baseline u_e*, and the
    # component's elevated weight is drawn uniformly above that baseline.
    u_e1  = pm.Uniform('u_e1', lower=0.0, upper=1/9)
    u_e2  = pm.Uniform('u_e2', lower=0.0, upper=1/9)
    u_n   = pm.Uniform('u_n', lower=u_e1, upper=1.0)
    u_m   = pm.Uniform('u_m', lower=u_e2, upper=1.0)

    # Per-component concentration.
    c1 = pm.LogNormal('c1', 0.0, 1.5)
    c2 = pm.LogNormal('c2', 0.0, 1.5)

    # Component base measures:
    #   component 1: [n, e, e, e, e, 4e]
    #   component 2: [e, m, m, e, e, 4e]
    theta1_raw = pm.math.stack([u_n, u_e1, u_e1, u_e1, u_e1, 4*u_e1])
    theta2_raw = pm.math.stack([u_e2, u_m,  u_m,  u_e2, u_e2, 4*u_e2])
    theta1 = theta1_raw / pm.math.sum(theta1_raw)
    theta2 = theta2_raw / pm.math.sum(theta2_raw)

    # Dirichlet parameters; each alpha sums to its component's concentration.
    alpha1 = c1 * theta1
    alpha2 = c2 * theta2

    # Subject-level mixture likelihood (marginal over z). N is the per-subject
    # total count, the same quantity the hybrid models pass as `n`.
    like1 = pm.DirichletMultinomial.dist(a=alpha1, n=N)
    like2 = pm.DirichletMultinomial.dist(a=alpha2, n=N)

    # Mixture across subjects
    pm.Mixture('x', w, comp_dists=[like1, like2], observed=counts)

    mixture_trace = pm.sample(
        10000,
        tune=10000,
        target_accept=0.95,
        chains=8,
        # Fixed seed: without it the draws, and hence the LOO estimate and the
        # model ranking, change on every run.
        random_seed=42,
        # Pointwise log-likelihood is required by az.loo / az.compare.
        idata_kwargs={"log_likelihood": True},
        return_inferencedata=True,
    )

# The steps below only read the trace, so they do not need the model context.
# Check convergence
print("Mixture model convergence diagnostics:")
print(az.summary(mixture_trace))
loo_mixture = az.loo(mixture_trace)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [w, u_e1, u_e2, u_n, u_m, c1, c2]


Output()

Sampling 8 chains for 10_000 tune and 10_000 draw iterations (80_000 + 80_000 draws total) took 327 seconds.


Mixture model convergence diagnostics:
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
w[0]   0.570  0.115   0.354    0.784      0.001    0.001   53889.0   36131.0   
w[1]   0.430  0.115   0.216    0.646      0.001    0.001   53889.0   36131.0   
u_e1   0.075  0.026   0.028    0.111      0.000    0.000   36006.0   36456.0   
u_e2   0.066  0.026   0.024    0.111      0.000    0.000   36294.0   36603.0   
u_n    0.250  0.119   0.039    0.454      0.001    0.001   36156.0   32610.0   
u_m    0.619  0.235   0.219    1.000      0.001    0.001   32821.0   34261.0   
c1     5.221  1.227   3.071    7.533      0.005    0.006   65943.0   49600.0   
c2    14.190  6.475   4.715   25.381      0.030    0.095   52476.0   49734.0   

      r_hat  
w[0]    1.0  
w[1]    1.0  
u_e1    1.0  
u_e2    1.0  
u_n     1.0  
u_m     1.0  
c1      1.0  
c2      1.0  


In [None]:
#%% [code]

# Mixture:
#   1) GPI zero
#   2) Policy reuse cued
#   3) Model-based / GPI
#
# Each subject's counts come from one of three Dirichlet–Multinomial
# components; component membership is marginalised out by pm.Mixture.

with pm.Model() as mixture_model_3:
    # Mixture weights: Dirichlet(1, 1, 1), i.e. uniform over the 3-simplex.
    w = pm.Dirichlet('w', a=np.ones(3))

    # Raw weights (positive). Each component has its own baseline u_e*, and the
    # component's elevated weight is drawn uniformly above that baseline.
    u_e1  = pm.Uniform('u_e1', lower=0.0, upper=1/9)
    u_e2  = pm.Uniform('u_e2', lower=0.0, upper=1/9)
    u_e3  = pm.Uniform('u_e3', lower=0.0, upper=1/9)
    u_n   = pm.Uniform('u_n', lower=u_e1, upper=1.0)
    u_m   = pm.Uniform('u_m', lower=u_e2, upper=1.0)
    u_o   = pm.Uniform('u_o', lower=u_e3, upper=1.0)

    # Per-component concentration.
    c1 = pm.LogNormal('c1', 0.0, 1.5)
    c2 = pm.LogNormal('c2', 0.0, 1.5)
    c3 = pm.LogNormal('c3', 0.0, 1.5)

    # Component base measures:
    #   component 1: [n, e, e, e, e, 4e]
    #   component 2: [e, m, m, e, e, 4e]
    #   component 3: [e, e, e, e, o, 4e]
    theta1_raw = pm.math.stack([u_n, u_e1, u_e1, u_e1, u_e1, 4*u_e1])
    theta2_raw = pm.math.stack([u_e2, u_m,  u_m,  u_e2, u_e2, 4*u_e2])
    theta3_raw = pm.math.stack([u_e3, u_e3,  u_e3,  u_e3, u_o, 4*u_e3])
    theta1 = theta1_raw / pm.math.sum(theta1_raw)
    theta2 = theta2_raw / pm.math.sum(theta2_raw)
    theta3 = theta3_raw / pm.math.sum(theta3_raw)

    # Dirichlet parameters; each alpha sums to its component's concentration.
    alpha1 = c1 * theta1
    alpha2 = c2 * theta2
    alpha3 = c3 * theta3

    # Subject-level mixture likelihood (marginal over z). N is the per-subject
    # total count, the same quantity the hybrid models pass as `n`.
    like1 = pm.DirichletMultinomial.dist(a=alpha1, n=N)
    like2 = pm.DirichletMultinomial.dist(a=alpha2, n=N)
    like3 = pm.DirichletMultinomial.dist(a=alpha3, n=N)

    # Mixture across subjects
    pm.Mixture('x', w, comp_dists=[like1, like2, like3], observed=counts)

    mixture_trace_3 = pm.sample(
        10000,
        tune=10000,
        target_accept=0.95,
        chains=8,
        # Fixed seed: without it the draws, and hence the LOO estimate and the
        # model ranking, change on every run.
        random_seed=42,
        # Pointwise log-likelihood is required by az.loo / az.compare.
        idata_kwargs={"log_likelihood": True},
        return_inferencedata=True,
    )

# The steps below only read the trace, so they do not need the model context.
# Check convergence
print("Mixture model 3 convergence diagnostics:")
print(az.summary(mixture_trace_3))
loo_mixture_3 = az.loo(mixture_trace_3)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [w, u_e1, u_e2, u_e3, u_n, u_m, u_o, c1, c2, c3]


Output()

Sampling 8 chains for 10_000 tune and 10_000 draw iterations (80_000 + 80_000 draws total) took 514 seconds.


Mixture model 3 convergence diagnostics:
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
w[0]   0.300  0.124   0.085    0.543      0.001    0.001   19100.0   15269.0   
w[1]   0.405  0.098   0.223    0.589      0.000    0.000   74913.0   58089.0   
w[2]   0.295  0.126   0.034    0.514      0.001    0.001   16898.0    9704.0   
u_e1   0.072  0.026   0.026    0.111      0.000    0.000   47252.0   40820.0   
u_e2   0.066  0.025   0.023    0.111      0.000    0.000   38571.0   35526.0   
u_e3   0.074  0.026   0.026    0.111      0.000    0.000   48335.0   44590.0   
u_n    0.491  0.236   0.093    0.933      0.001    0.001   24107.0   35652.0   
u_m    0.640  0.232   0.236    1.000      0.001    0.001   38862.0   40833.0   
u_o    0.267  0.156   0.017    0.529      0.001    0.002   24776.0   13423.0   
c1     4.950  2.056   1.527    8.746      0.009    0.026   46160.0   36850.0   
c2    15.338  6.221   5.784   26.325      0.026    0.046   69079.0   48221.0   

In [None]:
#%% [code]

# Hybrid model 2: GPI zero + Policy reuse cued + MB/GPI
#
# A single Dirichlet–Multinomial shared by all subjects, whose base proportions
# follow the tied pattern [n, m, m, e, o, 4e].

with pm.Model() as hybrid_model_2:
    # Positive raw weights for the tied pattern [n, m, m, e, o, 4e].
    # u_n, u_m and u_o are drawn uniformly above u_e, so none can fall below it.
    u_e = pm.Uniform("u_e", lower=0.0, upper=1/9)
    u_n = pm.Uniform("u_n", lower=u_e, upper=1.0)
    u_m = pm.Uniform("u_m", lower=u_e, upper=1.0)
    u_o = pm.Uniform("u_o", lower=u_e, upper=1.0)

    # Total concentration (how similar subjects are). Because theta is
    # normalised to sum to 1, the Dirichlet parameters alpha sum to c.
    c = pm.LogNormal("c", mu=0.0, sigma=1.5)

    # Construct tied base proportions and normalise them.
    theta_raw = at.stack([u_n, u_m, u_m, u_e, u_o, 4*u_e])
    theta = theta_raw / theta_raw.sum()

    # Dirichlet parameters
    alpha = c * theta

    # Vectorized Dirichlet–Multinomial over subjects (one row of counts each).
    pm.DirichletMultinomial("x", a=alpha, n=N, shape=(S, K), observed=counts)

    hybrid_trace_2 = pm.sample(
        10000,
        tune=10000,
        target_accept=0.95,
        chains=8,
        # Fixed seed: without it the draws, and hence the LOO estimate and the
        # model ranking, change on every run.
        random_seed=42,
        # Pointwise log-likelihood is required by az.loo / az.compare.
        idata_kwargs={"log_likelihood": True},
        return_inferencedata=True,
    )

# The steps below only read the trace, so they do not need the model context.
# Check convergence
print("Hybrid model 2 convergence diagnostics:")
print(az.summary(hybrid_trace_2))
# Predictive accuracy (LOO computed from the stored pointwise log-likelihood).
loo_hybrid_2 = az.loo(hybrid_trace_2)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [u_e, u_n, u_m, u_o, c]


Output()

Sampling 8 chains for 10_000 tune and 10_000 draw iterations (80_000 + 80_000 draws total) took 100 seconds.


Hybrid model 2 convergence diagnostics:
      mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
u_e  0.090  0.018   0.057    0.111      0.000    0.000   24281.0   31432.0   
u_n  0.363  0.119   0.149    0.592      0.001    0.001   27472.0   29556.0   
u_m  0.396  0.118   0.176    0.619      0.001    0.001   24969.0   28277.0   
u_o  0.318  0.106   0.130    0.523      0.001    0.000   27589.0   28140.0   
c    4.175  0.679   2.948    5.453      0.003    0.003   48446.0   50787.0   

     r_hat  
u_e    1.0  
u_n    1.0  
u_m    1.0  
u_o    1.0  
c      1.0  


In [None]:
#%% [code]

# Mixture 4:
#   1) GPI zero + Policy reuse cued
#   2) MB/GPI
#
# Each subject's counts come from one of two Dirichlet–Multinomial components;
# component membership is marginalised out by pm.Mixture.

with pm.Model() as mixture_model_4:
    # Mixture weights: Dirichlet(1, 1), i.e. uniform over the 2-simplex.
    w = pm.Dirichlet('w', a=np.ones(2))

    # Raw weights (positive). u_n and u_m share component 1's baseline u_e1;
    # u_o is drawn above component 2's baseline u_e2.
    u_e1  = pm.Uniform('u_e1', lower=0.0, upper=1/9)
    u_e2  = pm.Uniform('u_e2', lower=0.0, upper=1/9)
    u_n   = pm.Uniform('u_n', lower=u_e1, upper=1.0)
    u_m   = pm.Uniform('u_m', lower=u_e1, upper=1.0)
    u_o   = pm.Uniform('u_o', lower=u_e2, upper=1.0)

    # Per-component concentration.
    c1 = pm.LogNormal('c1', 0.0, 1.5)
    c2 = pm.LogNormal('c2', 0.0, 1.5)

    # Component base measures
    # Component 1: GPI zero + Policy reuse cued [n, m, m, e, e, 4e]
    theta1_raw = pm.math.stack([u_n, u_m, u_m, u_e1, u_e1, 4*u_e1])
    # Component 2: MB/GPI [e, e, e, e, o, 4e]
    theta2_raw = pm.math.stack([u_e2, u_e2, u_e2, u_e2, u_o, 4*u_e2])
    theta1 = theta1_raw / pm.math.sum(theta1_raw)
    theta2 = theta2_raw / pm.math.sum(theta2_raw)

    # Dirichlet parameters; each alpha sums to its component's concentration.
    alpha1 = c1 * theta1
    alpha2 = c2 * theta2

    # Subject-level mixture likelihood (marginal over z). N is the per-subject
    # total count, the same quantity the hybrid models pass as `n`.
    like1 = pm.DirichletMultinomial.dist(a=alpha1, n=N)
    like2 = pm.DirichletMultinomial.dist(a=alpha2, n=N)

    # Mixture across subjects
    pm.Mixture('x', w, comp_dists=[like1, like2], observed=counts)

    mixture_trace_4 = pm.sample(
        10000,
        tune=10000,
        target_accept=0.95,
        chains=8,
        # Fixed seed: without it the draws, and hence the LOO estimate and the
        # model ranking, change on every run.
        random_seed=42,
        # Pointwise log-likelihood is required by az.loo / az.compare.
        idata_kwargs={"log_likelihood": True},
        return_inferencedata=True,
    )

# The steps below only read the trace, so they do not need the model context.
# Check convergence
print("Mixture model 4 convergence diagnostics:")
print(az.summary(mixture_trace_4))
loo_mixture_4 = az.loo(mixture_trace_4)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (8 chains in 4 jobs)
NUTS: [w, u_e1, u_e2, u_n, u_m, u_o, c1, c2]


Output()

Sampling 8 chains for 10_000 tune and 10_000 draw iterations (80_000 + 80_000 draws total) took 305 seconds.


Mixture model 4 convergence diagnostics:
        mean     sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  \
w[0]   0.445  0.126   0.216    0.668      0.001    0.001   27525.0   12929.0   
w[1]   0.555  0.126   0.332    0.784      0.001    0.001   27525.0   12929.0   
u_e1   0.071  0.023   0.034    0.111      0.000    0.000   34610.0   32053.0   
u_e2   0.075  0.026   0.028    0.111      0.000    0.000   45080.0   41325.0   
u_n    0.180  0.082   0.043    0.333      0.000    0.000   40541.0   42295.0   
u_m    0.712  0.203   0.345    1.000      0.001    0.001   31173.0   28226.0   
u_o    0.184  0.093   0.030    0.334      0.001    0.002   36663.0   28278.0   
c1    14.132  7.053   3.906   26.232      0.034    0.073   32448.0   16663.0   
c2     4.361  1.127   2.433    6.544      0.007    0.010   32753.0   14913.0   

      r_hat  
w[0]    1.0  
w[1]    1.0  
u_e1    1.0  
u_e2    1.0  
u_n     1.0  
u_m     1.0  
u_o     1.0  
c1      1.0  
c2      1.0  


In [None]:
#%% [code]

# Comprehensive convergence diagnostics
#
# R-hat and ESS are computed directly with az.rhat / az.ess instead of being
# read back from az.summary: the summary table rounds r_hat to two decimals by
# default, so printing it with four decimals and comparing it against 1.01
# would act on a rounded value. Tail ESS is reported next to bulk ESS because
# a chain can mix well in the bulk while the tails are still poorly explored.

RHAT_MAX = 1.01  # a model passes only if its largest R-hat is below this
ESS_MIN = 400    # a model passes only if its smallest ESS is above this


def _worst_value(diagnostic, reducer):
    """Collapse a per-parameter diagnostic to one number.

    Parameters
    ----------
    diagnostic : xarray.Dataset
        Result of ``az.rhat`` or ``az.ess``: one data variable per model
        parameter, possibly with extra dimensions (e.g. the mixture weights).
    reducer : callable
        ``np.nanmax`` to get the largest R-hat, ``np.nanmin`` to get the
        smallest ESS. NaN entries are ignored, as pandas does when reducing
        the corresponding az.summary column.

    Returns
    -------
    float
        The reduced value over every element of every parameter.
    """
    flat = np.concatenate(
        [np.ravel(values.values) for values in diagnostic.data_vars.values()]
    )
    return float(reducer(flat))


print("\n" + "="*80)
print("CONVERGENCE DIAGNOSTICS SUMMARY")
print("="*80)

traces = {
    "hybrid": hybrid_trace,
    "mixture": mixture_trace,
    "mixture_3": mixture_trace_3,
    "hybrid_2": hybrid_trace_2,
    "mixture_4": mixture_trace_4
}

for name, trace in traces.items():
    print(f"\n{name.upper()} MODEL:")
    print("-" * 80)

    # Largest (unrounded) rank-normalised R-hat over all parameters.
    max_rhat = _worst_value(az.rhat(trace), np.nanmax)
    rhat_flag = "✓" if max_rhat < RHAT_MAX else f"✗ WARNING: R-hat > {RHAT_MAX}"
    print(f"Max R-hat: {max_rhat:.4f} {rhat_flag}")

    # Smallest effective sample size over all parameters, bulk and tail.
    for ess_method in ("bulk", "tail"):
        min_ess = _worst_value(az.ess(trace, method=ess_method), np.nanmin)
        ess_flag = "✓" if min_ess > ESS_MIN else f"✗ WARNING: ESS < {ESS_MIN}"
        print(f"Min ESS ({ess_method}): {min_ess:.0f} {ess_flag}")

print("\n" + "="*80 + "\n")


CONVERGENCE DIAGNOSTICS SUMMARY

HYBRID MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 21766 ✓

MIXTURE MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 32821 ✓

MIXTURE_3 MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 16207 ✓

HYBRID_2 MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 24281 ✓

MIXTURE_4 MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 27525 ✓




In [None]:
#%% [code]

# Compare models
#
# Rank all fitted models with ArviZ's default comparison settings. The mapping
# keys become the row labels of the resulting table.

models_to_compare = dict(
    hybrid=hybrid_trace,
    mixture=mixture_trace,
    mixture_3=mixture_trace_3,
    hybrid_2=hybrid_trace_2,
    mixture_4=mixture_trace_4,
)

az.compare(models_to_compare)

Unnamed: 0,rank,elpd_loo,p_loo,elpd_diff,weight,se,dse,warning,scale
mixture_3,0,-301.518286,7.781459,0.0,0.4467214,8.23438,0.0,False,log
mixture,1,-304.547087,6.486995,3.028801,0.2191483,9.482639,3.329168,False,log
hybrid_2,2,-306.716681,5.291372,5.198395,0.3341303,5.122518,5.77664,False,log
mixture_4,3,-308.60336,7.837388,7.085074,2.018339e-14,7.26717,4.132043,False,log
hybrid,4,-313.903813,4.542803,12.385527,0.0,6.984115,5.471624,False,log
