In [None]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr

from pymc import HalfCauchy, Model, Normal, MvNormal, Uniform, Beta, sample

print(f"Running on PyMC v{pm.__version__}")

# pip install blackjax numpyro
# ... to make it use jax for NUTS (MCMC)

import pymc.sampling.jax

def backward_simplex(value):
    """Map unconstrained simplex coordinates back onto the probability simplex.

    A final coordinate equal to minus the sum of the last axis is appended,
    then a numerically stable softmax is taken along that axis. The notebook
    uses this to turn the unconstrained ``π`` entries of ``approx.mean`` back
    into mixture weights.
    NOTE(review): intended to match PyMC's simplex transform backward step —
    confirm against the installed PyMC version.

    Parameters
    ----------
    value : array_like, shape (..., K-1)
        Unconstrained coordinates.

    Returns
    -------
    numpy.ndarray, shape (..., K)
        Non-negative weights summing to 1 along the last axis.
    """
    completed = np.concatenate([value, -np.sum(value, -1, keepdims=True)], axis=-1)
    # Shift by the per-row maximum before exponentiating so exp() cannot overflow.
    shifted = completed - np.max(completed, -1, keepdims=True)
    weights = np.exp(shifted)
    return weights / np.sum(weights, -1, keepdims=True)


In [None]:
from tools import obj_dic, show_heatmap_contours, show_heatmap

SEED = 1412
#TRAIN_K = 1412

def gen_data(N, seed=SEED):
    """Sample N noisy 2-D points along one randomly placed segment.

    The endpoints ``a`` and ``b`` (shape (1, 2)) are drawn from N(0, 30^2) per
    coordinate. Each point sits at a uniform fraction ``u`` of the way from
    ``a`` to ``b`` and is perturbed with unit-variance Gaussian noise.

    Returns
    -------
    p : numpy.ndarray, shape (N, 2)
        The sampled points.
    gt : obj_dic
        Ground truth built from ``locals()`` (N, seed, r, a, b, u, p); the
        notebook reads ``gt.a`` / ``gt.b`` as attributes. Renaming or adding a
        local here changes what ``gt`` contains.

    NOTE(review): ``gen_data`` is redefined later in this notebook with a
    required ``p`` argument; calls to ``gen_data(N)`` made after that cell runs
    will fail.
    """
    r = np.random.default_rng(seed)
    a = r.normal(0, 30, (1,2))
    b = r.normal(0, 30, (1,2))
    u = r.uniform(0, 1, (N,1))
    # Linear interpolation between the endpoints plus isotropic N(0, 1) noise.
    p = a + (b-a) * u + r.normal(0, 1, (N, 2))
    return p, obj_dic(locals())

data, gt = gen_data(100)

# Observed points plus the two true endpoints of the segment.
plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1, label="data")
plt.scatter(gt.a[0,0], gt.a[0,1], marker='+', label="a (true)")
plt.scatter(gt.b[0,0], gt.b[0,1], marker='+', label="b (true)")
plt.legend()
plt.show()


In [None]:
def true_contours(N=50000, bins=30):
    """Plot density contours of a large sample from the single-segment generator.

    Parameters
    ----------
    N : int
        Number of points to draw; more points give smoother contours.
    bins : int
        Number of bins passed to ``show_heatmap_contours``.

    NOTE(review): ``gen_data`` is looked up when this is called. The notebook
    later redefines ``gen_data`` with a required ``p`` argument, so calling
    this after that cell has run raises a TypeError.
    """
    data, gt = gen_data(N)
    show_heatmap_contours(data[:,0], data[:,1], bins=bins)
    # True segment endpoints.
    plt.scatter(gt.a[0,0], gt.a[0,1], marker='+')
    plt.scatter(gt.b[0,0], gt.b[0,1], marker='+')

true_contours()

In [None]:
with Model(coords={"points": list(range(data.shape[0]))}) as model:
    # Weak endpoint priors (sd = 100 per coordinate, cov = sd**2).
    # NOTE(review): the -1 / +1 mean offsets presumably break the a <-> b
    # labelling symmetry — confirm.
    a = MvNormal("a", mu=np.zeros(2)-1, cov=np.eye(2)*100**2)
    b = MvNormal("b", mu=np.zeros(2)+1, cov=np.eye(2)*100**2)
    # Position of each observed point along the segment, in [0, 1].
    z = Beta("z", alpha=1, beta=1, dims="points")
    # z[..., None] has shape (points, 1), so it broadcasts against the 2-D
    # endpoints and mu gets shape (points, 2): one mean per observed point.
    likelihood = MvNormal("x", mu=a + (b-a)*z[...,None], cov=np.eye(2)*1, observed=data)
    # Fixed seed so a Restart & Run All reproduces the same draws.
    idata = sample(1000, random_seed=SEED)

In [None]:
# Display the InferenceData object returned by `sample` in the cell above.
idata

In [None]:
# Shape of every variable in the model, then its graph.
rv_shapes = model.eval_rv_shapes()
for var_name in rv_shapes:
    print(f"{var_name:>11}: shape={rv_shapes[var_name]}")
pm.model_to_graphviz(model)

In [None]:
# Marginal posterior densities of the segment endpoints a and b.
az.plot_posterior(idata, var_names=["a", "b"])

In [None]:
# Trace plots of all sampled variables (including the per-point z) to check mixing.
az.plot_trace(idata)

### Exploring ADVI (Automatic Differentiation Variational Inference)

In [None]:
with model:
    # ADVI via pm.fit; the seed makes the stochastic optimisation reproducible.
    approx = pm.fit(n=20000, obj_optimizer=pm.adam(learning_rate=0.1), random_seed=SEED)

In [None]:
# First entries of the flattened variational mean (unconstrained space).
# NOTE(review): presumably a (2 values) followed by b (2 values), following the
# model's variable order — confirm before relying on these indices.
approx.mean.eval()[:4]

### Trying blackjax and numpyro

In [None]:

with model:
    # JAX-based NUTS via BlackJAX; seeded for reproducibility.
    idata_blackjax = pm.sampling.jax.sample_blackjax_nuts(1000, random_seed=SEED)

In [None]:
# Trace plots of the BlackJAX run, for comparison with the pm.sample run above.
az.plot_trace(idata_blackjax)

In [None]:
with model:
    # JAX-based NUTS via NumPyro; seeded for reproducibility.
    idata_numpyro = pm.sampling.jax.sample_numpyro_nuts(1000, random_seed=SEED)

In [None]:
# Trace plots of the NumPyro run, restricted to the segment endpoints.
az.plot_trace(idata_numpyro, var_names=["a", "b"])

# Now with a mixture of noisy segments

In [None]:

SEED = 1412

def gen_data(N, p, seed=SEED):
    """Sample N noisy 2-D points from a mixture of K random segments.

    NOTE: this replaces the single-segment ``gen_data`` defined earlier in the
    notebook; calls of the form ``gen_data(N)`` fail once this cell has run.

    Parameters
    ----------
    N : int
        Number of points.
    p : sequence of float, length K
        Mixture weights, one per segment (passed to ``Generator.choice``).
    seed : int
        Seed for ``numpy.random.default_rng``.

    Returns
    -------
    p : numpy.ndarray, shape (N, 2)
        The sampled points.
    gt : obj_dic
        Ground truth built from ``locals()`` (read as attributes by the
        notebook). ``gt.a`` / ``gt.b`` are the (K, 2) endpoints, ``gt.z`` the
        segment index of each point, and ``gt.weights`` the mixture weights.
        ``gt.p`` is still the points array, as before.
    """
    # `p` is rebound to the points below; keep the mixture weights under their
    # own name so they are not lost from the ground-truth record.
    weights = p
    K = len(weights)
    r = np.random.default_rng(seed)
    a = r.normal(0, 30, (K,2))
    b = r.normal(0, 30, (K,2))
    u = r.uniform(0, 1, (N,1))
    # Segment assignment of each point, drawn with the mixture weights.
    z = r.choice(range(K), p=weights, size=(N,))
    p = a[z,:] + (b-a)[z,:] * u + r.normal(0, 1, (N, 2))
    return p, obj_dic(locals())

data, gt = gen_data(100, [0.3, 0.7])

plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)
# True endpoints of every segment, one marker style per segment (previously
# hard-coded to segments 0 and 1 only, so extra segments were silently omitted).
segment_markers = ['+', 'o', 'x', '^', 's', 'v', 'D']
for k in range(gt.a.shape[0]):
    marker = segment_markers[k % len(segment_markers)]
    plt.scatter(gt.a[k,0], gt.a[k,1], marker=marker)
    plt.scatter(gt.b[k,0], gt.b[k,1], marker=marker)

In [None]:
FULL_MANUAL = False
K = 10
with Model(coords={"segments": list(range(K)), "points": list(range(data.shape[0]))}) as model:
    # Endpoint priors: sd = 100 per coordinate (cov = sd**2), the same convention
    # as the single-segment model. gen_data draws the true endpoints with sd 30,
    # so the prior must be wider than that.
    a = MvNormal("a", mu=np.zeros(2)-1, cov=np.eye(2)*100**2, dims="segments", shape=(K, 2))
    b = MvNormal("b", mu=np.zeros(2)+1, cov=np.eye(2)*100**2, dims="segments", shape=(K, 2))
    π = pm.Dirichlet("π", a=[1]*K)
    u = Uniform("u", dims="points")
    if FULL_MANUAL:
        # Explicit discrete segment assignment per point. This works with MCMC
        # but not with ADVI (discrete variables are not differentiable).
        z = pm.Categorical("z", p=π, dims="points")
        # u[..., None] broadcasts the per-point position against the 2-D endpoints.
        likelihood = MvNormal("x", mu=a[z,...] + (b-a)[z,...]*u[...,None], cov=np.eye(2)*1, observed=data)
    else:
        # The assignment is marginalised by pm.Mixture: one MvNormal component per
        # segment, each with a (points, 2) mean, so the model stays continuous.
        components = [
            MvNormal.dist(mu=a[k,...] + (b-a)[k,...]*u[...,None], cov=np.eye(2)*1)
            for k in range(K)
        ]
        likelihood = pm.Mixture("x", w=π, comp_dists=components, observed=data)

In [None]:
# SLOW: PyMC's own sampler (pm.sample) on the mixture model; seeded for reproducibility.
with model:
    idata = sample(1000, random_seed=SEED)


In [None]:
# Fast, but may hang during or after sampling on big problems.
with model:
    idata_blackjax = pm.sampling.jax.sample_blackjax_nuts(10000, chains=4, random_seed=SEED)

In [None]:
# Less fast (but possibly more robust?) than BlackJAX; seeded for reproducibility.
with model:
    idata_numpyro = pm.sampling.jax.sample_numpyro_nuts(1000, chains=6, random_seed=SEED)

In [None]:
# Shape of every variable in the mixture model, then its graph.
rv_shapes = model.eval_rv_shapes()
for var_name in rv_shapes:
    print(f"{var_name:>11}: shape={rv_shapes[var_name]}")

pm.model_to_graphviz(model)

In [None]:
# Posterior of the mixture weights π and the segment endpoints from the NumPyro run.
az.plot_posterior(idata_numpyro, var_names=["π", "a", "b"])

In [None]:
# Trace plots of all sampled variables from the NumPyro run, to check mixing.
az.plot_trace(idata_numpyro)

# ADVI (Automatic Differentiation Variational Inference)

> ADVI raises an error if FULL_MANUAL is True, because the discrete assignment variable `z` is not differentiable.
> The mixture likelihood could also be written down by hand, as in
> https://www.pymc.io/projects/examples/en/latest/variational_inference/gaussian-mixture-model-advi.html, but that requires some geometric reasoning.


In [None]:
with model:
    # ADVI on the marginalised mixture; seeded so the fit is reproducible.
    approx = pm.fit(n=7500, obj_optimizer=pm.adam(learning_rate=1e-1), random_seed=SEED)

In [None]:
# Size of the flattened variational mean and its first few entries.
approx_mean = approx.mean.eval()
approx_mean.shape, approx_mean[:14]

In [None]:
m = approx.mean.eval()
# Assumed layout of the flattened variational mean (model variable order):
# a -> m[:2K], b -> m[2K:4K], unconstrained π -> m[4K:5K-1], then u.
# NOTE(review): confirm this ordering if the model's variables change.
a_mean = m[:2*K].reshape(K, 2)
b_mean = m[2*K:4*K].reshape(K, 2)
# Weights from the mean in unconstrained space (not the mean of π itself).
pi = backward_simplex(m[4*K:5*K-1])

plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)
# One line per fitted segment, faded in proportion to its mixture weight.
for k in range(K):
    plt.plot([a_mean[k,0], b_mean[k,0]], [a_mean[k,1], b_mean[k,1]], alpha=pi[k]/np.max(pi))
plt.show()

plt.bar(x=range(K), height=pi)
plt.xlabel("segment")
plt.ylabel("mixture weight π");

# Circular dataset

In [None]:

SEED = 1412

def gen_data_ring(N, seed=SEED, radius=20, noise=.05):
    """Sample N noisy 2-D points around a circle centred at the origin.

    Directions come from normalised standard-normal draws. Each point is placed
    at distance ``radius`` and then scaled by a N(1, noise^2) factor, so the
    radial spread is proportional to ``radius``.

    Parameters
    ----------
    N : int
        Number of points.
    seed : int
        Seed for ``numpy.random.default_rng``.
    radius : float
        Ring radius (previously hard-coded to 20).
    noise : float
        Standard deviation of the multiplicative radial factor (previously 0.05).

    Returns
    -------
    p : numpy.ndarray, shape (N, 2)
        The sampled points.
    gt : obj_dic
        Ground truth built from ``locals()``.
    """
    r = np.random.default_rng(seed)
    x = r.normal(0, 1, (N, 2))
    # Project onto the circle of the requested radius.
    x = radius * x / np.sum(x**2, axis=-1, keepdims=True)**0.5
    p = x * r.normal(1, noise, (N, 1))
    return p, obj_dic(locals())

data, gt = gen_data_ring(300)

# Scatter the ring points (columns of `data` are the x and y coordinates).
xs, ys = data.T
plt.scatter(xs, ys, marker='.', alpha=0.1)


In [None]:
K = 20
coords = {"segments": list(range(K)), "points": list(range(data.shape[0]))}
with Model(coords=coords) as model:
    # Endpoint priors: sd = 40 per coordinate (cov = sd**2) for a ring of radius 20.
    a = MvNormal("a", mu=np.zeros(2)-1, cov=np.eye(2)*40**2, dims="segments", shape=(K, 2))
    b = MvNormal("b", mu=np.zeros(2)+1, cov=np.eye(2)*40**2, dims="segments", shape=(K, 2))
    π = pm.Dirichlet("π", a=[1]*K)
    u = Uniform("u", dims="points")
    # One continuous MvNormal component per segment; pm.Mixture marginalises the
    # assignment so the model can be fitted with ADVI.
    components = []
    for k in range(K):
        segment_mean = a[k, ...] + (b - a)[k, ...] * u[..., None]
        components.append(MvNormal.dist(mu=segment_mean, cov=np.eye(2) * 1))
    likelihood = pm.Mixture("x", w=π, comp_dists=components, observed=data)

In [None]:
with model:
    # ADVI on the ring mixture; seeded so the fit is reproducible.
    approx = pm.fit(n=10000, obj_optimizer=pm.adam(learning_rate=1e-1), random_seed=SEED)


In [None]:
# Evaluate the variational mean once, then show its size, endpoint entries and weights.
approx_mean = approx.mean.eval()
pi = backward_simplex(approx_mean[4*K:5*K-1])
approx_mean.shape, approx_mean[:4*K], pi

In [None]:
m = approx.mean.eval()
# Assumed layout of the flattened variational mean (model variable order):
# a -> m[:2K], b -> m[2K:4K], unconstrained π -> m[4K:5K-1], then u.
# NOTE(review): confirm this ordering if the model's variables change.
a_mean = m[:2*K].reshape(K, 2)
b_mean = m[2*K:4*K].reshape(K, 2)
# Weights from the mean in unconstrained space (not the mean of π itself).
pi = backward_simplex(m[4*K:5*K-1])

plt.scatter(data[:,0], data[:,1], marker='.', alpha=0.1)
# One line per fitted segment, faded in proportion to its mixture weight.
for k in range(K):
    plt.plot([a_mean[k,0], b_mean[k,0]], [a_mean[k,1], b_mean[k,1]], alpha=pi[k]/np.max(pi))
plt.show()

plt.bar(x=range(K), height=pi)
plt.xlabel("segment")
plt.ylabel("mixture weight π");