# Reproduction: Neal's Funnel

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from numba import njit
from ipywidgets import interact, FloatLogSlider
import warnings
warnings.filterwarnings("ignore")

# --- Neal's funnel potential ---
@njit
def grad_U(z):
    """Gradient of the funnel potential U(y, x) = y**2 / 2 + exp(-y) * x**2 / 2.

    Parameters
    ----------
    z : length-2 sequence/array ``(y, x)``.

    Returns
    -------
    np.ndarray ``[dU/dy, dU/dx]``.

    NOTE(review): for Neal's funnel x | y ~ N(0, exp(y)), the normalizing
    constant contributes an extra +y/2 to U (so +0.5 to dU/dy). It is absent
    here and in ``log_p``, which makes the y-marginal N(1/2, 1) rather than
    N(0, 1). Confirm this is intended.
    """
    y_pos, x_pos = z
    precision = np.exp(-y_pos)  # 1 / Var(x | y)
    d_dx = precision * x_pos
    d_dy = y_pos - 0.5 * precision * (x_pos**2)
    return np.array([d_dy, d_dx])

@njit
def laplacian_U(z):
    """Laplacian d2U/dy2 + d2U/dx2 of the funnel potential at ``z = (y, x)``.

    With U = y**2/2 + exp(-y) x**2/2:
    d2U/dx2 = exp(-y) and d2U/dy2 = 1 + exp(-y) x**2 / 2.
    """
    y_pos, x_pos = z
    decay = np.exp(-y_pos)
    d2_xx = decay                       # d2U/dx2
    d2_yy_extra = 0.5 * decay * x_pos**2  # d2U/dy2 minus its constant 1
    return 1.0 + d2_xx + d2_yy_extra

# --- BAOAB integrator ---
@njit
def step_BAOAB(x, p, z, h, gamma, alpha, beta):
    """Advance one BAOAB Langevin step with the fixed step size ``h``.

    Splitting, with unit mass:
    B (half kick), A (half drift), O (exact Ornstein-Uhlenbeck momentum
    refresh with friction ``gamma`` at inverse temperature ``beta``), A, B.

    ``z`` and ``alpha`` are unused. They exist so this function has the same
    signature as ``step_ZBAOABZ`` and either can be passed to ``run_sampler``.
    ``z`` is returned unchanged.

    Note: ``x`` and the incoming ``p`` array are modified in place.

    Returns
    -------
    (x, p, z, dt), with ``dt == h``.
    """
    dt = h
    half_dt = 0.5 * dt

    p -= half_dt * grad_U(x)  # B
    x += half_dt * p          # A

    damping = np.exp(-gamma * dt)  # O
    kick_std = np.sqrt((1 - damping**2) / beta)
    p = damping * p + kick_std * np.random.randn(2)

    x += half_dt * p          # A
    p -= half_dt * grad_U(x)  # B
    return x, p, z, dt

@njit
def g(x):
    """Return |grad U(x)|, the Euclidean norm of the potential gradient.

    This is the signal fed into the auxiliary variable ``z`` of
    ``step_ZBAOABZ``.
    """
    force = grad_U(x)
    return np.linalg.norm(force)

@njit
def psi(z, m, M, r):
    """Step-size multiplier psi(z) = m * (z**r + M) / (z**r + m).

    psi(0) = M, and psi(z) -> m as z -> infinity (for r > 0). When M > m it
    decreases monotonically in z, so a large gradient memory ``z`` gives a
    small step.
    """
    z_pow = z**r
    numerator = m * (z_pow + M)
    return numerator / (z_pow + m)

# --- ZBAOABZ integrator ---
@njit
def step_ZBAOABZ(x, p, z, dtau, gamma, alpha, beta):
    """One ZBAOABZ step: an adaptive-size BAOAB step between two Z half-steps.

    ``z`` is an exponentially discounted memory of g(x) = |grad U(x)|. Each
    Z half-step applies the exact solution of dz/dtau = -alpha*z + g(x) over
    dtau/2 with x held fixed. After the first Z half-step, the physical step
    is dt = psi(z) * dtau, and a B-A-O-A-B step of size dt is taken. That
    step is arithmetically identical to ``step_BAOAB`` with ``h = dt``.

    Reads the module-level globals ``m``, ``M`` and ``r`` for ``psi``. numba
    treats globals as compile-time constants, so their values are frozen at
    the first compilation of this function.

    Note: ``x`` and the incoming ``p`` array are modified in place.

    Returns
    -------
    (x, p, z, dt), where ``dt`` is the physical step size that was used.
    """
    decay = np.exp(-alpha*0.5*dtau)  # same factor for both Z half-steps

    z = decay*z + (1-decay) * g(x) / alpha  # Z
    dt = psi(z, m, M, r) * dtau

    x, p, _z_same, _dt_same = step_BAOAB(x, p, z, dt, gamma, alpha, beta)

    z = decay*z + (1-decay) * g(x) / alpha  # Z
    return x, p, z, dt

# --- Run sampler with optional trace recording ---
@njit
def run_sampler(stepper, nsteps, h, gamma, alpha, beta, burnin=1000, record_trace=False):
    """Run ``stepper`` for ``burnin + nsteps`` steps from (y, x) = (5, 0) at rest.

    Parameters
    ----------
    stepper : jitted step function ``(x, p, z, h, gamma, alpha, beta) -> (x, p, z, dt)``,
        e.g. ``step_BAOAB`` or ``step_ZBAOABZ``.
    nsteps : number of post-burn-in steps to record.
    h, gamma, alpha, beta : forwarded unchanged to ``stepper`` on every step.
    burnin : number of leading steps that are taken but not recorded.
    record_trace : if True, also fill ``traces``.

    Returns
    -------
    samples : (nsteps, 2) array of positions ``[y, x]``.
    traces : (nsteps, 6) array with columns ``[y, x, p_y, p_x, dt, T_conf]``,
        where T_conf = |grad U|**2 / laplacian U at that sample. The array is
        all zeros when ``record_trace`` is False.
    """
    q = np.array([5.0, 0.0])    # position (y, x)
    mom = np.array([0.0, 0.0])  # momentum (p_y, p_x)
    zeta = 0.0                  # auxiliary variable (only step_ZBAOABZ uses it)
    samples = np.zeros((nsteps, 2))
    traces = np.zeros((nsteps, 6))

    for t in range(nsteps + burnin):
        q, mom, zeta, dt = stepper(q, mom, zeta, h, gamma, alpha, beta)
        row = t - burnin
        if row < 0:
            continue  # burn-in: advance the chain, record nothing

        samples[row, 0] = q[0]
        samples[row, 1] = q[1]

        if record_trace:
            grad = grad_U(q)
            lapl = laplacian_U(q)
            traces[row, 0] = q[0]
            traces[row, 1] = q[1]
            traces[row, 2] = mom[0]
            traces[row, 3] = mom[1]
            traces[row, 4] = dt
            traces[row, 5] = np.dot(grad, grad) / lapl
    return samples, traces

# --- Effective Sample Size ---
def autocorr_func_1d(x, max_lag=2000):
    """Normalized sample autocorrelation function of a 1-D series.

    acf[k] = sum_t c[t] * c[t + k] / sum_t c[t]**2, with c = x - mean(x),
    for lags k = 0, ..., min(max_lag, len(x)) - 1, so acf[0] == 1.

    This is the usual biased estimator: every lag is divided by the same
    lag-0 sum. It matches the ``np.correlate(c, c, mode='full')`` formulation
    up to rounding, but is computed via a zero-padded FFT in O(n log n)
    instead of O(n**2).

    An empty input returns an empty array. The ACF is undefined for a
    constant series (zero variance).
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    if n == 0:
        return np.zeros(0)
    centered = x - np.mean(x)
    # Pad to a power of two >= 2n - 1 so the circular correlation computed by
    # the FFT equals the linear one for every lag 0..n-1 (no wrap-around).
    nfft = 1 << (2 * n - 1).bit_length()
    spectrum = np.fft.rfft(centered, n=nfft)
    acov = np.fft.irfft(np.abs(spectrum) ** 2, n=nfft)[:n]
    return acov[:max_lag] / acov[0]

def ess(x, max_lag=2000):
    """Effective sample size n / tau of a 1-D chain.

    The integrated autocorrelation time is
    tau = 1 + 2 * sum_{k=1}^{K-1} acf[k],
    where K is the first lag (k >= 1) at which the autocorrelation becomes
    non-positive, or the end of the computed ACF (``max_lag`` lags) if that
    never happens.

    Truncating at the first non-positive lag avoids adding the noisy,
    spuriously positive ACF values at large lags. Summing every positive
    value up to ``max_lag`` inflates tau and underestimates the ESS.
    """
    acf = autocorr_func_1d(x, max_lag)
    nonpositive = np.flatnonzero(acf[1:] <= 0)
    cutoff = nonpositive[0] + 1 if nonpositive.size else acf.size
    tau = 1 + 2 * np.sum(acf[1:cutoff])
    return len(x) / tau

# --- Target log density for contours ---
def log_p(y, x):
    """Unnormalized log-density -U(y, x) = -y**2/2 - exp(-y) * x**2 / 2.

    Works elementwise on scalars or broadcastable arrays; used for the
    background contour plots.

    NOTE(review): the x | y ~ N(0, exp(y)) normalizer term -y/2 is not
    included (matching ``grad_U``). Confirm this is the intended target.
    """
    quad_y = 0.5 * y**2
    quad_x = 0.5 * np.exp(-y) * x**2
    return -quad_y - quad_x

# Evaluation grid for the background log-density contours (x horizontal, y vertical).
xs = np.linspace(-10, 10, 400)
ys = np.linspace(-6, 6, 300)
X, Y = np.meshgrid(xs, ys)
LOGZ = log_p(Y, X)

# Contour levels cover the 40 units of log-density just below the maximum.
vmax = LOGZ.max()
vmin = vmax - 40
levels = np.linspace(vmin, vmax, 60)

# Parameters of the step-size map psi(z, m, M, r), read as globals by
# step_ZBAOABZ. numba freezes global values at a function's first
# compilation, so editing these after that call has no effect without
# recompiling.
m, M, r = 0.01, 60, 0.5

def _trace_summary(traces):
    """Summary statistics of one ``run_sampler(..., record_trace=True)`` trace array.

    ``traces`` columns are ``[y, x, p_y, p_x, dt, T_conf]``, where column 5 is
    the per-sample ratio |grad U|**2 / laplacian U. Returned dict keys:

    - ``"ess"``: effective sample size of the y trace;
    - ``"T_kin"``: kinetic temperature <|p|**2> / 2 (two unit-mass momentum
      components), which should be close to 1/beta;
    - ``"T_conf"``: configurational temperature <|grad U|**2> / <laplacian U>,
      which should also be close to 1/beta;
    - ``"y_avg"``, ``"x_avg"``: running averages of y and x.
    """
    steps = np.arange(1, len(traces) + 1)
    # The configurational-temperature estimator is a ratio of *averages*; the
    # average of the per-sample ratios in column 5 is a biased estimate. The
    # numerator |grad U|**2 is recovered as T_conf * laplacian before averaging.
    lapl = np.array([laplacian_U(q) for q in traces[:, :2]])
    return {
        "ess": ess(traces[:, 0]),
        "T_kin": np.mean(np.sum(traces[:, 2:4]**2, axis=1)) / 2,
        "T_conf": np.mean(traces[:, 5] * lapl) / np.mean(lapl),
        "y_avg": np.cumsum(traces[:, 0]) / steps,
        "x_avg": np.cumsum(traces[:, 1]) / steps,
    }


def plot_samplers(alpha=1.0, h=0.01, gamma=1.0, beta=1.0):
    """Run BAOAB and ZBAOABZ on the funnel and plot side-by-side diagnostics.

    Both samplers record 10_000 steps (after run_sampler's default burn-in)
    with the same (h, gamma, alpha, beta). ``alpha`` only affects ZBAOABZ,
    since step_BAOAB ignores it.

    The figure has 8 rows:
    - row 0: sample paths over log-density contours;
    - rows 1-3: position, momentum and step-size traces;
    - row 4: configurational vs kinetic temperature;
    - row 5: ESS bar chart;
    - row 6: temperature histogram;
    - row 7: running averages of y and x.
    """
    nsteps = 10_000  # keep interactive speed reasonable

    runs = []
    for name, stepper in (("BAOAB", step_BAOAB), ("ZBAOABZ", step_ZBAOABZ)):
        samples, traces = run_sampler(stepper, nsteps, h, gamma, alpha, beta, record_trace=True)
        runs.append((name, samples, traces, _trace_summary(traces)))

    fig = plt.figure(figsize=(12, 24))
    gs = fig.add_gridspec(8, 2, height_ratios=[2, 1, 1, 1, 1, 1, 1, 1], hspace=0.5)

    # --- Per-sampler columns (rows 0-4 and 7) ---
    for col, (name, samples, traces, stats) in enumerate(runs):
        n = len(traces)

        # Row 0: sample path over the contours
        ax = fig.add_subplot(gs[0, col])
        ax.contourf(X, Y, LOGZ, levels=levels, cmap='viridis')
        ax.plot(samples[:, 1], samples[:, 0], lw=0.7, color='red', alpha=0.7)
        ax.set_title(f'{name} (h={h}, γ={gamma}, α={alpha})')
        ax.set_xlabel('x'); ax.set_ylabel('y')

        # Row 1: position traces
        ax = fig.add_subplot(gs[1, col])
        ax.plot(traces[:, 0], lw=0.7, label="y")
        ax.plot(traces[:, 1], lw=0.7, label="x")
        ax.set_title(f"{name} trace: positions"); ax.set_xlabel("Step"); ax.legend()

        # Row 2: momentum traces
        ax = fig.add_subplot(gs[2, col])
        ax.plot(traces[:, 2], lw=0.7, label="p_y")
        ax.plot(traces[:, 3], lw=0.7, label="p_x")
        ax.set_title(f"{name} trace: momenta"); ax.set_xlabel("Step"); ax.legend()

        # Row 3: step size trace
        ax = fig.add_subplot(gs[3, col])
        ax.plot(traces[:, 4], lw=0.7)
        ax.set_title(f"{name} trace: dt (step size)"); ax.set_xlabel("Step")

        # Row 4: configurational vs kinetic temperature
        ax = fig.add_subplot(gs[4, col])
        ax.plot(traces[:, 5], lw=0.7, label="T_conf")
        ax.hlines(stats["T_kin"], 0, n, color='orange', lw=1.5, linestyle='--',
                  label=f"T_kin={stats['T_kin']:.3f}")
        ax.hlines(stats["T_conf"], 0, n, color='red', lw=1.5, linestyle=':',
                  label=f"<|∇U|²>/<ΔU>={stats['T_conf']:.3f}")
        ax.set_title(f"{name} trace: Configurational vs Kinetic T"); ax.set_xlabel("Step"); ax.legend()

        # Row 7: running averages of y and x
        ax = fig.add_subplot(gs[7, col])
        ax.plot(stats["y_avg"], lw=0.7, color='purple', label='y_avg')
        ax.plot(stats["x_avg"], lw=0.7, color='green', label='x_avg')
        ax.set_title(f"{name}: Running averages of y and x")
        ax.set_xlabel("Step"); ax.set_ylabel("Average")
        ax.legend()

    (name_a, _, traces_a, stats_a), (name_b, _, traces_b, stats_b) = runs

    # --- Row 5: ESS comparison ---
    ax_ess = fig.add_subplot(gs[5, :])
    ax_ess.bar([name_a, name_b], [stats_a["ess"], stats_b["ess"]], color=["blue", "green"], alpha=0.7)
    ax_ess.set_title("Effective Sample Size (ESS) for y"); ax_ess.set_ylabel("ESS")

    # --- Row 6: Histogram of per-sample T_conf vs T_kin ---
    ax_hist = fig.add_subplot(gs[6, :])
    ax_hist.hist(traces_a[:, 5], bins=50, alpha=0.5, label=f"{name_a} T_conf")
    ax_hist.hist(traces_b[:, 5], bins=50, alpha=0.5, label=f"{name_b} T_conf")
    ax_hist.axvline(stats_a["T_kin"], color='blue', linestyle='--', label=f"{name_a} T_kin")
    ax_hist.axvline(stats_b["T_kin"], color='green', linestyle='--', label=f"{name_b} T_kin")
    ax_hist.set_title("Histogram of Configurational vs Kinetic Temperature")
    ax_hist.set_xlabel("Temperature"); ax_hist.set_ylabel("Frequency"); ax_hist.legend()

    plt.show()

# --- Sliders: log-scale controls for each plot_samplers argument ---
interact(
    plot_samplers,
    alpha=FloatLogSlider(value=1.0, base=10, min=-2, max=2, step=0.1, description='alpha'),
    h=FloatLogSlider(value=0.01, base=10, min=-3, max=0, step=0.1, description='h'),
    gamma=FloatLogSlider(value=1.0, base=10, min=-2, max=2, step=0.1, description='gamma'),
    beta=FloatLogSlider(value=1.0, base=10, min=-4, max=4, step=0.1, description='beta'),
);


interactive(children=(FloatLogSlider(value=1.0, description='alpha', max=2.0, min=-2.0), FloatLogSlider(value=…

# Motivating toy example: Multimodal GMM with varying mode sizes and noisy gradients

In [None]:
# Use 4 modes, all of different sizes, separated by entropic barriers. Compare SGLD, SA-SGLD, SGHMC, SA-SGHMC, pSGLD and AdamSGLD.
# KL divergence / Wasserstein-2 between empirical marginal and ground truth (available for these toy problems). Compute on a grid or analytically where possible.
# ESS per gradient-eval (use spectral density estimator).
# Autocorrelation time for key coordinates.
# Trace plots and kernel density overlays with ground-truth contours.
# Acceptance-free diagnostics: sample quantiles, mean & covariance bias.


# Optimizer

In [None]:
# Study its role at zero temperature (i.e. as an optimizer).
# Start without gradient noise.
# Compare SGD, mSGD, SA-SGD, SA-mSGD.
# Contrast to Adam/RMSProp.
# Then simulate artificial gradient noise.

# Hierarchical Model for Radon Measurements

In [None]:
# Show a hierarchical Bayesian model with curved posterior modes, first using full-batch gradients plus simulated gradient noise.
# Then repeat with smaller batch sizes and tune the temperature.
# https://www.pymc.io/projects/examples/en/2021.11.0/variational_inference/GLM-hierarchical-advi-minibatch.html
# https://brendanhasz.github.io/2018/11/15/hmm-vs-gp-part2.html

# Bayesian Neural Networks

In [None]:
# Use a horseshoe or Laplace prior to demonstrate the benefit over SGLD/SGHMC. Eliminate the cold-posterior effect through the choice of model architecture and data augmentation. Use a simple MLP on MNIST. Initialize with L-BFGS.
# Show improvements in NLL/Brier/ECE/test accuracy.
# Show that pSGLD fails as the temperature increases, and that it only works because we are averaging over optimizer iterates.
# Show performance with change in batch size.
# Reference: https://arxiv.org/html/2303.05101v4

# Splitting vs Euler-Maruyama