# In [None]:
# =========================
# LAPLACE (location), scale b known
# Conditional on the observed median equalling mu_star (for known b the Laplace location MLE is the sample median),
# interpreted as: exactly m/2 points < mu_star and m/2 points > mu_star.
# Two x-updates are provided:
#   (A) full resampling each sweep (simplest, very stable)
#   (B) pairwise resampling (keeps each point's "below/above" assignment fixed)
# Then Gibbs: mu | x (MH) and x | mu, median-constraint
# =========================

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import random, jit, vmap

from jax.scipy.stats import norm
from tqdm import tqdm

# Small margin used to keep probabilities / uniform draws away from 0 and 1
# (clip in laplace_ppf, lower bound of the MH acceptance uniform) so that the
# logarithms applied to them stay finite.
EPS_U: float = 1e-12

# -------------------------
# Laplace log-density
# -------------------------
@jit
def laplace_logpdf(x, loc, b):
    """Elementwise log-density of Laplace(loc, b) at x.

    log p(x) = -(|x - loc| / b + log(2 b))
    """
    scaled_dist = jnp.abs(x - loc) / b
    return -(scaled_dist + jnp.log(2.0 * b))

# -------------------------
# Laplace CDF / PPF (inverse CDF)
# CDF:
#   x < loc: 0.5 * exp((x-loc)/b)
#   x >=loc: 1 - 0.5*exp(-(x-loc)/b)
# PPF:
#   u < 0.5: loc + b * log(2u)
#   u>=0.5: loc - b * log(2(1-u))
# -------------------------
@jit
def laplace_cdf(x, loc, b):
    """Elementwise CDF of Laplace(loc, b) at x.

    ``jnp.where`` always evaluates both branches.  Writing them as
    ``exp((x - loc) / b)`` and ``exp(-(x - loc) / b)`` lets the *unselected*
    branch overflow to +/-inf for x far from ``loc``; the forward value is
    still right, but ``jax.grad`` through the ``where`` then becomes NaN
    (0 * inf).  Both branches are therefore expressed through
    ``exp(-|x - loc| / b)``, which is always in [0, 1] and gives
    bit-identical values on the selected branch.
    """
    half_tail = 0.5 * jnp.exp(-jnp.abs(x - loc) / b)  # 0.5 * exp(-|z|)
    return jnp.where(x < loc, half_tail, 1.0 - half_tail)

@jit
def laplace_ppf(u, loc, b):
    """Inverse CDF of Laplace(loc, b).

    ``u`` is first clipped to [EPS_U, 1 - EPS_U] so both logarithms are finite;
    the lower half of the unit interval maps below ``loc``, the upper half above.
    """
    p = jnp.clip(u, EPS_U, 1.0 - EPS_U)
    log_offset = jnp.where(p < 0.5, jnp.log(2.0 * p), -jnp.log(2.0 * (1.0 - p)))
    return loc + b * log_offset

# -------------------------
# Truncated Laplace sampling
# Only the two truncations needed by the median constraint are implemented:
#   (-inf, upper)  -> sample_laplace_trunc_left
#   (lower, +inf)  -> sample_laplace_trunc_right
# Both are exact even when the truncation point lies far in a tail, and both
# return values strictly on the requested side of the boundary.
# -------------------------
@jit
def sample_laplace_trunc_left(key, loc, b, upper):
    """Draw one X ~ Laplace(loc, b) conditioned on X < upper.

    Args:
        key: JAX PRNG key.
        loc: Laplace location.
        b: Laplace scale (assumed > 0).
        upper: exclusive upper truncation point.

    Returns:
        A scalar strictly less than ``upper``.

    Two regimes, selected with ``jnp.where`` so the function is jit/vmap friendly:

    * ``upper <= loc``: below ``loc`` the Laplace density is an exponential
      tail, which is memoryless, so ``X = upper - b * E`` with ``E ~ Exp(1)``.
      This is exact however far ``upper`` lies in the tail.  Inverse-CDF
      sampling with ``F(upper)`` clipped at 1e-12 breaks down there: for
      ``upper < loc - ~26 b`` it returns values *above* ``upper``.
    * ``upper > loc``: inverse CDF on ``(0, F(upper)]``.  Here
      ``F(upper) >= 0.5``, so the sampled interval is never tiny.
    """
    w = random.uniform(key, shape=())  # w in [0, 1)
    v = 1.0 - w                        # v in (0, 1]  -> log(v) is finite
    d = (upper - loc) / b

    # Regime upper <= loc: memoryless exponential tail ending at `upper`.
    x_tail = upper + b * jnp.log(v)

    # Regime upper > loc: u = v * F(upper) is uniform on (0, F(upper)].
    mass_above = 0.5 * jnp.exp(-jnp.maximum(d, 0.0))  # 1 - F(upper), in [0, 0.5]
    u = v * (1.0 - mass_above)
    one_minus_u = w + v * mass_above  # == 1 - u, computed without cancellation
    x_body = jnp.where(
        u < 0.5,
        loc + b * jnp.log(2.0 * u),
        loc - b * jnp.log(2.0 * one_minus_u),
    )

    x = jnp.where(d <= 0.0, x_tail, x_body)
    # Round-off guard: the draw must be strictly below `upper` so it is
    # counted on the "left" side by the strict median-constraint check.
    return jnp.minimum(x, jnp.nextafter(upper, -jnp.inf))

@jit
def sample_laplace_trunc_right(key, loc, b, lower):
    """Draw one X ~ Laplace(loc, b) conditioned on X > lower.

    Uses reflection symmetry: if Y ~ Laplace(-loc, b) conditioned on
    Y < -lower, then X = -Y has the required law.  Strictness carries over,
    so the result is strictly greater than ``lower``.
    """
    return -sample_laplace_trunc_left(key, -loc, b, -lower)

# Vectorized versions: batch over keys; loc, b and the boundary are shared.
sample_laplace_trunc_left_batch  = vmap(sample_laplace_trunc_left,  in_axes=(0, None, None, None))
sample_laplace_trunc_right_batch = vmap(sample_laplace_trunc_right, in_axes=(0, None, None, None))

# -------------------------
# (A) Full resampling of x | mu under the median constraint.
# A uniformly random half of the indices is redrawn below mu_star and the
# other half above it, so exactly m/2 points end up on each side.
# -------------------------
@jit
def update_x_full_resample_laplace(key, x_current, mu_current, mu_star, b):
    """Redraw every x_i from Laplace(mu_current, b) subject to the half/half split.

    Args:
        key: JAX PRNG key.
        x_current: current x; only its length m (assumed even) is used.
        mu_current: Laplace location for the new draws.
        mu_star: split point; m/2 draws land below it, m/2 above.
        b: Laplace scale.

    Returns:
        New x array of length m.
    """
    n = x_current.shape[0]
    assert n % 2 == 0
    n_half = n // 2

    k_perm, k_lo, k_hi = random.split(key, 3)
    order = random.permutation(k_perm, n)

    below = sample_laplace_trunc_left_batch(random.split(k_lo, n_half), mu_current, b, mu_star)
    above = sample_laplace_trunc_right_batch(random.split(k_hi, n_half), mu_current, b, mu_star)
    stacked = jnp.concatenate([below, above], axis=0)

    # Scatter: the i-th stacked draw goes to position order[i]
    # (equivalent to gathering with the inverse permutation).
    return jnp.zeros_like(stacked).at[order].set(stacked)

# -------------------------
# (B) Side-preserving update of x | mu under the median constraint.
# Every point keeps its current side of mu_star ("below" if x < mu_star,
# otherwise "above") and is redrawn from the matching truncated Laplace.
# The below/above counts of x_current are therefore preserved exactly; with
# the runner's initialisation (first half below, second half above) that is
# the half/half split.  It resamples x from its conditional given mu and the
# current side pattern, so it is a valid Gibbs-type move, but it never moves
# points across mu_star: it explores less label symmetry than (A).
# -------------------------
@jit
def update_x_pairwise_laplace(key, x_current, mu_current, mu_star, b):
    """Redraw each x_i on its current side of mu_star.

    Args:
        key: JAX PRNG key.
        x_current: current x (length m, assumed even); its sign pattern
            relative to mu_star decides each point's side.  Points exactly
            equal to mu_star are treated as "above".
        mu_current: Laplace location for the new draws.
        mu_star: split point.
        b: Laplace scale.

    Returns:
        New x array with the same below/above pattern as ``x_current``.
    """
    m = x_current.shape[0]
    assert m % 2 == 0

    # One key per index.  Both candidates below reuse that key, which is
    # harmless because only one of the two is kept per index.
    keys = random.split(key, m)
    x_left = sample_laplace_trunc_left_batch(keys, mu_current, b, mu_star)
    x_right = sample_laplace_trunc_right_batch(keys, mu_current, b, mu_star)

    # Static shapes under jit: draw both candidates, keep the matching one.
    is_below = x_current < mu_star
    return jnp.where(is_below, x_left, x_right)

# -------------------------
# Unnormalized log posterior of mu given x:
# Laplace(mu, b) log-likelihood summed over x plus a Normal log prior on mu.
# -------------------------
@jit
def unnormalized_posterior_mu_logpdf_laplace(mu, x, prior_loc, prior_scale, b):
    """Return sum_i log Laplace(x_i | mu, b) + log Normal(mu | prior_loc, prior_scale)."""
    log_prior = norm.logpdf(mu, loc=prior_loc, scale=prior_scale)
    log_lik = laplace_logpdf(x, loc=mu, b=b).sum()
    return log_lik + log_prior

@jit
def update_mu_metropolis_jax_laplace(key, mu_current, x_current, sigma_mu, prior_loc, prior_scale, b):
    """One Gaussian random-walk Metropolis step for mu targeting p(mu | x).

    Returns:
        (mu_next, accepted): the new state and the boolean acceptance flag.
        A non-finite log acceptance ratio is treated as a rejection.
    """
    key_step, key_accept = random.split(key, 2)
    proposal = mu_current + sigma_mu * random.normal(key_step)

    def log_target(mu_value):
        return unnormalized_posterior_mu_logpdf_laplace(mu_value, x_current, prior_loc, prior_scale, b)

    log_ratio = log_target(proposal) - log_target(mu_current)
    log_ratio = jnp.where(jnp.isfinite(log_ratio), log_ratio, -jnp.inf)

    # Uniform on [EPS_U, 1): the lower bound keeps log(u) finite.
    log_u = jnp.log(random.uniform(key_accept, shape=(), minval=EPS_U, maxval=1.0))
    accepted = log_u < log_ratio
    return jnp.where(accepted, proposal, mu_current), accepted

# -------------------------
# Diagnostic for the median constraint: count points strictly below and
# strictly above mu_star (points exactly at mu_star are counted in neither).
# The constraint holds when both counts equal m/2.
# -------------------------
@jit
def median_constraint_check(x, mu_star):
    """Return (number of x < mu_star, number of x > mu_star)."""
    below = x < mu_star
    above = x > mu_star
    return below.sum(), above.sum()

# -------------------------
# Gibbs runner
# Choose x_update_mode in {"full", "pair"}.
# -------------------------
def run_gibbs_sampler_mle_jax_laplace(key, mu_star: float, params: dict) -> dict:
    """Run the two-block Gibbs sampler for the Laplace location mu.

    Each sweep does (a) a random-walk Metropolis update of mu | x and then
    (b) a draw of x | mu under the median constraint at ``mu_star``.

    Args:
        key: JAX PRNG key.
        mu_star: the conditioned-on median; also the initial value of mu.
        params: dict with keys "num_iterations_T" (int >= 1), "m" (even int),
            "proposal_std_mu", "prior_mean", "prior_std", optional "b"
            (default 1.0) and optional "x_update_mode" in {"full", "pair"}
            (default "full").

    Returns:
        dict with "mu_acceptance_rate", "mu_chain" (length T+1),
        "x_chain" (shape (T+1, m)), "n_left_chain" / "n_right_chain"
        (int32, length T), "b" and "x_update_mode".

    Raises:
        AssertionError: if m is odd.
        ValueError: if x_update_mode is unknown or num_iterations_T < 1.
    """
    T = params["num_iterations_T"]
    m = params["m"]
    b = params.get("b", 1.0)
    x_update_mode = params.get("x_update_mode", "full")  # "full" or "pair"

    assert m % 2 == 0, "m must be even to enforce the median constraint via half/half split."
    if T < 1:
        # The acceptance rate and the "last check" summary need at least one sweep.
        raise ValueError("num_iterations_T must be >= 1")

    # Resolve the x-update once, before any sampling work is done.
    x_updaters = {
        "full": update_x_full_resample_laplace,
        "pair": update_x_pairwise_laplace,
    }
    if x_update_mode not in x_updaters:
        raise ValueError("x_update_mode must be 'full' or 'pair'")
    update_x = x_updaters[x_update_mode]

    # Easy feasible init: half strictly below mu_star, half strictly above
    # (putting every point at mu_star would violate the strict split).
    half = m // 2
    x_current = jnp.concatenate([
        (mu_star - 1.0) * jnp.ones(half),
        (mu_star + 1.0) * jnp.ones(half),
    ])
    mu_current = jnp.asarray(mu_star, dtype=x_current.dtype)

    # Collect draws in Python lists and stack once at the end.  Calling
    # `.at[t].set(...)` on a preallocated chain outside jit returns a full
    # copy of the chain each sweep, i.e. O(T^2 * m) work and memory traffic.
    mu_draws = [mu_current]
    x_draws = [x_current]
    n_left_draws = []
    n_right_draws = []
    mu_acc = 0

    for _ in tqdm(range(1, T + 1), desc=f"Running Gibbs (Laplace-loc, mode={x_update_mode})"):
        key, key_mu, key_x = random.split(key, 3)

        # Step (a): mu | x
        mu_current, acc = update_mu_metropolis_jax_laplace(
            key_mu, mu_current, x_current,
            params["proposal_std_mu"],
            params["prior_mean"],
            params["prior_std"],
            b,
        )
        mu_acc += int(acc)

        # Step (b): x | mu, median constraint at mu_star
        x_current = update_x(key_x, x_current, mu_current, mu_star, b)

        nL, nR = median_constraint_check(x_current, mu_star)
        mu_draws.append(mu_current)
        x_draws.append(x_current)
        n_left_draws.append(nL)
        n_right_draws.append(nR)

    mus = jnp.stack(mu_draws)
    xs = jnp.stack(x_draws)
    checks_left = jnp.stack(n_left_draws).astype(jnp.int32)
    checks_right = jnp.stack(n_right_draws).astype(jnp.int32)

    mu_acc_rate = mu_acc / T
    print("\n--- Sampling Complete (Laplace-loc) ---")
    print(f"Mu Acceptance Rate: {mu_acc_rate:.4f}")
    print(f"Median constraint check (last): n_left={int(checks_left[-1])}, n_right={int(checks_right[-1])}")

    return {
        "mu_acceptance_rate": mu_acc_rate,
        "mu_chain": mus,
        "x_chain": xs,
        "n_left_chain": checks_left,
        "n_right_chain": checks_right,
        "b": b,
        "x_update_mode": x_update_mode,
    }

# =========================
# Quick demo: run the sampler once and summarise the mu chain.
# =========================
if __name__ == "__main__":
    key = random.PRNGKey(0)

    params = dict(
        num_iterations_T=2000,
        m=200,                 # must be even
        b=1.0,                 # known Laplace scale
        proposal_std_mu=0.5,
        prior_mean=0.0,
        prior_std=5.0,
        x_update_mode="full",  # "pair" is the alternative
    )
    mu_star = 0.0

    out = run_gibbs_sampler_mle_jax_laplace(key, mu_star, params)

    # Posterior summaries of mu after discarding the first 30% as burn-in.
    T = params["num_iterations_T"]
    burn = int(0.3 * T)
    mu_chain = out["mu_chain"][burn:]
    print("\nmu posterior summary (after burn):")
    print("  mean:", float(jnp.mean(mu_chain)))
    print("  std :", float(jnp.std(mu_chain)))
    for label, level in (("q05", 0.05), ("q50", 0.50), ("q95", 0.95)):
        print(f"  {label} :", float(jnp.quantile(mu_chain, level)))

    # If the median constraint held at every sweep, each side shows one unique count (m/2).
    nL = out["n_left_chain"][burn:]
    nR = out["n_right_chain"][burn:]
    print("\nconstraint counts after burn:")
    for side, counts in (("n_left", nL), ("n_right", nR)):
        print(f"  unique {side}:", jnp.unique(counts))