In [None]:
# =========================
# Location-LOGISTIC (scale=1 known)
# Same architecture as your Student-location code
# =========================

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import random, vmap, jit
from jax.scipy.stats import norm, truncnorm
from jax.scipy.special import logsumexp
from jax.nn import softplus

# Inward margin on the support of z = tanh(y/2): z is kept in (-1 + EPS_Z, 1 - EPS_Z)
# so that arctanh(z) stays finite.
EPS_Z   = 1e-12
# Lower bound of the uniform draw in the Metropolis tests, so jnp.log(u) is finite.
EPS_U   = 1e-12
# NOTE(review): EPS_DIV is not referenced anywhere in this cell — confirm it is still needed.
EPS_DIV = 1e-12

# -------------------------
# Logistic logpdf (stable)
# f(x|loc,s)=exp(-(x-loc)/s)/(s(1+exp(-(x-loc)/s))^2)
# log f = -t - log s - 2*softplus(-t), t=(x-loc)/s
# -------------------------
@jit
def logistic_logpdf(y: jnp.ndarray, loc: jnp.ndarray, scale: jnp.ndarray) -> jnp.ndarray:
    """Log-density of the logistic distribution, evaluated without overflow.

    With t = (y - loc) / scale:
        log f = -t - log(scale) - 2 * softplus(-t)
    softplus keeps log(1 + exp(-t)) finite for large negative t.
    """
    standardized = (y - loc) / scale
    log_scale = jnp.log(scale)
    tail_term = 2.0 * softplus(-standardized)
    return -standardized - log_scale - tail_term

# -------------------------
# Score transform for location-logistic (scale=1):
# score in mu is (1/s) * tanh( (x-mu)/(2s) )
# For the constrained manifold at mu_star, we use psi(y)=tanh(y/2), y = x - mu_star.
# Range: (-1,1), injective.
# -------------------------
@jit
def z_support_logistic():
    """Return the open interval (z_min, z_max) treated as the support of z = tanh(y/2).

    tanh maps onto (-1, 1); both ends are pulled inward by EPS_Z so that
    arctanh of any in-support z is finite.
    """
    margin = EPS_Z
    z_min = margin - 1.0
    z_max = 1.0 - margin
    return (z_min, z_max)

@jit
def psi_logistic(y: jnp.ndarray) -> jnp.ndarray:
    """Score transform psi(y) = tanh(y / 2), mapping the real line injectively onto (-1, 1)."""
    half_y = 0.5 * y
    return jnp.tanh(half_y)

@jit
def sum_psi_logistic(y: jnp.ndarray) -> jnp.ndarray:
    """Return sum_i psi(y_i); it is zero exactly when mu_star is the MLE of x = y + mu_star."""
    return psi_logistic(y).sum()

@jit
def psi_inverse_logistic(z: jnp.ndarray):
    """Unique inverse of psi(y) = tanh(y / 2): y = 2 * arctanh(z).

    No clipping is applied; callers are expected to mask z outside the
    support themselves (z = +/-1 gives +/-inf, |z| > 1 gives nan).
    """
    inner = jnp.arctanh(z)
    return 2.0 * inner

@jit
def log_psi_prime_abs_logistic(y: jnp.ndarray) -> jnp.ndarray:
    """Return log|psi'(y)| for psi(y) = tanh(y / 2), evaluated stably for any y.

    psi'(y) = (1/2) sech^2(y/2) = 2 e^{-|y|} / (1 + e^{-|y|})^2, hence
        log psi'(y) = log 2 - |y| - 2 * softplus(-|y|).
    This form avoids the cancellation in 1 - tanh^2(y/2). That cancellation
    loses precision as |y| grows and gives exactly 0 once tanh rounds to 1,
    where a floor clamp would return a wrong constant.
    """
    abs_y = jnp.abs(y)
    return jnp.log(2.0) - abs_y - 2.0 * softplus(-abs_y)

# -------------------------
# f_Y under current mu, expressed in y = x - mu_star
# If X ~ Logistic(loc=mu_current, scale=1), then Y = X - mu_star ~ Logistic(loc=mu_current-mu_star, scale=1).
# -------------------------
@jit
def fy_logpdf_logistic(y: jnp.ndarray, mu_current: jnp.ndarray, mu_star: jnp.ndarray) -> jnp.ndarray:
    """Log-density of Y = X - mu_star when X ~ Logistic(mu_current, scale=1).

    Shifting X by mu_star only moves the location: Y ~ Logistic(mu_current - mu_star, 1).
    """
    shifted_loc = mu_current - mu_star
    return logistic_logpdf(y, loc=shifted_loc, scale=1.0)

# -------------------------
# q(z) = f_Y(y(z)) * |dy/dz|
# In your Student code you compute q via log f(y) - log|psi'(y)|, i.e. f(y)/|psi'(y)|
# because q(z)=f(y(z))*|dy/dz| and dy/dz = 1/psi'(y).
# Here injective => single term.
# -------------------------
@jit
def q_logpdf_logistic(z: jnp.ndarray, mu_current: jnp.ndarray, mu_star: jnp.ndarray) -> jnp.ndarray:
    """Log-density of Z = psi(Y), where Y = X - mu_star and X ~ Logistic(mu_current, 1).

    psi is injective, so the change of variables has a single term:
        log q(z) = log f_Y(y) - log|psi'(y)|,   y = psi^{-1}(z).
    Returns -inf for z outside the open support (z_min, z_max).
    """
    y = psi_inverse_logistic(z)
    log_density = fy_logpdf_logistic(y, mu_current, mu_star) - log_psi_prime_abs_logistic(y)

    lo, hi = z_support_logistic()
    inside = jnp.logical_and(z > lo, z < hi)
    return jnp.where(inside, log_density, -jnp.inf)

@jit
def q_tilde_logpdf_logistic(z: jnp.ndarray, delta: jnp.ndarray, mu_current: jnp.ndarray, mu_star: jnp.ndarray) -> jnp.ndarray:
    """Unnormalized log-density of z on the line z + z_partner = delta: log q(z) + log q(delta - z)."""
    partner = delta - z
    log_q_self = q_logpdf_logistic(z, mu_current, mu_star)
    log_q_partner = q_logpdf_logistic(partner, mu_current, mu_star)
    return log_q_self + log_q_partner

# -------------------------
# Metropolis-Hastings update of z_i on the line z_i + z_j = delta, targeting the logistic q_tilde
# -------------------------
def update_z_one_logistic(
    key: jax.random.PRNGKey,
    z_current: jnp.ndarray,
    delta: jnp.ndarray,
    mu_current: jnp.ndarray,
    mu_star: jnp.ndarray,
    sigma_z: float
):
    """One Metropolis-Hastings step for z_i with its partner fixed to delta - z_i.

    Proposal: a Gaussian random walk around z_current with scale sigma_z,
    truncated to the interval where both z_i and delta - z_i lie inside the
    support of z. Target: q(z) * q(delta - z) (see q_tilde_logpdf_logistic).

    Returns:
        (z_new, accept): z_new equals the proposal if accepted, else z_current.
        If the feasible interval is empty, returns (z_current, False) without
        drawing a proposal.
    """
    key_prop, key_u = random.split(key, 2)

    low, high = z_support_logistic()

    # partner must also lie in (low, high): delta - z in (low, high)
    low2  = delta - high
    high2 = delta - low

    # Feasible interval for z_i. It depends only on delta, so the forward and
    # reverse proposal densities below use the same truncation bounds.
    low_int  = jnp.maximum(low,  low2)
    high_int = jnp.minimum(high, high2)

    valid = low_int < high_int

    def do_reject(_):
        return z_current, False

    def do_update(_):
        # Standardized truncation bounds for a draw centred at z_current.
        a = (low_int  - z_current) / sigma_z
        b = (high_int - z_current) / sigma_z

        z_prop = z_current + sigma_z * random.truncated_normal(
            key_prop, shape=(), lower=a, upper=b
        )

        log_k_cur_to_prop = truncnorm.logpdf(z_prop, a=a, b=b, loc=z_current, scale=sigma_z)

        # Reverse move: same interval, re-standardized around the proposal.
        # The truncation normalizer depends on the centre, so the kernel is
        # not symmetric and the Hastings correction is required.
        a_back = (low_int  - z_prop) / sigma_z
        b_back = (high_int - z_prop) / sigma_z
        log_k_prop_to_cur = truncnorm.logpdf(z_current, a=a_back, b=b_back, loc=z_prop, scale=sigma_z)

        log_post_cur  = q_tilde_logpdf_logistic(z_current, delta, mu_current, mu_star)
        log_post_prop = q_tilde_logpdf_logistic(z_prop,     delta, mu_current, mu_star)

        log_alpha = log_post_prop - log_post_cur + log_k_prop_to_cur - log_k_cur_to_prop
        # Any non-finite ratio (nan, +inf or -inf) is treated as a rejection.
        log_alpha = jnp.where(jnp.isfinite(log_alpha), log_alpha, -jnp.inf)

        # u is bounded away from 0 (EPS_U) so log(u) is finite.
        u = random.uniform(key_u, minval=EPS_U, maxval=1.0)
        accept = jnp.log(u) < log_alpha

        z_new = jnp.where(accept, z_prop, z_current)
        return z_new, accept

    return jax.lax.cond(valid, do_update, do_reject, operand=None)

# -------------------------
# Pair update: now inverse is unique, no branch sampling needed.
# -------------------------
def update_xi_xj_one_logistic(key, xi, xj, mu_current, mu_star, sigma_z):
    """Update one pair (x_i, x_j) while keeping psi(x_i - mu_star) + psi(x_j - mu_star) fixed.

    The pair is mapped to z = psi(x - mu_star). z_i is moved along the line
    z_i + z_j = delta by `update_z_one_logistic`, z_j is set to delta - z_i,
    and both are mapped back through the unique inverse psi^{-1}.

    Returns:
        (xi_new, xj_new, pair_ok, z_accepted). `pair_ok` is True whenever the
        partner z_j stayed inside the support; it does not by itself mean the
        pair moved. `z_accepted` reports whether the z-move was accepted.
    """
    yi, yj = xi - mu_star, xj - mu_star
    zi, zj = psi_logistic(yi), psi_logistic(yj)
    delta = zi + zj

    zi_tilde, z_accepted = update_z_one_logistic(
        key, zi, delta, mu_current, mu_star, sigma_z
    )
    zj_tilde = delta - zi_tilde

    z_min, z_max = z_support_logistic()
    in_supp_partner = (zj_tilde > z_min) & (zj_tilde < z_max)

    def reject_pair(_):
        return xi, xj, False, z_accepted

    def accept_pair(_):
        xi_tilde = psi_inverse_logistic(zi_tilde) + mu_star
        xj_tilde = psi_inverse_logistic(zj_tilde) + mu_star
        # On a rejected z-move keep the exact current state. Otherwise the
        # round trip x -> tanh -> arctanh -> x (and delta - zi != zj in
        # floating point) would perturb x slightly on every rejection.
        xi_new = jnp.where(z_accepted, xi_tilde, xi)
        xj_new = jnp.where(z_accepted, xj_tilde, xj)
        return xi_new, xj_new, True, z_accepted

    return jax.lax.cond(in_supp_partner, accept_pair, reject_pair, operand=None)

# -------------------------
# Full x-update (same structure)
# -------------------------
@jit
def delta_from_xi_xj_logistic(xi, xj, mu_star):
    """Return the pair's conserved quantity delta = psi(xi - mu_star) + psi(xj - mu_star)."""
    shifted = (xi - mu_star, xj - mu_star)
    zi, zj = (psi_logistic(s) for s in shifted)
    return zi + zj

# Vectorized over pairs: xi and xj are mapped along axis 0, mu_star is shared.
delta_from_xi_xj_batch_logistic = vmap(delta_from_xi_xj_logistic, in_axes=(0, 0, None))

@jit
def update_x_full_jax_logistic(key, x_current, mu_current, mu_star, sigma_z):
    """Apply one sweep of pairwise constrained z-moves to x.

    Coordinates are randomly permuted and consecutive entries are paired.
    Each pair is updated by `update_xi_xj_one_logistic`, which keeps
    psi(x_i - mu_star) + psi(x_j - mu_star) fixed. When m is odd, the single
    coordinate left unpaired keeps its value for this sweep; which one it is
    is decided by the random permutation.

    Returns:
        x_new, pair_accepted_count, z_accepted_count, deltas, deltas_new,
        where deltas / deltas_new are the per-pair psi sums before / after.

    Raises:
        ValueError: (at trace time) if x_current has fewer than 2 entries.
    """
    m = x_current.shape[0]
    n_pairs = m // 2
    if n_pairs == 0:
        raise ValueError(f"x_current needs at least 2 entries to form a pair, got {m}")
    n_paired = 2 * n_pairs  # equals m when m is even

    key_perm, key_pairs = random.split(key)
    perm = random.permutation(key_perm, m)
    x_perm = x_current[perm]

    # Consecutive entries of the permuted vector form the pairs (i, j).
    xis = x_perm[0:n_paired:2]
    xjs = x_perm[1:n_paired:2]
    keys_pairs = random.split(key_pairs, n_pairs)

    update_xi_xj_batch = vmap(
        update_xi_xj_one_logistic,
        in_axes=(0, 0, 0, None, None, None)
    )

    xis_new, xjs_new, pair_accepted_vec, z_accepted_vec = update_xi_xj_batch(
        keys_pairs, xis, xjs, mu_current, mu_star, sigma_z
    )

    # Per-pair psi sums before and after the sweep (constraint diagnostics).
    deltas = delta_from_xi_xj_batch_logistic(xis, xjs, mu_star)
    deltas_new = delta_from_xi_xj_batch_logistic(xis_new, xjs_new, mu_star)

    # Interleave the updated pairs back into place, then undo the permutation.
    x_updated_pairs = jnp.stack([xis_new, xjs_new], axis=1).reshape(-1)
    x_perm_new = x_perm.at[0:n_paired].set(x_updated_pairs)
    x_new = x_perm_new[jnp.argsort(perm)]

    pair_accepted_count = jnp.sum(pair_accepted_vec)
    z_accepted_count = jnp.sum(z_accepted_vec)

    return x_new, pair_accepted_count, z_accepted_count, deltas, deltas_new

# -------------------------
# Posterior for mu given x (logistic likelihood + normal prior)
# -------------------------
def unnormalized_posterior_mu_logpdf_logistic(mu: float, x: jnp.ndarray, prior_loc: float, prior_scale: float) -> float:
    """Unnormalized log-posterior of mu: sum_i log Logistic(x_i | mu, 1) + log N(mu | prior_loc, prior_scale).

    A scalar mu gives a scalar result. A 1-D array of candidate mu values gives
    one log-posterior per candidate, obtained by broadcasting observations
    (rows) against candidates (columns).
    """
    data = jnp.asarray(x)
    mu_arr = jnp.asarray(mu)

    # The Gaussian prior is elementwise in mu, so it is the same for both shapes.
    log_prior = norm.logpdf(mu_arr, loc=prior_loc, scale=prior_scale)

    if mu_arr.ndim != 0:
        per_obs = logistic_logpdf(data[:, None], loc=mu_arr[None, :], scale=1.0)
        log_lik = per_obs.sum(axis=0)
    else:
        log_lik = logistic_logpdf(data, loc=mu_arr, scale=1.0).sum()

    return log_lik + log_prior

@jit
def update_mu_metropolis_jax_logistic(key, mu_current, x_current, sigma_mu, prior_loc, prior_scale):
    """One random-walk Metropolis step for mu given the current x.

    The Gaussian proposal is symmetric, so the acceptance ratio is the ratio
    of unnormalized posteriors. Non-finite log ratios are treated as
    rejections.

    Returns:
        (mu_new, accepted)
    """
    k_prop, k_u = random.split(key)
    mu_proposed = mu_current + sigma_mu * random.normal(k_prop)

    def log_target(mu):
        return unnormalized_posterior_mu_logpdf_logistic(mu, x_current, prior_loc, prior_scale)

    log_ratio = log_target(mu_proposed) - log_target(mu_current)
    log_ratio = jnp.where(jnp.isfinite(log_ratio), log_ratio, -jnp.inf)

    log_u = jnp.log(random.uniform(k_u, minval=EPS_U, maxval=1.0))
    accepted = log_u < log_ratio
    return jnp.where(accepted, mu_proposed, mu_current), accepted

# -------------------------
# Gibbs sampler (mu | x) then (x | mu, MLE=mu_star)
# NOTE: x0 must satisfy sum psi(x0-mu_star)=0. Setting all x=mu_star works (psi(0)=0).
# -------------------------
from tqdm import tqdm

def run_gibbs_sampler_mle_jax_logistic(key, mu_star: float, params: dict) -> dict:
    """Run the Gibbs sampler for the location-logistic model conditioned on the MLE.

    Each sweep (a) updates mu | x with a random-walk Metropolis step and
    (b) updates x | mu, MLE = mu_star with pairwise z-moves that preserve
    sum_i psi(x_i - mu_star) = 0. Both chains start at mu_star, which
    satisfies the constraint because psi(0) = 0.

    Args:
        key: JAX PRNG key.
        mu_star: conditioning value of the MLE.
        params: dict with keys "num_iterations_T", "m", "proposal_std_mu",
            "prior_mean", "prior_std", "proposal_std_z".

    Returns:
        dict with acceptance rates, "mu_chain" (T+1,), "x_chain" (T+1, m) and
        "grad_likelihood_checks" (T,), the constraint value after each sweep.

    Raises:
        ValueError: if num_iterations_T < 1.
    """
    import numpy as np  # host-side buffers for the chain history

    T = params["num_iterations_T"]
    m = params["m"]
    if T < 1:
        raise ValueError(f"num_iterations_T must be >= 1, got {T}")

    # The chain history is written in place into preallocated host buffers.
    # Calling `jnp_array.at[t].set(...)` outside jit copies the whole
    # (T+1, m) array on every iteration, i.e. O(T^2 * m) work in total.
    float_dtype = jnp.zeros(()).dtype  # float64 when jax_enable_x64 is on
    mus_host = np.zeros(T + 1, dtype=float_dtype)
    xs_host = np.zeros((T + 1, m), dtype=float_dtype)
    checks_host = np.zeros(T, dtype=float_dtype)

    mu_curr = jnp.asarray(mu_star, dtype=float_dtype)
    x_curr = jnp.full((m,), mu_star, dtype=float_dtype)
    mus_host[0] = mu_curr
    xs_host[0] = x_curr

    mu_acceptance_count = 0
    z_i_acceptance_count = 0
    pair_acceptance_count = 0

    total_z_moves = T * (m // 2)

    for t in tqdm(range(1, T + 1), desc="Running Gibbs Sampler (Logistic-loc)"):
        key, key_mu, key_x = random.split(key, 3)

        # Step (a): mu | x
        mu_curr, acc_mu = update_mu_metropolis_jax_logistic(
            key_mu,
            mu_curr,
            x_curr,
            params["proposal_std_mu"],
            params["prior_mean"],
            params["prior_std"],
        )
        mu_acceptance_count += acc_mu.astype(jnp.int32)

        # Step (b): x | mu, MLE=mu_star
        x_curr, accepted_pairs, accepted_z_is, _deltas, _deltas_new = update_x_full_jax_logistic(
            key_x, x_curr, mu_curr, mu_star, params["proposal_std_z"]
        )
        z_i_acceptance_count += accepted_z_is
        pair_acceptance_count += accepted_pairs

        mus_host[t] = mu_curr
        xs_host[t] = x_curr
        checks_host[t - 1] = sum_psi_logistic(x_curr - mu_star)

    mus = jnp.asarray(mus_host)
    xs = jnp.asarray(xs_host)
    grad_likelihood_checks = jnp.asarray(checks_host)

    mu_acceptance_rate   = mu_acceptance_count / T
    z_i_acceptance_rate  = z_i_acceptance_count / total_z_moves
    pair_acceptance_rate = pair_acceptance_count / total_z_moves

    print("\n--- Sampling Complete (Logistic-loc) ---")
    print(f"Mu Acceptance Rate:  {float(mu_acceptance_rate):.4f}")
    print(f"Z_i Acceptance Rate: {float(z_i_acceptance_rate):.4f}")
    print(f"Pair Acceptance Rate:{float(pair_acceptance_rate):.4f}")
    print(f"Score constraint check (last): {float(grad_likelihood_checks[-1]):.4e}")

    return {
        "mu_acceptance_rate": mu_acceptance_rate,
        "pair_acceptance_rate": pair_acceptance_rate,
        "z_i_acceptance_rate": z_i_acceptance_rate,
        "mu_chain": mus,
        "x_chain": xs,
        "grad_likelihood_checks": grad_likelihood_checks,
    }

In [None]:
# =========================
# TEST + RUN SCRIPT (Location-Logistic, scale=1 known)
# Assumes the logistic sampler functions from the previous cell are already defined:
#   - run_gibbs_sampler_mle_jax_logistic
#   - (and all helpers it calls)
# =========================

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import random
import matplotlib.pyplot as plt

# -------------------------
# Helper: simple summaries
# -------------------------
def summarize_chain(name, chain, burn=0, thin=1):
    """Return basic summaries of chain[burn::thin] as a dict of plain Python values.

    Keys, in order: name, n, mean, std (jnp.std default, ddof=0), q05, q50, q95.
    """
    kept = jnp.asarray(chain)[burn::thin]
    summary = {
        "name": name,
        "n": int(kept.shape[0]),
        "mean": float(jnp.mean(kept)),
        "std": float(jnp.std(kept)),
    }
    for label, prob in (("q05", 0.05), ("q50", 0.50), ("q95", 0.95)):
        summary[label] = float(jnp.quantile(kept, prob))
    return summary

def quick_plot_mu(mu_chain, burn=0, thin=1, title="mu chain"):
    """Show a trace plot of mu_chain[burn::thin] in a new figure."""
    trace = jnp.asarray(mu_chain)[burn::thin]
    plt.figure()
    plt.plot(trace)
    plt.xlabel("iteration (post burn/thin)")
    plt.ylabel("mu")
    plt.title(title)
    plt.show()

def quick_plot_constraint(checks, title="score constraint check"):
    """Plot the per-iteration constraint values sum_i tanh((x_i - mu*)/2) in a new figure."""
    values = jnp.asarray(checks)
    plt.figure()
    plt.plot(values)
    plt.xlabel("iteration")
    plt.ylabel("sum_i tanh((x_i-mu*)/2)")
    plt.title(title)
    plt.show()

def quick_hist_mu(mu_chain, burn=0, thin=1, title="mu posterior (approx)"):
    """Show a 40-bin density histogram of mu_chain[burn::thin] in a new figure."""
    draws = jnp.asarray(mu_chain)[burn::thin]
    plt.figure()
    plt.hist(draws, bins=40, density=True)
    plt.xlabel("mu")
    plt.ylabel("density")
    plt.title(title)
    plt.show()

# -------------------------
# Main test runner
# -------------------------
def main():
    """Run the location-logistic MLE-conditioned sampler, then print and plot diagnostics."""
    # ----- RNG
    seed = 0
    key = random.PRNGKey(seed)

    # ----- MLE constraint value: x is kept on sum_i tanh((x_i - mu_star)/2) = 0
    mu_star = 0.0

    # ----- Sampler params
    # Tune proposal_std_z first: acceptance ~ 0.2-0.6 is often OK.
    params = {
        "num_iterations_T": 100000,   # number of Gibbs sweeps (first 30% dropped as burn-in below)
        "m": 6,                       # number of x coordinates; use an even value so every coordinate is paired
        "proposal_std_z": .001,       # z random-walk scale inside the truncated-normal proposal
        "proposal_std_mu": 0.50,      # mu random-walk scale
        "prior_mean": 0.0,
        "prior_std": 5.0,
    }

    # ----- Run Gibbs
    results = run_gibbs_sampler_mle_jax_logistic(key, mu_star=mu_star, params=params)

    # ----- Basic diagnostics
    T = params["num_iterations_T"]
    burn = int(0.3 * T)
    thin = 1

    print("\n--- Diagnostics ---")
    print("Acceptance rates:")
    print(f"  mu:   {float(results['mu_acceptance_rate']):.3f}")
    print(f"  z_i:  {float(results['z_i_acceptance_rate']):.3f}")
    print(f"  pair: {float(results['pair_acceptance_rate']):.3f}")

    summ = summarize_chain("mu", results["mu_chain"], burn=burn, thin=thin)
    print("\nmu posterior summary (after burn):")
    for k, v in summ.items():
        if k != "name":
            print(f"  {k}: {v}")

    # Constraint check: should stay near 0 (up to floating-point error)
    checks = results["grad_likelihood_checks"]
    print("\nConstraint check:")
    print(f"  mean(abs(check)) after burn: {float(jnp.mean(jnp.abs(checks[burn:]))):.3e}")
    print(f"  max(abs(check))  after burn: {float(jnp.max(jnp.abs(checks[burn:]))):.3e}")

    # ----- Plots
    quick_plot_mu(results["mu_chain"], burn=burn, thin=thin, title="mu chain (Logistic-loc | MLE constraint)")
    quick_hist_mu(results["mu_chain"], burn=burn, thin=thin, title="mu posterior (approx) after burn")
    quick_plot_constraint(results["grad_likelihood_checks"], title="Constraint check over iterations")

    # Spot-check that the final x still lies on the constraint manifold
    x_last = results["x_chain"][-1]
    constraint_last = sum_psi_logistic(x_last - mu_star)
    print(f"\nConstraint at last x: {float(constraint_last):.3e}")

# Run the test script only when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()