In [6]:
import pytensor.tensor as pt

import pymc as pm

In [61]:
with pm.Model() as m:
    X = pm.Normal("X", 0, 1, size=(100, 10))
    alpha = pm.Normal("alpha", 100, 10)
    beta = pm.Normal("beta", 0, 5, size=(10,))

    mu = alpha + X @ beta
    sigma = pm.Exponential("sigma", 1)
    y = pm.Normal("y", mu=mu, sigma=sigma)

    prior = pm.sample_prior_predictive()

Sampling: [X, alpha, beta, sigma, y]


In [62]:
draw = 123
true_params = np.r_[
    prior.prior.alpha.sel(chain=0, draw=draw).values, prior.prior.beta.sel(chain=0, draw=draw)
]
X_data = prior.prior.X.sel(chain=0, draw=draw).values
y_data = prior.prior.y.sel(chain=0, draw=draw).values

In [63]:
m_obs = pm.observe(pm.do(m, {X: X_data}), {"y": y_data})

In [168]:
Parameter = pt.tensor

draws = pt.tensor("draws", shape=(), dtype="int64")

with pm.Model() as guide_model:
    X = pm.Data("X", X_data)
    alpha_loc = Parameter("alpha_loc", shape=())
    alpha_scale = Parameter("alpha_scale", shape=())
    alpha_z = pm.Normal("alpha_z", mu=0, sigma=1, shape=(draws,))
    alpha = pm.Deterministic("alpha", alpha_loc + alpha_scale * alpha_z)

    beta_loc = Parameter("beta_loc", shape=(10,))
    beta_scale = Parameter("beta_scale", shape=(10,))
    beta_z = pm.Normal("beta_z", mu=0, sigma=1, shape=(draws, 10))
    beta = pm.Deterministic("beta", beta_loc + beta_scale * beta_z)

    mu = alpha + X @ beta

    sigma_loc = Parameter("sigma_loc", shape=())
    sigma_scale = Parameter("sigma_scale", shape=())
    sigma_z = pm.Normal("sigma_z", 0, 1, shape=(draws,))
    sigma = pm.Deterministic("sigma", pt.softplus(sigma_loc + sigma_scale * sigma_z))

# with pm.Model() as guide_model2:
#     n = 10 + 1 + 1
#     loc = Parameter("loc", shape=(n,))
#     chol_flat = Parameter("chol", shape=(n * n-1, ))
#     chol = pm.expand_packed_triangular(n, chol_flat)
#     latent_mvn = pm.MvNormal("latent_mvn", chol=chol)

#     pm.Deterministic("beta", latent_mvn[:10])
#     pm.Deterministic("alpha", latent_mvn[10])
#     pm.Deterministic("sigma", pm.math.exp(latent_mvn[11]))

In [169]:
params = [alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]

In [171]:
f_draw = pm.compile([*params, draws], guide_model.deterministics)

In [182]:
f_draw(**param_dict, draws=1)

[array([1.78150955]),
 array([[ 0.22971278,  0.4621461 ,  0.81535912,  0.62397751,  1.12162984,
          0.99310042, -0.04733258,  1.20791346,  0.61310399,  0.6248215 ]]),
 array([1.04612599])]

In [204]:
init_dict = m_obs.initial_point()
init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}
param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}

In [187]:
from pytensor.graph.replace import graph_replace, vectorize_graph

outputs = [m_obs.datalogp, m_obs.varlogp]
inputs = m_obs.value_vars
inputs_to_guide_rvs = {
    model_value_var: guide_model[rv.name]
    for rv, model_value_var in m_obs.rvs_to_values.items()
    if rv not in m_obs.observed_RVs
}
model_logp = vectorize_graph(m_obs.logp(), inputs_to_guide_rvs)
guide_logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)

elbo_loss = (guide_logq - model_logp).mean()
d_loss = pt.grad(elbo_loss, params)

f_loss_dloss = pm.compile(params + [draws], [elbo_loss, *d_loss], trust_input=True)

In [207]:
learning_rate = 1e-4
for _ in range(100):
    loss, *grads = f_loss_dloss(**param_dict, draws=100)
    for (name, value), grad in zip(param_dict.items(), grads):
        param_dict[name] = value - learning_rate * grad
    print(loss)

-993.9173656892265
-926.3296229583198
-1005.6712508438249
-919.6431050083108
-968.0575921556582
-987.30412574489
-964.7096352135038
-983.4362084632743
-910.8310347546233
-946.7555282915881
-943.9616597508198
-963.8766211135949
-916.6993999462212
-950.0475547851752
-967.001546433409
-943.9433767873104
-946.3015434524876
-934.0032018513697
-947.6452058569878
-893.224934118183
-979.4405814988748
-937.9997192780036
-931.1970724735944
-959.3588230003159
-932.1233734322371
-940.2791556640857
-969.4679954671045
-954.6606993395337
-982.9304227234845
-935.3404389982638
-982.6885250322749
-964.4628736035113
-939.0580477804804
-955.1672719181267
-982.0467504680682
-992.8688427985264
-967.8588846826874
-966.7655668600194
-949.3323016540423
-934.3919364553586
-1028.4493525361129
-982.4944127954707
-931.9404059809432
-981.3845063690508
-930.1688452196312
-952.0305908505434
-1012.7969628343917
-937.6379090307294
-926.5273721862864
-981.6665090046366
-980.4287957334458
-946.1849036479391
-969.40756535

In [208]:
param_dict

{'alpha_loc': 0.6942395020077371,
 'alpha_scale': 0.5102778323184816,
 'beta_loc': array([0.49114368, 0.46379742, 0.52277764, 0.53227815, 0.48851862,
        0.50044113, 0.5441339 , 0.46643231, 0.47894475, 0.51122713]),
 'beta_scale': array([0.49807094, 0.49740464, 0.5005586 , 0.4991539 , 0.49749037,
        0.49825551, 0.49992182, 0.4981251 , 0.49745959, 0.49767778]),
 'sigma_loc': 4.2148859620931765,
 'sigma_scale': -5.964631972194452e-05}

In [209]:
true_params

array([83.43771778, -6.07738876, -3.3268889 ,  8.38393732, 10.77212434,
       -2.81776509,  0.46737085,  8.7204497 , -4.79822835, -3.47220908,
        8.76186526])

In [165]:
f_loss(**param_dict, draws=100)

array(273519.9558606)

In [166]:
params

[alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]

In [106]:
inputs_to_guide_inputs

{alpha: alpha, beta: beta, sigma_log__: sigma}

In [22]:
def compute_loss(m_obs, m_guide, beta=1.0):
    return -m_obs.datalogp + beta * (m_guide.logp() - m_obs.varlogp)

In [24]:
compute_loss(m_obs, guide_model).dprint()

Add [id A]
 ├─ Neg [id B]
 │  └─ Add [id C]
 │     ├─ Sum{axes=None} [id D] '__logp'
 │     │  └─ MakeVector{dtype='float64'} [id E]
 │     │     └─ Sum{axes=None} [id F]
 │     │        └─ Check{sigma > 0} [id G]
 │     │           ├─ Sub [id H]
 │     │           │  ├─ Sub [id I]
 │     │           │  │  ├─ Mul [id J]
 │     │           │  │  │  ├─ ExpandDims{axis=0} [id K]
 │     │           │  │  │  │  └─ -0.5 [id L]
 │     │           │  │  │  └─ Pow [id M]
 │     │           │  │  │     ├─ True_div [id N]
 │     │           │  │  │     │  ├─ Sub [id O]
 │     │           │  │  │     │  │  ├─ [122.65317 ... .32067026] [id P]
 │     │           │  │  │     │  │  └─ Add [id Q]
 │     │           │  │  │     │  │     ├─ ExpandDims{axis=0} [id R]
 │     │           │  │  │     │  │     │  └─ alpha [id S]
 │     │           │  │  │     │  │     └─ Squeeze{axis=1} [id T]
 │     │           │  │  │     │  │        └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id U]
 │     │           │  │  │   

<ipykernel.iostream.OutStream at 0x7fc2f638ea10>

In [None]:
## TODO:
# 1. Create hyperparameters for mean field approx (mu + sigma of normals)
# 2. Replace in the logp