In [None]:
from jax import numpy as jnp, random, ops, lax

from matplotlib import pyplot as plt
from tools import obj_dic, show_heatmap_contours

import numpyro
from numpyro import distributions as dist, sample, plate, param
from numpyro.infer import autoguide


In [None]:
# Record which NumPyro release this notebook was run with.
print(f"{numpyro.__version__}")
# numpyro.enable_validation(True)  # uncomment to validate distribution arguments

In [None]:

TRAIN_K = 1412  # PRNG seed used to generate the training set

def gen_data(r, N):
    """Sample N noisy 2-D points along a random segment [a, b].

    Args:
        r: jax PRNG key; split into 4 sub-keys (a, b, u, noise).
        N: number of points to draw.

    Returns:
        (p, gt): p is the (N, 2) array of points; gt is obj_dic(locals()),
        i.e. built from ALL locals of this function (r, N, rk, a, b, u, p).
        NOTE(review): obj_dic presumably exposes them as attributes (gt.a, gt.b
        are read by the plotting cells), so renaming or adding locals here
        changes what gt contains.
    """
    rk = random.split(r, 4)
    a = random.normal(rk[0], (1,2))*30  # first endpoint, shape (1, 2), std 30
    b = random.normal(rk[1], (1,2))*30  # second endpoint, shape (1, 2), std 30
    u = random.uniform(rk[2], (N,1))  # position of each point along the segment, in [0, 1)
    p = a + (b-a) * u + random.normal(rk[3], (N, 2))  # plus unit-variance isotropic Gaussian noise
    return p, obj_dic(locals())

# Training set: 1000 noisy points along the segment [a, b], drawn with the TRAIN_K seed.
data, gt = gen_data(random.PRNGKey(TRAIN_K), 1000)

# Observations as faint dots, the two true endpoints as crosses.
plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)
plt.scatter(*gt.a[0], marker='+')
plt.scatter(*gt.b[0], marker='+')


In [None]:
def true_contours():
    """Draw a dense reference picture of the generative process.

    Re-samples 50000 points with the training seed TRAIN_K. The endpoint draws
    do not depend on N, so a and b are the same as in the training set. The
    points are handed to show_heatmap_contours (from tools), and the true
    endpoints are overlaid as crosses.
    """
    big_data, truth = gen_data(random.PRNGKey(TRAIN_K), 50000)
    show_heatmap_contours(big_data[:,0], big_data[:,1], bins=100)
    for endpoint in (truth.a, truth.b):
        plt.scatter(endpoint[0,0], endpoint[0,1], marker='+')

true_contours()

In [None]:
# Subsample size of the 'data' plate. Currently the full batch; `data.shape[0] // 10`
# would enable minibatching (that does not work properly...)
SUB = data.shape[0]

In [None]:
# Generative model/story, used (as p(x|θ)) for variational inference and for MCMC
def model(data, sub=SUB, with_obs=True):
    """Noisy points placed uniformly on the segment between two unknown endpoints.

        a, b ~ N(0, 100^2 I)             (2-D endpoints)
        u_i  ~ Uniform(0, 1)             (position of point i along [a, b])
        x_i  ~ N(a + (b - a) u_i, I)

    Args:
        data: observations indexed as data[i, :]; when with_obs is False only
            data.shape[0] (the plate size) is read.
        sub: subsample size of the 'data' plate (default: the global SUB as it
            was when this cell ran).
        with_obs: if False, 'obs' is sampled rather than conditioned on data.

    Returns:
        The 'obs' values for the (sub)sampled plate indices.
    """
    n_points = data.shape[0]
    endpoint_prior = dist.MultivariateNormal(jnp.zeros((2,)), jnp.eye(2)*100**2)
    a = sample('a', endpoint_prior)
    b = sample('b', endpoint_prior)

    with plate('data', n_points, subsample_size=sub) as idx:
        u = sample('u', dist.Uniform(0, 1))
        loc = a + (b - a) * u[:, None]
        observed = data[idx, :] if with_obs else None
        return sample('obs', dist.MultivariateNormal(loc, jnp.eye(2)), obs=observed)

    
# This guide function is the variational distribution (definition of the approximating q(θ))
def guide(data, sub=None, with_obs=True):
    """Mean-field variational family q(a, b, u) for `model`.

    Accepts the same arguments as `model`, because SVI forwards identical
    args/kwargs to model and guide. Previously the guide accepted only `data`,
    so passing `sub=` through SVI raised a TypeError. It also read the global
    SUB directly, so the two 'data' plates could disagree on subsample size.

    Args:
        data: observations; only data.shape[0] (the plate size) and the
            per-column min / max / median (for initialisation) are read.
        sub: subsample size of the 'data' plate. None (default) means the
            global SUB, read at call time exactly as before.
        with_obs: accepted for signature parity with `model`; unused here.
    """
    if sub is None:
        sub = SUB
    N = data.shape[0]

    # Not-so-bad init: start a at the per-column min and b at the per-column
    # max, each moved a random fraction of the way towards the median.
    aμinit = jnp.min(data, 0)
    bμinit = jnp.max(data, 0)
    med = jnp.median(data, 0)
    aμinit = aμinit + random.uniform(random.PRNGKey(201), shape=(2,)) * (med - aμinit)
    bμinit = bμinit + random.uniform(random.PRNGKey(202), shape=(2,)) * (med - bμinit)

    # NOTE: the second positional argument of MultivariateNormal is the covariance
    # matrix, so 'qaσ' / 'qbσ' are isotropic *variances*, not standard deviations.
    aμ = param('qaμ', aμinit)
    aσ = param('qaσ', 0.1, constraint=dist.constraints.positive)
    bμ = param('qbμ', bμinit)
    bσ = param('qbσ', 0.1, constraint=dist.constraints.positive)
    sample('a', dist.MultivariateNormal(aμ, jnp.eye(2)*aσ))
    sample('b', dist.MultivariateNormal(bμ, jnp.eye(2)*bσ))

    # One Beta(1+α_i, 1+β_i) per data point. The +1 keeps both concentrations
    # above 1, so every q(u_i) is unimodal on (0, 1).
    uα = param('uα', jnp.zeros(N)+1, constraint=dist.constraints.positive)
    uβ = param('uβ', jnp.zeros(N)+1, constraint=dist.constraints.positive)

    with plate('data', N, subsample_size=sub) as ind:
        # Only the rows of the local parameters selected by the plate are used.
        sample('u', dist.Beta(1+uα[ind], 1+uβ[ind]))

# Placeholder; replaced by the trajectory of the variational means after training.
history = list()

##### alternative configuration
# True: fit an AutoDiagonalNormal auto-guide; False: use the hand-written `guide` above.
auto_guide = False

In [None]:
# SVI configuration. The guide actually used is chosen from `auto_guide` each time
# this cell runs. The hand-written `guide` function is never overwritten, so flipping
# `auto_guide` and re-running only this cell really switches between the two.
lr = 0.05

if auto_guide:
    svi_guide = autoguide.AutoDiagonalNormal(model)
    #svi_guide = autoguide.AutoLowRankMultivariateNormal(model, rank=20)
else:
    svi_guide = guide

# Both configurations use the same ELBO estimator and optimizer.
elbo = numpyro.infer.Trace_ELBO()
optimizer = numpyro.optim.Adam(step_size=lr)

svi = numpyro.infer.SVI(model, svi_guide, optimizer, loss=elbo)

In [None]:
### %%time

n_steps = 2000
init_state = svi.init(random.PRNGKey(42000), data)

#print(init_state)

def scanner(pstate, i):
    """One SVI step in lax.scan form: (state, step) -> (new state, (loss, params after the step))."""
    next_state, step_loss = svi.update(pstate, data)
    return next_state, (step_loss, svi.get_params(next_state))

# All steps run inside a single scan. `losses` and every leaf of `params` get a
# leading axis of length n_steps (one entry per step).
state, (losses, params) = lax.scan(scanner, init_state, jnp.arange(n_steps))


In [None]:
# Training curve: SVI loss (the negative ELBO) at each step; the last few values
# are printed to judge convergence.
plt.plot(losses)
plt.xlabel('SVI step')
plt.ylabel('loss (negative ELBO)')
print(losses[-10:])

In [None]:
# Sanity-check the shapes of the stacked (per-step) parameter trajectories.
if not auto_guide:
    print(params['qaμ'].shape)
    print(jnp.hstack([params['qaμ'], params['qbμ']]).shape)
else:
    print(params['auto_loc'].shape)


In [None]:
import numpy

# Trajectory of the variational means of a and b over the SVI steps, as rows of
# [a_x, a_y, b_x, b_y].
# NOTE(review): with the auto-guide, the first 4 entries of 'auto_loc' are assumed
# to be a then b (flattened latent vector); verify if sites are renamed or reordered.
if auto_guide:
    history = params['auto_loc'][:,:4]
else:
    history = numpy.hstack([params['qaμ'], params['qbμ']])

plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)
plt.scatter(gt.a[0,0], gt.a[0,1], marker='+')
plt.scatter(gt.b[0,0], gt.b[0,1], marker='+')

h = numpy.array(history[-2000:])
print(h.shape)
plt.plot(h[:,0], h[:,1], label='Mean a across iterations')
plt.plot(h[:,2], h[:,3], label='Mean b across iterations')
plt.scatter(h[-1,0], h[-1,1])
plt.scatter(h[-1,2], h[-1,3])
plt.plot([h[-1,0], h[-1,2]], [h[-1,1], h[-1,3]], '--', label="Final a--b")
plt.legend()

plt.show()


In [None]:

plot_posterior_predictive = False

if plot_posterior_predictive:
    # Predictive check with a and b fixed at the final variational means
    # (the last row of the trajectory `h`); u and obs are sampled from the model.
    pred = numpyro.handlers.seed(model, random.PRNGKey(4242))
    pred = numpyro.handlers.condition(pred,
                                      dict(
                                          a=jnp.array([h[-1,0], h[-1,1]]),
                                          b=jnp.array([h[-1,2], h[-1,3]])))

    N = 50000
    # The zeros array is only a placeholder: with_obs=False means just its length is
    # read, and sub=N makes the plate draw all N points (no subsampling).
    samples = pred(jnp.zeros((N,2)), sub=N, with_obs=False)
    print(samples.shape)
    plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)
    show_heatmap_contours(samples[:,0], samples[:,1], bins=100)
    plt.plot([h[-1,0], h[-1,2]], [h[-1,1], h[-1,3]], 'r--', label="Final a--b")
    plt.legend()  # without this the label above is never displayed
    plt.show()
    


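NB: you can go back up, flip `auto_guide` (to use the automatic `AutoDiagonalNormal` guide instead of the hand-written one), and re-run the cells from there to compare the two.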

# Now with MCMC

In [None]:
from numpyro.infer.mcmc import MCMC
from numpyro.infer.hmc import NUTS, HMC


In [None]:
# HMC kernel over every latent site of `model` (a, b and all u_i); 'obs' is conditioned on data.
mcmodel = HMC(model)
mcmc = MCMC(
    mcmodel,
    num_warmup=500,
    num_samples=1000,
    num_chains=10,
)
mcmc.run(random.PRNGKey(42043), data)
#mcmc.print_summary()

In [None]:
# As a and b can be equivalently swapped, the MCMC sampler generates 'a' samples at both configurations.
# Transposing the (num_draws, 2) array of 'a' draws gives the x and y columns for scatter.
plt.scatter(*mcmc.get_samples()['a'].T)
plt.gca().set(xlabel='$a_x$', ylabel='$a_y$')
plt.show()


In [None]:
# Let's see if 'a' and 'b' are always coherent: each point pairs the x-coordinate of a
# with that of b from the same draw. If the swap is always consistent, the points
# form two clusters mirrored across the diagonal.
plt.scatter(*(mcmc.get_samples()[site][:,0] for site in ('a', 'b')))
plt.gca().set(xlabel='$a_x$', ylabel='$b_x$')
plt.show()

In [None]:
# All posterior draws, with the chains merged along the first axis (the default).
mcmc.get_samples(group_by_chain=False)

In [None]:
import numpy
# Posterior spread of u_i for the first 100 data points. Each row sits at the height
# of the last draw of u_i (plus a tiny vertical jitter) and spreads horizontally over
# the last 100 draws of u_i. The draws are the chain-merged samples, so these are
# presumably the tail of the last chain.
u = mcmc.get_samples()['u']  # fetched once instead of on every loop iteration
jitter_rng = numpy.random.default_rng(0)  # seeded so the figure is reproducible
for i in range(100):
    subu = u[-100:,i]
    plt.scatter(subu, u[-1,i]+jitter_rng.uniform(0, 1, subu.shape)/10000, marker='.', alpha=0.01)
plt.show()