In [2]:
import jax
import jax.numpy as jnp
import genjax
from genjax import gen, Mask
from genjax import ChoiceMapBuilder as C
# Notebook display config: presumably switches genjax to its pretty-printed
# rendering of values shown in cell output — confirm against the genjax version in use.
genjax.pretty()

def masked_scan_combinator(step, **scan_kwargs):
    """
    Build a scan over `step` whose per-iteration inputs can be masked off.

    `step` must be a generative function for which `step.scan(n=N)` is valid;
    its return value is treated as a `(new_state, output_value)` pair.
    All keyword arguments are forwarded unchanged to `.scan`.

    The returned generative function accepts
    `(initial_state, masked_input_values_array)` — the second element being a
    `Mask` over the scanned inputs — and returns
    `(masked_final_state, masked_returnvalue_sequence)`.

    An iteration is active only when both the carried state and that
    iteration's input are flagged active. Because the state's flag is carried
    forward into the next iteration, once one iteration is inactive every later
    iteration is inactive too.
    """

    def _unwrap_step_args(masked_state, masked_inval):
        # The masked step takes `(flag, state, input)`: run this iteration only
        # if both the incoming state and the current input are active.
        active = jnp.logical_and(masked_state.flag, masked_inval.flag)
        return (active, masked_state.value, masked_inval.value)

    def _split_step_retval(args, masked_retval):
        # `step` returns a `(state, output)` pair; give each half the step's
        # flag so the scan carries a masked state and emits a masked output.
        active = masked_retval.flag
        new_state = masked_retval.value[0]
        output_value = masked_retval.value[1]
        return (Mask(active, new_state), Mask(active, output_value))

    def _wrap_scan_args(initial_state, masked_input_values):
        # The scan is fed `(Mask(True, initial_state), Mask(flags, inputs))`:
        # the initial state is always active, and the input mask is rebuilt
        # from its `flag` / `value` fields.
        return (
            Mask(True, initial_state),
            Mask(masked_input_values.flag, masked_input_values.value),
        )

    def _passthrough(args, retval):
        # The scan already returns `(masked_final_state, masked_returnvalue_sequence)`.
        return retval

    masked_step = step.mask().dimap(pre=_unwrap_step_args, post=_split_step_retval)
    return masked_step.scan(**scan_kwargs).dimap(pre=_wrap_scan_args, post=_passthrough)

### Usage example ###
@gen
def step(state, sigma):
    """One scan step: sample a value at address "x" from Normal(state, sigma).

    `state` is the scan carry and `sigma` is the scanned input, used as the
    Normal's standard deviation (so it should be strictly positive).
    Returns `(carry, output)` = `(sample, sample + 1)`.
    """
    x_next = genjax.normal(state, sigma) @ "x"
    emitted = x_next + 1
    return x_next, emitted

# Usage example: a 10-step masked scan starting from state 2.0, in which only
# the first 5 inputs are active (flags `jnp.arange(10) < 5`).
# `step` uses each scanned input as a Normal standard deviation, so the inputs
# start at 1.0 rather than 0.0: sigma = 0 on the first (active) step would make
# its log-density, and presumably the whole trace score, NaN.
trace = masked_scan_combinator(step, n=10).simulate(
    jax.random.PRNGKey(20),
    (
        2.,
        Mask(jnp.arange(10) < 5, 1. + jnp.arange(10, dtype=float)),
    ),
)

In [3]:
# Simulate a zero-length scan of `step`: no iterations, empty float input array.
# The choices of this trace are reused by a later importance cell.
trace = step.scan(n=0).simulate(
    jax.random.PRNGKey(20),
    (2., jnp.arange(0, dtype=float)),
)

In [4]:
# Display the score of the zero-length scan trace simulated in the previous cell
# (presumably 0, since no choices were made — confirm).
trace.get_score()

In [5]:
# Importance-sample a zero-length scan of `step` against `C.n()` (presumably the
# empty choice map). The trailing `;` suppresses the displayed result, replacing
# the previous trick of ending the cell with a bare `None`.
step.scan(n=0).importance(jax.random.PRNGKey(20), C.n(), (2., jnp.arange(0, dtype=float)));

In [8]:
# Reproduction: run importance on a length-2 scan using the choices of the
# zero-length scan trace (`trace` from the `step.scan(n=0).simulate` cell).
# When last run this raised
#   IndexError: index is out of bounds for axis 0 with size 0
# The error is caught and shown so "Restart & Run All" reaches the cells below.
try:
    step.scan(n=2).importance(
        jax.random.PRNGKey(20),
        trace.get_choices(),
        (2., jnp.arange(2, dtype=float)),
    )
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: index is out of bounds for axis 0 with size 0

In [12]:
@gen
def foo(length_const):
    """Scan `step` a static number of times (address "x"), then sample "y".

    `length_const` wraps the scan length in a static constant — the calls in
    this notebook build it with `genjax.Pytree.const(n)` — and its `.const`
    field sets both the scan length and the length of the scanned inputs.

    Returns the sample drawn at address "y" from Normal(0, 1).

    NOTE(review): the scanned inputs start at 0.0 and `step` uses them as its
    Normal standard deviation, so for lengths > 0 the first step presumably has
    a NaN log-density — confirm whether that is intended for this reproduction.
    """
    n_steps = length_const.const
    # Only the choices recorded under "x" matter; the scan's return value is unused.
    step.scan(n=n_steps)(2., jnp.arange(n_steps, dtype=float)) @ "x"
    y = genjax.normal(0., 1.) @ "y"
    return y

In [15]:
# Simulate `foo` with a zero-length inner scan; the length is passed as a
# static constant via `genjax.Pytree.const`.
trace = foo.simulate(
    jax.random.PRNGKey(20),
    (genjax.Pytree.const(0),),
)

In [19]:
# Importance-sample `foo` with a length-2 inner scan, constraining only
# address "y" to 2.0.
# NOTE(review): genjax `importance` is expected to return a (trace, weight)
# pair, so `trace2` presumably holds that pair rather than a bare trace — confirm.
trace2 = foo.importance(
    jax.random.PRNGKey(20),
    C["y"].set(2.),
    (genjax.Pytree.const(2),),
)