In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.lines import Line2D
from matplotlib.patches import Patch

import pypoman

from copy import deepcopy

import hopsy

# ---------------------------------------------------------------------------
# Global matplotlib appearance shared by every figure in this notebook.
# (Alternative serif face, currently disabled:
#  plt.rcParams['font.serif'] = ["Times"])
# ---------------------------------------------------------------------------
plt.rcParams.update({
    "font.family": "serif",
    "mathtext.fontset": "dejavuserif",
    'axes.linewidth': .2,
    'xtick.major.size': 3,
    'xtick.major.width': .2,
    'ytick.major.size': 3,
    'ytick.major.width': .2,
})

# Line style (keyword arguments for matplotlib) of every statistic that the
# brute-force figures draw; a statistic and its per-time variant share a colour.
target_style = {
    "neff": dict(color='C0'),
    "neff/t": dict(color='C0'),
    "rhat": dict(color='C1', linestyle='dashed'),
    "acc": dict(color='C2'),
    "esjd": dict(color='C3'),
    "esjd/t": dict(color='C3'),
    "t": dict(color='C4'),
    "n": dict(color='C5'),
}

# Colours keyed by the (multi-line) display names of the tuning targets.
tuning_target_style = {
    "Acceptance\nRate\n(1-norm)": dict(color='C1'),
    "Acceptance\nRate\n(2-norm)": dict(color='C2'),
    "ESJD": dict(color='C1'),
    "1,5-ESJD": dict(color='C1'),
    "ESJD/s": dict(color='C1'),
    "1,5-ESJD/s": dict(color='C1'),
}

# Font size constant.
fs = 10

# Resolution of regular and of large figures.
dpi = 300
large_dpi = 300

# Figure sizes in inches (width, height).
figsize = np.array([6.4, 4.8])
broad_figsize = np.array([6.4, 2.8])
mid_figsize = np.array([6.4, 4.8])
large_figsize = 2 * np.array([6.4, 3.6])

# File extension appended to every saved figure.
fmt = ".png"

from experimental_setup import *

def collect_connected_sets(x, good):
    """Split a grid into maximal runs of "good" points and the gaps between them.

    Parameters
    ----------
    x : sequence
        Grid positions (must be non-empty and indexable).
    good : sequence of bool
        ``good[i]`` tells whether ``x[i]`` counts as good; same length as ``x``.

    Returns
    -------
    (cs_good, cs_bad) : tuple of two lists of ``(start, end)`` pairs
        ``cs_good`` holds, for every run of consecutive good points, the first
        and last grid position of that run.  ``cs_bad`` holds the intervals in
        between; each bad interval reaches from the last good point before the
        gap (or ``x[0]``) to the first good point after it (or ``x[-1]``).
        If ``x[0]`` itself is good, ``cs_bad`` starts with the degenerate
        interval ``(x[0], x[0])``.
    """
    good_runs = []
    bad_runs = []
    run_start = None      # start of the good run that is currently open
    bad_start = x[0]      # start of the bad interval that is currently open

    for pos, value in enumerate(x):
        if good[pos]:
            if run_start is None:
                # A good run opens here, which also closes the bad interval.
                run_start = value
                bad_runs.append((bad_start, value))
        elif run_start is not None:
            # First bad point after a good run: the run ended one step back.
            previous = x[pos - 1]
            good_runs.append((run_start, previous))
            run_start = None
            bad_start = previous

    # Close whichever interval is still open at the end of the grid.
    if good[-1] and run_start is not None:
        good_runs.append((run_start, x[-1]))
    else:
        bad_runs.append((bad_start, x[-1]))

    return good_runs, bad_runs

def choose_good(x, lims):
    """Return the indices of all entries of ``x`` inside a closed interval.

    Parameters
    ----------
    x : array_like
        One-dimensional array of values.
    lims : (a, b)
        Lower and upper bound, both inclusive.

    Returns
    -------
    tuple
        A one-element tuple ``(indices,)`` in the style of ``np.where``.
        ``indices`` is an integer array in ascending order, so it can be used
        directly for indexing, also when no entry matches.
    """
    lower, upper = lims
    values = np.asarray(x)
    # A single boolean mask replaces the former set intersection, whose
    # element order was unspecified and whose empty result had float dtype.
    inside = (values >= lower) & (values <= upper)
    return (np.flatnonzero(inside),)
    

# Isotope labeling enrichments throughout time

In [None]:
# Two comma-separated tables; column 0 of `A` is used as the time axis and
# columns 1 and 2 of each table are plotted as isotopomer fractions.
A, P = np.loadtxt("data/A_data", delimiter=","), np.loadtxt("data/P_data", delimiter=",")

x = A[:,0]
data = np.hstack([A[:,1:2], A[:,2:3], P[:,1:2], P[:,2:3]])

# (subscript, time) of the marked time points.  The last subscript is a raw
# string: "\infty" in a normal string is an invalid escape sequence.
ts = [(0, 2), (1, 5), (2, 8), (r"\infty", 20)]

fig, ax = plt.subplots(1, 1, dpi=dpi)

# One curve per data column, coloured in tab10 order (matches the legend below).
for i in range(data.shape[1]):
    ax.plot(x, data[:,i], color=cm.tab10(i))

# Dashed vertical marker plus a "t_i" label just above the axes for each time point.
for i, t in ts:
    ax.text(t-0.2, 1.03, r"$t_" + str(i) + "$")
    ax.plot([t, t], [-1, 2], color='gray', linestyle='dashed', zorder=-1, alpha=.8)

ax.set_ylim([-.01, 1.01])

ax.set_ylabel("Isotopomer fraction")
ax.set_xlabel(r"$t$")

ax.legend([r"$m_{00}$", r"$m_{10}$", r"$p_{00}$", r"$p_{10}$"])
fig.savefig("img/simplicus" + fmt)
plt.show()

# Acceptance rate score functions

In [None]:
# Log step size grid and a sigmoid-shaped acceptance rate that falls from 1 to 0.
s = np.linspace(-4, 6, 200)
a = -1 / (1 + np.exp(-s)) + 1

# Acceptance rate the score functions should peak at.
target = 0.234


def _score_scale(x):
    """Largest possible distance to `target` on the side of `target` where x lies."""
    # np.where (instead of summing two strict comparisons) keeps the scale
    # non-zero at x == target, which would otherwise give 0/0.
    return np.where(x < target, target, 1 - target)


def p1_score(x):
    """Score 1 - |x - target| of acceptance rate(s) x."""
    return 1 - np.abs(x - target)


def p2_score(x):
    """Score 1 - (x - target)^2 of acceptance rate(s) x."""
    return 1 - (x - target)**2


def p1_score_norm(x):
    """Like p1_score, with the distance rescaled so the score spans [0, 1]."""
    return 1 - np.abs(x - target) / _score_scale(x)


def p2_score_norm(x):
    """Like p2_score, with the distance rescaled so the score spans [0, 1]."""
    return 1 - ((x - target) / _score_scale(x))**2


# Log step size at which the acceptance rate is closest to the target.
s_star = s[np.argmax(p1_score(a))]

plt.figure(dpi=dpi, figsize=figsize)

plt.plot(s, a, label=r'$\alpha$')
plt.plot(s, p1_score(a), label=r'$1 - |\alpha - \alpha^*|$')
plt.plot(s, p2_score(a), label=r'$1 - (\alpha - \alpha^*)^2$')
# The plotted curve uses the absolute distance, so the label shows it as well.
plt.plot(s, p1_score_norm(a), label=r'$1 - \frac{|\alpha - \alpha^*|}{Z(\alpha)}$')
plt.plot(s, p2_score_norm(a), label=r'$1 - \frac{(\alpha - \alpha^*)^2}{Z(\alpha)^2}$')

# Guide lines through the optimum (s*, alpha*).
plt.plot([s_star, s_star], [-1, 2], linestyle='dashed', color='gray')
plt.plot([-10, 10], [target, target], linestyle='dashed', color='gray')

plt.xticks([-4, -2, 0, s_star, 2, 4, 6], [-4, -2, 0, r"$s^*$", 2, 4, 6])
plt.yticks([0, .2, target, .4, .6, .8, 1], [0, .2, r"$\alpha^*$",.4, .6, .8, 1])

plt.xlim([-4.5, 6.5])
plt.ylim([-.01, 1.01])

plt.ylabel(r"Acceptance rate / score")
plt.xlabel(r"log step size")

plt.legend()
plt.savefig("img/scorefunctions" + fmt)
plt.show()

# Gaussian process kernel hyperparametrization

In [None]:
def kernel(x, y, l=1, sigma=1):
    """Squared-exponential kernel matrix between two point sets.

    Parameters
    ----------
    x : ndarray of shape (n, d)
    y : ndarray of shape (m, d)
    l : float
        Length scale.
    sigma : float
        Amplitude; the kernel value at zero distance is ``sigma**2``.

    Returns
    -------
    ndarray of shape (n, m)
        ``sigma**2 * exp(-0.5 * ||x_i - y_j||**2 / l**2)``.
    """
    # ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>
    x_norms = (x**2).sum(axis=1)[:, np.newaxis]
    y_norms = (y**2).sum(axis=1)
    sqdist = x_norms + y_norms - 2 * (x @ y.T)
    return sigma**2 * np.exp(-0.5 / l**2 * sqdist)
    
def posterior(x, x_obs, y_obs, l = 1, sigma_f = 1, sigma_y = 0, prior_mu = 0):
    """Gaussian-process posterior mean and covariance at the points ``x``.

    Parameters
    ----------
    x : ndarray of shape (n, d)
        Points at which the posterior is evaluated.
    x_obs : ndarray of shape (k, d)
        Observed inputs.
    y_obs : ndarray
        Observed outputs, one entry (or row) per row of ``x_obs``.
    l, sigma_f : float
        Length scale and amplitude passed on to ``kernel``.
    sigma_y : float
        Observation noise term; it is added to the diagonal of the kernel
        matrix as-is.
    prior_mu : float
        Constant prior mean.

    Returns
    -------
    (mu, cov)
        Posterior mean at ``x`` and the (n, n) posterior covariance matrix.
    """
    K = kernel(x_obs, x_obs, l, sigma_f) + sigma_y * np.eye(len(x_obs))
    K_s = kernel(x_obs, x, l, sigma_f)
    K_ss = kernel(x, x, l, sigma_f)

    # Solve the linear systems directly rather than forming K^{-1}
    # explicitly; this is the numerically more stable way to apply the inverse.
    mu = prior_mu + K_s.T.dot(np.linalg.solve(K, y_obs - prior_mu))
    cov = K_ss - K_s.T.dot(np.linalg.solve(K, K_s))

    return mu, cov

# A single observation of the ground truth exp(-(x - 4)^2), taken at x = 5.
x_obs = np.array([5]).reshape(-1, 1)
y_obs = np.exp(-(x_obs - 4)**2)

# Dense grid and ground truth values for plotting.
X = np.linspace(0, 10, 100).reshape(-1, 1)
Y = np.exp(-(X - 4)**2)

fig, ax = plt.subplots(2, 4, dpi=dpi, figsize=(12, 6))

# Hyperparameter values: the noise term varies over the rows, the length
# scale over column pairs and the amplitude within each pair.
sigma_fs = [2, 1]
sigma_ys = [.01, .5]
lengths = [1, 5]

half = ax.shape[1] // 2
for i in range(ax.shape[0]):
    for j in range(half):
        for k in range(half):
            panel = ax[i, 2*j+k]
            mu, cov = posterior(X, x_obs, y_obs, sigma_f = sigma_fs[k], sigma_y = sigma_ys[i], l=lengths[j])

            title_parts = [
                r"$",
                r"\alpha = ", str(sigma_fs[k]),
                r", \, ",
                r"\sigma_y = ", str(sigma_ys[i]),
                r", \, ",
                r"\ell = ", str(lengths[j]),
                r"$",
            ]
            panel.set_title("".join(title_parts))

            # Posterior mean with a one-standard-deviation band.
            mean_line = mu.T[0]
            std = np.sqrt(np.diag(cov))
            panel.plot(X, mu)
            panel.fill_between(X.flatten(),
                               (mean_line - std).flatten(),
                               (mean_line + std).flatten(), alpha=0.1)

            panel.set_ylim(-2.1, 2.1)

            # Ground truth and the observation on top.
            panel.plot(X, Y, alpha=0.5, linestyle='dashed', zorder=10)
            panel.scatter(x_obs, y_obs)

            # Only the bottom row gets an x label, only the first column a y label.
            if i == 0:
                panel.set_xticks([])
            else:
                panel.set_xlabel("$x$")

            if 2*j+k != 0:
                panel.set_yticks([])
            else:
                panel.set_ylabel("$y$")


# Proxy artists for the shared figure legend.
custom_lines = [
    Line2D([0], [0], color='C0'),
    Patch(color='C0', alpha=.1),
    Line2D([0], [0], color='C1', alpha=.5, linestyle='dashed'),
    Line2D([0], [0], marker='o', color='w', markersize=10, markerfacecolor='C1'),
]
legend_labels = ["GP mean", "GP std. dev.", "Ground truth", "Data"]

fig.subplots_adjust(wspace=.1)
fig.legend(custom_lines, legend_labels, bbox_to_anchor=(0.5, 0.05), loc='upper center', ncol=len(custom_lines))
fig.savefig("img/sqexp" + fmt, bbox_inches='tight')
plt.show()

# Credible intervals

In [None]:
# Target: mixture of `n` one-dimensional Gaussians with means 0, a, 2a, ...
a = 4
n = 2
# Constraint data handed to hopsy.Problem (presumably the polytope A x <= b,
# i.e. -5 <= x <= 10 here -- see the hopsy documentation).
A = [[1], [-1]]
b = [10, 5]
mixture = hopsy.Mixture([hopsy.Gaussian(mean=[a*i]) for i in range(n)])
problem = hopsy.Problem(A, b, mixture)

# `c` chains, all started at the same point; every chain gets its own
# random stream of the same seed.
c = 4
mcs = []
rngs = []
for stream in range(c):
    mcs.append(hopsy.MarkovChain(problem, hopsy.GaussianHitAndRunProposal, starting_point=[a*n / 2]))
    rngs.append(hopsy.RandomNumberGenerator(seed=0, stream=stream))

accrate, samples = hopsy.sample(mcs, rngs, 100000, n_threads=4)

In [None]:
def mixture_density(points):
    """Evaluate exp(-negative log likelihood) of `mixture` at each point."""
    return np.array([np.exp(-mixture.compute_negative_log_likelihood([p])) for p in points])


fig, (ax1, ax2) = plt.subplots(2, 1, dpi=dpi, figsize=np.sqrt(1.5)*figsize)

# Density curve on a regular grid, shown in both panels.
a, b, N = -4, 12, 200
L = b - a
x = np.linspace(a, b, N)
density = mixture_density(x)
ax1.plot(x, density)
ax2.plot(x, density)

# Probability mass covered by the credible interval / high density region.
alpha = 0.75
samples = samples.flatten()

## compute cris

# Equal-tailed interval: cut (1 - alpha) / 2 of the samples off at each end.
sorted_samples = np.sort(samples)
lower_idx = int(len(samples) * (1 - alpha) / 2)
# Clamp so that alpha = 1 does not index one past the end.
upper_idx = min(int(np.floor(len(samples) * (1 + alpha) / 2)), len(samples) - 1)
cri = (sorted_samples[lower_idx], sorted_samples[upper_idx])

# Use the same point density as the full curve for the shaded part.
n_points = int((cri[1] - cri[0]) / L * N)
_x = np.linspace(cri[0], cri[1], n_points)
_d = mixture_density(_x)
ax1.fill_between(_x, _d * 0 - .01, _d, color='C1', alpha=.5, label='Credible\ninterval')

ax1.set_xlim([-3, 7])

ax1.set_ylabel(r"$f(\theta)$")

## compute hdis

# Density threshold p_alpha: a fraction alpha of the samples has a higher density.
sample_density = mixture_density(samples)
p_alpha = np.sort(sample_density)[int((1-alpha) * len(samples))]
is_hdi_sample = sample_density > p_alpha

# Order samples (and their flags) by position to read off contiguous runs.
order = np.argsort(samples)
samples = samples[order]
is_hdi_sample = is_hdi_sample[order]

hdis = []
last_start = None
for i in range(len(samples)):
    if is_hdi_sample[i] and last_start is None:
        last_start = samples[i]
    if not is_hdi_sample[i] and last_start is not None:
        hdis.append((last_start, samples[i-1]))
        last_start = None
# A run that is still open at the largest sample has to be closed as well,
# otherwise that part of the region would be missing from the figure.
if last_start is not None:
    hdis.append((last_start, samples[-1]))

# hdi level
ax2.plot([a, b], [p_alpha, p_alpha], linestyle='dashed', color='gray')

# hdis: shade every run; only the first one carries the legend label.
for k, hdi in enumerate(hdis):
    n_points = int((hdi[1] - hdi[0]) / L * N)
    _x = np.linspace(hdi[0], hdi[1], n_points)
    _d = mixture_density(_x)
    label = "High\ndensity\nregion" if k == 0 else None
    ax2.fill_between(_x, _d * 0, _d, color='C1', alpha=.5, label=label)

ax2.set_xlim([-3, 7])

ax2.set_yticks([0, p_alpha, 0.2])
ax2.set_yticklabels([0, r"$p_{\alpha}$", 0.2])

ax2.set_ylabel(r"$f(\theta)$")
ax2.set_xlabel(r"$\theta$")

ax1.legend()
ax2.legend()
fig.savefig("img/cri-hdi" + fmt)
plt.show()

# hopsy demo

In [None]:
# Minimal hopsy example: a standard 2-d Gaussian restricted by three linear
# constraints (constraint matrix and right-hand side below).
constraint_matrix = [[2, 1], [-1, 0], [0, -1]]
constraint_rhs = [5, 0, 0]
problem = hopsy.Problem(constraint_matrix, constraint_rhs, hopsy.Gaussian(dim=2))

mc = hopsy.MarkovChain(problem, starting_point=[.1, .1])
rng = hopsy.RandomNumberGenerator(42)

acceptance_rate, draws = hopsy.sample(mc, rng, n_samples=5000)

# Scatter plot of the two coordinates of all draws.
plt.figure(dpi=dpi)
first_coordinate = draws[:, :, 0]
second_coordinate = draws[:, :, 1]
plt.scatter(first_coordinate, second_coordinate, alpha=.2)
plt.xlim([-1, 4])
plt.ylim([-1, 4])
plt.savefig("img/codesample" + fmt)
plt.show()

# Sandwiching ratio

In [None]:
# Opacity of the faint facet extension lines drawn by draw_polytope.
alpha = 0.2

# Multiples of an edge vector to which draw_polytope extends each edge beyond
# its end point (u) and backwards beyond its start point (l).
u = 10
l = -10

def draw_polytope(A, b, ax):
    """Draw the outline of the 2-d polytope given by ``A`` and ``b`` onto ``ax``.

    The vertices are computed with pypoman and connected in the order in which
    pypoman returns them, with the last vertex connected back to the first.
    Every edge is additionally extended in both directions as a faint line;
    the module-level ``u``, ``l`` and ``alpha`` control how far and how opaque.
    """
    verts = np.array(pypoman.compute_polytope_vertices(np.array(A), np.array(b)))

    def _draw_extensions(start, end):
        # Faint continuation of the edge beyond `end` and backwards beyond `start`.
        direction = end - start
        beyond_end = u * direction + start
        beyond_start = l * direction + start
        ax.plot([end[0], beyond_end[0]], [end[1], beyond_end[1]], color='C0', alpha=alpha)
        ax.plot([start[0], beyond_start[0]], [start[1], beyond_start[1]], color='C0', alpha=alpha)

    # Edges between consecutive vertices.
    for start, end in zip(verts[:-1], verts[1:]):
        ax.plot([start[0], end[0]], [start[1], end[1]], color='C0')
        _draw_extensions(start, end)

    # Closing edge from the last vertex back to the first.
    ax.plot([verts[-1, 0], verts[0, 0]], [verts[-1, 1], verts[0, 1]], color='C0')
    _draw_extensions(verts[0], verts[-1])

    
def draw_min_inscribing_ball(A, b, x, ax):
    """Draw the smallest circle around ``x`` that contains all polytope vertices.

    Parameters
    ----------
    A, b : array_like
        Polytope description passed to pypoman to obtain the vertices.
    x : ndarray
        Centre of the circle (flattened before use).
    ax : matplotlib axes
        Axes to draw the dotted circle on.

    Returns
    -------
    R
        Largest distance from ``x`` to any vertex (0 if there are no vertices).
    """
    vertices = np.array(pypoman.compute_polytope_vertices(np.array(A), np.array(b)))
    center = x.reshape(-1)

    # Radius = distance to the farthest vertex; the leading 0 is the start value.
    R = max([0] + [np.linalg.norm(vertex - center) for vertex in vertices])

    angles = np.linspace(0, 2*np.pi)
    circle_x = R * np.sin(angles) + center[0]
    circle_y = R * np.cos(angles) + center[1]
    ax.plot(circle_x, circle_y, color='C1', linestyle='dotted')

    return R


def draw_max_inscribed_ball(A, b, x, ax):
    """Draw the largest circle around ``x`` that touches no facet of ``A x <= b``.

    Parameters
    ----------
    A : array_like of shape (m, 2)
        Facet normals, one per row.
    b : array_like of length m
        Right-hand sides of the inequalities ``A x <= b``.
    x : ndarray
        Centre of the circle (flattened before use).
    ax : matplotlib axes
        Axes to draw the dashed circle on.

    Returns
    -------
    r : float
        Distance from ``x`` to the closest facet hyperplane (``inf`` if there
        are no facets).
    """
    x = x.reshape(-1)
    normals = np.asarray(A, dtype=float)
    offsets = np.asarray(b, dtype=float).reshape(-1)

    if normals.shape[0] == 0:
        r = np.inf
    else:
        # Distance of x to the hyperplane a.x = b is |b - a.x| / ||a||.
        # (The earlier `a.dot(x) + b` had the wrong sign on b and was only
        # right for facets with b = 0.)
        distances = np.abs(offsets - normals.dot(x)) / np.linalg.norm(normals, axis=1)
        r = distances.min()

    ball = np.array([[r * np.sin(phi) + x[0], r * np.cos(phi) + x[1]] for phi in np.linspace(0, 2*np.pi)])

    ax.plot(ball[:,0], ball[:,1], color='C1', linestyle='dashed')

    return r
    
# Example polytope and its Chebyshev centre, before and after rounding.
A, b = [[1, 0], [-1, 0], [-1, 1], [1, -1]], [1, 0, 0, 1]
x = hopsy.compute_chebyshev_center(hopsy.Problem(A, b))

rounded = hopsy.round(hopsy.Problem(A, b))
Ar, br = rounded.A, rounded.b
xr = hopsy.compute_chebyshev_center(rounded)

fig, axs = plt.subplots(1, 2, figsize=(9.6, 4.8), dpi=dpi)

# Left panel: original polytope, right panel: rounded polytope.
axs[0].scatter([x[0]], [x[1]], color='C1', marker='x')
axs[1].scatter([xr[0]], [xr[1]], color='C1', marker='x')

draw_polytope(A, b, axs[0])
draw_polytope(Ar, br, axs[1])

# Outer radius R (farthest vertex) and inner radius r (closest facet) around
# the respective centre.
R = draw_min_inscribing_ball(A, b, x, axs[0])
Rr = draw_min_inscribing_ball(Ar, br, xr, axs[1])

r = draw_max_inscribed_ball(A, b, x, axs[0])
rr = draw_max_inscribed_ball(Ar, br, xr, axs[1])

axs[0].set_xlim([-1, 2])
axs[1].set_xlim([-2, 2])

axs[0].set_ylim([-1.5, 1.5])
axs[1].set_ylim([-2, 2])

# Print the ratios rounded to two decimals.  (Cutting str(...) after four
# characters truncated instead of rounding and broke for values >= 100 or
# values rendered in scientific notation.)
axs[0].text(-0.9, 1.35, f'$R/r = {R/r:.2f}$', va='top', ha='left')
axs[1].text(-1.8, 1.8, f'$R/r = {Rr/rr:.2f}$', va='top', ha='left')

for panel in axs:
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_aspect(1)

fig.subplots_adjust(wspace=.1, hspace=.1)

# Proxy artists for the shared figure legend.
custom_lines = [Line2D([0], [0], color='C0'),
                Line2D([0], [0], color='C1', linestyle='dotted'),
                Line2D([0], [0], color='C1', linestyle='dashed'),
                Line2D([0], [0], marker='x', color='w', markersize=5, markeredgecolor='C1'),
               ]

fig.legend(custom_lines, ["Polytope", "Min inscribing ball", "Max inscribed ball", "Chebyshev center"], bbox_to_anchor=(0.5, 0.1), loc='upper center', ncol=2)
fig.savefig("img/sandwiching-ratio" + fmt, bbox_inches='tight')
plt.show()

# Proposals

# Thompson sampling

In [None]:
def kernel(x, y, l=1, sigma=1):
    """Squared-exponential kernel between the rows of ``x`` and the rows of ``y``.

    ``x`` has shape (n, d) and ``y`` shape (m, d); the result is the (n, m)
    matrix ``sigma**2 * exp(-0.5 / l**2 * ||x_i - y_j||**2)`` with length
    scale ``l`` and amplitude ``sigma``.
    """
    # Squared distances via the expansion |u - v|^2 = |u|^2 + |v|^2 - 2 u.v
    row_term = np.square(x).sum(axis=1).reshape(-1, 1)
    col_term = np.square(y).sum(axis=1)
    cross_term = np.dot(x, y.T)
    sqdist = row_term + col_term - 2 * cross_term

    exponent_scale = -0.5 / l**2
    return sigma**2 * np.exp(exponent_scale * sqdist)
  
def posterior(x, x_obs, y_obs, l = 1, sigma_f = 1, sigma_y = 0, prior_mu = 0):
    """Gaussian-process posterior at ``x``; falls back to the prior without data.

    Parameters
    ----------
    x : ndarray of shape (n, d)
        Points at which the posterior is evaluated.
    x_obs : ndarray of shape (k, d)
        Observed inputs; may be empty (k = 0).
    y_obs : ndarray of length k
        Observed outputs.
    l, sigma_f : float
        Length scale and amplitude passed on to ``kernel``.
    sigma_y : float
        Observation noise term; it is added to the diagonal of the kernel
        matrix as-is.
    prior_mu : float
        Constant prior mean.

    Returns
    -------
    (mu, cov)
        Mean vector of length n and (n, n) covariance matrix.
    """
    mu = np.array([prior_mu] * len(x))
    cov = kernel(x, x, l, sigma_f)

    if len(x_obs) > 0:
        K = kernel(x_obs, x_obs, l, sigma_f) + sigma_y * np.eye(len(x_obs))
        K_s = kernel(x_obs, x, l, sigma_f)

        # Solve the linear systems directly rather than forming K^{-1}
        # explicitly; this is the numerically more stable way to apply the inverse.
        mu = mu + K_s.T.dot(np.linalg.solve(K, y_obs - prior_mu))
        cov = cov - K_s.T.dot(np.linalg.solve(K, K_s))

    return mu, cov

# Fixed seed so that the figure is reproducible.
np.random.seed(5)

# Standard deviation of the observation noise; used for the simulated
# measurements, the GP noise term and the error bars alike.
noise_std = .1

def target(x):
    """Noisy evaluation of the ground truth exp(-x^2)."""
    return np.exp(-x**2) + np.random.normal(0, noise_std)

# Candidate locations: the default linspace grid on [-2, 2] without its end points.
x = np.linspace(-2, 2)
x = x[1:-1]
n = 3  # number of Thompson sampling iterations, one panel each
n_sample_posterior_max = 200  # not used in this cell

truth = np.exp(-x**2)
x_obs, y_obs = [], []

# One panel per iteration (was hard-coded to 3 although `n` drives the loop).
fig, axs = plt.subplots(1, n, figsize=(12, 3), dpi=dpi)
axs = np.atleast_1d(axs)

for i in range(n):
    # Posterior given everything observed so far (the prior in the first round).
    mu, cov = posterior(x.reshape(-1, 1), np.array(x_obs).reshape(-1, 1), np.array(y_obs), sigma_y = noise_std)
    std = np.sqrt(np.diag(cov))
    axs[i].plot(x, mu, zorder=0)
    axs[i].fill_between(x, mu-std, mu+std, alpha=.1, zorder=-20)

    # Thompson step: draw one function from the posterior and evaluate the
    # target where that draw is largest.  (Named `posterior_draw` rather than
    # `u`, because draw_polytope reads a module-level `u`.)
    posterior_draw = np.random.multivariate_normal(mu, cov)
    axs[i].plot(x, posterior_draw, color='C2', alpha=.5)

    x_obs.append(x[np.argmax(posterior_draw)])
    y_obs.append(target(x_obs[-1]))

    axs[i].scatter(x_obs, y_obs, color='C1')
    axs[i].errorbar(x_obs, y_obs, yerr=noise_std, color='C1', fmt='none')

    # Vertical marker at the location that was just evaluated.
    axs[i].plot([x_obs[-1], x_obs[-1]], [-2, 2], color='gray', linestyle='dashed', zorder=-10)

    axs[i].plot(x, truth, color='C1', alpha=.5, linestyle='dashed', zorder=10)
    axs[i].set_ylim([-1.4, 1.4])
    if i == 0:
        axs[i].set_yticks([-1, 0, 1])
        axs[i].set_ylabel('$y$')
    else:
        axs[i].set_yticks([])

    axs[i].set_xlabel('$x$')


fig.subplots_adjust(wspace=.05)

# Proxy artists for the shared figure legend.
custom_lines = [Line2D([0], [0], color='C0'),
                Patch(color='C0', alpha=.1),
                Line2D([0], [0], color='C1', alpha=.5, linestyle='dashed'),
                Line2D([0], [0], marker='o', color='w', markersize=10, markerfacecolor='C1'),
                Line2D([0], [0], color='C2', alpha=.5),
                Line2D([0], [0], color='gray', linestyle='dashed'),
               ]

fig.legend(custom_lines, ["GP mean", "GP std. dev.", "Ground truth", "Data", "Acquisition function", "Next evaluation"], bbox_to_anchor=(0.5, 0), loc='upper center', ncol=len(custom_lines))
fig.savefig("img/ts-demo" + fmt, bbox_inches='tight')
plt.show()

# $n_{\mathrm{eff}}$ and tuning targets as functions of stepsizes

In [None]:
# Grid of panels (one row per proposal, one column per problem) showing the
# normalised statistics and R-hat as functions of the log step size.
# `dill`, `warnings`, `problems`, `proposals`, `targets`, `stepsize_grids` and
# `rhat_threshold` are not defined in this notebook; presumably they come from
# the `experimental_setup` star import.
#
# NOTE(review): dill.load can execute arbitrary code while unpickling; only
# load result files you produced yourself.
with open("data/bruteforce_data", "rb") as fhandle:
    results = dill.load(fhandle)

labelsize=6

alpha=1  # not read in this cell; draw_polytope reads the global `alpha`
gauss_opt_acc = .234
mala_opt_acc = .574

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    # One extra row and column hold the problem / proposal names.
    rows, cols = len(proposals) + 1, len(problems) + 1
    fig, axs = plt.subplots(rows, cols, figsize=large_figsize, dpi=large_dpi, gridspec_kw={"height_ratios": [1] + [3]*(rows-1), "width_ratios": [1] + [3]*(cols-1)})

    fig.patch.set_facecolor('white')

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            if (problem, proposal) not in results or results[(problem, proposal)] is None: continue

            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc
            # Reset per panel so that a crossing found in an earlier panel is
            # never reused when this panel has none.
            x_opt_acc = None

            # ax1: normalised statistics, ax2 (twin axis): R-hat.
            ax1 = axs[j+1, i+1]
            ax2 = ax1.twinx()
            ax1.tick_params(axis='both', which='major', labelsize=labelsize, length=1)
            ax2.tick_params(axis='both', which='major', labelsize=labelsize, length=1)

            stepsize_grid = stepsize_grids[(problem, proposal)]
            log_steps = np.log10(stepsize_grid)

            ax_a, ax_b = log_steps[0]-.1, log_steps[-1]+.1


            # identify good convergence
            target = "rhat"
            target_idx = list(targets).index(target)

            # np.minimum returns a new array, so the loaded results stay
            # untouched while R-hat is capped at 20.
            y = np.minimum(results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1), 20)

            y_mean = np.mean(y, axis=-1)

            y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            # Step sizes whose mean R-hat is at most the threshold count as converged.
            is_good = (y_mean <= rhat_threshold)
            good, _ = collect_connected_sets(log_steps, is_good)

            # Converged stretches are drawn at full opacity on top of the faint full curve.
            for xa, xb in good:
                idx = sorted(choose_good(log_steps, (xa, xb))[0])
                ax2.plot(log_steps[idx], y_mean[idx], label=target, **target_style[target])
                ax2.fill_between(log_steps[idx], y_err[0][idx], y_err[1][idx], alpha=.1, **target_style[target])

            ax2.plot(log_steps, y_mean, label=target, alpha=.2, **target_style[target])
            ax2.fill_between(log_steps, y_err[0], y_err[1], alpha=.1, **target_style[target])

            if i == len(problems) - 1:
                ax2.set_yticks([1, 2, 3], [1, 2, 3])
            else:
                ax2.set_yticks([], [])


            for target_idx, target in enumerate(targets):
                # This figure shows the plain statistics only; the per-time
                # variants and the bookkeeping targets are skipped.
                if target in ["neff/t", "esjd/t", "t", "n", "rhat"]: continue

                y = results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1)
                y_mean = np.mean(y, axis=-1)

                if target != "acc":
                    # Normalise by the best mean value among converged step
                    # sizes (or overall, if that is zero).
                    good_max = np.max(y_mean * is_good)
                    good_max = good_max if good_max != 0 else np.max(y_mean)
                    # Out-of-place division: `y` may be a view into `results`,
                    # which an in-place `/=` would overwrite.
                    y = y / good_max
                    y_mean = y_mean / good_max

                y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

                for xa, xb in good:
                    idx = sorted(choose_good(log_steps, (xa, xb))[0])
                    ax1.plot(log_steps[idx], y_mean[idx], label=target, **target_style[target])
                    ax1.fill_between(log_steps[idx], y_err[0][idx], y_err[1][idx], alpha=.1, **target_style[target])

                ax1.plot(log_steps, y_mean, label=target, alpha=.2, **target_style[target])
                ax1.fill_between(log_steps, y_err[0], y_err[1], alpha=.1, **target_style[target])


                if target == "acc":
                    # Log step size at which the mean acceptance rate crosses
                    # the optimum, by linear interpolation between the last
                    # grid point above the optimum and its successor.
                    diff = y_mean - opt_acc
                    if diff[0] > 0 and diff[-1] < 0:
                        last_pos_idx = np.where(diff > 0)[-1][-1]
                        x1 = log_steps[last_pos_idx]
                        y1 = y_mean[last_pos_idx]
                        x2 = log_steps[last_pos_idx+1]
                        y2 = y_mean[last_pos_idx+1]

                        m = (y2 - y1) / (x2 - x1)
                        c = y1 - m * x1

                        x_opt_acc = (opt_acc - c) / m
                    else:
                        x_opt_acc = None

            if i == 0:
                ax1.set_yticks([0, opt_acc, 1], [0, r"$\alpha^*$", 1])
            else:
                ax1.set_yticks([], [])
            ax1.plot([ax_a, ax_b], [opt_acc, opt_acc], linestyle='dashed', color='gray')

            if x_opt_acc is not None:
                ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3, x_opt_acc], [-5, -4, -3, -2, -1, 0, 1, 2, 3, r"$s_{\alpha}^*$"])
                ax1.plot([x_opt_acc, x_opt_acc], [-.1, 1.1], linestyle='dashed', color='gray')

            ylim1 = [-.05, 1.1]
            ax1.set_xlim([ax_a, ax_b])
            ax1.set_ylim(ylim1)
            # The R-hat axis is cut off at 3.1; larger values only distort the
            # picture.  Its lower limit leaves the same proportion of margin
            # below R-hat = 1 as ax1 has below 0, so the two baselines end up
            # at roughly the same height.
            rhat_upper = 3.1
            p = np.abs(ylim1[0]) / (ylim1[1] - ylim1[0])
            ylim2 = [1 - p*(rhat_upper - 1), rhat_upper]
            ax2.set_ylim(ylim2)

            if i+1 == len(problems):
                ax2.set_ylabel(r"$\hat{R}$")

    # First column: proposal names.
    for i in range(0, rows-1):
        axs[i+1, 0].axis('off')
        axs[i+1, 0].text(0, .5, list(proposals)[i].replace('Hit-And-Run', 'Hit & Run').replace('Rounding ', 'Rounding\n'), ha='center', va='center')

    # First row: problem names; bottom row: x labels.
    for i in range(1, cols):
        axs[0, i].axis('off')
        axs[0, i].text(.5, 0, list(problems)[i-1], ha='center', va='center')
        axs[-1, i].set_xlabel("log step size", fontsize=labelsize, labelpad=0)

    axs[0, 0].axis('off')

    custom_lines = [Line2D([0], [0], **target_style['neff']),
                    Line2D([0], [0], **target_style['rhat']),
                    Line2D([0], [0], **target_style['acc']),
                    Line2D([0], [0], **target_style['esjd']),
                   ]

    # The legend is attached below the bottom panel of column 3, so the layout
    # needs at least three problems.
    axs[rows-1, 3].legend(custom_lines, [r"min $n_{\mathrm{eff}}$", r"max $\hat{R}$", "Acceptance rate", "ESJD",], bbox_to_anchor=(0.5, -0.5), loc='upper center', ncol=len(custom_lines))

    fig.subplots_adjust(hspace=.45, wspace=.05)
    fig.savefig("img/neff-bruteforce-results" + fmt, bbox_inches='tight')
    plt.show()

# $n_{\mathrm{eff}}/t$ and tuning targets as functions of stepsizes

In [None]:
# Grid of panels (one row per proposal, one column per problem) showing the
# per-time statistics and R-hat as functions of the log step size.
# `dill`, `warnings`, `problems`, `proposals`, `targets`, `stepsize_grids` and
# `rhat_threshold` are not defined in this notebook; presumably they come from
# the `experimental_setup` star import.
#
# NOTE(review): dill.load can execute arbitrary code while unpickling; only
# load result files you produced yourself.
with open("data/bruteforce_data", "rb") as fhandle:
    results = dill.load(fhandle)

# Tick label size; set here so the cell does not depend on an earlier cell.
labelsize = 6

alpha=1  # not read in this cell; draw_polytope reads the global `alpha`
gauss_opt_acc = .234
mala_opt_acc = .574

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    # One extra row and column hold the problem / proposal names.
    rows, cols = len(proposals) + 1, len(problems) + 1
    fig, axs = plt.subplots(rows, cols, figsize=large_figsize, dpi=large_dpi, gridspec_kw={"height_ratios": [1] + [3]*(rows-1), "width_ratios": [1] + [3]*(cols-1)})

    fig.patch.set_facecolor('white')

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            if (problem, proposal) not in results or results[(problem, proposal)] is None: continue

            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc
            # Reset per panel so that a crossing found in an earlier panel is
            # never reused when this panel has none.
            x_opt_acc = None

            # ax1: normalised statistics, ax2 (twin axis): R-hat.
            ax1 = axs[j+1, i+1]
            ax2 = ax1.twinx()
            ax1.tick_params(axis='both', which='major', labelsize=labelsize, length=1)
            ax2.tick_params(axis='both', which='major', labelsize=labelsize, length=1)

            stepsize_grid = stepsize_grids[(problem, proposal)]
            log_steps = np.log10(stepsize_grid)

            ax_a, ax_b = log_steps[0]-.1, log_steps[-1]+.1


            # identify good convergence
            target = "rhat"
            target_idx = list(targets).index(target)

            # np.minimum returns a new array, so the loaded results stay
            # untouched while R-hat is capped at 20.
            y = np.minimum(results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1), 20)

            y_mean = np.mean(y, axis=-1)

            y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            # Step sizes whose mean R-hat is at most the threshold count as converged.
            is_good = (y_mean <= rhat_threshold)
            good, _ = collect_connected_sets(log_steps, is_good)

            # Converged stretches are drawn at full opacity on top of the faint full curve.
            for xa, xb in good:
                idx = sorted(choose_good(log_steps, (xa, xb))[0])
                ax2.plot(log_steps[idx], y_mean[idx], label=target, **target_style[target])
                ax2.fill_between(log_steps[idx], y_err[0][idx], y_err[1][idx], alpha=.1, **target_style[target])

            ax2.plot(log_steps, y_mean, label=target, alpha=.2, **target_style[target])
            ax2.fill_between(log_steps, y_err[0], y_err[1], alpha=.1, **target_style[target])

            if i == len(problems) - 1:
                ax2.set_yticks([1, 2, 3], [1, 2, 3])
            else:
                ax2.set_yticks([], [])


            for target_idx, target in enumerate(targets):
                # This figure shows the per-time statistics only; the plain
                # variants and the bookkeeping targets are skipped.
                if target in ["neff", "esjd", "t", "n", "rhat"]: continue

                y = results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1)
                y_mean = np.mean(y, axis=-1)

                if target != "acc":
                    # Normalise by the best mean value among converged step
                    # sizes (or overall, if that is zero).
                    good_max = np.max(y_mean * is_good)
                    good_max = good_max if good_max != 0 else np.max(y_mean)
                    # Out-of-place division: `y` may be a view into `results`,
                    # which an in-place `/=` would overwrite.
                    y = y / good_max
                    y_mean = y_mean / good_max

                y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

                for xa, xb in good:
                    idx = sorted(choose_good(log_steps, (xa, xb))[0])
                    ax1.plot(log_steps[idx], y_mean[idx], label=target, **target_style[target])
                    ax1.fill_between(log_steps[idx], y_err[0][idx], y_err[1][idx], alpha=.1, **target_style[target])

                ax1.plot(log_steps, y_mean, label=target, alpha=.2, **target_style[target])
                ax1.fill_between(log_steps, y_err[0], y_err[1], alpha=.1, **target_style[target])


                if target == "acc":
                    # Log step size at which the mean acceptance rate crosses
                    # the optimum, by linear interpolation between the last
                    # grid point above the optimum and its successor.
                    diff = y_mean - opt_acc
                    if diff[0] > 0 and diff[-1] < 0:
                        last_pos_idx = np.where(diff > 0)[-1][-1]
                        x1 = log_steps[last_pos_idx]
                        y1 = y_mean[last_pos_idx]
                        x2 = log_steps[last_pos_idx+1]
                        y2 = y_mean[last_pos_idx+1]

                        m = (y2 - y1) / (x2 - x1)
                        c = y1 - m * x1

                        x_opt_acc = (opt_acc - c) / m
                    else:
                        x_opt_acc = None

            if i == 0:
                ax1.set_yticks([0, opt_acc, 1], [0, r"$\alpha^*$", 1])
            else:
                ax1.set_yticks([], [])
            ax1.plot([ax_a, ax_b], [opt_acc, opt_acc], linestyle='dashed', color='gray')

            if x_opt_acc is not None:
                ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3, x_opt_acc], [-5, -4, -3, -2, -1, 0, 1, 2, 3, r"$s_{\alpha}^*$"])
                ax1.plot([x_opt_acc, x_opt_acc], [-.1, 1.1], linestyle='dashed', color='gray')

            ylim1 = [-.05, 1.1]
            ax1.set_xlim([ax_a, ax_b])
            ax1.set_ylim(ylim1)
            # The R-hat axis is cut off at 3.1; larger values only distort the
            # picture.  Its lower limit leaves the same proportion of margin
            # below R-hat = 1 as ax1 has below 0, so the two baselines end up
            # at roughly the same height.
            rhat_upper = 3.1
            p = np.abs(ylim1[0]) / (ylim1[1] - ylim1[0])
            ylim2 = [1 - p*(rhat_upper - 1), rhat_upper]
            ax2.set_ylim(ylim2)

            if i+1 == len(problems):
                ax2.set_ylabel(r"$\hat{R}$")

    # First column: proposal names.
    for i in range(rows-1):
        axs[i+1, 0].axis('off')
        axs[i+1, 0].text(0, .5, list(proposals)[i].replace('Hit-And-Run', 'Hit & Run').replace('Rounding ', 'Rounding\n'), ha='center', va='center')

    # First row: problem names; bottom row: x labels.
    for i in range(1, cols):
        axs[0, i].axis('off')
        axs[0, i].text(.5, 0, list(problems)[i-1], ha='center', va='center')
        axs[-1, i].set_xlabel("log step size", fontsize=labelsize, labelpad=0)

    axs[0, 0].axis('off')

    custom_lines = [Line2D([0], [0], **target_style['neff']),
                    Line2D([0], [0], **target_style['rhat']),
                    Line2D([0], [0], **target_style['acc']),
                    Line2D([0], [0], **target_style['esjd']),
                   ]

    # The legend is attached below the bottom panel of column 3, so the layout
    # needs at least three problems.
    axs[rows-1, 3].legend(custom_lines, [r"min $n_{\mathrm{eff}}/s$", r"max $\hat{R}$", "Acceptance rate", "ESJD/s",], bbox_to_anchor=(0.5, -0.5), loc='upper center', ncol=len(custom_lines))

    fig.subplots_adjust(hspace=.45, wspace=.05)
    fig.savefig("img/neff-t-bruteforce-results" + fmt, bbox_inches='tight')
    plt.show()

# $n_{\mathrm{eff}}$, $n_{\mathrm{eff}}/t$ and tuning targets as functions of stepsizes

In [None]:
with open("data/bruteforce_data", "rb") as fhandle:
    results = dill.load(fhandle)
    
alpha=1
gauss_opt_acc = .234
mala_opt_acc = .574

with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    
    rows, cols = len(proposals) + 1, len(problems) + 1
    fig, axs = plt.subplots(rows, cols, figsize=large_figsize, dpi=large_dpi, gridspec_kw={"height_ratios": [1] + [3]*(rows-1), "width_ratios": [1] + [3]*(cols-1)})

    fig.patch.set_facecolor('white')

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            if (problem, proposal) not in results or results[(problem, proposal)] is None: continue
            
            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc
            
            ax1 = axs[j+1, i+1]
            ax1.tick_params(axis='both', which='major', labelsize=labelsize, length=1)
            
            stepsize_grid = stepsize_grids[(problem, proposal)]

            ax_a, ax_b = np.log10(stepsize_grid[0])-.1, np.log10(stepsize_grid[-1])+.1
            
            
            # identify good convergence
            target = "rhat"
            target_idx = list(targets).index(target)

            y = deepcopy(results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1))
            
            y[np.where(y > 20)] = 20

            y_mean = np.mean(y, axis=-1)

            y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            is_good = (y_mean <= rhat_threshold)
            good, bad = collect_connected_sets(np.log10(stepsize_grid), is_good)
            
            for target_idx, target in enumerate(targets):
                if target in ["t", "n", "rhat"]: continue

                y = results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1)
                y_mean = np.mean(y, axis=-1) 

                if target not in ["acc", "rhat"]:
                    good_max = np.max(y_mean * is_good)
                    good_max = good_max if good_max != 0 else np.max(y_mean)
                    y /= good_max
                    y_mean /= good_max

                y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))
                
                linestyle = 'dashed' if target in ["neff/t", "esjd/t"] else 'solid'

                for xa, xb in good:
                    idx = sorted(choose_good(np.log10(stepsize_grid), (xa, xb))[0])
                    ax1.plot(np.log10(stepsize_grid)[idx], y_mean[idx], label=target, linestyle=linestyle, **target_style[target])
                    ax1.fill_between(np.log10(stepsize_grid)[idx], y_err[0][idx], y_err[1][idx], alpha=.1, linestyle=linestyle, **target_style[target])

                ax1.plot(np.log10(stepsize_grid), y_mean, label=target, alpha=.2, linestyle=linestyle, **target_style[target])
                ax1.fill_between(np.log10(stepsize_grid), y_err[0], y_err[1], alpha=.1, linestyle=linestyle, **target_style[target])

                
                if target == "acc":
                    diff = y_mean - opt_acc
                    if diff[0] > 0 and diff[-1] < 0:
                        last_pos_idx = np.where(diff > 0)[-1][-1]
                        x1 = np.log10(stepsize_grid)[last_pos_idx]
                        y1 = y_mean[last_pos_idx]
                        x2 = np.log10(stepsize_grid)[last_pos_idx+1]
                        y2 = y_mean[last_pos_idx+1]

                        m = (y2 - y1) / (x2 - x1)
                        c = y1 - m * x1

                        x_opt_acc = (opt_acc - c) / m
                    else:
                        x_opt_acc = None

            if i == 0:
                ax1.set_yticks([0, opt_acc, 1], [0, r"$\alpha^*$", 1])
            else:
                ax1.set_yticks([], [])
            ax1.plot([ax_a, ax_b], [opt_acc, opt_acc], linestyle='dashed', color='gray')
            
            if x_opt_acc is not None:
                ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3, x_opt_acc], [-5, -4, -3, -2, -1, 0, 1, 2, 3, r"$s_{\alpha}^*$"])
                ax1.plot([x_opt_acc, x_opt_acc], [-.1, 1.1], linestyle='dashed', color='gray')
            
            ylim1 = [-.05, 1.1]
            ax1.set_xlim([ax_a, ax_b])
            ax1.set_ylim(ylim1)
            ylim2 = list(ax2.get_ylim())
            ylim2[1] = min(ylim2[1], 10) # upper constraint on rhat, since rhat > 10 is freaking bad and just distorts the image
            p = np.abs(ylim1[0]) / (ylim1[1] - ylim1[0])
            ylim2[0] = 1 - p*(ylim2[1] - 1)
            ax2.set_ylim(ylim2)
            
            #ax1.set_ylabel('normalized statistic')
            if i+1 == len(problems):
                ax2.set_ylabel(r"$\hat{R}$")

            #ax1.legend()
            #ax2.legend()
            
    #handles, labels = axs[proposal_idx, i].get_legend_handles_labels()
    #fig.legend(handles, labels, loc=(.07, .07))
            
    for i in range(rows-1):
        axs[i+1, 0].axis('off')
        axs[i+1, 0].text(0, .5, list(proposals)[i].replace('Hit-And-Run', 'Hit & Run').replace('Rounding ', 'Rounding\n'), ha='center', va='center')

    for i in range(1, cols):
        axs[0, i].axis('off')
        axs[0, i].text(.5, 0, list(problems)[i-1], ha='center', va='center')
        axs[-1, i].set_xlabel("log step size", fontsize=labelsize, labelpad=0)
        
    axs[0, 0].axis('off')
    
    custom_lines = [Line2D([0], [0], **target_style['neff']),
                    Line2D([0], [0], **target_style['neff/t'], linestyle='dashed'),
                    Line2D([0], [0], **target_style['esjd']),
                    Line2D([0], [0], **target_style['esjd/t'], linestyle='dashed'),
                    Line2D([0], [0], **target_style['acc']),
                   ]

    #axs[rows-2, 1].legend(custom_lines, [r"min $n_{\mathrm{eff}}$", r"max $\hat{R}$", "Acceptance Rate", "ESJD",], loc=(-.35, -.35))
    axs[rows-1, 3].legend(custom_lines, [r"min $n_{\mathrm{eff}}$", r"min $n_{\mathrm{eff}}/s$", "ESJD", "ESJD/s", "Acceptance rate",], bbox_to_anchor=(0.5, -0.5), loc='upper center', ncol=len(custom_lines))
    
    fig.subplots_adjust(hspace=.45, wspace=.05)
    fig.savefig("img/neff-vs-neff-t" + fmt, bbox_inches='tight')
    plt.show()

# Time costs

In [None]:
# Figure "time costs": one panel per (proposal, problem) showing the per-run cost "t"
# (scatter of all runs plus a pre-computed regression curve) and the acceptance rate,
# both as functions of log10(step size).

# Reuse `results` if an earlier cell already loaded it; otherwise read it from disk.
# NOTE: dill.load can execute arbitrary code while unpickling -- only open trusted files.
if 'results' not in locals():
    with open("data/bruteforce_data", "rb") as fhandle:
        results = dill.load(fhandle)

# Per (problem, proposal): column 0 is drawn as the curve, column 1 as the band half-width.
with open("data/time_gps", "rb") as fhandle:
    gps = dill.load(fhandle)

alpha=1
# Not used by this figure (no optimal-acceptance markers are drawn here); kept so that the
# cell leaves the same names defined as before.
gauss_opt_acc = .234
mala_opt_acc = .44

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    # First row / first column of the grid only carry the problem / proposal captions.
    rows, cols = len(proposals)+1, len(problems)+1
    fig, axs = plt.subplots(rows, cols, figsize=large_figsize, dpi=large_dpi, gridspec_kw={"height_ratios": [1] + [3]*(rows-1), "width_ratios": [1] + [3]*(cols-1)})

    fig.patch.set_facecolor('white')

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            # Same guard as the other figure cells: leave the panel empty when there is no data.
            if (problem, proposal) not in results or results[(problem, proposal)] is None: continue

            ax1 = axs[j+1, i+1]
            ax1.tick_params(axis='both', which='major', labelsize=labelsize, length=1)

            stepsize_grid = stepsize_grids[(problem, proposal)]
            log_steps = np.log10(stepsize_grid)

            for target_idx, target in enumerate(targets):
                # Only the acceptance rate and the time cost are shown in this figure.
                if target in ["rhat", "neff", "neff/t", "esjd", "esjd/t", "n"]: continue

                # One row per step size, one column per repetition.
                y = results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1)

                y_mean = np.mean(y, axis=-1)

                if target not in ["acc", "rhat"]:
                    # `y` may be a view into `results`; divide out-of-place so the shared
                    # data is not rescaled as a side effect of drawing this figure.
                    scale = np.max(y_mean)
                    y = y / scale
                    y_mean = y_mean / scale

                if target == "t":
                    # Every individual run as a faint dot ...
                    X, Y = np.array([log_steps]*y.shape[-1]).T.flatten().reshape(-1, 1), y.flatten().reshape(-1, 1)

                    ax1.scatter(X, Y, s=3, alpha=.1, **target_style[target])

                    # ... plus the stored regression curve.
                    # assumes gps[...] was evaluated on exactly this 200-point grid -- TODO confirm
                    _x = np.linspace(-5, 3, 200)
                    mu, std = gps[(problem, proposal)][:,0], gps[(problem, proposal)][:,1]

                    ax1.plot(_x, mu, **target_style[target])
                    ax1.fill_between(_x, mu - std, mu + std, alpha=.5, **target_style[target])
                else:
                    # Mean with a 5%-95% quantile band over the repetitions.
                    y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

                    ax1.plot(log_steps, y_mean, label=target, alpha=alpha, **target_style[target])
                    ax1.fill_between(log_steps, y_err[0], y_err[1], alpha=.2*alpha, **target_style[target])

            # Only the left-most column of panels shows y tick labels.
            if i == 0:
                ax1.set_yticks([0, 1], [0, 1])
            else:
                ax1.set_yticks([], [])

            ylim1 = [-.05, 1.1]
            ax1.set_ylim(ylim1)


    # Row captions (proposal names) in the first column.
    for i in range(rows-1):
        axs[i+1, 0].axis('off')
        axs[i+1, 0].text(0, .5, list(proposals)[i].replace('Hit-And-Run', 'Hit & Run').replace('Rounding ', 'Rounding\n'), ha='center', va='center')

    # Column captions (problem names) in the first row, x labels on the bottom row.
    for i in range(1, cols):
        axs[0, i].axis('off')
        axs[0, i].text(.5, 0, list(problems)[i-1], ha='center', va='center')
        axs[-1, i].set_xlabel("log step size", fontsize=labelsize, labelpad=0)

    axs[0, 0].axis('off')

    custom_lines = [Line2D([0], [0], **target_style['t']),
                    Line2D([0], [0], **target_style['acc']),
                   ]

    # The legend hangs below a bottom-row panel; fall back to the last column on narrow grids.
    axs[rows-1, min(3, cols-1)].legend(custom_lines, ["T/n", "Acceptance rate",], bbox_to_anchor=(0.5, -0.5), loc='upper center', ncol=len(custom_lines))

    fig.subplots_adjust(hspace=.45, wspace=.05)
    fig.savefig("img/time-costs" + fmt, bbox_inches='tight')
    plt.show()


# $n_{\mathrm{eff}}$ at the optimal step size w.r.t. ESJD and acceptance rate

In [None]:
# Figure "neff results": for every (problem, proposal) pair compare the (normalised) n_eff
# obtained at the step size that maximises ESJD with the n_eff obtained at the step size
# that hits the target acceptance rate. Filled markers mean the mean R-hat at that step
# size is below `rhat_threshold`.
group_by_proposal = False
normalize = True

alpha=1
gauss_opt_acc = .234
mala_opt_acc = .574

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    rows, cols = 1, 1
    fig, ax = plt.subplots(rows, cols, figsize=1.5*broad_figsize, dpi=dpi)

    fig.patch.set_facecolor('white')

    x = []
    xlabels = []

    neff_accs = []
    neff_accs_err = []
    good_accs = []

    neff_esjds = []
    neff_esjds_err = []
    good_esjds = []

    # Positions of the statistics inside each result entry.
    neff_idx = list(targets).index("neff")
    rhat_idx = list(targets).index("rhat")
    acc_idx = list(targets).index("acc")
    esjd_idx = list(targets).index("esjd")

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc

            # Record position and tick label together so they can never get out of step
            # (x is not monotone when grouping by proposal).
            if group_by_proposal:
                x.append((len(problems) + 2)*j + i)
                xlabels.append(problem)
            else:
                x.append((len(proposals) + 2)*i + j)
                xlabels.append(proposal)

            stepsize_grid = stepsize_grids[(problem, proposal)]
            log_steps = np.log10(stepsize_grid)
            run = results[(problem, proposal)]

            # neff is the final performance measure; work on a copy because it is rescaled.
            y = deepcopy(run[neff_idx]).reshape(len(stepsize_grid), -1)
            y_mean = np.mean(y, axis=-1)

            if normalize:
                scale = np.max(y_mean)
                y /= scale
                y_mean /= scale

            neff = y_mean
            neff_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            # Mean R-hat per step size decides whether a point counts as converged.
            rhat = np.mean(run[rhat_idx].reshape(len(stepsize_grid), -1), axis=-1)

            # Step size maximising the mean ESJD. The arg-max does not depend on the scale,
            # so the stored ESJD values are read only and never rescaled in place.
            esjd_mean = np.mean(run[esjd_idx].reshape(len(stepsize_grid), -1), axis=-1)
            opt_esjd_idx = np.argmax(esjd_mean)

            neff_esjds.append(neff[opt_esjd_idx])
            neff_esjds_err.append((neff_err[0][opt_esjd_idx], neff_err[1][opt_esjd_idx]))
            good_esjds.append(rhat[opt_esjd_idx] < rhat_threshold)

            # Step size at which the mean acceptance rate crosses the target value.
            acc_mean = np.mean(run[acc_idx].reshape(len(stepsize_grid), -1), axis=-1)
            diff = acc_mean - opt_acc

            if diff[0] > 0 and diff[-1] < 0:
                # Linear interpolation between the last grid point above the target and
                # its right neighbour.
                last_pos_idx = np.where(diff > 0)[-1][-1]
                x1, x2 = log_steps[last_pos_idx], log_steps[last_pos_idx+1]
                y1, y2 = acc_mean[last_pos_idx], acc_mean[last_pos_idx+1]

                m = (y2 - y1) / (x2 - x1)
                c = y1 - m * x1

                x_opt_acc = (opt_acc - c) / m

                # Evaluate neff, both neff quantiles and R-hat at that step size with the
                # same two-point interpolation, all four curves at once.
                curves = np.array([neff, neff_err[0], neff_err[1], rhat])
                slopes = (curves[:, last_pos_idx+1] - curves[:, last_pos_idx]) / (x2 - x1)
                offsets = curves[:, last_pos_idx] - slopes * x1

                neff_acc, err_lo, err_hi, rhat_acc = slopes * x_opt_acc + offsets

                neff_acc_err = [err_lo, err_hi]
                good_acc = rhat_acc < rhat_threshold
            else:
                # The target acceptance rate is not bracketed by the step size grid.
                neff_acc = np.nan
                neff_acc_err = [np.nan, np.nan]
                good_acc = False

            neff_accs.append(neff_acc)
            neff_accs_err.append(neff_acc_err)
            good_accs.append(good_acc)

    x = np.array(x)

    fontsize = 7

    a = .2 # alpha of markers that did not converge
    d = .06 # horizontal offset between the two markers of one pair

    # Turn the absolute quantiles into (lower, upper) distances from the marker, shape (2, N).
    neff_esjds_err = np.abs(np.array(neff_esjds_err).T - np.array(neff_esjds))
    neff_accs_err = np.abs(np.array(neff_accs_err).T - np.array(neff_accs))

    good_esjds = np.array(good_esjds)
    good_accs = np.array(good_accs)

    for i, _x in enumerate(x):
        esjd_facecolors = 'C0' if good_esjds[i] else 'none'
        acc_facecolors = 'C1' if good_accs[i] else 'none'

        # Only the first pair carries labels so that the legend gets one entry per series.
        esjd_label = {'label': r'min $n_{\mathrm{eff}}(\mathrm{ESJD}^*)$'} if i == 0 else {}
        acc_label = {'label': r'min $n_{\mathrm{eff}}(\alpha^*)$'} if i == 0 else {}

        ax.scatter(_x+d, neff_esjds[i], zorder=-1, color='C0',
                   facecolors=esjd_facecolors,
                   marker='o', alpha=(1-a)*good_esjds[i]+a, **esjd_label)
        ax.scatter(_x-d, neff_accs[i], zorder=1, color='C1',
                   facecolors=acc_facecolors,
                   marker='o', alpha=(1-a)*good_accs[i]+a, **acc_label)

    # Thin guide line from the axis up to the higher of the two markers (NaN is ignored).
    for i, _x in enumerate(x):
        neff_max = np.nanmax([neff_esjds[i], neff_accs[i]])
        ax.plot([_x, _x], [0, neff_max], linestyle='dashed', linewidth=.1, color='black', alpha=.5, zorder=-10)

    for i, _x in enumerate(x):
        ax.errorbar(_x+d, neff_esjds[i], yerr=neff_esjds_err[:,i:i+1], zorder=-1, fmt='none', color='C0', alpha=(1-a)*good_esjds[i]+a)
        ax.errorbar(_x-d, neff_accs[i], yerr=neff_accs_err[:,i:i+1], zorder=1, fmt='none', color='C1', alpha=(1-a)*good_accs[i]+a)

    if normalize:
        # Fixed limits, presumably so that the result figures share one y-range -- TODO confirm
        ax.set_ylim((-0.02082685768576953, 1.0486108027469414))

    ax.legend(bbox_to_anchor=(0.5, -0.4), loc='upper center', ncol=2)

    ub = ax.get_ylim()[1] / 1.05

    # A group holds `group_size` consecutive markers followed by a gap of two positions;
    # its caption is centred over exactly those markers.
    if group_by_proposal:
        group_names, group_size = list(proposals), len(problems)
    else:
        group_names, group_size = list(problems), len(proposals)

    for g, name in enumerate(group_names):
        ax.text((group_size - 1) / 2 + (group_size + 2) * g, 1.1 * ub, name, ha='center', va='center')

    ax.set_xticks(x)
    ax.set_xticklabels(xlabels, rotation=75, fontsize=fontsize)

    fig.savefig("img/neff-results" + fmt, bbox_inches='tight')
    plt.show()


# $n_{\mathrm{eff}}/t$ at the optimal step size w.r.t. ESJD/t and acceptance rate

In [None]:
# Figure "neff/t results": for every (problem, proposal) pair compare the (normalised)
# n_eff/t obtained at the step size that maximises ESJD/t with the n_eff/t obtained at the
# step size that hits the target acceptance rate. Filled markers mean the mean R-hat at
# that step size is below `rhat_threshold`.
group_by_proposal = False
normalize = True

alpha=1
gauss_opt_acc = .234
# NOTE(review): other cells of this notebook use .574 for CSmMALA -- confirm .44 is intended here.
mala_opt_acc = .44

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    rows, cols = 1, 1
    fig, ax = plt.subplots(rows, cols, figsize=1.5*broad_figsize, dpi=dpi)

    fig.patch.set_facecolor('white')

    x = []
    xlabels = []

    neff_accs = []
    neff_accs_err = []
    good_accs = []

    neff_esjds = []
    neff_esjds_err = []
    good_esjds = []

    # Positions of the statistics inside each result entry.
    neff_idx = list(targets).index("neff/t")
    rhat_idx = list(targets).index("rhat")
    acc_idx = list(targets).index("acc")
    esjd_idx = list(targets).index("esjd/t")

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc

            # Record position and tick label together so they can never get out of step
            # (x is not monotone when grouping by proposal).
            if group_by_proposal:
                x.append((len(problems) + 2)*j + i)
                xlabels.append(problem)
            else:
                x.append((len(proposals) + 2)*i + j)
                xlabels.append(proposal)

            stepsize_grid = stepsize_grids[(problem, proposal)]
            log_steps = np.log10(stepsize_grid)
            run = results[(problem, proposal)]

            # neff/t is the final performance measure; work on a copy because it is rescaled.
            y = deepcopy(run[neff_idx]).reshape(len(stepsize_grid), -1)
            y_mean = np.mean(y, axis=-1)

            if normalize:
                scale = np.max(y_mean)
                y /= scale
                y_mean /= scale

            neff = y_mean
            neff_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            # Mean R-hat per step size decides whether a point counts as converged.
            rhat = np.mean(run[rhat_idx].reshape(len(stepsize_grid), -1), axis=-1)

            # Step size maximising the mean ESJD/t. The arg-max does not depend on the scale,
            # so the stored values are read only and never rescaled in place.
            esjd_mean = np.mean(run[esjd_idx].reshape(len(stepsize_grid), -1), axis=-1)
            opt_esjd_idx = np.argmax(esjd_mean)

            neff_esjds.append(neff[opt_esjd_idx])
            neff_esjds_err.append((neff_err[0][opt_esjd_idx], neff_err[1][opt_esjd_idx]))
            good_esjds.append(rhat[opt_esjd_idx] < rhat_threshold)

            # Step size at which the mean acceptance rate crosses the target value.
            acc_mean = np.mean(run[acc_idx].reshape(len(stepsize_grid), -1), axis=-1)
            diff = acc_mean - opt_acc

            if diff[0] > 0 and diff[-1] < 0:
                # Linear interpolation between the last grid point above the target and
                # its right neighbour.
                last_pos_idx = np.where(diff > 0)[-1][-1]
                x1, x2 = log_steps[last_pos_idx], log_steps[last_pos_idx+1]
                y1, y2 = acc_mean[last_pos_idx], acc_mean[last_pos_idx+1]

                m = (y2 - y1) / (x2 - x1)
                c = y1 - m * x1

                x_opt_acc = (opt_acc - c) / m

                # Evaluate neff/t, both of its quantiles and R-hat at that step size with
                # the same two-point interpolation, all four curves at once.
                curves = np.array([neff, neff_err[0], neff_err[1], rhat])
                slopes = (curves[:, last_pos_idx+1] - curves[:, last_pos_idx]) / (x2 - x1)
                offsets = curves[:, last_pos_idx] - slopes * x1

                neff_acc, err_lo, err_hi, rhat_acc = slopes * x_opt_acc + offsets

                neff_acc_err = [err_lo, err_hi]
                good_acc = rhat_acc < rhat_threshold
            else:
                # The target acceptance rate is not bracketed by the step size grid.
                neff_acc = np.nan
                neff_acc_err = [np.nan, np.nan]
                good_acc = False

            neff_accs.append(neff_acc)
            neff_accs_err.append(neff_acc_err)
            good_accs.append(good_acc)

    x = np.array(x)

    fontsize = 7

    a = .2 # alpha of markers that did not converge
    d = .06 # horizontal offset between the two markers of one pair

    # Turn the absolute quantiles into (lower, upper) distances from the marker, shape (2, N).
    neff_esjds_err = np.abs(np.array(neff_esjds_err).T - np.array(neff_esjds))
    neff_accs_err = np.abs(np.array(neff_accs_err).T - np.array(neff_accs))

    good_esjds = np.array(good_esjds)
    good_accs = np.array(good_accs)

    for i, _x in enumerate(x):
        esjd_facecolors = 'C0' if good_esjds[i] else 'none'
        acc_facecolors = 'C1' if good_accs[i] else 'none'

        # Only the first pair carries labels so that the legend gets one entry per series.
        esjd_label = {'label': r'min $n_{\mathrm{eff}}/s(\mathrm{ESJD}/s^*)$'} if i == 0 else {}
        acc_label = {'label': r'min $n_{\mathrm{eff}}/s(\alpha^*)$'} if i == 0 else {}

        ax.scatter(_x+d, neff_esjds[i], zorder=-1, color='C0',
                   facecolors=esjd_facecolors,
                   marker='o', alpha=(1-a)*good_esjds[i]+a, **esjd_label)
        ax.scatter(_x-d, neff_accs[i], zorder=1, color='C1',
                   facecolors=acc_facecolors,
                   marker='o', alpha=(1-a)*good_accs[i]+a, **acc_label)

    # Thin guide line from the axis up to the higher of the two markers (NaN is ignored).
    for i, _x in enumerate(x):
        neff_max = np.nanmax([neff_esjds[i], neff_accs[i]])
        ax.plot([_x, _x], [0, neff_max], linestyle='dashed', linewidth=.1, color='black', alpha=.5, zorder=-10)

    for i, _x in enumerate(x):
        ax.errorbar(_x+d, neff_esjds[i], yerr=neff_esjds_err[:,i:i+1], zorder=-1, fmt='none', color='C0', alpha=(1-a)*good_esjds[i]+a)
        ax.errorbar(_x-d, neff_accs[i], yerr=neff_accs_err[:,i:i+1], zorder=1, fmt='none', color='C1', alpha=(1-a)*good_accs[i]+a)

    if normalize:
        # Fixed limits, presumably so that the result figures share one y-range -- TODO confirm
        ax.set_ylim((-0.02082685768576953, 1.0486108027469414))

    ax.legend(bbox_to_anchor=(0.5, -0.4), loc='upper center', ncol=2)

    ub = ax.get_ylim()[1] / 1.05

    # A group holds `group_size` consecutive markers followed by a gap of two positions;
    # its caption is centred over exactly those markers.
    if group_by_proposal:
        group_names, group_size = list(proposals), len(problems)
    else:
        group_names, group_size = list(problems), len(proposals)

    for g, name in enumerate(group_names):
        ax.text((group_size - 1) / 2 + (group_size + 2) * g, 1.1 * ub, name, ha='center', va='center')

    ax.set_xticks(x)
    ax.set_xticklabels(xlabels, rotation=75, fontsize=fontsize)

    fig.savefig("img/neff-t-results" + fmt, bbox_inches='tight')
    plt.show()


# Tuning results

In [None]:
# Figures "<target>-tuning-result": one figure per statistic, one panel per
# (proposal, problem). Each panel overlays the brute-force ("ground truth") curve of the
# statistic with the observations and regression curves recorded by the tuning runs that
# optimise that statistic.

# NOTE: dill.load can execute arbitrary code while unpickling -- only open trusted files.
with open("data/bruteforce_data", "rb") as fhandle:
    results = dill.load(fhandle)

with open("data/tuning_data", "rb") as fhandle:
    tuning_results = dill.load(fhandle)

alpha=1
gauss_opt_acc = .234
mala_opt_acc = .574

# log10 step sizes on which the tuning regression curves are drawn.
tuning_stepsize_grid = np.linspace(np.log10(ts_params['lower_bound']), np.log10(ts_params['upper_bound']), ts_params['grid_size'])

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    for target in ["acc", "esjd", "esjd/t", "neff", "neff/t"]:
        # First row / first column of the grid only carry the problem / proposal captions.
        rows, cols = len(proposals) + 1, len(problems) + 1
        fig, axs = plt.subplots(rows, cols, figsize=large_figsize, dpi=large_dpi, gridspec_kw={"height_ratios": [1] + [3]*(rows-1), "width_ratios": [1] + [3]*(cols-1)})

        fig.patch.set_facecolor('white')

        # Tuning targets drawn anywhere in this figure, in first-seen order. Collected per
        # figure (not per panel) so the legend does not depend on which panel came last and
        # is still defined when every panel is skipped.
        corresponding_tuning_targets = []

        for i, problem in enumerate(problems):
            for j, proposal in enumerate(proposals):
                if (problem, proposal) not in results or results[(problem, proposal)] is None: continue

                opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc

                ax1 = axs[j+1, i+1]
                ax1.tick_params(axis='both', which='major', labelsize=labelsize, length=1)

                stepsize_grid = stepsize_grids[(problem, proposal)]
                log_steps = np.log10(stepsize_grid)

                ax_a, ax_b = np.log10(stepsize_grid[0])-.1, np.log10(stepsize_grid[-1])+.1


                # Identify good convergence: mean R-hat (values clipped at 20) at or below
                # the threshold. The copy keeps the clipping out of `results`.
                rhat_idx = list(targets).index("rhat")

                y = deepcopy(results[(problem, proposal)][rhat_idx].reshape(len(stepsize_grid), -1))
                y[np.where(y > 20)] = 20

                is_good = (np.mean(y, axis=-1) <= rhat_threshold)
                good, _ = collect_connected_sets(log_steps, is_good)


                # Tuning runs whose objective corresponds to the statistic of this figure.
                for tuning_target in tuning_results[(problem, proposal)]:
                    if target_map[tuning_target] != target: continue

                    if tuning_target not in corresponding_tuning_targets:
                        corresponding_tuning_targets.append(tuning_target)

                    _, reg_mean, reg_err, obs = tuning_results[(problem, proposal)][tuning_target]
                    for k in range(len(reg_mean)):
                        # Rescale each run so that its upper regression band peaks at 1.
                        z = 1./np.nanmax(reg_mean[k] + np.sqrt(reg_err[k]))

                        ax1.scatter(obs[k][:,0], z*obs[k][:,1], **tuning_target_style[tuning_target], s=3, alpha=1, zorder=1)
                        ax1.errorbar(obs[k][:,0], z*obs[k][:,1], yerr=z*np.sqrt(obs[k][:,2]), fmt='none', **tuning_target_style[tuning_target], alpha=.1)

                        ax1.plot(tuning_stepsize_grid, z*reg_mean[k], **tuning_target_style[tuning_target], alpha=.2)
                        ax1.fill_between(tuning_stepsize_grid, z*reg_mean[k]-z*np.sqrt(reg_err[k]), z*reg_mean[k]+z*np.sqrt(reg_err[k]), **tuning_target_style[tuning_target], alpha=.1)


                # Ground-truth curve of the statistic itself.
                target_idx = list(targets).index(target)

                y = results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1)
                y_mean = np.mean(y, axis=-1)

                if target not in ["acc", "rhat"]:
                    # Normalise by the best converged mean; fall back to the overall maximum
                    # when no step size converged. `y` may be a view into `results`, so the
                    # division is done out-of-place to leave the loaded data untouched.
                    good_max = np.max(y_mean * is_good)
                    good_max = good_max if good_max != 0 else np.max(y_mean)
                    y = y / good_max
                    y_mean = y_mean / good_max

                y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

                # Converged stretches are drawn opaque on top of the faint full curve.
                for xa, xb in good:
                    idx = sorted(choose_good(log_steps, (xa, xb))[0])
                    ax1.plot(log_steps[idx], y_mean[idx], label=target, color='C0', zorder=10)
                    ax1.fill_between(log_steps[idx], y_err[0][idx], y_err[1][idx], alpha=.1, color='C0')

                ax1.plot(log_steps, y_mean, label=target, alpha=.2, color='C0', zorder=10)
                ax1.fill_between(log_steps, y_err[0], y_err[1], alpha=.1, color='C0')


                if target == "acc":
                    # Step size at which the mean acceptance rate crosses the optimum,
                    # found by linear interpolation between the two bracketing grid points.
                    diff = y_mean - opt_acc
                    if diff[0] > 0 and diff[-1] < 0:
                        last_pos_idx = np.where(diff > 0)[-1][-1]
                        x1 = log_steps[last_pos_idx]
                        y1 = y_mean[last_pos_idx]
                        x2 = log_steps[last_pos_idx+1]
                        y2 = y_mean[last_pos_idx+1]

                        m = (y2 - y1) / (x2 - x1)
                        c = y1 - m * x1

                        x_opt_acc = (opt_acc - c) / m
                    else:
                        x_opt_acc = None

                    if i == 0:
                        ax1.set_yticks([0, opt_acc, 1], [0, r"$\alpha^*$", 1])
                    else:
                        ax1.set_yticks([], [])
                    ax1.plot([ax_a, ax_b], [opt_acc, opt_acc], linestyle='dashed', color='gray')

                    if x_opt_acc is not None:
                        ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3, x_opt_acc], [-5, -4, -3, -2, -1, 0, 1, 2, 3, r"$s_{\alpha}^*$"])
                        ax1.plot([x_opt_acc, x_opt_acc], [-.1, 1.1], linestyle='dashed', color='gray')
                else:
                    if i == 0:
                        ax1.set_yticks([0, 1], [0, 1])
                    else:
                        ax1.set_yticks([], [])
                    ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3], [-5, -4, -3, -2, -1, 0, 1, 2, 3])

                ylim1 = [-.05, 1.1]
                ax1.set_xlim([ax_a, ax_b])
                ax1.set_ylim(ylim1)


        # Row captions (proposal names) in the first column.
        for i in range(rows-1):
            axs[i+1, 0].axis('off')
            axs[i+1, 0].text(0, .5, list(proposals)[i].replace('Hit-And-Run', 'Hit & Run').replace('Rounding ', 'Rounding\n'), ha='center', va='center')

        # Column captions (problem names) in the first row, x labels on the bottom row.
        for i in range(1, cols):
            axs[0, i].axis('off')
            axs[0, i].text(.5, 0, list(problems)[i-1], ha='center', va='center')
            axs[-1, i].set_xlabel("log step size", fontsize=labelsize, labelpad=0)

        axs[0, 0].axis('off')

        # Legend: ground truth first, then one entry per tuning target shown in this figure.
        display_names = []
        custom_lines = [Line2D([0], [0], color='C0')]
        for tuning_target in corresponding_tuning_targets:
            custom_lines.append(Line2D([0], [0], **tuning_target_style[tuning_target]))
            # Names ending in ")" get "tuning" merged into their parenthesis.
            _tuning = " tuning)" if tuning_target[-1] == ")" else " (tuning)"
            display_names.append(tuning_target.replace('\n', ' ').replace('Rate', 'rate').replace(')', ',') + _tuning)

        # The legend hangs below a bottom-row panel; fall back to the last column on narrow grids.
        axs[rows-1, min(3, cols-1)].legend(custom_lines, [target_display_names[target] + " (ground truth)"] + display_names, bbox_to_anchor=(0.5, -0.5), loc='upper center', ncol=len(display_names)+1)

        fig.subplots_adjust(hspace=.45, wspace=.05)
        fig.savefig("img/" + target.replace('/', '-') + "-tuning-result" + fmt, bbox_inches='tight')
        plt.show()
        # Release the figure; several large high-dpi figures are created by this loop.
        plt.close(fig)

# Tuning posteriors overview

In [None]:
# Overview figure: for the 'Rounding Gaussian' proposal, compare every ground-truth
# target curve (from the brute-force runs) with the regression curves and
# observations recorded by the corresponding tuning runs.
#
# NOTE: dill.load (like pickle) can execute arbitrary code while reading a file;
# only load data files produced by a trusted run of this project.
with open("data/bruteforce_data", "rb") as fhandle:
    results = dill.load(fhandle)

with open("data/tuning_data", "rb") as fhandle:
    tuning_results = dill.load(fhandle)

alpha=1
# Reference acceptance rates drawn as the dashed alpha^* line in the "acc" row;
# the MALA value is only selected for the 'CSmMALA' proposal.
gauss_opt_acc = .234
mala_opt_acc = .574

# Problems shown as columns, in this order.
consider_problems = ['Gauss', 'STAT-1', 'STAT-2', 'STAT-2-ni']

# log10 step sizes on which the tuning regression curves are drawn.
tuning_stepsize_grid = np.linspace(np.log10(ts_params['lower_bound']), np.log10(ts_params['upper_bound']), ts_params['grid_size'])

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    for target_group in [("acc", "esjd", "neff", "esjd/t", "neff/t")]:
        # One row per target plus a thin last row that only carries the problem names.
        rows, cols = len(target_group) + 1, len(consider_problems)
        fig, axs = plt.subplots(rows, cols, figsize=mid_figsize, dpi=large_dpi, gridspec_kw={"height_ratios": [3]*(rows-1) + [1], "width_ratios": [3]*cols})

        fig.patch.set_facecolor('white')

        for j, target in enumerate(target_group):
            # Start every row with an empty list so the row legend can never pick up
            # the tuning targets of a previous row (or fail when all pairs are skipped).
            corresponding_tuning_targets = []

            for i, problem in enumerate(consider_problems):
                for proposal in ['Rounding Gaussian']:
                    if (problem, proposal) not in results or results[(problem, proposal)] is None: continue

                    opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc

                    ax1 = axs[j, i]
                    ax1.tick_params(axis='both', which='major', labelsize=labelsize, length=1)

                    stepsize_grid = stepsize_grids[(problem, proposal)]
                    log_stepsize_grid = np.log10(stepsize_grid)

                    ax_a, ax_b = np.log10(stepsize_grid[0])-.1, np.log10(stepsize_grid[-1])+.1

                    # Identify the step sizes with good convergence: mean rhat
                    # (clipped at 20) must not exceed the threshold.
                    target_idx = list(targets).index("rhat")

                    y = deepcopy(results[(problem, proposal)][target_idx].reshape(len(stepsize_grid), -1))

                    y[np.where(y > 20)] = 20

                    y_mean = np.mean(y, axis=-1)

                    is_good = (y_mean <= rhat_threshold)
                    good, bad = collect_connected_sets(log_stepsize_grid, is_good)

                    # Draw every tuning run whose tuning target maps onto this row's target.
                    corresponding_tuning_targets = []

                    for tuning_target in tuning_results[(problem, proposal)]:
                        if target_map[tuning_target] == target:
                            corresponding_tuning_targets += [tuning_target]
                            stepsizes, reg_mean, reg_err, obs = tuning_results[(problem, proposal)][tuning_target]
                            for k, _ in enumerate(reg_mean):
                                # Rescale so the upper edge of the regression band peaks at 1.
                                # reg_err / obs[:, 2] are passed through sqrt, i.e. presumably
                                # variances -- TODO confirm against the tuning code.
                                z = 1./np.nanmax(reg_mean[k] + np.sqrt(reg_err[k]))
                                ax1.scatter(obs[k][:,0], z*obs[k][:,1], **tuning_target_style[tuning_target], s=3, alpha=1, zorder=1)
                                ax1.errorbar(obs[k][:,0], z*obs[k][:,1], yerr=z*np.sqrt(obs[k][:,2]), fmt='none', **tuning_target_style[tuning_target], alpha=.1)

                                ax1.plot(tuning_stepsize_grid, z*reg_mean[k], **tuning_target_style[tuning_target], alpha=.2)
                                ax1.fill_between(tuning_stepsize_grid, z*reg_mean[k]-z*np.sqrt(reg_err[k]), z*reg_mean[k]+z*np.sqrt(reg_err[k]), **tuning_target_style[tuning_target], alpha=.1)

                    # Ground-truth curve of this row's target.
                    target_idx = list(targets).index(target)

                    # Copy before normalising: reshape can return a view, and the in-place
                    # division below would otherwise rescale the loaded `results` that
                    # later cells read.
                    y = deepcopy(results[(problem, proposal)][target_idx]).reshape(len(stepsize_grid), -1)
                    y_mean = np.mean(y, axis=-1)

                    if target not in ["acc", "rhat"]:
                        # Normalise by the best value inside the converged region; fall back
                        # to the global maximum if that region is empty (or its maximum is 0).
                        good_max = np.max(y_mean * is_good)
                        good_max = good_max if good_max != 0 else np.max(y_mean)
                        y /= good_max
                        y_mean /= good_max

                    y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

                    # Converged stretches are drawn opaque ...
                    for xa, xb in good:
                        idx = sorted(choose_good(log_stepsize_grid, (xa, xb))[0])
                        ax1.plot(log_stepsize_grid[idx], y_mean[idx], label=target, color='C0', zorder=10)
                        ax1.fill_between(log_stepsize_grid[idx], y_err[0][idx], y_err[1][idx], alpha=.1, color='C0')

                    # ... on top of the faint curve over the whole grid.
                    ax1.plot(log_stepsize_grid, y_mean, label=target, alpha=.2, color='C0', zorder=10)
                    ax1.fill_between(log_stepsize_grid, y_err[0], y_err[1], alpha=.1, color='C0')

                    if target == "acc":
                        # Locate the log step size at which the mean acceptance rate crosses
                        # opt_acc by linear interpolation between the last grid point above
                        # it and its successor.
                        diff = y_mean - opt_acc
                        if diff[0] > 0 and diff[-1] < 0:
                            last_pos_idx = np.where(diff > 0)[-1][-1]
                            x1 = log_stepsize_grid[last_pos_idx]
                            y1 = y_mean[last_pos_idx]
                            x2 = log_stepsize_grid[last_pos_idx+1]
                            y2 = y_mean[last_pos_idx+1]

                            m = (y2 - y1) / (x2 - x1)
                            c = y1 - m * x1

                            x_opt_acc = (opt_acc - c) / m
                        else:
                            x_opt_acc = None

                        # Only the left-most column carries y tick labels.
                        if i == 0:
                            ax1.set_yticks([0, opt_acc, 1], [0, r"$\alpha^*$", 1])
                        else:
                            ax1.set_yticks([], [])
                        ax1.plot([ax_a, ax_b], [opt_acc, opt_acc], linestyle='dashed', color='gray')

                        if x_opt_acc is not None:
                            ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3, x_opt_acc], [-5, -4, -3, -2, -1, 0, 1, 2, 3, r"$s_{\alpha}^*$"])
                            ax1.plot([x_opt_acc, x_opt_acc], [-.1, 1.1], linestyle='dashed', color='gray')
                    else:
                        if i == 0:
                            ax1.set_yticks([0, 1], [0, 1])
                        else:
                            ax1.set_yticks([], [])
                        ax1.set_xticks([-5, -4, -3, -2, -1, 0, 1, 2, 3], [-5, -4, -3, -2, -1, 0, 1, 2, 3])

                    ylim1 = [-.05, 1.1]
                    ax1.set_xlim([ax_a, ax_b])
                    ax1.set_ylim(ylim1)

            # Row legend: ground-truth line first, then one entry per tuning target
            # collected for the last plotted problem of this row.
            display_names = []
            custom_lines = [Line2D([0], [0], color='C0')]

            for tuning_target in corresponding_tuning_targets:
                custom_lines.append(Line2D([0], [0], **tuning_target_style[tuning_target]))
                _tuning = " tuning)" if tuning_target[-1] == ")" else " (tuning)"
                display_names.append(tuning_target.replace("Acceptance", "Acc.").replace('\n', ' ').replace('Rate', 'rate').replace(')', ',') + _tuning)

            axs[j, -1].legend(custom_lines, [target_display_names[target].replace("Acceptance", "Acc.") + " (ground truth)"] + display_names, bbox_to_anchor=(1.01, 1.1), loc='upper left')

        # The extra bottom row only holds the column (problem) captions.
        _problems = consider_problems
        for i in range(0, cols):
            axs[-1, i].axis('off')
            axs[-1, i].text(.5, 0, _problems[i], ha='center', va='center')

        fig.subplots_adjust(hspace=.45, wspace=.05)
        fig.savefig("img/tuning-result-gaussian-rounding" + fmt, bbox_inches='tight')
        plt.show()

# $n_{\mathrm{eff}}$ of optimal stepsize w.r.t. ESJD and acceptance rate

In [None]:
# Summary figure: for every (problem, proposal) pair, show the ground-truth n_eff
# obtained at the step sizes chosen by each tuning target.
group_by_proposal = False
normalize = True

alpha=1
gauss_opt_acc = .234
mala_opt_acc = .574

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    rows, cols = 1, 1
    fig, ax = plt.subplots(rows, cols, figsize=1.5*broad_figsize, dpi=dpi)

    fig.patch.set_facecolor('white')

    # x holds one position per (problem, proposal) pair; the three dicts map a
    # tuning target to a list with one entry per pair, in the same order as x.
    # The plotting loops below index those lists by position in x, so every pair
    # is assumed to provide every tuning target.
    x = []
    achieved_neff = {}
    achieved_neff_err = {}
    good_neff = {}

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc

            # Pairs are laid out in groups separated by a gap of two positions.
            if group_by_proposal:
                x.append((len(problems) + 2)*j + i)
            else:
                x.append((len(proposals) + 2)*i + j)

            # extract neff as final performance measure
            target = "neff"
            target_idx = list(targets).index(target)

            stepsize_grid = stepsize_grids[(problem, proposal)]

            y = deepcopy(results[(problem, proposal)][target_idx]).reshape(len(stepsize_grid), -1)

            y_mean = np.mean(y, axis=-1)

            if normalize:
                # Scale so that the best mean n_eff over the grid equals 1.
                y /= np.max(y_mean)
                y_mean /= np.max(y_mean)

            y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            neff = y_mean
            neff_err = y_err

            # extract rhat to determine goodness
            target = "rhat"
            target_idx = list(targets).index(target)

            y = deepcopy(results[(problem, proposal)][target_idx]).reshape(len(stepsize_grid), -1)
            y_mean = np.mean(y, axis=-1)
            rhat = y_mean

            for target in ["esjd", "neff", "acc"]:
                corresponding_tuning_targets = []

                for tuning_target in tuning_results[(problem, proposal)]:
                    if target_map[tuning_target] == target:
                        if tuning_target not in achieved_neff:
                            achieved_neff[tuning_target] = []
                            achieved_neff_err[tuning_target] = []
                            good_neff[tuning_target] = []

                        corresponding_tuning_targets += [tuning_target]
                        stepsizes, reg_mean, reg_err, obs = tuning_results[(problem, proposal)][tuning_target]

                        # Look up the ground-truth n_eff at the grid point closest to each
                        # tuned step size; convergence is judged at the mean tuned step size.
                        _neffs = np.array([neff[np.argmin(np.abs(stepsize_grid - stepsize))] for stepsize in stepsizes])
                        mean_idx = np.argmin(np.abs(stepsize_grid - np.mean(stepsizes)))

                        achieved_neff[tuning_target].append(np.mean(_neffs))
                        achieved_neff_err[tuning_target].append([np.min(_neffs), np.max(_neffs)])
                        good_neff[tuning_target].append(rhat[mean_idx] < rhat_threshold)

    x = np.array(x)

    fontsize = 7

    a = .2 # alpha
    d = .06

    for k, tuning_target in enumerate(achieved_neff):
        # Asymmetric error bars: distance from the mean to the min / max achieved n_eff.
        err = np.array((
            np.abs(np.array(achieved_neff_err[tuning_target])[:,0] - achieved_neff[tuning_target]),
            np.abs(np.array(achieved_neff_err[tuning_target])[:,1] - achieved_neff[tuning_target])
        ))
        good = np.array(good_neff[tuning_target])

        # Small horizontal offset so markers of different tuning targets do not overlap.
        _d = (k - len(achieved_neff) / 2) * d

        for i, _x in enumerate(x):
            # Non-converged results are drawn hollow and faint.
            facecolors = 'none' if not good[i] else cm.tab20(k)
            # Only the first marker of each tuning target carries the legend label.
            legend_kwargs = {'label': tuning_target.replace('Rate\n', 'rate ')} if i == 0 else {}
            ax.scatter(_x+_d, achieved_neff[tuning_target][i], zorder=1, color=cm.tab20(k),
                       marker='o',
                       alpha=(1-a)*good[i]+a, facecolors=facecolors, **legend_kwargs)

        for i, _x in enumerate(x):
            ax.errorbar(_x+_d, achieved_neff[tuning_target][i], yerr=err[:,i:i+1], zorder=-1, fmt='none', color=cm.tab20(k), alpha=(1-a)*good[i]+a)

    # Thin vertical guide from 0 up to the best value achieved for each pair.
    for i, _x in enumerate(x):
        _neffs = []
        for tuning_target in achieved_neff:
            _neffs.append(achieved_neff[tuning_target][i])
        neff_max = max(_neffs)
        ax.plot([_x, _x], [0, neff_max], linestyle='dashed', linewidth=.1, color='black', alpha=.5, zorder=-10)

    if normalize:
        # Hard-coded limits: a slightly padded [0, 1] range.
        ax.set_ylim((-0.02082685768576953, 1.0486108027469414))

    ax.legend(bbox_to_anchor=(0.5, -0.4), loc='upper center', ncol=len(achieved_neff))

    ub = ax.get_ylim()[1] / 1.05

    if group_by_proposal:
        # Each proposal group spans len(problems) consecutive positions.
        x_group = [((len(problems) - 1) / 2) + (len(problems) + 2) * i for i in range(len(proposals))]
        for i, _x in enumerate(x_group):
            ax.text(_x, 1.1 * ub, list(proposals)[i], ha='center', va='center')

        xlabels = [name for name in problems] * len(proposals)
        ax.set_xticks(x)
        ax.set_xticklabels(xlabels, rotation=75, fontsize=fontsize)
    else:
        # Each problem group spans len(proposals) consecutive positions, so its
        # centre is offset by half the number of proposals (not of problems).
        x_group = [((len(proposals) - 1) / 2) + (len(proposals) + 2) * i for i in range(len(problems))]
        for i, _x in enumerate(x_group):
            ax.text(_x, 1.1 * ub, list(problems)[i], ha='center', va='center')

        xlabels = [name.replace('-And-', ' & ') for name in proposals] * len(problems)
        ax.set_xticks(x)
        ax.set_xticklabels(xlabels, rotation=75, fontsize=fontsize)

    fig.savefig("img/neff-tuning-results" + fmt, bbox_inches='tight')
    plt.show()


# $n_{\mathrm{eff}}/t$ of optimal stepsize w.r.t. ESJD/t and acceptance rate

In [None]:
# Summary figure: for every (problem, proposal) pair, show the ground-truth n_eff/t
# obtained at the step sizes chosen by each tuning target.
group_by_proposal = False
normalize = True

alpha=1
gauss_opt_acc = .234
mala_opt_acc = .574

with warnings.catch_warnings():
    warnings.simplefilter('ignore')

    rows, cols = 1, 1
    fig, ax = plt.subplots(rows, cols, figsize=1.5*broad_figsize, dpi=dpi)

    fig.patch.set_facecolor('white')

    # x holds one position per (problem, proposal) pair; the three dicts map a
    # tuning target to a list with one entry per pair, in the same order as x.
    # The plotting loops below index those lists by position in x, so every pair
    # is assumed to provide every tuning target.
    x = []
    achieved_neff = {}
    achieved_neff_err = {}
    good_neff = {}

    for i, problem in enumerate(problems):
        for j, proposal in enumerate(proposals):
            opt_acc = mala_opt_acc if proposal == 'CSmMALA' else gauss_opt_acc

            # Pairs are laid out in groups separated by a gap of two positions.
            if group_by_proposal:
                x.append((len(problems) + 2)*j + i)
            else:
                x.append((len(proposals) + 2)*i + j)

            # extract neff/t as final performance measure
            target = "neff/t"
            target_idx = list(targets).index(target)

            stepsize_grid = stepsize_grids[(problem, proposal)]

            y = deepcopy(results[(problem, proposal)][target_idx]).reshape(len(stepsize_grid), -1)

            y_mean = np.mean(y, axis=-1)

            if normalize:
                # Scale so that the best mean n_eff/t over the grid equals 1.
                y /= np.max(y_mean)
                y_mean /= np.max(y_mean)

            y_err = (np.quantile(y, 0.05, axis=-1), np.quantile(y, 0.95, axis=-1))

            neff = y_mean
            neff_err = y_err

            # extract rhat to determine goodness
            target = "rhat"
            target_idx = list(targets).index(target)

            y = deepcopy(results[(problem, proposal)][target_idx]).reshape(len(stepsize_grid), -1)
            y_mean = np.mean(y, axis=-1)
            rhat = y_mean

            for target in ["esjd/t", "neff/t", "acc"]:
                corresponding_tuning_targets = []

                for tuning_target in tuning_results[(problem, proposal)]:
                    if target_map[tuning_target] == target:
                        if tuning_target not in achieved_neff:
                            achieved_neff[tuning_target] = []
                            achieved_neff_err[tuning_target] = []
                            good_neff[tuning_target] = []

                        corresponding_tuning_targets += [tuning_target]
                        stepsizes, reg_mean, reg_err, obs = tuning_results[(problem, proposal)][tuning_target]

                        # Look up the ground-truth n_eff/t at the grid point closest to each
                        # tuned step size; convergence is judged at the mean tuned step size.
                        _neffs = np.array([neff[np.argmin(np.abs(stepsize_grid - stepsize))] for stepsize in stepsizes])
                        mean_idx = np.argmin(np.abs(stepsize_grid - np.mean(stepsizes)))

                        achieved_neff[tuning_target].append(np.mean(_neffs))
                        achieved_neff_err[tuning_target].append([np.min(_neffs), np.max(_neffs)])
                        good_neff[tuning_target].append(rhat[mean_idx] < rhat_threshold)

    x = np.array(x)

    fontsize = 7

    a = .2 # alpha
    d = .06

    for k, tuning_target in enumerate(achieved_neff):
        # Asymmetric error bars: distance from the mean to the min / max achieved value.
        err = np.array((
            np.abs(np.array(achieved_neff_err[tuning_target])[:,0] - achieved_neff[tuning_target]),
            np.abs(np.array(achieved_neff_err[tuning_target])[:,1] - achieved_neff[tuning_target])
        ))
        good = np.array(good_neff[tuning_target])

        # Small horizontal offset so markers of different tuning targets do not overlap.
        _d = (k - len(achieved_neff) / 2) * d

        for i, _x in enumerate(x):
            # Non-converged results are drawn hollow and faint.
            facecolors = 'none' if not good[i] else cm.tab20(k)
            # Only the first marker of each tuning target carries the legend label.
            legend_kwargs = {'label': tuning_target.replace('Rate\n', 'rate ')} if i == 0 else {}
            ax.scatter(_x+_d, achieved_neff[tuning_target][i], zorder=1, color=cm.tab20(k),
                       marker='o',
                       alpha=(1-a)*good[i]+a, facecolors=facecolors, **legend_kwargs)

        for i, _x in enumerate(x):
            ax.errorbar(_x+_d, achieved_neff[tuning_target][i], yerr=err[:,i:i+1], zorder=-1, fmt='none',
                        color=cm.tab20(k), alpha=(1-a)*good[i]+a)

    # Thin vertical guide from 0 up to the best value achieved for each pair.
    for i, _x in enumerate(x):
        _neffs = []
        for tuning_target in achieved_neff:
            _neffs.append(achieved_neff[tuning_target][i])
        neff_max = max(_neffs)
        ax.plot([_x, _x], [0, neff_max], linestyle='dashed', linewidth=.1, color='black', alpha=.5, zorder=-10)

    if normalize:
        # Hard-coded limits: a slightly padded [0, 1] range.
        ax.set_ylim((-0.02082685768576953, 1.0486108027469414))

    ax.legend(bbox_to_anchor=(0.5, -0.4), loc='upper center', ncol=len(achieved_neff))

    ub = ax.get_ylim()[1] / 1.05

    if group_by_proposal:
        # Each proposal group spans len(problems) consecutive positions.
        x_group = [((len(problems) - 1) / 2) + (len(problems) + 2) * i for i in range(len(proposals))]
        for i, _x in enumerate(x_group):
            ax.text(_x, 1.1 * ub, list(proposals)[i], ha='center', va='center')

        xlabels = [name for name in problems] * len(proposals)
        ax.set_xticks(x)
        ax.set_xticklabels(xlabels, rotation=75, fontsize=fontsize)
    else:
        # Each problem group spans len(proposals) consecutive positions, so its
        # centre is offset by half the number of proposals (not of problems).
        x_group = [((len(proposals) - 1) / 2) + (len(proposals) + 2) * i for i in range(len(problems))]
        for i, _x in enumerate(x_group):
            ax.text(_x, 1.1 * ub, list(problems)[i], ha='center', va='center')

        xlabels = [name.replace('-And-', ' & ') for name in proposals] * len(problems)
        ax.set_xticks(x)
        ax.set_xticklabels(xlabels, rotation=75, fontsize=fontsize)

    fig.savefig("img/neff-t-tuning-results" + fmt, bbox_inches='tight')
    plt.show()


In [None]:
# Pair plots of prior (C0) and posterior (C1) samples for every problem: 1d
# marginal histograms on the outer axes and 2d scatter plots in the lower triangle.
#
# NOTE: dill.load (like pickle) can execute arbitrary code while reading a file;
# only load data files produced by a trusted run of this project.
with open("data/posterior", "rb") as fhandle:
    posterior = dill.load(fhandle)

with open("data/prior", "rb") as fhandle:
    prior = dill.load(fhandle)

print(list(prior))

# Grid cells used by a marginal panel (h1) and by a scatter panel (h2).
h1, h2 = 1, 3
combine_prior = True
combine_posterior = True

data = []

# prior 1d, prior 2d, posterior 1d, posterior 2d
show = {
    "Gauss": (False, False, True, True),
    "STAT-1": (True, True, True, True),
    "STAT-2": (True, True, True, True),
    "STAT-1-ni": (True, True, True, True),
    "STAT-2-ni": (True, True, True, True),
}

all_ticks = {
    "Gauss": [[-2, 0, 2]] * 20,
    "STAT-1": [[0, 0.5, 1, 1.5, 2], [0, 0.5, 1, 1.5, 2]],
    "STAT-2": [[0, 0.5, 1, 1.5, 2], [0, 0.5, 1, 1.5, 2]],
    "STAT-1-ni": [[0, 0.5, 1, 1.5, 2], [0, 0.5, 1, 1.5, 2], [0, 20, 40, 60, 80, 100]],
    "STAT-2-ni": [[0, 0.5, 1, 1.5, 2], [0, 0.5, 1, 1.5, 2], [0, 20, 40, 60, 80, 100]],
}

# Target number of samples kept per chain after thinning.
N = 500

styles = [
    {"color": 'C0'},
    {"color": 'C1'}
]

pairplot_figsize=2*np.array([3.2, 3.2])

for key in prior:
    ticks = all_ticks[key]
    show_prior_1d, show_prior_2d, show_posterior_1d, show_posterior_2d = show[key]

    # Problems with 20 or more dimensions are restricted to their first five.
    dim = prior[key].shape[-1]
    dim = dim if dim < 20 else 5
    axs = np.array([[None]*(dim+1)]*(dim))

    n = h2*(dim-1) + h1

    fig = plt.figure(figsize=pairplot_figsize, dpi=dpi, constrained_layout=True)
    fig.set_size_inches(*pairplot_figsize)
    gs = fig.add_gridspec(n, n)

    # Marginal panels above each scatter column (only the bottom spine stays visible).
    for i in range(dim-1):
        axs[i, i] = fig.add_subplot(gs[i*h2:i*h2+h1, i*h2:i*h2+h2])
        axs[i, i].spines['top'].set_visible(False)
        axs[i, i].spines['left'].set_visible(False)
        axs[i, i].spines['right'].set_visible(False)
        axs[i, i].axes.get_yaxis().set_visible(False)
        axs[i, i].patch.set_alpha(0)

    # Marginal panels to the right of each scatter row (only the left spine stays visible).
    for i in range(1, dim):
        axs[i, i+1] = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, i*h2:i*h2+h1])
        axs[i, i+1].spines['top'].set_visible(False)
        axs[i, i+1].spines['right'].set_visible(False)
        axs[i, i+1].spines['bottom'].set_visible(False)
        axs[i, i+1].axes.get_xaxis().set_visible(False)
        axs[i, i+1].patch.set_alpha(0)

    # Scatter panels in the lower triangle.
    for i in range(1, dim):
        for j in range(i):
            axs[i, j] = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, j*h2:(j+1)*h2])

    if hasattr(problems[key]['default'][1].model, "model"):
        # Work on a copy: the names are rewritten into LaTeX labels below, and that
        # must not modify the model's own parameter name list.
        param_names = list(problems[key]['default'][1].model.model.parameter_names)
    else:
        param_names = ["$x_{" + str(i) + "}$" for i in range(1, dim+1)]

    # Turn 'name.x' / 'name.n' parameter names into math-mode axis labels.
    for i, name in enumerate(param_names):
        if '.x' in name:
            param_names[i] = "$" + name.replace('.x','') + r"_{\mathrm{xch}}$"
        if '.n' in name:
            param_names[i] = "$" + name.replace('.n','') + "$"

    # Draw the prior first (l = 0), then the posterior (l = 1) on the same axes.
    layers = [
        (prior, show_prior_1d, show_prior_2d, combine_prior),
        (posterior, show_posterior_1d, show_posterior_2d, combine_posterior),
    ]
    for l, (states, show_1d, show_2d, combine) in enumerate(layers):
        # Keep roughly N samples per chain; never let the slice step drop to 0
        # when a chain holds fewer than N samples.
        thinning = max(1, states[key].shape[1] // N)

        if show_1d:
            for i in range(dim-1):
                orientation = 'vertical'

                if combine:
                    axs[i, i].hist(states[key][:,::thinning,i].flatten(), density=True, histtype='step', orientation=orientation, **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i].hist(states[key][k,::thinning,i].flatten(), density=True, histtype='step', orientation=orientation)

                if i+1 < dim:
                    axs[i, i].set_xticks([])
                if i > 0:
                    axs[i, i].set_yticks([])

            for i in range(1, dim):
                orientation = 'horizontal'

                if combine:
                    axs[i, i+1].hist(states[key][:,::thinning,i].flatten(), density=True, histtype='step', orientation=orientation, **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i+1].hist(states[key][k,::thinning,i].flatten(), density=True, histtype='step', orientation=orientation)

                if i+1 < dim:
                    axs[i, i+1].set_xticks([])
                if i > 0:
                    axs[i, i+1].set_yticks([])

        if show_2d:
            for i in range(dim-1):
                for j in range(i, dim-1):
                    if combine:
                        axs[j+1, i].scatter(states[key][:,::thinning,i].flatten(),
                                            states[key][:,::thinning,j+1].flatten(), s=2./(dim-1)**2, **styles[l])
                    else:
                        for k in range(n_chains):
                            axs[j+1, i].scatter(states[key][k,::thinning,i],
                                                states[key][k,::thinning,j+1], s=2./(dim-1)**2)

            # The rhat diagnostic is only reported for the posterior samples.
            if states is posterior:
                print(hopsy.rhat(states[key]))

    # Axis labels only on the bottom row and the left column of the scatter grid.
    for i in range(dim-1):
        for j in range(i, dim-1):
            if j+1 == dim-1:
                axs[j+1, i].set_xlabel(param_names[i], fontsize=fs)
                #axs[j+1, i].set_xticks(ticks[i], ticks[i], fontsize=int(12/np.sqrt(dim)))
                axs[j+1, i].tick_params(axis='x', labelsize=int(fs/np.sqrt(dim-1)))
            else:
                axs[j+1, i].set_xticks([])

            if i == 0:
                axs[j+1, i].set_ylabel(param_names[j+1], fontsize=fs)
                axs[j+1, i].tick_params(axis='y', labelsize=int(fs/np.sqrt(dim-1)))
            else:
                axs[j+1, i].set_yticks([])

    # Legend entries only for the layers that are actually drawn.
    custom_lines = []
    custom_lines += [Line2D([0], [0], color='C0')] if show_prior_1d or show_prior_2d else []
    custom_lines += [Line2D([0], [0], color='C1')] if show_posterior_1d or show_posterior_2d else []

    names = []
    names += ['Prior'] if show_prior_1d or show_prior_2d else []
    names += ['Posterior'] if show_posterior_1d or show_posterior_2d else []

    fig.legend(custom_lines, names, bbox_to_anchor=(1, 1))

    fig.savefig("img/" + key + "-full-marginals" + fmt)

    plt.show()

In [None]:
# Load the stored posterior and prior samples.
# NOTE(review): dill unpickling executes code from the file; only load trusted data.
with open("data/posterior", "rb") as fhandle:
    posterior = dill.load(fhandle)
with open("data/prior", "rb") as fhandle:
    prior = dill.load(fhandle)

# Grid extents for the pair-plot layout below: histogram strips use h1 grid cells,
# scatter panels use h2 grid cells.
h1 = 1
h2 = 3

# True pools all chains into one histogram/scatter; False draws every chain separately.
combine_prior = combine_posterior = True

data = []

# Per problem: (prior 1D, prior 2D, posterior 1D, posterior 2D) visibility flags.
show = {"Gauss": (False, False, True, True)}
show.update({
    name: (True, True, True, True)
    for name in ("STAT-1", "STAT-2", "STAT-1-ni", "STAT-2-ni")
})

# Every `thinning`-th sample along the second axis is plotted.
thinning = 1

# Plot styles; index 0 is used for the prior layer, index 1 for the posterior layer.
styles = [{"color": "C%d" % k} for k in range(2)]

# Parameters each problem's pair plot is restricted to.
restrict_on = {
    "Gauss": ['$x_{1}$', '$x_{2}$'],
    "STAT-1": ['u.n', 'q.n'],
    "STAT-1-ni": ['q.n', 'w.x'],
    "STAT-2": ['u.n', 'q.n'],
    "STAT-2-ni": ['q.n', 'w.x'],
}

# Figure size (inches) of each restricted pair plot.
pairplot_figsize = np.array([3.2, 3.2])

# Draw one pair plot per problem in `posterior`, restricted to the parameters listed in
# `restrict_on[key]`: histogram strips on the top/right and scatter panels below the
# diagonal, with the prior (styles[0]) and posterior (styles[1]) overlaid.
for key in posterior:
    show_prior_1d, show_prior_2d, show_posterior_1d, show_posterior_2d = show[key]

    # Full parameter-name list. Models without a nested `.model` get generic names whose
    # count is read from the samples, not from a `dim` left over from an earlier cell or
    # loop iteration.
    if hasattr(problems[key]['default'][1].model, "model"):
        _param_names = problems[key]['default'][1].model.model.parameter_names
    else:
        n_params = posterior[key].shape[-1]
        _param_names = ["$x_{" + str(i) + "}$" for i in range(1, n_params+1)]

    # Column indices and names of the selected parameters, in model order.
    idx = [i for i, name in enumerate(_param_names) if name in restrict_on[key]]
    param_names = [name for name in _param_names if name in restrict_on[key]]
    dim = len(restrict_on[key])

    # Turn identifiers into mathtext labels: a ".x" suffix becomes an "xch" subscript,
    # a ".n" suffix is dropped. Raw string: "\m" is not a valid escape sequence.
    for i, name in enumerate(param_names):
        if name.find('.x') >= 0:
            param_names[i] = "$" + name.replace('.x','') + r"_{\mathrm{xch}}$"
        if name.find('.n') >= 0:
            param_names[i] = "$" + name.replace('.n','') + "$"

    print(param_names)

    # axs[i, i]   : histogram strip above column i
    # axs[i, i+1] : histogram strip to the right of row i
    # axs[i, j]   : scatter panel in row i, column j (j < i)
    axs = np.array([[None]*(dim+1)]*(dim))

    n = h2*(dim-1) + h1

    fig = plt.figure(figsize=pairplot_figsize, dpi=dpi, constrained_layout=True)
    gs = fig.add_gridspec(n, n)

    for i in range(dim-1):
        ax = fig.add_subplot(gs[i*h2:i*h2+h1, i*h2:i*h2+h2])
        for side in ('top', 'left', 'right'):
            ax.spines[side].set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        ax.patch.set_alpha(0)
        axs[i, i] = ax

    for i in range(1, dim):
        ax = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, i*h2:i*h2+h1])
        for side in ('top', 'right', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.axes.get_xaxis().set_visible(False)
        ax.patch.set_alpha(0)
        axs[i, i+1] = ax

    for i in range(1, dim):
        for j in range(i):
            axs[i, j] = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, j*h2:(j+1)*h2])

    # Prior first, posterior second, so the posterior is drawn on top.
    layers = (
        (prior, show_prior_1d, show_prior_2d, combine_prior),
        (posterior, show_posterior_1d, show_posterior_2d, combine_posterior),
    )
    for l, (states, show_1d, show_2d, combine) in enumerate(layers):
        if not (show_1d or show_2d):
            # Do not index a layer that is hidden; it may not contain this problem.
            continue
        # Indexed below as [chain, sample, parameter] -- presumably that layout; confirm.
        chains = states[key]

        if show_1d:
            # Strips above the columns.
            for i in range(dim-1):
                d = idx[i]
                if combine:
                    axs[i, i].hist(chains[:,::thinning,d].flatten(), density=True,
                                   histtype='step', orientation='vertical', **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i].hist(chains[k,::thinning,d].flatten(), density=True,
                                       histtype='step', orientation='vertical')
                axs[i, i].set_xticks([])
                if i > 0:
                    axs[i, i].set_yticks([])

            # Strips to the right of the rows.
            for i in range(1, dim):
                d = idx[i]
                if combine:
                    axs[i, i+1].hist(chains[:,::thinning,d].flatten(), density=True,
                                     histtype='step', orientation='horizontal', **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i+1].hist(chains[k,::thinning,d].flatten(), density=True,
                                         histtype='step', orientation='horizontal')
                if i+1 < dim:
                    axs[i, i+1].set_xticks([])
                axs[i, i+1].set_yticks([])

        if show_2d:
            for i in range(dim-1):
                d = idx[i]
                for j in range(i, dim-1):
                    e = idx[j+1]
                    panel = axs[j+1, i]
                    if combine:
                        panel.scatter(chains[:,::thinning,d].flatten(),
                                      chains[:,::thinning,e].flatten(),
                                      s=2./(dim-1)**2, **styles[l])
                    else:
                        for k in range(n_chains):
                            panel.scatter(chains[k,::thinning,d],
                                          chains[k,::thinning,e], s=2./(dim-1)**2)

                    # Only the bottom row carries x labels, only the first column y labels.
                    if j+1 == dim-1:
                        panel.set_xlabel(param_names[i], fontsize=fs)
                    else:
                        panel.set_xticks([])

                    if i == 0:
                        panel.set_ylabel(param_names[j+1], fontsize=fs)
                    else:
                        panel.set_yticks([])

    # Legend handles for whichever layers are shown.
    custom_lines = []
    names = []
    if show_prior_1d or show_prior_2d:
        custom_lines.append(Line2D([0], [0], color='C0'))
        names.append('Prior')
    if show_posterior_1d or show_posterior_2d:
        custom_lines.append(Line2D([0], [0], color='C1'))
        names.append('Posterior')

    # The legend is drawn on the STAT-2 figure only.
    if key == "STAT-2":
        fig.legend(custom_lines, names, bbox_to_anchor=(1, 1))

    fig.savefig("img/" + key + "-marginal" + fmt)

    plt.show()
    # Release the figure so the loop does not accumulate open figures.
    plt.close(fig)

In [None]:
# Load the stored posterior and prior samples.
# NOTE(review): dill unpickling executes code from the file; only load trusted data.
with open("data/posterior", "rb") as fhandle:
    posterior = dill.load(fhandle)
with open("data/prior", "rb") as fhandle:
    prior = dill.load(fhandle)

# Show which problems the prior container holds.
print(list(prior))

# Grid extents for the corner-plot layout below: histogram strips use h1 grid cells,
# scatter panels use h2 grid cells.
h1 = 1
h2 = 5

# Prior chains are pooled; posterior chains are drawn one by one.
combine_prior = True
combine_posterior = False

data = []

# Only the posterior is drawn, as 1D histograms and 2D scatter panels.
show_prior_1d = show_prior_2d = False
show_posterior_1d = show_posterior_2d = True

# Every `thinning`-th sample along the second axis is plotted.
thinning = 1

# Plot styles; index 0 is used for the prior layer, index 1 for the posterior layer.
styles = [{"color": "C%d" % k} for k in range(2)]

# Full corner plot (all parameters) for the listed problems: histogram strips on the
# top/right, scatter panels below the diagonal.
for key in ['STAT-2-ni']:
    # Number of plotted parameters, capped at 5 when the problem has 20 or more.
    # The cap is decided on this problem's own dimension, not on a stale `dim`.
    dim = prior[key].shape[-1]
    dim = dim if dim < 20 else 5

    # axs[i, i]   : histogram strip above column i
    # axs[i, i+1] : histogram strip to the right of row i
    # axs[i, j]   : scatter panel in row i, column j (j < i)
    axs = np.array([[None]*(dim+1)]*(dim))

    n = h2*(dim-1) + h1

    fig = plt.figure(figsize=(6.4, 6.4), dpi=dpi, constrained_layout=True)
    gs = fig.add_gridspec(n, n)

    for i in range(dim-1):
        ax = fig.add_subplot(gs[i*h2:i*h2+h1, i*h2:i*h2+h2])
        for side in ('top', 'left', 'right'):
            ax.spines[side].set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        ax.patch.set_alpha(0)
        axs[i, i] = ax

    for i in range(1, dim):
        ax = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, i*h2:i*h2+h1])
        for side in ('top', 'right', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.axes.get_xaxis().set_visible(False)
        ax.patch.set_alpha(0)
        axs[i, i+1] = ax

    for i in range(1, dim):
        for j in range(i):
            axs[i, j] = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, j*h2:(j+1)*h2])

    # Work on a copy: the labels are rewritten below and the model's own name list
    # must stay untouched for other cells that match on the raw identifiers.
    if hasattr(problems[key]['default'][1].model, "model"):
        param_names = list(problems[key]['default'][1].model.model.parameter_names)
    else:
        param_names = ["$x_{" + str(i) + "}$" for i in range(1, dim+1)]

    # Turn identifiers into mathtext labels: a ".x" suffix becomes an "xch" subscript,
    # a ".n" suffix is dropped. Raw string: "\m" is not a valid escape sequence.
    for i, name in enumerate(param_names):
        if name.find('.x') >= 0:
            param_names[i] = "$" + name.replace('.x','') + r"_{\mathrm{xch}}$"
        if name.find('.n') >= 0:
            param_names[i] = "$" + name.replace('.n','') + "$"

    # Prior first, posterior second, so the posterior is drawn on top.
    layers = (
        (prior, show_prior_1d, show_prior_2d, combine_prior),
        (posterior, show_posterior_1d, show_posterior_2d, combine_posterior),
    )
    for l, (states, show_1d, show_2d, combine) in enumerate(layers):
        if not (show_1d or show_2d):
            continue
        # Indexed below as [chain, sample, parameter] -- presumably that layout; confirm.
        chains = states[key]

        if show_1d:
            # Strips above the columns.
            for i in range(dim-1):
                if combine:
                    axs[i, i].hist(chains[:,::thinning,i].flatten(), density=True,
                                   histtype='step', orientation='vertical', **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i].hist(chains[k,::thinning,i].flatten(), density=True,
                                       histtype='step', orientation='vertical')
                axs[i, i].set_xticks([])
                if i > 0:
                    axs[i, i].set_yticks([])

            # Strips to the right of the rows.
            for i in range(1, dim):
                if combine:
                    axs[i, i+1].hist(chains[:,::thinning,i].flatten(), density=True,
                                     histtype='step', orientation='horizontal', **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i+1].hist(chains[k,::thinning,i].flatten(), density=True,
                                         histtype='step', orientation='horizontal')
                if i+1 < dim:
                    axs[i, i+1].set_xticks([])
                axs[i, i+1].set_yticks([])

        if show_2d:
            for i in range(dim-1):
                for j in range(i, dim-1):
                    panel = axs[j+1, i]
                    if combine:
                        panel.scatter(chains[:,::thinning,i].flatten(),
                                      chains[:,::thinning,j+1].flatten(),
                                      s=1./(dim-1)**2, **styles[l])
                    else:
                        for k in range(n_chains):
                            panel.scatter(chains[k,::thinning,i],
                                          chains[k,::thinning,j+1], s=1./(dim-1)**2)

                    # Only the bottom row carries x labels, only the first column y labels.
                    if j+1 == dim-1:
                        panel.set_xlabel(param_names[i], fontsize=fs)
                    else:
                        panel.set_xticks([])

                    if i == 0:
                        panel.set_ylabel(param_names[j+1], fontsize=fs)
                    else:
                        panel.set_yticks([])

    # Report the R-hat diagnostic of the posterior chains alongside the 2D panels.
    if show_posterior_2d:
        print(hopsy.rhat(posterior[key]))

    fig.savefig("img/" + key + "-posterior-corner-plot" + fmt)

    plt.show()

# Simplicus prior

In [None]:
# Load parameter names plus prior and posterior samples for the Simplicus problem.
# NOTE(review): dill unpickling executes code from the file; only load trusted data.
with open("data/simplicus", "rb") as fhandle:
    param_names, prior, posterior = dill.load(fhandle)

# Grid extents for the layout below: h1 = 0 reserves no cells for histogram strips,
# scatter panels use h2 grid cells.
h1 = 0
h2 = 3

# True pools all chains into one histogram/scatter; False draws every chain separately.
combine_prior = combine_posterior = True

data = []

# Per problem: (prior 1D, prior 2D, posterior 1D, posterior 2D) visibility flags.
show = {"Simplicus": (False, True, False, False)}

# Per problem: one tick list per plotted axis.
all_ticks = {
    "Simplicus": [[0, 0.5, 1, 1.5, 2] for _ in range(2)],
}

# Sample count used to derive the thinning stride in the plotting loop.
N = 10000

# Plot styles; index 0 is used for the prior layer, index 1 for the posterior layer.
styles = [{"color": "C%d" % k} for k in range(2)]

# Scatter-matrix of the sampled space for each problem in `prior`, thinned so that about
# N samples per chain are drawn. The figure is written to one fixed file name, so with
# several keys the last one wins.
for key in prior:
    ticks = all_ticks[key]  # only referenced by the disabled tick override below
    show_prior_1d, show_prior_2d, show_posterior_1d, show_posterior_2d = show[key]

    # Number of plotted parameters, capped at 5 when the problem has 20 or more.
    dim = prior[key].shape[-1]
    dim = dim if dim < 20 else 5

    # axs[i, i]   : histogram strip above column i (created only when 1D plots are on)
    # axs[i, i+1] : histogram strip to the right of row i (same)
    # axs[i, j]   : scatter panel in row i, column j (j < i)
    axs = np.array([[None]*(dim+1)]*(dim))

    n = h2*(dim-1) + h1

    fig = plt.figure(figsize=(6.4, 6.4), dpi=dpi, constrained_layout=True)
    gs = fig.add_gridspec(n, n)

    # The 1D branches below index axs[i, i] / axs[i, i+1]; create those strips when a
    # 1D layer is requested instead of failing later on a None entry.
    if show_prior_1d or show_posterior_1d:
        if h1 <= 0:
            raise ValueError("1D marginals need h1 > 0 to reserve grid cells for the histogram strips")

        for i in range(dim-1):
            ax = fig.add_subplot(gs[i*h2:i*h2+h1, i*h2:i*h2+h2])
            for side in ('top', 'left', 'right'):
                ax.spines[side].set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            ax.patch.set_alpha(0)
            axs[i, i] = ax

        for i in range(1, dim):
            ax = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, i*h2:i*h2+h1])
            for side in ('top', 'right', 'bottom'):
                ax.spines[side].set_visible(False)
            ax.axes.get_xaxis().set_visible(False)
            ax.patch.set_alpha(0)
            axs[i, i+1] = ax

    for i in range(1, dim):
        for j in range(i):
            axs[i, j] = fig.add_subplot(gs[(i-1)*h2+h1:i*h2+h1, j*h2:(j+1)*h2])

    # Turn identifiers into mathtext labels: a ".x" suffix becomes an "xch" subscript,
    # a ".n" suffix is dropped. Raw string: "\m" is not a valid escape sequence.
    for i, name in enumerate(param_names):
        if name.find('.x') >= 0:
            param_names[i] = "$" + name.replace('.x','') + r"_{\mathrm{xch}}$"
        if name.find('.n') >= 0:
            param_names[i] = "$" + name.replace('.n','') + "$"

    # Prior first, posterior second, so the posterior is drawn on top.
    layers = (
        (prior, show_prior_1d, show_prior_2d, combine_prior),
        (posterior, show_posterior_1d, show_posterior_2d, combine_posterior),
    )
    for l, (states, show_1d, show_2d, combine) in enumerate(layers):
        if not (show_1d or show_2d):
            # Do not index a layer that is hidden; it may not contain this problem.
            continue
        # Indexed below as [chain, sample, parameter] -- presumably that layout; confirm.
        chains = states[key]

        # Stride that leaves about N samples per chain; never 0, since a zero slice
        # step raises ValueError when fewer than N samples are stored.
        thinning = max(1, chains.shape[1] // N)

        if show_1d:
            # Strips above the columns.
            for i in range(dim-1):
                if combine:
                    axs[i, i].hist(chains[:,::thinning,i].flatten(), density=True,
                                   histtype='step', orientation='vertical', **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i].hist(chains[k,::thinning,i].flatten(), density=True,
                                       histtype='step', orientation='vertical')
                axs[i, i].set_xticks([])
                if i > 0:
                    axs[i, i].set_yticks([])

            # Strips to the right of the rows.
            for i in range(1, dim):
                if combine:
                    axs[i, i+1].hist(chains[:,::thinning,i].flatten(), density=True,
                                     histtype='step', orientation='horizontal', **styles[l])
                else:
                    for k in range(n_chains):
                        axs[i, i+1].hist(chains[k,::thinning,i].flatten(), density=True,
                                         histtype='step', orientation='horizontal')
                if i+1 < dim:
                    axs[i, i+1].set_xticks([])
                axs[i, i+1].set_yticks([])

        if show_2d:
            for i in range(dim-1):
                for j in range(i, dim-1):
                    if combine:
                        axs[j+1, i].scatter(chains[:,::thinning,i].flatten(),
                                            chains[:,::thinning,j+1].flatten(),
                                            s=1./(dim-1)**2, **styles[l])
                    else:
                        for k in range(n_chains):
                            axs[j+1, i].scatter(chains[k,::thinning,i],
                                                chains[k,::thinning,j+1], s=1./(dim-1)**2)

    # Report the R-hat diagnostic of the posterior chains alongside the 2D panels.
    if show_posterior_2d:
        print(hopsy.rhat(posterior[key]))

    # Axis labels: only the bottom row carries x labels and ticks, only the first
    # column y labels and ticks; tick labels shrink with the number of panels.
    for i in range(dim-1):
        for j in range(i, dim-1):
            if j+1 == dim-1:
                axs[j+1, i].set_xlabel(param_names[i], fontsize=fs)
                #axs[j+1, i].set_xticks(ticks[i], ticks[i], fontsize=int(12/np.sqrt(dim)))
                axs[j+1, i].tick_params(axis='x', labelsize=int(fs/np.sqrt(dim-1)))
            else:
                axs[j+1, i].set_xticks([])

            if i == 0:
                axs[j+1, i].set_ylabel(param_names[j+1], fontsize=fs)
                axs[j+1, i].tick_params(axis='y', labelsize=int(fs/np.sqrt(dim-1)))
            else:
                axs[j+1, i].set_yticks([])

    # Legend handles for whichever layers are shown; the legend call itself is disabled.
    custom_lines = []
    names = []
    if show_prior_1d or show_prior_2d:
        custom_lines.append(Line2D([0], [0], color='C0'))
        names.append('Prior')
    if show_posterior_1d or show_posterior_2d:
        custom_lines.append(Line2D([0], [0], color='C1'))
        names.append('Posterior')
    #fig.legend(custom_lines, names, bbox_to_anchor=(1, 1))

    fig.savefig("img/feasible-flux-space" + fmt)

    plt.show()