In [1]:
import pymc as pm
import numpy as np
import pymc_experimental as pmx

I follow **Example 9** from [here](https://projecteuclid.org/ebooks/nsf-cbms-regional-conference-series-in-probability-and-statistics/Nonparametric-Bayesian-Inference/Chapter/Chapter-3-Dirichlet-Process/10.1214/cbms/1362163748) (displayed below) and also attempt to replicate the results shown in Figure 3.2. The example draws the data i.i.d. from $\mathcal{N}(2, 4)$ (mean 2, variance 4, so the standard deviation passed to `rng.normal` is 2.0), while the base distribution is $G_0 = \mathcal{N}(0, 1)$.

<img src="dp-example-9.png" width=500>
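For reference, suppose the observations are modelled as drawn directly from $P \sim \mathrm{DP}(\alpha, G_0)$. The standard conjugacy of the Dirichlet process then gives the posterior in closed form. Its mean CDF is a weighted average of the base CDF $\Phi$ (of $G_0 = \mathcal{N}(0, 1)$) and the empirical CDF $\mathbb{F}_n$:

$$
P \mid X_1, \dots, X_n \sim \mathrm{DP}\left(\alpha + n,\ \frac{\alpha G_0 + \sum_{i=1}^{n} \delta_{X_i}}{\alpha + n}\right),
\qquad
\mathbb{E}\left[F(x) \mid X_1, \dots, X_n\right] = \frac{\alpha}{\alpha + n}\,\Phi(x) + \frac{n}{\alpha + n}\,\mathbb{F}_n(x).
$$

With $\alpha = 5$ and $n = 50$, the prior receives weight $5/55 \approx 0.09$. The posterior mean CDF is therefore dominated by the data, even though $G_0$ is centred at 0 rather than 2.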

In [2]:
alpha = 5.0  # concentration parameter
K = 19  # truncation parameter

rng = np.random.default_rng(seed=34)
obs = rng.normal(2.0, 2.0, size=50)
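`alpha` and `K` play their usual roles in the truncated stick-breaking construction of a Dirichlet process draw. In that construction, $P = \sum_{k=1}^{K+1} w_k \delta_{\theta_k}$, the atoms are i.i.d. $\theta_k \sim G_0$, and the weights are built from $\mathrm{Beta}(1, \alpha)$ breaks. The NumPy sketch below draws one truncated prior realization. It assumes K breaks with the leftover mass assigned to a final (K + 1)-th weight, which matches `shape=(K + 1,)` for `base_dist` in the model cell. This is not necessarily how `pmx.dp.DirichletProcess` builds its weights internally, and `stick_breaking_weights` is a hypothetical helper.

```python
import numpy as np


def stick_breaking_weights(alpha, K, rng):
    """Truncated stick-breaking weights with K + 1 components.

    Assumption: K Beta(1, alpha) breaks, with the leftover stick mass
    assigned to the final (K + 1)-th weight so the weights sum to one.
    """
    v = rng.beta(1.0, alpha, size=K)
    # Stick length remaining before each break: 1, (1 - v_1), (1 - v_1)(1 - v_2), ...
    remaining = np.concatenate([[1.0], np.cumprod(1.0 - v)])
    # First K weights are the broken-off pieces; the last is whatever is left.
    return np.concatenate([v * remaining[:-1], remaining[-1:]])


alpha = 5.0
K = 19
rng = np.random.default_rng(seed=0)

w = stick_breaking_weights(alpha, K, rng)
atoms = rng.normal(0.0, 1.0, size=K + 1)  # i.i.d. draws from G0 = N(0, 1)

print(w.shape, atoms.shape)  # (20,) (20,)
```

Smaller `alpha` concentrates mass on the first few sticks. The expected mass left after K breaks is $(\alpha / (1 + \alpha))^K$, about 0.03 for $\alpha = 5$, $K = 19$.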

In [3]:
with pm.Model() as model:
    base_dist = pm.Normal("base_dist", 0.0, 1.0, shape=(K + 1,))
    sbw, atoms = pmx.dp.DirichletProcess("dp", alpha, base_dist, K, observed=obs)

    trace = pm.sample()

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>NUTS: [base_dist, sbw]
>BinaryGibbsMetropolis: [idx]
>CategoricalGibbsMetropolis: [atom_selection]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 7 seconds.
Chain 0 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 1 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 2 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
Chain 3 reached the maximum tree depth. Increase `max_treedepth`, increase `target_accept` or reparameterize.
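Every chain hit the maximum tree depth, so an independent reference for the posterior is useful. Assume the observations are drawn directly from $P \sim \mathrm{DP}(\alpha, G_0)$ with no mixture kernel. Conjugacy then gives $P \mid X_{1:n} \sim \mathrm{DP}\big(\alpha + n, (\alpha G_0 + \sum_i \delta_{X_i})/(\alpha + n)\big)$, which can be sampled exactly (up to truncation) without MCMC. The sketch below does this with a hypothetical helper, `sample_posterior_dp`. The posterior concentration is $\alpha + n = 55$, so the sketch uses a much larger truncation than `K = 19`. The expected leftover stick mass after $K$ breaks with concentration $c$ is $(c/(1+c))^K$, about 0.71 for $c = 55$, $K = 19$.

```python
import numpy as np


def sample_posterior_dp(obs, alpha, K, rng, n_draws=1000):
    """Truncated stick-breaking draws from the conjugate DP posterior.

    Assumes obs are drawn directly from P ~ DP(alpha, N(0, 1)), so that
    P | obs ~ DP(alpha + n, (alpha * N(0, 1) + sum_i delta_{obs_i}) / (alpha + n)).
    Returns (weights, atoms), each of shape (n_draws, K + 1).
    """
    n = obs.size
    post_alpha = alpha + n

    # K Beta(1, alpha + n) breaks per draw; leftover mass goes to the last weight.
    v = rng.beta(1.0, post_alpha, size=(n_draws, K))
    remaining = np.concatenate(
        [np.ones((n_draws, 1)), np.cumprod(1.0 - v, axis=1)], axis=1
    )
    weights = np.concatenate([v * remaining[:, :-1], remaining[:, -1:]], axis=1)

    # Atoms from the posterior base measure: a N(0, 1) draw with probability
    # alpha / (alpha + n), otherwise an observation chosen uniformly at random.
    from_prior = rng.random((n_draws, K + 1)) < alpha / post_alpha
    prior_atoms = rng.normal(0.0, 1.0, size=(n_draws, K + 1))
    data_atoms = rng.choice(obs, size=(n_draws, K + 1))
    atoms = np.where(from_prior, prior_atoms, data_atoms)
    return weights, atoms


# Same data-generating step as the notebook (seed 34, N(2, sd=2), n = 50).
rng = np.random.default_rng(seed=34)
obs = rng.normal(2.0, 2.0, size=50)

weights, atoms = sample_posterior_dp(obs, alpha=5.0, K=500, rng=rng)
print(weights.shape, atoms.shape)  # (1000, 501) (1000, 501)
```

Each row of `weights` and `atoms` can be turned into a CDF curve in the same way as the MCMC draws of `sbw` and `atoms`. This gives a baseline to compare the sampler's output against.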


In [9]:
x_plot = np.linspace(-4, 8, num=1001)

In [21]:
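# dirac[i, k] = 1 if atom k lies at or below x_plot[i], i.e. delta_{atom_k}((-inf, x]) (chain 0, draw 0)
dirac = np.greater_equal.outer(x_plot, trace.posterior["atoms"].values[0, 0])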
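Each posterior draw is a discrete measure $P = \sum_k w_k \delta_{\theta_k}$, so its CDF is $F(x) = \sum_k w_k \mathbf{1}\{\theta_k \le x\}$. This is an indicator matrix of shape `(len(x_plot), K + 1)` multiplied by the weight vector of the same draw. The sketch below shows this for a single draw. It uses hypothetical stand-ins (`sbw_draw`, `atoms_draw`) in place of one draw of `trace.posterior["sbw"]` and `trace.posterior["atoms"]`.

```python
import numpy as np

rng = np.random.default_rng(seed=1)
K = 19
x_plot = np.linspace(-4, 8, num=1001)

# Hypothetical stand-ins for one posterior draw of the weights and atoms;
# in the notebook these would come from the same (chain, draw) of the trace.
sbw_draw = rng.dirichlet(np.ones(K + 1))
atoms_draw = rng.normal(0.0, 1.0, size=K + 1)

# indicator[i, k] = 1.0 if atom k lies at or below x_plot[i]
indicator = np.greater_equal.outer(x_plot, atoms_draw).astype(np.float64)
print(indicator.shape)  # (1001, 20)

# F(x) = sum_k w_k * 1{theta_k <= x}: one matrix-vector product over the grid
cdf = indicator @ sbw_draw
```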

In [26]:
trace.posterior["sbw"].values[0].shape

(1000, 20)

In [22]:
dirac.shape

(1001, 20)
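The shapes line up. One chain of `sbw` is (draws, K + 1) = (1000, 20), and the indicator matrix is (grid points, K + 1) = (1001, 20). However, the indicator cell above uses the atoms of chain 0, draw 0 only. The atoms change from draw to draw, so evaluating every draw's CDF needs one indicator matrix per draw. A minimal sketch follows, with hypothetical random arrays shaped like one chain of the trace standing in for the real posterior values:

```python
import numpy as np

rng = np.random.default_rng(seed=2)
n_draws, K = 1000, 19
x_plot = np.linspace(-4, 8, num=1001)

# Hypothetical stand-ins shaped like one chain of trace.posterior["sbw"]
# and trace.posterior["atoms"]: (draws, K + 1) each.
sbw = rng.dirichlet(np.ones(K + 1), size=n_draws)
atoms = rng.normal(0.0, 1.0, size=(n_draws, K + 1))

# Atoms change from draw to draw, so build one indicator matrix per draw.
cdfs = np.empty((n_draws, x_plot.size))
for d in range(n_draws):
    indicator = np.greater_equal.outer(x_plot, atoms[d]).astype(np.float64)  # (1001, 20)
    cdfs[d] = indicator @ sbw[d]  # CDF of draw d on the grid

# Pointwise posterior mean and a 90% band across draws, e.g. for plotting.
cdf_mean = cdfs.mean(axis=0)
band_lo, band_hi = np.quantile(cdfs, [0.05, 0.95], axis=0)
print(cdfs.shape)  # (1000, 1001)
```

To pool all chains, the same loop can run over `trace.posterior["atoms"].values.reshape(-1, K + 1)` and the matching reshape of `sbw`. This assumes both are stored with shape (chain, draw, K + 1).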