In [5]:
import cmdstanpy
import numpy as np
import scipy as sp
import cmdstanpy as csp
from scipy.special import expit
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
import ot
from scipy.optimize import minimize
from scipy.stats import multivariate_normal

# Setup: simulation constants
seed = 583883   # seeds the NumPy generator below
D = 6           # number of regression coefficients
N = 1500        # observations per state
mu_x_OH = -1    # mean of every Ohio covariate
mu_x_NY = 1     # mean of every NY covariate
B = 10_000      # draw count; presumably the total number of Ohio posterior draws -- confirm against its users
h = 0.04        # presumably a kernel bandwidth for the KDE prior -- confirm against its users

rng = np.random.default_rng(seed)

# Generate true parameters and data
beta = rng.normal(loc=0.0, scale=1.0, size=D)
print(f"True beta: {beta}")

# Generate Ohio data.
# The covariate mean comes from mu_x_OH instead of a second hard-coded literal,
# so editing the constant above actually changes the simulated data.
x_OH = rng.normal(loc=mu_x_OH, scale=1.0, size=(N, D))
y_OH = rng.binomial(n=1, p=expit(x_OH @ beta))

# Generate NY data (same true parameters, different x distribution)
x_NY = rng.normal(loc=mu_x_NY, scale=1.0, size=(N, D))
y_NY = rng.binomial(n=1, p=expit(x_NY @ beta))
print(f"NY response rate: {np.mean(y_NY):.3f}")

# Ohio stage: sample the posterior for the Ohio data with flat-logistic.stan
print("\n=== Fitting Ohio data with flat prior ===")
model_OH = csp.CmdStanModel(stan_file='flat-logistic.stan')
data_OH = dict(N=N, D=D, x=x_OH, y=y_OH)
# 4 chains with B // 4 sampling iterations each
fit_OH = model_OH.sample(
    data=data_OH,
    chains=4,
    iter_sampling=B // 4,
    seed=seed,
)
print(fit_OH.summary())
# Draws of `beta` from the Ohio fit; the NY priors further down are built from these
beta_OH_draws = fit_OH.stan_variable('beta')

# === APPROACH 1: Zhong's KDE Method ===
print("\n=== APPROACH 1: Zhong's KDE Method ===")
# 'B' is taken from the draws array itself rather than the nominal constant B.
# The Ohio fit produces 4 * (B // 4) draws, which differs from B whenever B is
# not a multiple of 4; passing the real row count keeps 'B' and 'beta0' in sync.
# NOTE(review): assumes empirical-logistic.stan sizes `beta0` by `B` -- confirm
# against the Stan file.
data_NY_kde = {
    'N': N, 'D': D, 'x': x_NY, 'y': y_NY,
    'h': h, 'B': beta_OH_draws.shape[0], 'beta0': beta_OH_draws
}
model_NY_kde = csp.CmdStanModel(stan_file='empirical-logistic.stan')
fit_NY_kde = model_NY_kde.sample(
    data=data_NY_kde, chains=4, iter_warmup=500, iter_sampling=500,
    seed=seed, show_progress=False
)
print(fit_NY_kde.summary())

# === APPROACH 2: Wasserstein Moment Matching ===
print("\n=== APPROACH 2: Wasserstein Moment Matching ===")

def fit_wasserstein_prior_simple(posterior_draws):
    """Moment-match a Gaussian to a set of posterior draws.

    The original author describes this as the closed-form Wasserstein-optimal
    Gaussian; the code itself simply computes the sample mean and the sample
    covariance of the draws.

    Args:
        posterior_draws: 2-D array-like with one draw per row and one
            parameter per column.

    Returns:
        Tuple ``(mu, Sigma)``: the per-column mean (length D) and the D x D
        sample covariance (NumPy's default normalisation, i.e. n - 1).
    """
    draws = np.asarray(posterior_draws)
    mu = np.mean(draws, axis=0)
    # np.cov collapses to a 0-d array when there is a single column; force a
    # 2-D result so the covariance is always a D x D matrix, including D == 1.
    Sigma = np.atleast_2d(np.cov(draws, rowvar=False))
    return mu, Sigma

# Summarise the Ohio draws as a mean vector and covariance matrix
mu_prior, Sigma_prior = fit_wasserstein_prior_simple(beta_OH_draws)
print(f"Prior mean: {mu_prior}")
print(f"Prior covariance diagonal: {np.diag(Sigma_prior)}")

# Fit the NY data, passing that mean and covariance to gaussian-prior-logistic.stan
model_NY_wass = csp.CmdStanModel(stan_file='gaussian-prior-logistic.stan')
data_NY_wass = dict(
    N=N,
    D=D,
    x=x_NY,
    y=y_NY,
    mu_prior=mu_prior,
    Sigma_prior=Sigma_prior,
)
fit_NY_wass = model_NY_wass.sample(
    data=data_NY_wass,
    chains=4,
    iter_warmup=500,
    iter_sampling=500,
    seed=seed,
    show_progress=False,
)
print(fit_NY_wass.summary())

# === APPROACH 3: Wasserstein Barycenters ===
print("\n=== APPROACH 3: Wasserstein Barycenters ===")

def fit_wasserstein_barycenter_adaptive(posterior_draws, max_components=5, method='bic'):
    """Fit a Gaussian mixture to posterior draws with an adaptive component count.

    Despite the name, no optimal-transport barycenter is computed here: the
    function fits a scikit-learn Gaussian mixture and returns its parameters.

    The number of components is capped at
    ``min(max_components, max(1, n_samples // (50 * D)))``, i.e. at least
    ``50 * D`` draws are required per component.

    Args:
        posterior_draws: 2-D array with one draw per row and one parameter
            per column.
        max_components: Upper bound on the number of mixture components.
        method: ``'bic'`` fits a ``GaussianMixture`` for every component count
            from 1 to the cap and keeps the one with the lowest BIC.
            ``'bayesian'`` fits one ``BayesianGaussianMixture`` at the cap and
            keeps only components whose weight exceeds 1e-3, renormalising
            the surviving weights to sum to one.

    Returns:
        Tuple ``(weights, means, covariances)`` of the selected mixture. With
        scikit-learn's default full covariance type these are expected to have
        shapes (K,), (K, D) and (K, D, D).

    Raises:
        ValueError: If ``method`` is neither ``'bic'`` nor ``'bayesian'``.
    """
    # Reject unknown methods up front; previously they fell through and the
    # function returned None, which only failed later at tuple unpacking.
    if method not in ('bic', 'bayesian'):
        raise ValueError(
            f"Unknown method {method!r}; expected 'bic' or 'bayesian'"
        )

    D = posterior_draws.shape[1]
    n_samples = len(posterior_draws)

    # Conservative upper bound based on sample size and dimension
    max_k = min(max_components, max(1, n_samples // (50 * D)))
    print(f"Considering up to {max_k} components")

    if method == 'bic':
        # Use BIC for model selection
        n_components_range = range(1, max_k + 1)
        bic_scores = []
        models = []
        for n_components in n_components_range:
            gmm = GaussianMixture(n_components=n_components, random_state=42)
            gmm.fit(posterior_draws)
            bic_scores.append(gmm.bic(posterior_draws))
            models.append(gmm)

        # Choose best by BIC (lower is better); locate the minimum once
        best_idx = int(np.argmin(bic_scores))
        best_k = n_components_range[best_idx]
        best_gmm = models[best_idx]

        print(f"BIC scores: {dict(zip(n_components_range, bic_scores))}")
        print(f"Selected {best_k} components")

        return best_gmm.weights_, best_gmm.means_, best_gmm.covariances_

    # method == 'bayesian': use a Bayesian Gaussian mixture at the cap
    bgm = BayesianGaussianMixture(
        n_components=max_k,
        random_state=42,
        weight_concentration_prior=1.0
    )
    bgm.fit(posterior_draws)

    # Extract only components with significant weight
    significant_components = bgm.weights_ > 1e-3
    n_effective = np.sum(significant_components)

    weights = bgm.weights_[significant_components]
    weights = weights / np.sum(weights)  # Renormalize
    means = bgm.means_[significant_components]
    covariances = bgm.covariances_[significant_components]

    print(f"Bayesian selection: {n_effective} out of {max_k} components")
    print(f"Component weights: {weights}")

    return weights, means, covariances

# Approximate the Ohio draws with a BIC-selected Gaussian mixture
weights_bary, means_bary, covariances_bary = fit_wasserstein_barycenter_adaptive(
    beta_OH_draws,
    max_components=5,
    method='bic',
)

# Stan input: the NY data plus the mixture parameters (K = number of components)
data_NY_bary = dict(
    N=N,
    D=D,
    x=x_NY,
    y=y_NY,
    K=len(weights_bary),
    weights=weights_bary,
    means=means_bary,
    covariances=covariances_bary,
)

model_NY_bary = csp.CmdStanModel(stan_file='mixture-prior-logistic.stan')
fit_NY_bary = model_NY_bary.sample(
    data=data_NY_bary,
    chains=4,
    iter_warmup=500,
    iter_sampling=500,
    seed=seed,
    show_progress=False,
)
print(fit_NY_bary.summary())

# === COMPARISON ===
print("\n=== COMPARISON OF METHODS ===")

def compute_posterior_metrics(fit, true_beta, method_name):
    """Print and return summary metrics for the `beta` draws of one fit.

    Args:
        fit: Object exposing ``stan_variable('beta')``, which must return the
            draws as a 2-D array (one draw per row).
        true_beta: True coefficient values to compare against.
        method_name: Label printed as the heading of the report.

    Returns:
        Dict with keys ``'MSE'`` (mean squared error of the posterior mean),
        ``'Variance'`` (posterior variance averaged over coefficients) and
        ``'Coverage'`` (fraction of coefficients whose true value lies in the
        central 95% interval of the draws).
    """
    samples = fit.stan_variable('beta')
    post_mean = np.mean(samples, axis=0)
    post_std = np.std(samples, axis=0)

    # Squared error of the posterior mean, averaged over coefficients
    mse = np.square(post_mean - true_beta).mean()

    # Posterior variance, averaged over coefficients
    avg_var = np.square(post_std).mean()

    # Central 95% interval per coefficient, both bounds in one call
    lower, upper = np.percentile(samples, [2.5, 97.5], axis=0)
    inside = (true_beta >= lower) & (true_beta <= upper)
    coverage = np.mean(inside)

    print(f"\n{method_name}:")
    print(f"  MSE: {mse:.6f}")
    print(f"  Avg Posterior Variance: {avg_var:.6f}")
    print(f"  95% Coverage: {coverage:.3f}")
    print(f"  Posterior Mean: {post_mean}")

    return {'MSE': mse, 'Variance': avg_var, 'Coverage': coverage}

# Score the three NY fits that use Ohio-derived priors against the true coefficients
metrics_kde = compute_posterior_metrics(
    fit=fit_NY_kde, true_beta=beta, method_name="KDE (Zhong's method)"
)
metrics_wass = compute_posterior_metrics(
    fit=fit_NY_wass, true_beta=beta, method_name="Wasserstein Moments"
)
metrics_bary = compute_posterior_metrics(
    fit=fit_NY_bary, true_beta=beta, method_name="Wasserstein Barycenters"
)

# Baseline: refit the NY data alone with flat-logistic.stan (no Ohio information)
print("\n=== REFERENCE: Flat prior on NY data ===")
model_NY_flat = csp.CmdStanModel(stan_file='flat-logistic.stan')
data_NY_flat = dict(N=N, D=D, x=x_NY, y=y_NY)
fit_NY_flat = model_NY_flat.sample(
    data=data_NY_flat,
    chains=4,
    iter_warmup=500,
    iter_sampling=500,
    seed=seed,
    show_progress=False,
)
metrics_flat = compute_posterior_metrics(
    fit=fit_NY_flat, true_beta=beta, method_name="Flat Prior (no Ohio info)"
)

print(f"\nTrue beta: {beta}")

DEBUG:cmdstanpy:found newer exe file, not recompiling
DEBUG:cmdstanpy:cmd: /content/flat-logistic info
cwd: None
DEBUG:cmdstanpy:input tempfile: /tmp/tmp0qao1gqg/a8cm0lo9.json
19:44:44 - cmdstanpy - INFO - CmdStan start processing
INFO:cmdstanpy:CmdStan start processing


True beta: [ 0.32292202 -1.67712417  0.80797451  0.23766868  0.86741335 -1.506818  ]
NY response rate: 0.368

=== Fitting Ohio data with flat prior ===


chain 1 |          | 00:00 Status

chain 2 |          | 00:00 Status

chain 3 |          | 00:00 Status

chain 4 |          | 00:00 Status

DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:idx 1
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=2', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/a8cm0lo9.json', 'output', 'file=/tmp/tmp0qao1gqg/flat-logisticvy18ks1a/flat-logistic-20250514194444_2.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1']
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/a8cm0lo9.json', 'output', 'file=/tmp/tmp0qao1gqg/flat-logisticvy18ks1a/flat-logistic-20250514194444_1.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1']
DEBUG:cmdstanpy:idx 2
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=3', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/a8cm0lo9.json', 'output', 'file=/tmp/tmp0qao1gqg/flat-logis

                                                                                                                                                                                                                                                                                                                                

19:44:59 - cmdstanpy - INFO - CmdStan done processing.
INFO:cmdstanpy:CmdStan done processing.
DEBUG:cmdstanpy:runset
RunSet: chains=4, chain_ids=[1, 2, 3, 4], num_processes=4
 cmd (chain 1):
	['/content/flat-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/a8cm0lo9.json', 'output', 'file=/tmp/tmp0qao1gqg/flat-logisticvy18ks1a/flat-logistic-20250514194444_1.csv', 'method=sample', 'num_samples=2500', 'algorithm=hmc', 'adapt', 'engaged=1']
 retcodes=[0, 0, 0, 0]
 per-chain output files (showing chain 1 only):
 csv_file:
	/tmp/tmp0qao1gqg/flat-logisticvy18ks1a/flat-logistic-20250514194444_1.csv
 console_msgs (if any):
	/tmp/tmp0qao1gqg/flat-logisticvy18ks1a/flat-logistic-20250514194444_0-stdout.txt
DEBUG:cmdstanpy:Chain 1 console:
method = sample (Default)
  sample
    num_samples = 2500
    num_warmup = 1000 (Default)
    save_warmup = false (Default)
    thin = 1 (Default)
    adapt
      engaged = true (Default)
      gamma = 0.05 (Default)
      delta = 0.8 (




DEBUG:cmdstanpy:found newer exe file, not recompiling
DEBUG:cmdstanpy:cmd: /content/empirical-logistic info
cwd: None
DEBUG:cmdstanpy:input tempfile: /tmp/tmp0qao1gqg/zz8nbas1.json
19:44:59 - cmdstanpy - INFO - CmdStan start processing
INFO:cmdstanpy:CmdStan start processing
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/zz8nbas1.json', 'output', 'file=/tmp/tmp0qao1gqg/empirical-logisticcgl9nbam/empirical-logistic-20250514194459_1.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1']
19:44:59 - cmdstanpy - INFO - Chain [1] start processing
DEBUG:cmdstanpy:idx 1
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=2', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/zz8nbas1.json', 'output', 'file=/tmp/tmp0qao1gqg/empirical-logis

               Mean      MCSE    StdDev       MAD          5%         50%  \
lp__    -557.406000  0.026795  1.770460  1.568590 -560.791000 -557.063000   
beta[1]    0.366866  0.000819  0.072424  0.071326    0.250861    0.366657   
beta[2]   -1.789300  0.001410  0.107540  0.107689   -1.969140   -1.785910   
beta[3]    0.917581  0.000989  0.079849  0.078577    0.787154    0.916908   
beta[4]    0.245058  0.000795  0.072760  0.073003    0.124279    0.245038   
beta[5]    0.826520  0.000930  0.078852  0.078082    0.699908    0.825228   
beta[6]   -1.578740  0.001312  0.098768  0.098519   -1.742420   -1.576840   

                95%  ESS_bulk  ESS_tail     R_hat  
lp__    -555.199000   4633.58   5881.42  0.999884  
beta[1]    0.487298   7869.64   6899.19  1.000120  
beta[2]   -1.618440   5936.46   6436.66  1.000000  
beta[3]    1.051410   6569.76   6402.72  1.000210  
beta[4]    0.364416   8444.90   6252.19  1.000350  
beta[5]    0.956443   7271.80   6691.55  1.000470  
beta[6]   -1.418180

19:44:59 - cmdstanpy - INFO - Chain [2] start processing
INFO:cmdstanpy:Chain [1] start processing
INFO:cmdstanpy:Chain [2] start processing
19:46:06 - cmdstanpy - INFO - Chain [2] done processing
INFO:cmdstanpy:Chain [2] done processing
DEBUG:cmdstanpy:idx 2
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=3', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/zz8nbas1.json', 'output', 'file=/tmp/tmp0qao1gqg/empirical-logisticcgl9nbam/empirical-logistic-20250514194459_3.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1']
19:46:06 - cmdstanpy - INFO - Chain [3] start processing
INFO:cmdstanpy:Chain [3] start processing
19:46:07 - cmdstanpy - INFO - Chain [1] done processing
INFO:cmdstanpy:Chain [1] done processing
DEBUG:cmdstanpy:idx 3
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/empirical-logistic', 'id=4', 'random', 'seed=583

               Mean      MCSE    StdDev       MAD          5%         50%  \
lp__    -559.231000  0.052517  1.670540  1.512990 -562.479000 -558.881000   
beta[1]    0.350303  0.001256  0.053163  0.055772    0.264625    0.349386   
beta[2]   -1.684520  0.002128  0.072740  0.073189   -1.806480   -1.683590   
beta[3]    0.816993  0.001509  0.058009  0.055932    0.720248    0.818284   
beta[4]    0.308557  0.001222  0.052301  0.054825    0.222648    0.310081   
beta[5]    0.807412  0.001536  0.056870  0.057270    0.714319    0.807694   
beta[6]   -1.547500  0.002062  0.069704  0.068882   -1.664110   -1.546650   

                95%  ESS_bulk  ESS_tail     R_hat  
lp__    -557.139000   1042.35   1121.73  1.007360  
beta[1]    0.440071   1791.38   1473.59  1.000620  
beta[2]   -1.568570   1172.88   1269.70  1.001860  
beta[3]    0.910145   1522.17   1148.50  0.999944  
beta[4]    0.389971   1861.81   1350.30  1.001190  
beta[5]    0.902555   1385.68   1368.33  1.002960  
beta[6]   -1.435650

DEBUG:cmdstanpy:Console output:

--- Translating Stan model to C++ code ---
bin/stanc --filename-in-msg=gaussian-prior-logistic.stan --o=/content/gaussian-prior-logistic.hpp /content/gaussian-prior-logistic.stan

--- Compiling C++ code ---
g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess      -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I src -I stan/src -I stan/lib/rapidjson_1.1.0/ -I lib/CLI11-1.9.1/ -I stan/lib/stan_math/ -I stan/lib/stan_math/lib/eigen_3.4.0 -I stan/lib/stan_math/lib/boost_1.84.0 -I stan/lib/stan_math/lib/sundials_6.1.1/include -I stan/lib/stan_math/lib/sundials_6.1.1/src/sundials    -DBOOST_DISABLE_ASSERTS          -c -Wno-ignored-attributes   -x c++ -o /content/gaussian-prior-logistic.o /content/gaussian-prior-logistic.hpp

--- Linking model ---
g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess      -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I sr

               Mean      MCSE    StdDev       MAD          5%         50%  \
lp__    -568.833000  0.061184  1.775250  1.637530 -572.075000 -568.518000   
beta[1]    0.352418  0.001256  0.049502  0.048350    0.271975    0.351888   
beta[2]   -1.696680  0.002096  0.075393  0.072381   -1.818390   -1.698950   
beta[3]    0.824647  0.001493  0.057310  0.057457    0.732214    0.822308   
beta[4]    0.301689  0.001300  0.050560  0.050508    0.216780    0.301548   
beta[5]    0.808698  0.001430  0.056148  0.054268    0.714924    0.808041   
beta[6]   -1.550370  0.001959  0.071111  0.067636   -1.671050   -1.548620   

                95%  ESS_bulk  ESS_tail     R_hat  
lp__    -566.585000   881.262   1143.34  1.002360  
beta[1]    0.434444  1584.370   1544.86  1.001480  
beta[2]   -1.570150  1307.350   1407.69  0.999047  
beta[3]    0.922475  1502.040   1410.28  1.002310  
beta[4]    0.386601  1520.230   1296.52  1.000290  
beta[5]    0.901743  1563.190   1176.63  1.000660  
beta[6]   -1.436520

19:47:44 - cmdstanpy - INFO - compiling stan file /content/mixture-prior-logistic.stan to exe file /content/mixture-prior-logistic
INFO:cmdstanpy:compiling stan file /content/mixture-prior-logistic.stan to exe file /content/mixture-prior-logistic
DEBUG:cmdstanpy:cmd: make STANCFLAGS+=--filename-in-msg=mixture-prior-logistic.stan /content/mixture-prior-logistic
cwd: /root/.cmdstan/cmdstan-2.36.0


BIC scores: {1: np.float64(-144967.7349004295), 2: np.float64(-144778.18607570804), 3: np.float64(-144518.11879955523), 4: np.float64(-144302.60325127767), 5: np.float64(-144075.85674640306)}
Selected 1 components


DEBUG:cmdstanpy:Console output:

--- Translating Stan model to C++ code ---
bin/stanc --filename-in-msg=mixture-prior-logistic.stan --o=/content/mixture-prior-logistic.hpp /content/mixture-prior-logistic.stan

--- Compiling C++ code ---
g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess      -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I src -I stan/src -I stan/lib/rapidjson_1.1.0/ -I lib/CLI11-1.9.1/ -I stan/lib/stan_math/ -I stan/lib/stan_math/lib/eigen_3.4.0 -I stan/lib/stan_math/lib/boost_1.84.0 -I stan/lib/stan_math/lib/sundials_6.1.1/include -I stan/lib/stan_math/lib/sundials_6.1.1/src/sundials    -DBOOST_DISABLE_ASSERTS          -c -Wno-ignored-attributes   -x c++ -o /content/mixture-prior-logistic.o /content/mixture-prior-logistic.hpp

--- Linking model ---
g++ -std=c++17 -pthread -D_REENTRANT -Wno-sign-compare -Wno-ignored-attributes -Wno-class-memaccess      -I stan/lib/stan_math/lib/tbb_2020.3/include    -O3 -I src -I 

               Mean      MCSE    StdDev       MAD          5%         50%  \
lp__    -558.570000  0.058193  1.764730  1.537460 -562.060000 -558.230000   
beta[1]    0.352823  0.001369  0.051899  0.051785    0.265846    0.351955   
beta[2]   -1.697660  0.002141  0.074776  0.071995   -1.819920   -1.698660   
beta[3]    0.824432  0.001377  0.055388  0.056361    0.734386    0.824173   
beta[4]    0.300913  0.001182  0.049481  0.049706    0.215984    0.300157   
beta[5]    0.809686  0.001453  0.056591  0.056076    0.719055    0.808768   
beta[6]   -1.547870  0.001895  0.066391  0.069326   -1.654070   -1.549730   

                95%  ESS_bulk  ESS_tail     R_hat  
lp__    -556.376000    922.66   1443.26  1.003070  
beta[1]    0.436480   1460.40   1280.90  0.999683  
beta[2]   -1.571360   1233.38   1251.29  1.001840  
beta[3]    0.915763   1615.54   1494.00  1.001040  
beta[4]    0.383092   1761.10   1363.84  1.000110  
beta[5]    0.900963   1530.90   1304.08  1.002850  
beta[6]   -1.438090

19:48:24 - cmdstanpy - INFO - CmdStan start processing
INFO:cmdstanpy:CmdStan start processing
DEBUG:cmdstanpy:idx 0
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=1', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/5a47ebg3.json', 'output', 'file=/tmp/tmp0qao1gqg/flat-logisticge1c6cjy/flat-logistic-20250514194824_1.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1']
19:48:24 - cmdstanpy - INFO - Chain [1] start processing
INFO:cmdstanpy:Chain [1] start processing
DEBUG:cmdstanpy:idx 1
DEBUG:cmdstanpy:running CmdStan, num_threads: 1
DEBUG:cmdstanpy:CmdStan args: ['/content/flat-logistic', 'id=2', 'random', 'seed=583883', 'data', 'file=/tmp/tmp0qao1gqg/5a47ebg3.json', 'output', 'file=/tmp/tmp0qao1gqg/flat-logisticge1c6cjy/flat-logistic-20250514194824_2.csv', 'method=sample', 'num_samples=500', 'num_warmup=500', 'algorithm=hmc', 'adapt', 'engaged=1']
19:48:24 - cmdstanpy -


Flat Prior (no Ohio info):
  MSE: 0.004266
  Avg Posterior Variance: 0.006697
  95% Coverage: 1.000
  Posterior Mean: [ 0.34536999 -1.62938658  0.74918455  0.35338832  0.7991306  -1.54288859]

True beta: [ 0.32292202 -1.67712417  0.80797451  0.23766868  0.86741335 -1.506818  ]
