In [58]:
import pytensor as pt
import pytensor.tensor as ptt
import numpy as np
import pymc as pm

In [57]:
w = ptt.constant([[0.1, 0.2, 0.3]])
s = pt.shared(np.zeros((1, 3)))

def one_step(prev_w, s):
    """Scan step: receives one row of `w` (as `prev_w`) and one row of `s`.

    Returns a length-1 zero vector; scan stacks these per-step results into `values`.
    """
    # Rebinding a parameter here (e.g. `s = ptt.concatenate([s, prev_w], axis=0)`)
    # only changes a local name: scan keeps nothing but the returned value, and the
    # shared variable `s` is not modified by it. To accumulate rows, return them and
    # let scan stack the per-step outputs.
    return ptt.zeros(1)

# `w` and `s` each have a single row, so scan takes the number of steps from the
# sequences; an explicit n_steps larger than the sequence length is not valid.
values, updates = pt.scan(one_step, sequences=[w, s])

TypeError: one_step() missing 1 required positional argument: 's'

In [68]:
# Three categorical variables, one per row of p; each row sums to 1 and the
# diagonal is zero, so variable i never takes the value i.
dist = pm.Categorical.dist(
    p=ptt.constant(
        [
            [0, 0.4, 0.6],
            [0.2, 0.0, 0.8],
            [0.4, 0.6, 0],
        ]
    ),
    shape=3,
)

In [80]:
# Four joint draws of the three variables, transposed so rows are variables
# and columns are draws.
influencers = pm.draw(dist, draws=4).T
influencers

array([[2, 2, 1, 1],
       [2, 0, 2, 0],
       [0, 1, 1, 1]])

In [123]:
# Standard-normal draws of shape (3, 2, 4), scaled down by a factor of 0.5.
noise = 0.5 * pm.draw(pm.Normal.dist(mu=0, sigma=1, shape=(3, 2, 4)), draws=1)
noise

array([[[-0.93646121, -0.97899491,  0.05444104, -0.02495986],
        [ 0.42654827, -0.88293535, -0.26931151,  0.22898184]],

       [[ 0.05631156, -0.12275639, -0.66190417, -0.13972772],
        [-0.74398082,  0.45737705, -0.00349667,  0.25716846]],

       [[ 0.05765788,  0.59071399, -0.40708371, -0.05506085],
        [-0.42575659, -0.30512063,  0.26339662,  1.00548   ]]])

In [253]:
num_time_steps = 4
# Row i gives the probabilities used for position i's draw; the zero diagonal
# means position i never draws itself.
mixture_weights = ptt.constant([[0, 0.4, 0.6], [0.2, 0., 0.8], [0.4, 0.6, 0]])

# We sample the influencers in each time step using the mixture weights.
influencers_dist = pm.Categorical.dist(p=mixture_weights, shape=3)
# Stored as `influencers_values` (time x subject, no transpose): the scan cell
# below reads this name and iterates over its first axis, and `influencers` is
# later rebound to a symbolic matrix.
influencers_values = pm.draw(influencers_dist, num_time_steps)  # t x s
influencers_values

ValueError: shape mismatch: objects cannot be broadcast to a single shape.  Mismatch is between arg 0 with shape (3,) and arg 1 with shape (3, 4).
Apply node that caused the error: categorical_rv{0, (1,), int64, True}(RandomGeneratorSharedVariable(<Generator(PCG64) at 0x7FB5B2DDCF20>), TensorConstant{[3 4]}, TensorConstant{4}, TensorConstant{[[0.  0.4 .. 0.6 0. ]]})
Toposort index: 0
Inputs types: [RandomGeneratorType, TensorType(int64, (2,)), TensorType(int64, ()), TensorType(float64, (3, 3))]
Inputs shapes: ['No shapes', (2,), (), (3, 3)]
Inputs strides: ['No strides', (8,), (), (24, 8)]
Inputs values: [Generator(PCG64) at 0x7FB5B2DDCF20, array([3, 4]), array(4), 'not shown']
Outputs clients: [['output'], ['output']]

HINT: Re-running with most PyTensor optimizations disabled could provide a back-trace showing when this node was created. This can be done by setting the PyTensor flag 'optimizer=fast_compile'. If that does not work, PyTensor optimizations can be disabled with 'optimizer=None'.
HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.

In [248]:
initial_mean = 0
sigma = 1
# Prior over the initial values (3 x 2). Pass `sigma` rather than a literal so
# the setting above actually takes effect.
prior_dist = pm.Normal.dist(mu=initial_mean, sigma=sigma, shape=(3, 2))
# Masks are sized by num_time_steps (not a hard-coded 4) so they always broadcast
# against the time axis of sample_values below.
# prev_mask is 0 only at t = 0; wherever it is 1 the fresh sample is zeroed.
prev_mask = np.ones(num_time_steps, dtype=int)
prev_mask[0] = 0
subject_mask = np.ones(num_time_steps)
# Shape (3, 2, num_time_steps): the single prior draw at t = 0, zeros at later
# steps (and at any step where subject_mask is 0).
sample_values = (
    np.ones((3, 2, num_time_steps))
    * pm.draw(prior_dist, 1)[..., None]
    * (1 - prev_mask[None, :])
    * subject_mask[None, :]
)
sample_values

array([[[-1.20977353, -0.        , -0.        , -0.        ],
        [ 0.73473298,  0.        ,  0.        ,  0.        ]],

       [[-0.83286284, -0.        , -0.        , -0.        ],
        [-0.06210496, -0.        , -0.        , -0.        ]],

       [[-1.03320404, -0.        , -0.        , -0.        ],
        [ 0.48600279,  0.        ,  0.        ,  0.        ]]])

In [251]:
# Symbolic inputs for the scan below; scan iterates over the first axis of each.
sample, mask = ptt.tensor3("sample"), ptt.vector("mask")
# NOTE: this rebinds `influencers`, which earlier cells used for numpy draws.
influencers = ptt.imatrix("influencers")
coordination = ptt.vector("coordination")

def sample_from_mixture(sample, mask, influencers, coordination, prev_val):
    """One scan step of the coordinated mixture process.

    Args:
        sample: values for this step; always added to the result (shape assumed
            to match ``prev_val`` — TODO confirm).
        mask: weight between the coordinated blend (1) and carrying the previous
            values over unchanged (0).
        influencers: integer indices into the rows of ``prev_val``; entry i picks
            the row that position i blends towards.
        coordination: blend weight on the influencer's previous value; the
            position's own previous value gets ``1 - coordination``.
        prev_val: this function's output from the previous step.

    Returns:
        ``sample + mask * blend + (1 - mask) * prev_val``. Where ``mask`` is 0 this
        equals ``sample + prev_val``, i.e. the previous values are repeated only if
        ``sample`` is zero at that step (NOTE(review): callers appear to rely on it).
    """
    influenced = prev_val[influencers, :]
    blended = coordination * influenced + (1 - coordination) * prev_val
    carried = (1 - mask) * prev_val
    return sample + mask * blended + carried

# Unroll sample_from_mixture over time: each step's output is fed back as
# `prev_val`, starting from zeros shaped like one time slice of `sample`.
res, upd = pt.scan(
    fn=sample_from_mixture,
    sequences=[sample, mask, influencers, coordination],
    outputs_info=ptt.zeros_like(sample[-1]),
)

f = pt.function(inputs=[sample, mask, influencers, coordination], outputs=res)

coordination_values = np.array([1, 0, 0.5, 0.6])
# Scan wants time first: (3, 2, time) -> (time, 3, 2); the result is moved back
# to (3, 2, time) for display.
np.moveaxis(
    f(
        np.moveaxis(sample_values, -1, 0),
        subject_mask,
        influencers_values.astype(np.int32),
        coordination_values,
    ),
    0,
    -1,
)
# TODO: add observation noise


array([[[-1.20977353, -1.20977353, -1.12148878, -1.12148878],
        [ 0.73473298,  0.73473298,  0.61036789,  0.61036789]],

       [[-0.83286284, -0.83286284, -0.93303344, -1.04610664],
        [-0.06210496, -0.06210496,  0.21194892,  0.4510003 ]],

       [[-1.03320404, -1.03320404, -1.12148878, -1.12148878],
        [ 0.48600279,  0.48600279,  0.61036789,  0.61036789]]])

-1.121488785

In [259]:
# Ten joint draws (rows); column j is drawn from row j of mixture_weights, so
# column j never takes the value j.
(
    ptt.random.categorical(p=mixture_weights, size=(10, 3))
    .eval()
)

array([[2, 2, 0],
       [1, 0, 0],
       [2, 2, 1],
       [2, 2, 1],
       [2, 0, 0],
       [2, 0, 1],
       [2, 2, 1],
       [2, 2, 0],
       [1, 2, 1],
       [1, 2, 0]])