In [None]:
#%% [code]

# Comprehensive convergence diagnostics
#
# For every fitted model, print the worst-case R-hat (largest over all posterior
# parameters) and the worst-case effective sample size (smallest over all posterior
# parameters), each flagged against the thresholds defined below.

RHAT_THRESHOLD = 1.01  # a model passes when its max R-hat is strictly below this
ESS_THRESHOLD = 400    # a model passes when its min ESS is strictly above this


def _report(label, value, value_fmt, passed, warning):
    """Print ``label: value`` followed by a check mark, or by ``warning`` if not ``passed``."""
    status = "✓" if passed else f"✗ WARNING: {warning}"
    print(f"{label}: {value:{value_fmt}} {status}")


def _report_rhat(trace, summary):
    """Report the largest R-hat of ``trace``.

    Uses the ``r_hat`` column of the precomputed ``az.summary`` table when present,
    otherwise computes R-hat directly with ``az.rhat``.
    """
    if "r_hat" in summary.columns:
        max_rhat = float(summary["r_hat"].max())
    else:
        # az.rhat returns an xarray Dataset (one variable per parameter); a Dataset
        # has no .to_numpy(), so reduce each variable to its scalar max, stack those
        # scalars into one DataArray, then take the overall max.
        max_rhat = float(az.rhat(trace).max().to_array().max())
    _report("Max R-hat", max_rhat, ".4f",
            passed=max_rhat < RHAT_THRESHOLD,
            warning=f"R-hat >= {RHAT_THRESHOLD}")


def _report_ess(trace, summary):
    """Report the smallest effective sample size of ``trace``.

    Prefers the bulk (and, when available, tail) ESS columns of the ``az.summary``
    table, falls back to the legacy ``ess_mean`` column, and otherwise computes ESS
    directly with ``az.ess``.
    """
    warning = f"ESS <= {ESS_THRESHOLD}"

    if "ess_bulk" in summary.columns:
        ess_columns = {"ess_bulk": "Min ESS (bulk)", "ess_tail": "Min ESS (tail)"}
    elif "ess_mean" in summary.columns:
        ess_columns = {"ess_mean": "Min ESS (mean)"}
    else:
        # Same Dataset reduction as for R-hat: per-variable min, then overall min.
        min_ess = float(az.ess(trace).min().to_array().min())
        _report("Min ESS", min_ess, ".0f",
                passed=min_ess > ESS_THRESHOLD, warning=warning)
        return

    for column, label in ess_columns.items():
        if column not in summary.columns:
            continue
        min_ess = float(summary[column].min())
        _report(label, min_ess, ".0f",
                passed=min_ess > ESS_THRESHOLD, warning=warning)


print("\n" + "=" * 80)
print("CONVERGENCE DIAGNOSTICS SUMMARY")
print("=" * 80)

traces = {
    "hybrid": hybrid_trace,
    "mixture": mixture_trace,
    "mixture_3": mixture_trace_3,
    "hybrid_2": hybrid_trace_2,
    "mixture_4": mixture_trace_4
}

for name, trace in traces.items():
    print(f"\n{name.upper()} MODEL:")
    print("-" * 80)
    # One summary table per model supplies both R-hat and ESS columns.
    summary = az.summary(trace)
    _report_rhat(trace, summary)
    _report_ess(trace, summary)

print("\n" + "=" * 80 + "\n")


CONVERGENCE DIAGNOSTICS SUMMARY

HYBRID MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 18272 ✓

MIXTURE MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 28219 ✓

MIXTURE_3 MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 38359 ✓

HYBRID_2 MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 25643 ✓

MIXTURE_4 MODEL:
--------------------------------------------------------------------------------
Max R-hat: 1.0000 ✓
Min ESS (bulk): 22954 ✓




In [None]:
#%% [code]

# Compare models by PSIS-LOO expected log predictive density (elpd_loo).
# Reuses the `traces` mapping defined in the convergence-diagnostics cell so both
# cells always cover exactly the same set of models.
#
# Reading the table: `elpd_diff` is each model's gap to the top-ranked model and
# `dse` is the standard error of that gap; a gap that is small relative to `dse`
# does not clearly separate the two models.
model_comparison = az.compare(traces)
model_comparison

  weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
  return umr_sum(a, axis, dtype, out, keepdims, initial, where)


Unnamed: 0,rank,elpd_loo,p_loo,elpd_diff,weight,se,dse,warning,scale
mixture_4,0,-326.951707,7.580393,0.0,0.6094116,6.4014,0.0,False,log
mixture_3,1,-328.861611,8.03129,1.909904,0.3905884,6.368723,7.213591,False,log
hybrid,2,-330.740022,4.656922,3.788315,0.0,7.415212,4.761528,False,log
hybrid_2,3,-331.409091,5.417014,4.457384,9.560185e-16,6.386802,4.18643,False,log
mixture,4,-334.799637,7.695265,7.84793,0.0,7.911812,8.030706,False,log
