In [1]:
import pymc as pm
import numpy as np

### alpha - scalar

In [2]:
K = 4
alpha = 0.5
value = np.array([5, 4, 3, 2, 1]) / 15

pm.logp(pm.StickBreakingWeights.dist(alpha=alpha, K=K), value).eval()

array(1.51263013)

In [3]:
K = 4
alpha = 0.5
value = np.array([5, 4, 3, 2, 1]) / 15

pm.logp(pm.StickBreakingWeights_Batched.dist(alpha=alpha, K=K), value).eval()

array(1.51263013)

### alpha - 1D Array

In [None]:
pm.logp(pm.StickBreakingWeights.dist(alpha=alpha, K=K), value).eval()

In [39]:
K = 4
alpha = [0.5, 1, 2, 3]
value = np.array([5, 4, 3, 2, 1]) / 15

In [40]:
stick_break_batch = pm.logp(pm.StickBreakingWeights_Batched.dist(alpha=alpha, K=K), value).eval()
stick_break_batch

array([1.51263013, 2.93119375, 2.99573227, 1.9095425 ])

In [41]:
# pm.logp(pm.StickBreakingWeights.dist(alpha=alpha, K=K), value).eval()            #ERROR

vfun = np.vectorize(pm.StickBreakingWeights.dist)

In [42]:
stick_break = [pm.logp(stick, value).eval() for stick in vfun(alpha=alpha, K=K)]
stick_break

[array(1.51263013), array(2.93119375), array(2.99573227), array(1.9095425)]

In [43]:
np.allclose(stick_break, stick_break_batch)

True

### alpha - 2D Array

In [44]:
K = 4
alpha = np.arange(1, 13).reshape(3, 4)
value = np.array([5, 4, 3, 2, 1]) / 15

In [45]:
stick_break_batch = pm.logp(pm.StickBreakingWeights_Batched.dist(alpha=alpha, K=K), value).eval()
stick_break_batch

array([[  2.93119375,   2.99573227,   1.9095425 ,   0.35222059],
       [ -1.4632554 ,  -3.44201938,  -5.53346686,  -7.70739149],
       [ -9.94430955, -12.23091769, -14.55772717, -16.91773186]])

In [48]:
stick_break = [pm.logp(stick, value).eval() for stick in vfun(alpha=alpha, K=K).reshape(-1)]
stick_break

[array(2.93119375),
 array(2.99573227),
 array(1.9095425),
 array(0.35222059),
 array(-1.4632554),
 array(-3.44201938),
 array(-5.53346686),
 array(-7.70739149),
 array(-9.94430955),
 array(-12.23091769),
 array(-14.55772717),
 array(-16.91773186)]

In [50]:
np.allclose(stick_break, stick_break_batch.reshape(-1))

True

### alpha - 3D Array

In [None]:
K = 4
alpha = np.arange(1, 2*3*4+1).reshape(2, 3, 4)
value = np.array([5, 4, 3, 2, 1]) / 15

pm.logp(pm.StickBreakingWeights.dist(alpha=alpha, K=K), value).eval()

In [51]:
stick_break_batch = pm.logp(pm.StickBreakingWeights_Batched.dist(alpha=alpha, K=K), value).eval()
stick_break_batch

array([[  2.93119375,   2.99573227,   1.9095425 ,   0.35222059],
       [ -1.4632554 ,  -3.44201938,  -5.53346686,  -7.70739149],
       [ -9.94430955, -12.23091769, -14.55772717, -16.91773186]])

In [52]:
stick_break = [pm.logp(stick, value).eval() for stick in vfun(alpha=alpha, K=K).reshape(-1)]
stick_break

[array(2.93119375),
 array(2.99573227),
 array(1.9095425),
 array(0.35222059),
 array(-1.4632554),
 array(-3.44201938),
 array(-5.53346686),
 array(-7.70739149),
 array(-9.94430955),
 array(-12.23091769),
 array(-14.55772717),
 array(-16.91773186)]

In [53]:
np.allclose(stick_break, stick_break_batch.reshape(-1))

True