In [1]:
import sys
sys.path.append("..")

from matplotlib import pyplot as plt

import torch
import numpy as np
from torch.distributions import Beta

import pyro
from pyro.infer.mcmc import MCMC
import pyro.distributions as dist

from kernel.sgnuts import NUTS

In [2]:
pyro.clear_param_store()

pyro.set_rng_seed(101)

# create some data with 6 observed heads and 4 observed tails
x1 = torch.ones(10)
x1[0:6] = torch.zeros(6)

x2 = torch.ones(10)
x2[0:2] = torch.zeros(2)

    
def model(x1, x2, alpha0=1., beta0=1.):
    alpha0 = torch.tensor(alpha0)
    beta0 = torch.tensor(beta0)
    
    f1 = pyro.sample("coin1", dist.Beta(alpha0, beta0))
    f2 = pyro.sample("coin2", dist.Beta(alpha0, beta0))
    
    return pyro.sample("obs1", dist.Bernoulli(f1), obs=x1), pyro.sample("obs2", dist.Bernoulli(f2), obs=x2)

In [3]:
sgnuts_kernel = NUTS(model, 
                subsample_positions=[0, 1],
                batch_size=5,
                potential_fn=None,
                learning_rate=0.01, 
                momentum_decay=0.1,
                resample_every_n=50, 
                obs_info_noise=False, 
                compute_obs_info='every_sample',
                use_multinomial_sampling=True,
                max_tree_depth=10,
                )

sgnuts_mcmc = MCMC(sgnuts_kernel, num_samples=1000)

sgnuts_mcmc.run(x1, x2)
sgnuts_samples = sgnuts_mcmc.get_samples()

Warmup:   0%|                                                           | 2/2000 [00:00, 11.18it/s, step size=1.00e-01]

{'coin1': tensor(0.1107), 'coin2': tensor(0.1285)}
{'coin1': tensor(-0.1894), 'coin2': tensor(0.1033)}
{'coin1': tensor(0.1092), 'coin2': tensor(-0.1557)}
{'coin1': tensor(-0.0646), 'coin2': tensor(0.4351)}
{'coin1': tensor(-0.0539), 'coin2': tensor(0.5016)}
{'coin1': tensor(-0.3059), 'coin2': tensor(0.5223)}
{'coin1': tensor(-0.0192), 'coin2': tensor(0.1576)}
{'coin1': tensor(-0.0959), 'coin2': tensor(0.2100)}
{'coin1': tensor(-0.2179), 'coin2': tensor(0.1513)}
{'coin1': tensor(-0.2483), 'coin2': tensor(0.3094)}
{'coin1': tensor(-0.4650), 'coin2': tensor(0.5576)}
{'coin1': tensor(-0.3985), 'coin2': tensor(0.5033)}
{'coin1': tensor(-0.1984), 'coin2': tensor(0.5860)}
{'coin1': tensor(0.2874), 'coin2': tensor(0.2400)}
{'coin1': tensor(0.2704), 'coin2': tensor(0.2082)}
{'coin1': tensor(0.2807), 'coin2': tensor(0.3694)}
{'coin1': tensor(0.1486), 'coin2': tensor(0.4373)}
{'coin1': tensor(0.2228), 'coin2': tensor(0.3257)}
{'coin1': tensor(-0.4027), 'coin2': tensor(0.2337)}
{'coin1': tensor(-

Warmup:   0%|▏                                                          | 7/2000 [00:00, 20.36it/s, step size=1.00e-01]

{'coin1': tensor(0.3445), 'coin2': tensor(0.9661)}
{'coin1': tensor(0.0900), 'coin2': tensor(0.9033)}
{'coin1': tensor(0.1160), 'coin2': tensor(0.8612)}
{'coin1': tensor(-0.2290), 'coin2': tensor(0.7847)}
{'coin1': tensor(-0.3328), 'coin2': tensor(0.8184)}
{'coin1': tensor(1.0208), 'coin2': tensor(-0.3521)}
{'coin1': tensor(0.9261), 'coin2': tensor(-0.7908)}
{'coin1': tensor(0.8454), 'coin2': tensor(-1.2378)}
{'coin1': tensor(0.7290), 'coin2': tensor(-1.4120)}
{'coin1': tensor(0.2984), 'coin2': tensor(-1.4734)}
{'coin1': tensor(-0.1709), 'coin2': tensor(-1.3819)}
{'coin1': tensor(-0.5376), 'coin2': tensor(-1.6550)}
{'coin1': tensor(-0.5682), 'coin2': tensor(-1.5566)}
{'coin1': tensor(-0.1067), 'coin2': tensor(0.4069)}
{'coin1': tensor(-0.2277), 'coin2': tensor(-0.0055)}
{'coin1': tensor(-0.6666), 'coin2': tensor(0.0072)}
{'coin1': tensor(-0.9747), 'coin2': tensor(-0.3632)}
{'coin1': tensor(-1.0132), 'coin2': tensor(-0.5673)}
{'coin1': tensor(-0.8484), 'coin2': tensor(-0.5486)}
{'coin1'

Warmup:   1%|▎                                                         | 12/2000 [00:00, 14.82it/s, step size=1.00e-01]

{'coin1': tensor(0.5178), 'coin2': tensor(0.6651)}
{'coin1': tensor(1.0733), 'coin2': tensor(0.5369)}
{'coin1': tensor(0.5004), 'coin2': tensor(0.7486)}
{'coin1': tensor(0.1355), 'coin2': tensor(0.7312)}
{'coin1': tensor(-0.2121), 'coin2': tensor(0.4553)}
{'coin1': tensor(-0.1681), 'coin2': tensor(0.4443)}
{'coin1': tensor(-0.1249), 'coin2': tensor(0.4568)}
{'coin1': tensor(-0.5709), 'coin2': tensor(0.3144)}
{'coin1': tensor(-0.8144), 'coin2': tensor(0.5708)}
{'coin1': tensor(-0.6223), 'coin2': tensor(0.5379)}
{'coin1': tensor(-0.9789), 'coin2': tensor(0.4524)}
{'coin1': tensor(-1.1974), 'coin2': tensor(0.6025)}
{'coin1': tensor(-1.1945), 'coin2': tensor(0.3094)}
{'coin1': tensor(-0.9031), 'coin2': tensor(0.3635)}
{'coin1': tensor(-0.7777), 'coin2': tensor(-0.0139)}
{'coin1': tensor(-0.9543), 'coin2': tensor(0.0281)}
{'coin1': tensor(0.9573), 'coin2': tensor(1.5849)}
{'coin1': tensor(1.5679), 'coin2': tensor(1.3040)}
{'coin1': tensor(1.7556), 'coin2': tensor(1.4346)}
{'coin1': tensor(2

Warmup:   1%|▍                                                         | 14/2000 [00:00, 14.39it/s, step size=1.00e-01]

{'coin1': tensor(0.7499), 'coin2': tensor(-0.9424)}
{'coin1': tensor(0.8711), 'coin2': tensor(-0.6395)}
{'coin1': tensor(0.8468), 'coin2': tensor(-0.3640)}
{'coin1': tensor(1.3432), 'coin2': tensor(-0.0888)}
{'coin1': tensor(1.7889), 'coin2': tensor(-0.1881)}
{'coin1': tensor(2.2106), 'coin2': tensor(-0.0487)}
{'coin1': tensor(2.2555), 'coin2': tensor(-0.1938)}
{'coin1': tensor(2.1818), 'coin2': tensor(0.0678)}
{'coin1': tensor(2.2090), 'coin2': tensor(0.0741)}
{'coin1': tensor(2.0806), 'coin2': tensor(0.2430)}
{'coin1': tensor(1.7463), 'coin2': tensor(0.0223)}
{'coin1': tensor(1.4674), 'coin2': tensor(-0.1405)}
{'coin1': tensor(1.3573), 'coin2': tensor(0.2151)}
{'coin1': tensor(-1.6080), 'coin2': tensor(0.9283)}
{'coin1': tensor(-1.7007), 'coin2': tensor(0.4602)}
{'coin1': tensor(-1.4405), 'coin2': tensor(0.3828)}
{'coin1': tensor(-0.8937), 'coin2': tensor(0.5593)}
{'coin1': tensor(-0.3892), 'coin2': tensor(0.8641)}
{'coin1': tensor(-0.3495), 'coin2': tensor(0.6266)}
{'coin1': tensor(

Warmup:   1%|▍                                                         | 17/2000 [00:01, 15.12it/s, step size=1.00e-01]

{'coin1': tensor(-0.0634), 'coin2': tensor(0.6450)}
{'coin1': tensor(-0.6803), 'coin2': tensor(0.6763)}
{'coin1': tensor(-1.1013), 'coin2': tensor(0.7615)}
{'coin1': tensor(-1.4956), 'coin2': tensor(0.8801)}
{'coin1': tensor(-1.4063), 'coin2': tensor(0.8080)}
{'coin1': tensor(-1.5521), 'coin2': tensor(0.8460)}
{'coin1': tensor(-1.7651), 'coin2': tensor(1.0499)}
{'coin1': tensor(-1.7214), 'coin2': tensor(0.8057)}
{'coin1': tensor(1.3249), 'coin2': tensor(0.6937)}
{'coin1': tensor(1.1750), 'coin2': tensor(0.5877)}
{'coin1': tensor(1.1722), 'coin2': tensor(0.6705)}
{'coin1': tensor(1.1183), 'coin2': tensor(0.8023)}
{'coin1': tensor(1.1516), 'coin2': tensor(0.9093)}
{'coin1': tensor(0.9935), 'coin2': tensor(1.3739)}
{'coin1': tensor(1.2578), 'coin2': tensor(1.2075)}
{'coin1': tensor(1.0260), 'coin2': tensor(1.0503)}
{'coin1': tensor(0.9259), 'coin2': tensor(1.1474)}
{'coin1': tensor(0.9022), 'coin2': tensor(1.1803)}
{'coin1': tensor(0.6088), 'coin2': tensor(0.9774)}
{'coin1': tensor(0.1525

Warmup:   1%|▌                                                         | 19/2000 [00:01, 13.87it/s, step size=1.00e-01]

{'coin1': tensor(0.7866), 'coin2': tensor(0.6306)}
{'coin1': tensor(0.3085), 'coin2': tensor(0.2614)}
{'coin1': tensor(1.6171), 'coin2': tensor(1.4829)}
{'coin1': tensor(1.3942), 'coin2': tensor(1.5973)}
{'coin1': tensor(1.1168), 'coin2': tensor(1.6007)}
{'coin1': tensor(1.3212), 'coin2': tensor(1.8137)}
{'coin1': tensor(1.6977), 'coin2': tensor(1.4967)}
{'coin1': tensor(1.6361), 'coin2': tensor(1.0826)}
{'coin1': tensor(1.4062), 'coin2': tensor(0.9357)}
{'coin1': tensor(1.1671), 'coin2': tensor(0.7334)}
{'coin1': tensor(0.8453), 'coin2': tensor(0.8286)}
{'coin1': tensor(0.6213), 'coin2': tensor(0.9044)}
{'coin1': tensor(0.6063), 'coin2': tensor(1.1532)}
{'coin1': tensor(0.0965), 'coin2': tensor(1.1142)}
{'coin1': tensor(0.1686), 'coin2': tensor(0.8766)}
{'coin1': tensor(-0.2620), 'coin2': tensor(0.9063)}
{'coin1': tensor(-0.3777), 'coin2': tensor(0.8924)}
{'coin1': tensor(-0.6232), 'coin2': tensor(1.3021)}
{'coin1': tensor(-0.6389), 'coin2': tensor(0.9921)}
{'coin1': tensor(-0.6077), 

Warmup:   1%|▌                                                         | 21/2000 [00:01, 13.47it/s, step size=1.00e-01]

{'coin1': tensor(0.8115), 'coin2': tensor(0.6791)}
{'coin1': tensor(1.1012), 'coin2': tensor(0.5827)}
{'coin1': tensor(1.2681), 'coin2': tensor(0.5609)}
{'coin1': tensor(1.4468), 'coin2': tensor(0.7656)}
{'coin1': tensor(1.7139), 'coin2': tensor(0.8352)}
{'coin1': tensor(1.9895), 'coin2': tensor(0.9584)}
{'coin1': tensor(1.9790), 'coin2': tensor(0.8872)}
{'coin1': tensor(2.2216), 'coin2': tensor(0.8619)}
{'coin1': tensor(2.2746), 'coin2': tensor(0.7527)}
{'coin1': tensor(1.7815), 'coin2': tensor(0.9495)}
{'coin1': tensor(1.6948), 'coin2': tensor(0.3736)}
{'coin1': tensor(1.4637), 'coin2': tensor(0.3836)}
{'coin1': tensor(1.0558), 'coin2': tensor(0.2495)}
{'coin1': tensor(0.9928), 'coin2': tensor(0.1802)}
{'coin1': tensor(0.5221), 'coin2': tensor(0.0950)}
{'coin1': tensor(-0.0172), 'coin2': tensor(0.8849)}
{'coin1': tensor(-0.3290), 'coin2': tensor(0.7808)}
{'coin1': tensor(-0.4803), 'coin2': tensor(0.7166)}
{'coin1': tensor(0.1613), 'coin2': tensor(0.9881)}
{'coin1': tensor(0.5507), 'c

Warmup:   1%|▋                                                         | 23/2000 [00:01, 13.86it/s, step size=1.00e-01]

{'coin1': tensor(0.6870), 'coin2': tensor(-0.6685)}
{'coin1': tensor(0.9798), 'coin2': tensor(-0.9134)}
{'coin1': tensor(0.7650), 'coin2': tensor(-1.1820)}
{'coin1': tensor(0.5710), 'coin2': tensor(-1.5194)}
{'coin1': tensor(0.3187), 'coin2': tensor(-1.3588)}
{'coin1': tensor(0.0037), 'coin2': tensor(-1.4186)}
{'coin1': tensor(0.1112), 'coin2': tensor(-1.7229)}
{'coin1': tensor(0.1818), 'coin2': tensor(-1.6841)}
{'coin1': tensor(-0.0073), 'coin2': tensor(-1.6092)}
{'coin1': tensor(-0.3197), 'coin2': tensor(-1.0703)}
{'coin1': tensor(0.2977), 'coin2': tensor(0.2657)}
{'coin1': tensor(0.0611), 'coin2': tensor(0.2963)}
{'coin1': tensor(0.2188), 'coin2': tensor(0.2535)}
{'coin1': tensor(0.2768), 'coin2': tensor(0.5208)}
{'coin1': tensor(0.0874), 'coin2': tensor(0.7251)}
{'coin1': tensor(-0.1780), 'coin2': tensor(0.5522)}
{'coin1': tensor(-0.4874), 'coin2': tensor(0.8193)}
{'coin1': tensor(-0.7988), 'coin2': tensor(0.9562)}
{'coin1': tensor(-0.9604), 'coin2': tensor(1.0970)}
{'coin1': tenso

Warmup:   1%|▋                                                         | 25/2000 [00:01, 13.65it/s, step size=1.00e-01]

{'coin1': tensor(0.7539), 'coin2': tensor(-0.2693)}
{'coin1': tensor(-0.2465), 'coin2': tensor(0.9252)}
{'coin1': tensor(0.0283), 'coin2': tensor(0.5589)}
{'coin1': tensor(0.0575), 'coin2': tensor(0.2398)}
{'coin1': tensor(0.2433), 'coin2': tensor(0.2701)}
{'coin1': tensor(0.2386), 'coin2': tensor(0.2413)}
{'coin1': tensor(0.2664), 'coin2': tensor(-0.0628)}
{'coin1': tensor(0.1980), 'coin2': tensor(-0.4040)}
{'coin1': tensor(-0.4521), 'coin2': tensor(0.4739)}
{'coin1': tensor(-0.7227), 'coin2': tensor(0.5034)}
{'coin1': tensor(-0.9681), 'coin2': tensor(0.4069)}
{'coin1': tensor(0.1021), 'coin2': tensor(-0.2890)}
{'coin1': tensor(0.2348), 'coin2': tensor(-0.2412)}
{'coin1': tensor(0.3566), 'coin2': tensor(-0.2210)}
{'coin1': tensor(0.2432), 'coin2': tensor(-0.2991)}
{'coin1': tensor(1.2943), 'coin2': tensor(-0.3807)}
{'coin1': tensor(1.3144), 'coin2': tensor(-0.2747)}
{'coin1': tensor(1.7039), 'coin2': tensor(-0.2244)}
{'coin1': tensor(1.6530), 'coin2': tensor(-0.3419)}
{'coin1': tensor

Warmup:   2%|▊                                                         | 30/2000 [00:02, 14.53it/s, step size=1.00e-01]

{'coin1': tensor(-0.3080), 'coin2': tensor(-0.5788)}
{'coin1': tensor(-0.2395), 'coin2': tensor(-0.5661)}
{'coin1': tensor(-0.4054), 'coin2': tensor(-0.0718)}
{'coin1': tensor(-0.1566), 'coin2': tensor(-0.3361)}
{'coin1': tensor(0.0712), 'coin2': tensor(-0.5942)}
{'coin1': tensor(0.1892), 'coin2': tensor(-0.7291)}
{'coin1': tensor(0.3269), 'coin2': tensor(-0.8399)}
{'coin1': tensor(0.6194), 'coin2': tensor(-0.6060)}
{'coin1': tensor(1.1149), 'coin2': tensor(-0.2498)}
{'coin1': tensor(1.4609), 'coin2': tensor(-0.3755)}
{'coin1': tensor(1.6734), 'coin2': tensor(-0.4667)}
{'coin1': tensor(1.3568), 'coin2': tensor(-0.5934)}
{'coin1': tensor(1.5260), 'coin2': tensor(-0.6004)}
{'coin1': tensor(1.3761), 'coin2': tensor(-0.4737)}
{'coin1': tensor(1.2850), 'coin2': tensor(-0.3093)}
{'coin1': tensor(1.0410), 'coin2': tensor(-0.2899)}
{'coin1': tensor(1.1809), 'coin2': tensor(0.0296)}
{'coin1': tensor(0.0032), 'coin2': tensor(-1.2005)}
{'coin1': tensor(1.2270), 'coin2': tensor(-0.5576)}
{'coin1':

Warmup:   2%|▉                                                         | 33/2000 [00:02, 16.31it/s, step size=1.00e-01]

{'coin1': tensor(0.4695), 'coin2': tensor(0.3991)}
{'coin1': tensor(0.5936), 'coin2': tensor(-0.0261)}
{'coin1': tensor(0.4835), 'coin2': tensor(0.1783)}
{'coin1': tensor(0.4645), 'coin2': tensor(0.0863)}
{'coin1': tensor(-0.0917), 'coin2': tensor(0.0748)}
{'coin1': tensor(0.0554), 'coin2': tensor(-0.0428)}
{'coin1': tensor(1.1655), 'coin2': tensor(0.3106)}
{'coin1': tensor(1.3659), 'coin2': tensor(0.3995)}
{'coin1': tensor(-0.3012), 'coin2': tensor(-0.3002)}
{'coin1': tensor(-1.0253), 'coin2': tensor(-0.4336)}
{'coin1': tensor(-1.2260), 'coin2': tensor(-0.2781)}
{'coin1': tensor(-1.4858), 'coin2': tensor(-0.3412)}
{'coin1': tensor(-0.5145), 'coin2': tensor(-1.8027)}
{'coin1': tensor(-0.3354), 'coin2': tensor(-1.0822)}
{'coin1': tensor(-0.1003), 'coin2': tensor(-0.6326)}
{'coin1': tensor(-0.0648), 'coin2': tensor(-1.9413)}
{'coin1': tensor(0.4135), 'coin2': tensor(-1.1934)}
{'coin1': tensor(0.2194), 'coin2': tensor(-1.0078)}
{'coin1': tensor(0.3809), 'coin2': tensor(-0.6278)}
{'coin1':

Warmup:   2%|█                                                         | 35/2000 [00:02, 16.05it/s, step size=1.00e-01]

{'coin1': tensor(-0.6687), 'coin2': tensor(0.8592)}
{'coin1': tensor(-0.6015), 'coin2': tensor(0.8963)}
{'coin1': tensor(-0.6330), 'coin2': tensor(0.8050)}
{'coin1': tensor(-0.3658), 'coin2': tensor(0.9013)}
{'coin1': tensor(-0.3216), 'coin2': tensor(0.5841)}
{'coin1': tensor(0.2411), 'coin2': tensor(0.5455)}
{'coin1': tensor(0.5796), 'coin2': tensor(0.1209)}
{'coin1': tensor(1.2372), 'coin2': tensor(-0.1176)}
{'coin1': tensor(1.4664), 'coin2': tensor(-0.1863)}
{'coin1': tensor(1.0098), 'coin2': tensor(0.6809)}
{'coin1': tensor(1.1010), 'coin2': tensor(0.9340)}
{'coin1': tensor(1.4058), 'coin2': tensor(1.0591)}
{'coin1': tensor(1.3410), 'coin2': tensor(1.1180)}
{'coin1': tensor(1.1079), 'coin2': tensor(0.7698)}
{'coin1': tensor(0.9802), 'coin2': tensor(0.5260)}
{'coin1': tensor(1.3226), 'coin2': tensor(0.5890)}
{'coin1': tensor(0.4654), 'coin2': tensor(0.8821)}
{'coin1': tensor(0.1293), 'coin2': tensor(0.9300)}
{'coin1': tensor(-0.3216), 'coin2': tensor(0.8891)}
{'coin1': tensor(-0.787

Warmup:   2%|█                                                         | 37/2000 [00:02, 13.80it/s, step size=1.00e-01]

{'coin1': tensor(0.7333), 'coin2': tensor(-0.2386)}
{'coin1': tensor(0.1800), 'coin2': tensor(-0.2264)}
{'coin1': tensor(0.1063), 'coin2': tensor(-0.6081)}
{'coin1': tensor(-0.0769), 'coin2': tensor(-0.5141)}
{'coin1': tensor(-0.3045), 'coin2': tensor(-0.4014)}
{'coin1': tensor(-0.4818), 'coin2': tensor(-0.6691)}
{'coin1': tensor(-0.8410), 'coin2': tensor(-0.5275)}
{'coin1': tensor(-0.7629), 'coin2': tensor(-0.4695)}
{'coin1': tensor(-0.9263), 'coin2': tensor(-0.9291)}
{'coin1': tensor(-1.1044), 'coin2': tensor(-0.5761)}
{'coin1': tensor(-0.9979), 'coin2': tensor(-0.6710)}
{'coin1': tensor(-1.0138), 'coin2': tensor(-0.5199)}
{'coin1': tensor(-0.9044), 'coin2': tensor(-0.4245)}
{'coin1': tensor(-0.9028), 'coin2': tensor(0.2299)}
{'coin1': tensor(-0.3411), 'coin2': tensor(0.5714)}
{'coin1': tensor(-0.2391), 'coin2': tensor(0.7366)}
{'coin1': tensor(0.2441), 'coin2': tensor(0.8713)}
{'coin1': tensor(0.9579), 'coin2': tensor(0.6866)}
{'coin1': tensor(1.1836), 'coin2': tensor(0.7419)}
{'coi

Warmup:   2%|█▏                                                        | 41/2000 [00:02, 17.08it/s, step size=1.00e-01]

{'coin1': tensor(-0.3962), 'coin2': tensor(-0.1885)}
{'coin1': tensor(-0.4032), 'coin2': tensor(-0.3210)}
{'coin1': tensor(-0.7504), 'coin2': tensor(-0.2812)}
{'coin1': tensor(-0.8801), 'coin2': tensor(-0.4039)}
{'coin1': tensor(-0.8565), 'coin2': tensor(-0.7263)}
{'coin1': tensor(-0.6519), 'coin2': tensor(-0.6976)}
{'coin1': tensor(-0.8367), 'coin2': tensor(-0.6968)}
{'coin1': tensor(-0.0685), 'coin2': tensor(0.9120)}
{'coin1': tensor(0.2327), 'coin2': tensor(0.8764)}
{'coin1': tensor(-0.0418), 'coin2': tensor(0.8807)}
{'coin1': tensor(0.0644), 'coin2': tensor(0.9525)}
{'coin1': tensor(0.5197), 'coin2': tensor(0.7101)}
{'coin1': tensor(0.8730), 'coin2': tensor(0.5423)}
{'coin1': tensor(1.2217), 'coin2': tensor(0.4578)}
{'coin1': tensor(1.3021), 'coin2': tensor(0.3453)}
{'coin1': tensor(1.0464), 'coin2': tensor(0.2253)}
{'coin1': tensor(0.3638), 'coin2': tensor(0.0980)}
{'coin1': tensor(0.6759), 'coin2': tensor(0.1888)}
{'coin1': tensor(0.5663), 'coin2': tensor(0.2891)}
{'coin1': tenso

Warmup:   2%|█▏                                                        | 43/2000 [00:02, 14.08it/s, step size=1.00e-01]

{'coin1': tensor(0.1649), 'coin2': tensor(0.9477)}
{'coin1': tensor(-0.2011), 'coin2': tensor(0.8363)}
{'coin1': tensor(-0.3434), 'coin2': tensor(0.8741)}
{'coin1': tensor(-0.5256), 'coin2': tensor(1.0318)}
{'coin1': tensor(-0.6540), 'coin2': tensor(0.7140)}
{'coin1': tensor(-1.1418), 'coin2': tensor(0.8004)}
{'coin1': tensor(-1.0009), 'coin2': tensor(0.6321)}
{'coin1': tensor(-0.5382), 'coin2': tensor(0.2305)}
{'coin1': tensor(-0.4045), 'coin2': tensor(0.1216)}
{'coin1': tensor(-0.5732), 'coin2': tensor(0.2105)}
{'coin1': tensor(-0.2583), 'coin2': tensor(0.3198)}
{'coin1': tensor(0.0579), 'coin2': tensor(0.1203)}
{'coin1': tensor(-0.1682), 'coin2': tensor(0.0751)}
{'coin1': tensor(0.0051), 'coin2': tensor(0.1141)}
{'coin1': tensor(0.3031), 'coin2': tensor(0.1232)}
{'coin1': tensor(1.5971), 'coin2': tensor(0.0571)}
{'coin1': tensor(2.0313), 'coin2': tensor(0.3244)}
{'coin1': tensor(2.0514), 'coin2': tensor(0.9782)}
{'coin1': tensor(1.9287), 'coin2': tensor(1.0975)}
{'coin1': tensor(1.8

Warmup:   2%|█▎                                                        | 45/2000 [00:03, 12.98it/s, step size=1.00e-01]

{'coin1': tensor(-1.5103), 'coin2': tensor(-1.0425)}
{'coin1': tensor(1.2568), 'coin2': tensor(1.3614)}
{'coin1': tensor(1.7091), 'coin2': tensor(1.3772)}
{'coin1': tensor(1.9011), 'coin2': tensor(1.3218)}
{'coin1': tensor(2.3619), 'coin2': tensor(1.1937)}
{'coin1': tensor(2.6409), 'coin2': tensor(1.2361)}
{'coin1': tensor(2.8575), 'coin2': tensor(1.1317)}
{'coin1': tensor(2.6993), 'coin2': tensor(0.9352)}
{'coin1': tensor(2.7279), 'coin2': tensor(0.3171)}
{'coin1': tensor(2.2175), 'coin2': tensor(0.4454)}
{'coin1': tensor(1.7943), 'coin2': tensor(0.4236)}
{'coin1': tensor(1.0872), 'coin2': tensor(0.5189)}
{'coin1': tensor(0.7747), 'coin2': tensor(0.2986)}
{'coin1': tensor(0.6062), 'coin2': tensor(0.3780)}
{'coin1': tensor(0.2709), 'coin2': tensor(0.2714)}
{'coin1': tensor(0.0747), 'coin2': tensor(-0.2062)}
{'coin1': tensor(-2.7958), 'coin2': tensor(-0.8903)}
{'coin1': tensor(-1.9895), 'coin2': tensor(-0.8992)}
{'coin1': tensor(-1.6439), 'coin2': tensor(-0.9396)}
{'coin1': tensor(-2.78

KeyboardInterrupt: 

In [None]:
fig, axs = plt.subplots(2, 1, figsize=(6, 6), sharex=True)

def analytic_posterior(x, alpha, beta):
    fn = Beta(alpha, beta)
    return np.exp(fn.log_prob(x))
    
x = torch.tensor(np.linspace(0, 1, num=1000))

with torch.no_grad():
    axs[0].hist(sgnuts_samples["coin1"].numpy(), density=True, bins=30)
    axs[0].set(title="SGNUTS 'coin1' samples")
    axs[0].plot(x, analytic_posterior(x, 5., 7.), "r", label="analytic posterior")
    axs[1].hist(sgnuts_samples["coin2"].numpy(), density=True, bins=30)
    axs[1].set(title="SGNUTS 'coin2' samples")
    axs[1].plot(x, analytic_posterior(x, 9., 3.), "r", label="analytic posterior")
plt.xlim((0.0, 1.0))
plt.show()