In [1]:
import sys
sys.path.append("..")

from matplotlib import pyplot as plt

import torch
import numpy as np
from torch.distributions import Beta

import pyro
from pyro.infer.mcmc import MCMC
import pyro.distributions as dist

from kernel.sgnuts import NUTS

In [2]:
# Seed Pyro's RNGs before any sampling happens and start from an empty
# parameter store, so a fresh-kernel re-run begins from the same global state.
RNG_SEED = 101
pyro.set_rng_seed(RNG_SEED)
pyro.clear_param_store()

# Observed flips for two coins, encoded for the Bernoulli likelihood in model()
# as 1 = heads and 0 = tails (Bernoulli's probability parameter is P(obs = 1)).
#   coin 1 (x1): 6 tails followed by 4 heads
#   coin 2 (x2): 2 tails followed by 8 heads
# With the default Beta(1, 1) prior, the conjugate posteriors are therefore
# Beta(1 + 4, 1 + 6) = Beta(5, 7) for coin 1 and Beta(1 + 8, 1 + 2) = Beta(9, 3)
# for coin 2.
x1 = torch.cat([torch.zeros(6), torch.ones(4)])
x2 = torch.cat([torch.zeros(2), torch.ones(8)])

    
def model(x1, x2, alpha0=1., beta0=1.):
    """Two independent Beta-Bernoulli coin models sharing one Beta prior.

    Args:
        x1, x2: observed 0/1 outcomes for coin 1 and coin 2; passed as ``obs``
            to the Bernoulli sites ``"obs1"`` and ``"obs2"``.
        alpha0, beta0: Beta prior hyperparameters shared by both coins. May be
            Python numbers or tensors; non-floating values are promoted to the
            default float dtype because Beta needs floating-point parameters.

    Returns:
        Tuple of the values returned by the ``"obs1"`` and ``"obs2"`` sample
        sites (for observed sites Pyro returns the conditioning value).
    """
    def _float_tensor(value):
        # torch.as_tensor avoids the copy-construct UserWarning torch.tensor()
        # emits when the hyperparameter is already a tensor.
        t = torch.as_tensor(value)
        return t if t.is_floating_point() else t.to(torch.get_default_dtype())

    alpha0 = _float_tensor(alpha0)
    beta0 = _float_tensor(beta0)

    # Latent coin biases: each f is the Bernoulli probability of observing a 1.
    f1 = pyro.sample("coin1", dist.Beta(alpha0, beta0))
    f2 = pyro.sample("coin2", dist.Beta(alpha0, beta0))

    # NOTE(review): no pyro.plate wraps the observations, so the scalar-batch
    # Bernoulli is broadcast against each observation vector. Confirm how the
    # SGNUTS kernel subsamples positional args before introducing a plate.
    obs1 = pyro.sample("obs1", dist.Bernoulli(f1), obs=x1)
    obs2 = pyro.sample("obs2", dist.Bernoulli(f2), obs=x2)
    return obs1, obs2

In [3]:
# Stochastic-gradient NUTS from the project-local kernel.sgnuts module.
# NOTE(review): subsample_positions presumably refers to the positional model
# args (x1, x2) that get minibatched with batch_size — confirm in kernel/sgnuts.py.
sgnuts_options = {
    "subsample_positions": [0, 1],
    "batch_size": 5,
    "potential_fn": None,
    "learning_rate": 0.01,
    "momentum_decay": 0.1,
    "resample_every_n": 50,
    "obs_info_noise": False,
    "compute_obs_info": "every_sample",
    "use_multinomial_sampling": True,
    "max_tree_depth": 10,
}
sgnuts_kernel = NUTS(model, **sgnuts_options)

# No warmup_steps is passed; the captured progress bar shows 2000 total
# iterations for num_samples=1000, i.e. warmup defaults to num_samples.
sgnuts_mcmc = MCMC(sgnuts_kernel, num_samples=1000)
sgnuts_mcmc.run(x1, x2)
sgnuts_samples = sgnuts_mcmc.get_samples()

Warmup:   0%|                                                           | 2/2000 [00:00, 12.08it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([0.1107, 0.1285])}
{('coin1', 'coin2'): tensor([-0.1894,  0.1033])}
{('coin1', 'coin2'): tensor([ 0.1092, -0.1557])}
{('coin1', 'coin2'): tensor([-0.0646,  0.4351])}
{('coin1', 'coin2'): tensor([-0.0539,  0.5016])}
{('coin1', 'coin2'): tensor([-0.3059,  0.5223])}
{('coin1', 'coin2'): tensor([-0.0192,  0.1576])}
{('coin1', 'coin2'): tensor([-0.0959,  0.2100])}
{('coin1', 'coin2'): tensor([-0.2179,  0.1513])}
{('coin1', 'coin2'): tensor([-0.2483,  0.3094])}
{('coin1', 'coin2'): tensor([-0.4650,  0.5576])}
{('coin1', 'coin2'): tensor([-0.3985,  0.5033])}
{('coin1', 'coin2'): tensor([-0.1984,  0.5860])}
{('coin1', 'coin2'): tensor([0.2874, 0.2400])}
{('coin1', 'coin2'): tensor([0.2704, 0.2082])}
{('coin1', 'coin2'): tensor([0.2807, 0.3694])}
{('coin1', 'coin2'): tensor([0.1486, 0.4373])}
{('coin1', 'coin2'): tensor([0.2228, 0.3257])}
{('coin1', 'coin2'): tensor([-0.4027,  0.2337])}
{('coin1', 'coin2'): tensor([-0.4790,  0.1014])}
{('coin1', 'coin2'): tensor([-0.

Warmup:   0%|▏                                                          | 7/2000 [00:00, 20.27it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 0.9261, -0.7908])}
{('coin1', 'coin2'): tensor([ 0.8454, -1.2378])}
{('coin1', 'coin2'): tensor([ 0.7290, -1.4120])}
{('coin1', 'coin2'): tensor([ 0.2984, -1.4734])}
{('coin1', 'coin2'): tensor([-0.1709, -1.3819])}
{('coin1', 'coin2'): tensor([-0.5376, -1.6550])}
{('coin1', 'coin2'): tensor([-0.5682, -1.5566])}
{('coin1', 'coin2'): tensor([-0.1067,  0.4069])}
{('coin1', 'coin2'): tensor([-0.2277, -0.0055])}
{('coin1', 'coin2'): tensor([-0.6666,  0.0072])}
{('coin1', 'coin2'): tensor([-0.9747, -0.3632])}
{('coin1', 'coin2'): tensor([-1.0132, -0.5673])}
{('coin1', 'coin2'): tensor([-0.8484, -0.5486])}
{('coin1', 'coin2'): tensor([-0.7926, -0.3238])}
{('coin1', 'coin2'): tensor([-0.2040, -1.7463])}
{('coin1', 'coin2'): tensor([-0.2572, -1.8986])}
{('coin1', 'coin2'): tensor([ 0.1493, -1.3590])}
{('coin1', 'coin2'): tensor([-0.1987, -1.5523])}
{('coin1', 'coin2'): tensor([-0.1925, -1.0072])}
{('coin1', 'coin2'): tensor([-0.0026, -0.7208])}
{('coin1', 'coin2'):

Warmup:   0%|▎                                                         | 10/2000 [00:00, 16.49it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 0.8231, -0.2900])}
{('coin1', 'coin2'): tensor([0.9936, 0.0233])}
{('coin1', 'coin2'): tensor([1.1938, 0.3374])}
{('coin1', 'coin2'): tensor([1.3760, 0.5433])}
{('coin1', 'coin2'): tensor([1.3255, 0.6777])}
{('coin1', 'coin2'): tensor([0.8715, 0.5103])}
{('coin1', 'coin2'): tensor([0.5573, 0.4696])}
{('coin1', 'coin2'): tensor([0.4658, 0.9300])}
{('coin1', 'coin2'): tensor([0.7337, 1.0744])}
{('coin1', 'coin2'): tensor([0.6962, 1.2489])}
{('coin1', 'coin2'): tensor([0.3619, 1.4205])}
{('coin1', 'coin2'): tensor([-0.0784,  1.3421])}
{('coin1', 'coin2'): tensor([-0.0410,  1.3817])}
{('coin1', 'coin2'): tensor([0.1369, 1.1529])}
{('coin1', 'coin2'): tensor([-0.9650,  0.0721])}
{('coin1', 'coin2'): tensor([-0.3571,  0.1386])}
{('coin1', 'coin2'): tensor([-0.2297,  0.2869])}
{('coin1', 'coin2'): tensor([-1.9964, -0.8477])}
{('coin1', 'coin2'): tensor([-2.2477, -0.8757])}
{('coin1', 'coin2'): tensor([-2.3833, -1.0187])}
{('coin1', 'coin2'): tensor([-2.0974, -1.4

Warmup:   1%|▎                                                         | 12/2000 [00:00, 15.22it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([1.0733, 0.5369])}
{('coin1', 'coin2'): tensor([0.5004, 0.7486])}
{('coin1', 'coin2'): tensor([0.1355, 0.7312])}
{('coin1', 'coin2'): tensor([-0.2121,  0.4553])}
{('coin1', 'coin2'): tensor([-0.1681,  0.4443])}
{('coin1', 'coin2'): tensor([-0.1249,  0.4568])}
{('coin1', 'coin2'): tensor([-0.5709,  0.3144])}
{('coin1', 'coin2'): tensor([-0.8144,  0.5708])}
{('coin1', 'coin2'): tensor([-0.6223,  0.5379])}
{('coin1', 'coin2'): tensor([-0.9789,  0.4524])}
{('coin1', 'coin2'): tensor([-1.1974,  0.6025])}
{('coin1', 'coin2'): tensor([-1.1945,  0.3094])}
{('coin1', 'coin2'): tensor([-0.9031,  0.3635])}
{('coin1', 'coin2'): tensor([-0.7777, -0.0139])}
{('coin1', 'coin2'): tensor([-0.9543,  0.0281])}
{('coin1', 'coin2'): tensor([0.9573, 1.5849])}
{('coin1', 'coin2'): tensor([1.5679, 1.3040])}
{('coin1', 'coin2'): tensor([1.7556, 1.4346])}
{('coin1', 'coin2'): tensor([2.0902, 1.1444])}
{('coin1', 'coin2'): tensor([2.4018, 0.8229])}
{('coin1', 'coin2'): tensor([1.9868,

Warmup:   1%|▍                                                         | 14/2000 [00:00, 15.38it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 2.2106, -0.0487])}
{('coin1', 'coin2'): tensor([ 2.2555, -0.1938])}
{('coin1', 'coin2'): tensor([2.1818, 0.0678])}
{('coin1', 'coin2'): tensor([2.2090, 0.0741])}
{('coin1', 'coin2'): tensor([2.0806, 0.2430])}
{('coin1', 'coin2'): tensor([1.7463, 0.0223])}
{('coin1', 'coin2'): tensor([ 1.4674, -0.1405])}
{('coin1', 'coin2'): tensor([1.3573, 0.2151])}
{('coin1', 'coin2'): tensor([-1.6080,  0.9283])}
{('coin1', 'coin2'): tensor([-1.7007,  0.4602])}
{('coin1', 'coin2'): tensor([-1.4405,  0.3828])}
{('coin1', 'coin2'): tensor([-0.8937,  0.5593])}
{('coin1', 'coin2'): tensor([-0.3892,  0.8641])}
{('coin1', 'coin2'): tensor([-0.3495,  0.6266])}
{('coin1', 'coin2'): tensor([-0.1488,  0.4561])}
{('coin1', 'coin2'): tensor([-1.3781,  0.4174])}
{('coin1', 'coin2'): tensor([-0.9830,  0.5402])}
{('coin1', 'coin2'): tensor([-0.6073,  0.9235])}
{('coin1', 'coin2'): tensor([-0.3359,  0.9238])}
{('coin1', 'coin2'): tensor([-0.2264,  0.5966])}
{('coin1', 'coin2'): tensor([-

Warmup:   1%|▍                                                         | 17/2000 [00:01, 16.52it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([1.2578, 1.2075])}
{('coin1', 'coin2'): tensor([1.0260, 1.0503])}
{('coin1', 'coin2'): tensor([0.9259, 1.1474])}
{('coin1', 'coin2'): tensor([0.9022, 1.1803])}
{('coin1', 'coin2'): tensor([0.6088, 0.9774])}
{('coin1', 'coin2'): tensor([0.1525, 0.4732])}
{('coin1', 'coin2'): tensor([0.0489, 0.4896])}
{('coin1', 'coin2'): tensor([0.1163, 0.4603])}
{('coin1', 'coin2'): tensor([-0.2882,  0.3264])}
{('coin1', 'coin2'): tensor([0.6730, 0.0475])}
{('coin1', 'coin2'): tensor([ 0.2749, -0.2012])}
{('coin1', 'coin2'): tensor([-0.0083, -0.5541])}
{('coin1', 'coin2'): tensor([-0.3645, -0.6718])}
{('coin1', 'coin2'): tensor([-0.7082, -0.6804])}
{('coin1', 'coin2'): tensor([-0.9740, -0.4219])}
{('coin1', 'coin2'): tensor([-1.1584, -0.6781])}
{('coin1', 'coin2'): tensor([-1.1245, -0.9978])}
{('coin1', 'coin2'): tensor([-0.8222, -0.4608])}
{('coin1', 'coin2'): tensor([-1.6103, -0.8281])}
{('coin1', 'coin2'): tensor([-1.8611, -0.6902])}
{('coin1', 'coin2'): tensor([-1.7578, 

Warmup:   1%|▌                                                         | 19/2000 [00:01, 14.73it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.6389,  0.9921])}
{('coin1', 'coin2'): tensor([-0.6077,  0.7191])}
{('coin1', 'coin2'): tensor([-0.6858,  0.1390])}
{('coin1', 'coin2'): tensor([-0.8601,  0.4333])}
{('coin1', 'coin2'): tensor([-0.4987,  0.2322])}
{('coin1', 'coin2'): tensor([-0.5546,  0.4314])}
{('coin1', 'coin2'): tensor([-0.5496,  0.2061])}
{('coin1', 'coin2'): tensor([-0.4024, -0.0869])}
{('coin1', 'coin2'): tensor([-0.1384, -0.1830])}
{('coin1', 'coin2'): tensor([-0.1521, -0.4336])}
{('coin1', 'coin2'): tensor([-0.2038, -0.4043])}
{('coin1', 'coin2'): tensor([-0.0528, -0.3405])}
{('coin1', 'coin2'): tensor([-2.3878, -0.1754])}
{('coin1', 'coin2'): tensor([-2.2980,  0.5983])}
{('coin1', 'coin2'): tensor([-1.8428,  1.0846])}
{('coin1', 'coin2'): tensor([-1.6542,  0.8876])}
{('coin1', 'coin2'): tensor([-1.3690,  0.5361])}
{('coin1', 'coin2'): tensor([-1.1201,  0.7724])}
{('coin1', 'coin2'): tensor([-0.1553,  0.8482])}
{('coin1', 'coin2'): tensor([-2.1940, -0.1011])}
{('coin1', 'coin2'):

Warmup:   1%|▋                                                         | 23/2000 [00:01, 14.76it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.4803,  0.7166])}
{('coin1', 'coin2'): tensor([0.1613, 0.9881])}
{('coin1', 'coin2'): tensor([0.5507, 0.0831])}
{('coin1', 'coin2'): tensor([ 0.6515, -0.2934])}
{('coin1', 'coin2'): tensor([ 0.5292, -0.3769])}
{('coin1', 'coin2'): tensor([ 0.7861, -0.2684])}
{('coin1', 'coin2'): tensor([ 1.1603, -0.4049])}
{('coin1', 'coin2'): tensor([ 0.9281, -0.3202])}
{('coin1', 'coin2'): tensor([ 0.9613, -0.2610])}
{('coin1', 'coin2'): tensor([0.5735, 0.1706])}
{('coin1', 'coin2'): tensor([0.8018, 0.5051])}
{('coin1', 'coin2'): tensor([0.9579, 0.4497])}
{('coin1', 'coin2'): tensor([ 0.3491, -0.2548])}
{('coin1', 'coin2'): tensor([ 0.0258, -0.2361])}
{('coin1', 'coin2'): tensor([ 0.1178, -0.2471])}
{('coin1', 'coin2'): tensor([-0.0513, -0.2326])}
{('coin1', 'coin2'): tensor([-0.2195, -0.2272])}
{('coin1', 'coin2'): tensor([-0.5949, -0.1510])}
{('coin1', 'coin2'): tensor([-0.8491, -0.2450])}
{('coin1', 'coin2'): tensor([-1.0334, -0.1818])}
{('coin1', 'coin2'): tensor([0

Warmup:   1%|▋                                                         | 25/2000 [00:01, 14.36it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-1.0822,  1.7910])}
{('coin1', 'coin2'): tensor([-0.9992,  1.6131])}
{('coin1', 'coin2'): tensor([-0.8915,  1.4586])}
{('coin1', 'coin2'): tensor([-0.9918,  1.1270])}
{('coin1', 'coin2'): tensor([-0.9571,  0.7820])}
{('coin1', 'coin2'): tensor([-0.9777,  0.5576])}
{('coin1', 'coin2'): tensor([-0.9780,  0.3482])}
{('coin1', 'coin2'): tensor([-0.9375, -0.1291])}
{('coin1', 'coin2'): tensor([-0.9628, -0.5690])}
{('coin1', 'coin2'): tensor([-0.8020, -0.7262])}
{('coin1', 'coin2'): tensor([-0.2799, -0.9092])}
{('coin1', 'coin2'): tensor([-0.2521, -1.3265])}
{('coin1', 'coin2'): tensor([-0.0725, -1.2007])}
{('coin1', 'coin2'): tensor([ 0.2611, -1.5056])}
{('coin1', 'coin2'): tensor([ 0.5819, -1.1985])}
{('coin1', 'coin2'): tensor([ 0.8728, -1.3392])}
{('coin1', 'coin2'): tensor([ 0.7496, -1.2053])}
{('coin1', 'coin2'): tensor([ 0.5423, -1.2360])}
{('coin1', 'coin2'): tensor([ 0.6708, -1.1313])}
{('coin1', 'coin2'): tensor([ 0.7256, -1.1354])}
{('coin1', 'coin2'):

Warmup:   1%|▊                                                         | 28/2000 [00:01, 15.23it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 1.6867, -0.3195])}
{('coin1', 'coin2'): tensor([ 1.8512, -0.6290])}
{('coin1', 'coin2'): tensor([ 1.5326, -0.6438])}
{('coin1', 'coin2'): tensor([ 1.3393, -0.5928])}
{('coin1', 'coin2'): tensor([ 1.2881, -0.4570])}
{('coin1', 'coin2'): tensor([ 0.8081, -0.0300])}
{('coin1', 'coin2'): tensor([0.9495, 0.7010])}
{('coin1', 'coin2'): tensor([0.6764, 0.7848])}
{('coin1', 'coin2'): tensor([0.5032, 0.6761])}
{('coin1', 'coin2'): tensor([0.1972, 0.9182])}
{('coin1', 'coin2'): tensor([-0.0528,  0.5725])}
{('coin1', 'coin2'): tensor([-0.2810,  0.4130])}
{('coin1', 'coin2'): tensor([-0.2685,  0.5212])}
{('coin1', 'coin2'): tensor([0.0062, 0.3888])}
{('coin1', 'coin2'): tensor([-0.1627, -0.0542])}
{('coin1', 'coin2'): tensor([-0.2599, -0.0093])}
{('coin1', 'coin2'): tensor([-0.1248,  0.0477])}
{('coin1', 'coin2'): tensor([-0.1682, -0.3206])}
{('coin1', 'coin2'): tensor([-0.3899, -0.4471])}
{('coin1', 'coin2'): tensor([-0.3805, -0.2911])}
{('coin1', 'coin2'): tensor([-

Warmup:   2%|▉                                                         | 33/2000 [00:02, 17.28it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.9921, -1.0542])}
{('coin1', 'coin2'): tensor([-1.4061, -1.0602])}
{('coin1', 'coin2'): tensor([-1.2508, -1.1113])}
{('coin1', 'coin2'): tensor([-1.1038, -0.9269])}
{('coin1', 'coin2'): tensor([-1.3409, -0.9668])}
{('coin1', 'coin2'): tensor([-1.4147, -0.9738])}
{('coin1', 'coin2'): tensor([-1.5221, -1.2885])}
{('coin1', 'coin2'): tensor([-1.6765, -1.1199])}
{('coin1', 'coin2'): tensor([-0.6252,  0.6600])}
{('coin1', 'coin2'): tensor([-0.3693,  0.8324])}
{('coin1', 'coin2'): tensor([-0.3651,  0.9440])}
{('coin1', 'coin2'): tensor([-0.5683,  0.1532])}
{('coin1', 'coin2'): tensor([-0.5019, -0.0764])}
{('coin1', 'coin2'): tensor([-0.5028,  0.0678])}
{('coin1', 'coin2'): tensor([-0.3497,  0.2449])}
{('coin1', 'coin2'): tensor([-0.5245,  0.8573])}
{('coin1', 'coin2'): tensor([0.0340, 0.4410])}
{('coin1', 'coin2'): tensor([0.4361, 0.4437])}
{('coin1', 'coin2'): tensor([0.4695, 0.3991])}
{('coin1', 'coin2'): tensor([ 0.5936, -0.0261])}
{('coin1', 'coin2'): tenso

Warmup:   2%|█                                                         | 35/2000 [00:02, 17.05it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([0.0176, 1.0421])}
{('coin1', 'coin2'): tensor([0.2443, 0.9151])}
{('coin1', 'coin2'): tensor([0.3957, 1.1207])}
{('coin1', 'coin2'): tensor([-1.2037, -0.2835])}
{('coin1', 'coin2'): tensor([-1.2429, -0.8007])}
{('coin1', 'coin2'): tensor([-0.8340, -1.0304])}
{('coin1', 'coin2'): tensor([-0.9450, -0.9764])}
{('coin1', 'coin2'): tensor([-0.6186, -0.7773])}
{('coin1', 'coin2'): tensor([-0.5062, -0.7213])}
{('coin1', 'coin2'): tensor([-0.4393, -0.7183])}
{('coin1', 'coin2'): tensor([-0.2736, -0.8385])}
{('coin1', 'coin2'): tensor([-1.3012,  0.8677])}
{('coin1', 'coin2'): tensor([-1.1012,  0.4366])}
{('coin1', 'coin2'): tensor([-1.0127,  0.1280])}
{('coin1', 'coin2'): tensor([-0.9738,  0.9329])}
{('coin1', 'coin2'): tensor([-0.8070,  1.0644])}
{('coin1', 'coin2'): tensor([-0.7193,  1.1208])}
{('coin1', 'coin2'): tensor([-0.6687,  0.8592])}
{('coin1', 'coin2'): tensor([-0.6015,  0.8963])}
{('coin1', 'coin2'): tensor([-0.6330,  0.8050])}
{('coin1', 'coin2'): tenso

Warmup:   2%|█                                                         | 37/2000 [00:02, 14.69it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([0.0714, 1.9548])}
{('coin1', 'coin2'): tensor([0.6128, 1.8057])}
{('coin1', 'coin2'): tensor([0.9916, 1.1051])}
{('coin1', 'coin2'): tensor([1.0349, 1.4656])}
{('coin1', 'coin2'): tensor([1.4218, 1.2282])}
{('coin1', 'coin2'): tensor([1.5865, 0.8226])}
{('coin1', 'coin2'): tensor([1.5580, 0.5437])}
{('coin1', 'coin2'): tensor([1.2953, 0.2489])}
{('coin1', 'coin2'): tensor([ 0.9216, -0.0416])}
{('coin1', 'coin2'): tensor([1.0364, 0.0767])}
{('coin1', 'coin2'): tensor([1.0145, 0.0020])}
{('coin1', 'coin2'): tensor([1.0555, 0.0736])}
{('coin1', 'coin2'): tensor([ 0.9635, -0.4396])}
{('coin1', 'coin2'): tensor([ 0.9926, -0.5233])}
{('coin1', 'coin2'): tensor([ 0.7333, -0.2386])}
{('coin1', 'coin2'): tensor([ 0.1800, -0.2264])}
{('coin1', 'coin2'): tensor([ 0.1063, -0.6081])}
{('coin1', 'coin2'): tensor([-0.0769, -0.5141])}
{('coin1', 'coin2'): tensor([-0.3045, -0.4014])}
{('coin1', 'coin2'): tensor([-0.4818, -0.6691])}
{('coin1', 'coin2'): tensor([-0.8410, -0.5

Warmup:   2%|█▏                                                        | 41/2000 [00:02, 18.08it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.4857,  0.5312])}
{('coin1', 'coin2'): tensor([-0.2384,  0.9157])}
{('coin1', 'coin2'): tensor([1.0762, 0.3492])}
{('coin1', 'coin2'): tensor([ 0.8640, -0.0503])}
{('coin1', 'coin2'): tensor([ 0.1989, -0.2614])}
{('coin1', 'coin2'): tensor([0.7378, 0.7208])}
{('coin1', 'coin2'): tensor([0.4199, 0.6363])}
{('coin1', 'coin2'): tensor([0.4942, 0.7055])}
{('coin1', 'coin2'): tensor([0.2647, 0.5870])}
{('coin1', 'coin2'): tensor([ 0.0600, -0.3362])}
{('coin1', 'coin2'): tensor([-0.3962, -0.1885])}
{('coin1', 'coin2'): tensor([-0.4032, -0.3210])}
{('coin1', 'coin2'): tensor([-0.7504, -0.2812])}
{('coin1', 'coin2'): tensor([-0.8801, -0.4039])}
{('coin1', 'coin2'): tensor([-0.8565, -0.7263])}
{('coin1', 'coin2'): tensor([-0.6519, -0.6976])}
{('coin1', 'coin2'): tensor([-0.8367, -0.6968])}
{('coin1', 'coin2'): tensor([-0.0685,  0.9120])}
{('coin1', 'coin2'): tensor([0.2327, 0.8764])}
{('coin1', 'coin2'): tensor([-0.0418,  0.8807])}
{('coin1', 'coin2'): tensor([0.0

Warmup:   2%|█▏                                                        | 43/2000 [00:02, 14.59it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([2.1249, 1.7001])}
{('coin1', 'coin2'): tensor([1.8802, 1.5650])}
{('coin1', 'coin2'): tensor([1.7921, 1.3317])}
{('coin1', 'coin2'): tensor([1.2868, 1.1244])}
{('coin1', 'coin2'): tensor([1.0469, 1.3394])}
{('coin1', 'coin2'): tensor([1.0015, 1.2565])}
{('coin1', 'coin2'): tensor([0.6234, 0.8151])}
{('coin1', 'coin2'): tensor([0.7046, 1.1715])}
{('coin1', 'coin2'): tensor([0.5667, 1.3124])}
{('coin1', 'coin2'): tensor([0.1649, 0.9477])}
{('coin1', 'coin2'): tensor([-0.2011,  0.8363])}
{('coin1', 'coin2'): tensor([-0.3434,  0.8741])}
{('coin1', 'coin2'): tensor([-0.5256,  1.0318])}
{('coin1', 'coin2'): tensor([-0.6540,  0.7140])}
{('coin1', 'coin2'): tensor([-1.1418,  0.8004])}
{('coin1', 'coin2'): tensor([-1.0009,  0.6321])}
{('coin1', 'coin2'): tensor([-0.5382,  0.2305])}
{('coin1', 'coin2'): tensor([-0.4045,  0.1216])}
{('coin1', 'coin2'): tensor([-0.5732,  0.2105])}
{('coin1', 'coin2'): tensor([-0.2583,  0.3198])}
{('coin1', 'coin2'): tensor([0.0579, 0.1

Warmup:   2%|█▎                                                        | 45/2000 [00:02, 13.09it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.3977, -1.6624])}
{('coin1', 'coin2'): tensor([-0.9552, -1.4558])}
{('coin1', 'coin2'): tensor([-1.4095, -1.5521])}
{('coin1', 'coin2'): tensor([-1.8734, -1.5747])}
{('coin1', 'coin2'): tensor([-2.1775, -1.3993])}
{('coin1', 'coin2'): tensor([-2.6796, -1.4698])}
{('coin1', 'coin2'): tensor([-2.4132, -1.4811])}
{('coin1', 'coin2'): tensor([-2.6033, -1.4004])}
{('coin1', 'coin2'): tensor([-2.3933, -1.5735])}
{('coin1', 'coin2'): tensor([-2.4417, -1.4538])}
{('coin1', 'coin2'): tensor([-2.0724, -1.4694])}
{('coin1', 'coin2'): tensor([-1.6570, -0.9474])}
{('coin1', 'coin2'): tensor([-1.5103, -1.0425])}
{('coin1', 'coin2'): tensor([1.2568, 1.3614])}
{('coin1', 'coin2'): tensor([1.7091, 1.3772])}
{('coin1', 'coin2'): tensor([1.9011, 1.3218])}
{('coin1', 'coin2'): tensor([2.3619, 1.1937])}
{('coin1', 'coin2'): tensor([2.6409, 1.2361])}
{('coin1', 'coin2'): tensor([2.8575, 1.1317])}
{('coin1', 'coin2'): tensor([2.6993, 0.9352])}
{('coin1', 'coin2'): tensor([2.727

Warmup:   2%|█▎                                                        | 47/2000 [00:03, 14.17it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.9611, -0.1304])}
{('coin1', 'coin2'): tensor([-0.8537, -0.2148])}
{('coin1', 'coin2'): tensor([-1.0066, -0.6200])}
{('coin1', 'coin2'): tensor([-0.9980, -0.6527])}
{('coin1', 'coin2'): tensor([-0.8659, -0.6649])}
{('coin1', 'coin2'): tensor([-0.5498, -0.5277])}
{('coin1', 'coin2'): tensor([-0.0641, -0.0929])}
{('coin1', 'coin2'): tensor([-0.0612, -0.0254])}
{('coin1', 'coin2'): tensor([0.2871, 0.3262])}
{('coin1', 'coin2'): tensor([-0.0546,  0.5114])}
{('coin1', 'coin2'): tensor([ 0.7763, -0.6699])}
{('coin1', 'coin2'): tensor([ 0.5386, -0.4873])}
{('coin1', 'coin2'): tensor([ 0.3325, -0.2067])}
{('coin1', 'coin2'): tensor([ 0.3233, -0.0274])}
{('coin1', 'coin2'): tensor([0.3527, 0.0651])}
{('coin1', 'coin2'): tensor([0.4035, 0.2836])}
{('coin1', 'coin2'): tensor([0.1862, 0.4420])}
{('coin1', 'coin2'): tensor([0.7076, 0.5703])}
{('coin1', 'coin2'): tensor([1.0126, 0.2735])}
{('coin1', 'coin2'): tensor([1.1827, 0.1177])}
{('coin1', 'coin2'): tensor([ 1.14

Warmup:   3%|█▌                                                        | 52/2000 [00:03, 15.01it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 0.4520, -1.3733])}
{('coin1', 'coin2'): tensor([ 0.1342, -1.2281])}
{('coin1', 'coin2'): tensor([-0.0968, -0.8092])}
{('coin1', 'coin2'): tensor([ 1.0704, -0.8458])}
{('coin1', 'coin2'): tensor([ 1.5110, -0.6442])}
{('coin1', 'coin2'): tensor([ 1.3842, -0.3919])}
{('coin1', 'coin2'): tensor([ 1.7062, -0.5140])}
{('coin1', 'coin2'): tensor([-0.0564, -0.3039])}
{('coin1', 'coin2'): tensor([-0.7931, -0.4142])}
{('coin1', 'coin2'): tensor([-1.3459, -0.4109])}
{('coin1', 'coin2'): tensor([-1.4223, -0.1743])}
{('coin1', 'coin2'): tensor([-1.2369,  0.0428])}
{('coin1', 'coin2'): tensor([-1.2426,  0.4114])}
{('coin1', 'coin2'): tensor([-1.4595,  0.6673])}
{('coin1', 'coin2'): tensor([-1.5418,  0.5686])}
{('coin1', 'coin2'): tensor([ 0.3396, -0.7893])}
{('coin1', 'coin2'): tensor([ 0.1528, -1.1369])}
{('coin1', 'coin2'): tensor([-0.2153, -1.5650])}
{('coin1', 'coin2'): tensor([-0.4852, -1.4668])}
{('coin1', 'coin2'): tensor([-0.7260, -1.4460])}
{('coin1', 'coin2'):

Warmup:   3%|█▌                                                        | 55/2000 [00:03, 17.95it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.4773,  0.2050])}
{('coin1', 'coin2'): tensor([-0.3241,  0.2000])}
{('coin1', 'coin2'): tensor([-0.1046,  0.1722])}
{('coin1', 'coin2'): tensor([0.1180, 0.0297])}
{('coin1', 'coin2'): tensor([0.3089, 0.2435])}
{('coin1', 'coin2'): tensor([0.2623, 0.4003])}
{('coin1', 'coin2'): tensor([-0.4804, -0.5529])}
{('coin1', 'coin2'): tensor([-0.3584, -0.7313])}
{('coin1', 'coin2'): tensor([ 0.0693, -0.6782])}
{('coin1', 'coin2'): tensor([ 0.1997, -0.6251])}
{('coin1', 'coin2'): tensor([ 0.1143, -0.5199])}
{('coin1', 'coin2'): tensor([ 0.2285, -0.5358])}
{('coin1', 'coin2'): tensor([ 0.8372, -0.2611])}
{('coin1', 'coin2'): tensor([0.5649, 0.0094])}
{('coin1', 'coin2'): tensor([-0.2581, -0.0939])}
{('coin1', 'coin2'): tensor([-0.0303, -0.2347])}
{('coin1', 'coin2'): tensor([-0.0574, -0.7773])}
{('coin1', 'coin2'): tensor([-0.2429,  0.1272])}
{('coin1', 'coin2'): tensor([-0.3353,  0.5764])}
{('coin1', 'coin2'): tensor([0.0683, 0.5488])}
{('coin1', 'coin2'): tensor([-

Warmup:   3%|█▊                                                        | 61/2000 [00:03, 19.40it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-0.9141, -0.8763])}
{('coin1', 'coin2'): tensor([-0.5030, -0.9385])}
{('coin1', 'coin2'): tensor([-0.2842, -0.9120])}
{('coin1', 'coin2'): tensor([-0.1728, -1.1648])}
{('coin1', 'coin2'): tensor([ 0.0687, -0.9499])}
{('coin1', 'coin2'): tensor([ 0.6460, -0.7401])}
{('coin1', 'coin2'): tensor([ 0.8154, -0.7352])}
{('coin1', 'coin2'): tensor([ 0.5317, -0.4391])}
{('coin1', 'coin2'): tensor([ 0.4704, -0.0364])}
{('coin1', 'coin2'): tensor([ 0.2282, -0.0711])}
{('coin1', 'coin2'): tensor([0.0779, 0.0180])}
{('coin1', 'coin2'): tensor([ 0.0703, -0.4238])}
{('coin1', 'coin2'): tensor([-0.1710, -0.4725])}
{('coin1', 'coin2'): tensor([-0.0262, -0.4516])}
{('coin1', 'coin2'): tensor([1.1045, 0.8606])}
{('coin1', 'coin2'): tensor([0.8155, 0.4153])}
{('coin1', 'coin2'): tensor([0.6772, 0.2276])}
{('coin1', 'coin2'): tensor([ 0.4855, -0.1270])}
{('coin1', 'coin2'): tensor([0.4339, 0.0179])}
{('coin1', 'coin2'): tensor([ 0.5715, -0.2530])}
{('coin1', 'coin2'): tensor([ 

Warmup:   3%|█▊                                                        | 64/2000 [00:03, 18.67it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-1.5365, -0.9842])}
{('coin1', 'coin2'): tensor([-1.3614, -1.2649])}
{('coin1', 'coin2'): tensor([-0.7751, -0.7288])}
{('coin1', 'coin2'): tensor([-1.4736, -0.4377])}
{('coin1', 'coin2'): tensor([-1.3568, -0.3064])}
{('coin1', 'coin2'): tensor([-1.1858,  0.3888])}
{('coin1', 'coin2'): tensor([-0.4812,  0.6265])}
{('coin1', 'coin2'): tensor([-0.6229, -0.2920])}
{('coin1', 'coin2'): tensor([-0.5338, -0.3911])}
{('coin1', 'coin2'): tensor([ 0.0030, -0.5966])}
{('coin1', 'coin2'): tensor([ 0.1557, -0.7937])}
{('coin1', 'coin2'): tensor([ 0.6005, -0.7362])}
{('coin1', 'coin2'): tensor([ 0.4670, -1.0189])}
{('coin1', 'coin2'): tensor([ 0.6388, -0.9742])}
{('coin1', 'coin2'): tensor([ 0.1681, -0.4739])}
{('coin1', 'coin2'): tensor([1.6715, 0.1585])}
{('coin1', 'coin2'): tensor([ 1.5618, -0.1014])}
{('coin1', 'coin2'): tensor([ 1.4506, -0.5255])}
{('coin1', 'coin2'): tensor([ 1.5672, -0.7377])}
{('coin1', 'coin2'): tensor([ 1.5143, -0.5564])}
{('coin1', 'coin2'): t

Warmup:   3%|█▉                                                        | 68/2000 [00:04, 23.17it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 0.7056, -0.9751])}
{('coin1', 'coin2'): tensor([ 0.0115, -0.5097])}
{('coin1', 'coin2'): tensor([-0.6678, -0.2855])}
{('coin1', 'coin2'): tensor([-0.8671,  0.0422])}
{('coin1', 'coin2'): tensor([-1.1419,  0.7054])}
{('coin1', 'coin2'): tensor([-0.8170,  0.1578])}
{('coin1', 'coin2'): tensor([-0.4322, -0.2177])}
{('coin1', 'coin2'): tensor([ 0.0059, -0.4305])}
{('coin1', 'coin2'): tensor([ 0.7151, -0.7620])}
{('coin1', 'coin2'): tensor([ 0.9223, -1.2046])}
{('coin1', 'coin2'): tensor([ 0.9253, -1.4592])}
{('coin1', 'coin2'): tensor([ 0.3509, -0.5164])}
{('coin1', 'coin2'): tensor([ 0.2805, -1.3325])}
{('coin1', 'coin2'): tensor([-0.5755, -1.0777])}
{('coin1', 'coin2'): tensor([-1.2433, -0.5016])}
{('coin1', 'coin2'): tensor([-1.8037, -0.6501])}
{('coin1', 'coin2'): tensor([-2.1083, -0.1511])}
{('coin1', 'coin2'): tensor([-2.3129,  0.4027])}
{('coin1', 'coin2'): tensor([-2.2795,  0.4117])}
{('coin1', 'coin2'): tensor([ 1.8074, -2.0976])}
{('coin1', 'coin2'):

Warmup:   4%|██▏                                                       | 76/2000 [00:04, 29.43it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([ 0.4996, -2.4018])}
{('coin1', 'coin2'): tensor([ 0.1005, -1.9641])}
{('coin1', 'coin2'): tensor([ 0.0774, -1.5056])}
{('coin1', 'coin2'): tensor([ 0.5327, -2.2492])}
{('coin1', 'coin2'): tensor([ 0.3779, -2.2255])}
{('coin1', 'coin2'): tensor([ 0.4940, -2.4325])}
{('coin1', 'coin2'): tensor([ 0.7406, -2.2212])}
{('coin1', 'coin2'): tensor([-0.0700, -0.7179])}
{('coin1', 'coin2'): tensor([-0.1765,  0.2796])}
{('coin1', 'coin2'): tensor([ 0.8887, -1.3171])}
{('coin1', 'coin2'): tensor([ 0.7730, -0.9114])}
{('coin1', 'coin2'): tensor([ 0.6743, -0.5947])}
{('coin1', 'coin2'): tensor([ 0.8995, -1.5590])}
{('coin1', 'coin2'): tensor([ 1.0843, -1.5002])}
{('coin1', 'coin2'): tensor([ 0.6937, -1.5843])}
{('coin1', 'coin2'): tensor([ 0.2147, -1.4934])}
{('coin1', 'coin2'): tensor([ 0.5801, -0.3618])}
{('coin1', 'coin2'): tensor([0.2145, 0.1923])}
{('coin1', 'coin2'): tensor([-0.0184,  0.8889])}
{('coin1', 'coin2'): tensor([0.0669, 1.1940])}
{('coin1', 'coin2'): ten

Warmup:   4%|██▎                                                       | 80/2000 [00:04, 24.76it/s, step size=1.00e-01]

{('coin1', 'coin2'): tensor([-1.0897, -1.5172])}
{('coin1', 'coin2'): tensor([-1.2299, -1.2061])}
{('coin1', 'coin2'): tensor([-1.7173, -0.7084])}
{('coin1', 'coin2'): tensor([-1.7481, -0.2974])}
{('coin1', 'coin2'): tensor([-1.7692,  0.3406])}
{('coin1', 'coin2'): tensor([-2.0337,  0.6175])}
{('coin1', 'coin2'): tensor([-2.0971,  0.6266])}
{('coin1', 'coin2'): tensor([-2.2776,  0.4841])}
{('coin1', 'coin2'): tensor([-1.6554,  0.7318])}
{('coin1', 'coin2'): tensor([-1.6316,  1.0469])}
{('coin1', 'coin2'): tensor([-1.5479,  0.7885])}
{('coin1', 'coin2'): tensor([-1.4024,  0.3928])}
{('coin1', 'coin2'): tensor([-1.0905,  0.4807])}
{('coin1', 'coin2'): tensor([-1.3199,  0.4320])}
{('coin1', 'coin2'): tensor([-1.2522,  0.2492])}
{('coin1', 'coin2'): tensor([-0.8371,  0.0374])}
{('coin1', 'coin2'): tensor([-0.4983,  0.0272])}
{('coin1', 'coin2'): tensor([-0.3728,  0.1671])}
{('coin1', 'coin2'): tensor([-0.2464,  0.0045])}
{('coin1', 'coin2'): tensor([-0.1233,  0.0571])}
{('coin1', 'coin2'):

KeyboardInterrupt: 

In [None]:
# Two vertically stacked panels (one per coin) that share their x-axis.
fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6, 6))

def analytic_posterior(x, alpha, beta):
    """Evaluate the Beta(alpha, beta) probability density at ``x``.

    Used as the closed-form conjugate posterior overlaid on MCMC histograms.

    Args:
        x: points in [0, 1]; a tensor, NumPy array, sequence, or scalar.
            Non-floating inputs are promoted to the default float dtype.
        alpha, beta: Beta concentration parameters (numbers or tensors).

    Returns:
        torch.Tensor of densities with the same shape and floating dtype as
        ``x``.
    """
    x = torch.as_tensor(x)
    if not x.is_floating_point():
        x = x.to(torch.get_default_dtype())
    # Build the parameters in x's dtype so the result dtype follows the input.
    fn = Beta(torch.as_tensor(alpha, dtype=x.dtype),
              torch.as_tensor(beta, dtype=x.dtype))
    # Stay in torch rather than np.exp(tensor), which only yields a tensor via
    # NumPy's __array_wrap__ protocol.
    return fn.log_prob(x).exp()
    
# Overlay each coin's SGNUTS samples on its closed-form Beta posterior.
# For a Beta(a0, b0) prior and Bernoulli observations obs, the conjugate
# posterior is Beta(a0 + sum(obs), b0 + len(obs) - sum(obs)). Deriving it from
# the data keeps the reference curves correct if x1/x2 change (for the current
# data this gives Beta(5, 7) and Beta(9, 3)).
PRIOR_ALPHA, PRIOR_BETA = 1., 1.  # must match model()'s alpha0/beta0 defaults

x = torch.linspace(0, 1, 1000, dtype=torch.float64)

for ax, site, obs in zip(axs, ("coin1", "coin2"), (x1, x2)):
    n_ones = float(obs.sum())
    post_alpha = PRIOR_ALPHA + n_ones
    post_beta = PRIOR_BETA + len(obs) - n_ones
    ax.hist(sgnuts_samples[site].detach().cpu().numpy(), density=True, bins=30,
            label="SGNUTS samples")
    ax.plot(x.numpy(), np.asarray(analytic_posterior(x, post_alpha, post_beta)), "r",
            label=f"analytic posterior Beta({post_alpha:g}, {post_beta:g})")
    ax.set(title=f"SGNUTS '{site}' samples", ylabel="density")
    ax.legend()  # without this the label= arguments are never displayed

# sharex=True propagates the limits to both panels.
axs[-1].set(xlabel="P(obs = 1)", xlim=(0.0, 1.0))
plt.show()