In [3]:
import jax
import jax.numpy as jnp
import genjax
from genjax import gen, Mask
from genjax import ChoiceMapBuilder as C
genjax.pretty()  # presumably enables genjax's rich pretty-printing of traces/values in notebook output — confirm

def masked_scan_combinator(step, **scan_kwargs):
    """
    Lift `step` into a scan whose per-iteration inputs may be masked.

    Given a generative function `step` such that `step.scan(n=N)` is valid,
    return a generative function accepting an input
    `(initial_state, masked_input_values_array)` and returning a pair
    `(masked_final_state, masked_returnvalue_sequence)`.
    This operates similarly to `step.scan`, but the input values can be masked.

    Args:
        step: generative function called as `step(state, input_value)` and
            returning a pair `(new_state, returnvalue)`.
        **scan_kwargs: forwarded verbatim to `.scan(...)` (e.g. `n=10`).

    Returns:
        A generative function taking `(initial_state, Mask(flags, input_values))`
        and returning `(masked_final_state, masked_returnvalue_sequence)`.

    Note:
        The flag used at each iteration is the AND of the carried state's flag
        and that iteration's input flag. The new state is wrapped with the
        masked step's returned flag (presumably that same AND-ed flag). So once
        any iteration is masked out, every later iteration is masked out too.
        Effectively, only the leading run of active inputs is applied.
    """
    mstep = step.mask().dimap(
        pre=lambda masked_state, masked_inval: (
            jnp.logical_and(masked_state.flag, masked_inval.flag),
            masked_state.value,
            masked_inval.value,
        ),
        # Split the masked (new_state, returnvalue) pair into a pair of masks
        # sharing one flag, so the state can be carried by `scan`.
        post=lambda args, masked_retval: (
            Mask(masked_retval.flag, masked_retval.value[0]),
            Mask(masked_retval.flag, masked_retval.value[1]),
        ),
    )

    # `scanned` expects a pair (
    #     Mask(<always true>, initial_state),
    #     Mask(bools_indicating_active, input_vals)
    # ) and outputs a pair (masked_final_state, masked_returnvalue_sequence).
    scanned = mstep.scan(**scan_kwargs)

    def _wrap_initial_state(initial_state, masked_input_values):
        # The initial state is always active. The flag is built from a traced
        # op on the input flags instead of the Python literal `True`.
        # NOTE(review): `importance` failed with argdiffs `(Mask(...), Mask(...))`
        # that violate `Diff.static_check_tree_diff`. A plausible cause is that
        # a literal output is staged as a jaxpr Literal and is not Diff-wrapped
        # by genjax's change-propagation pass. Confirm against the genjax
        # version in use. `jnp.any` also keeps this valid for zero-length inputs.
        always_true = jnp.logical_or(True, jnp.any(masked_input_values.flag))
        return (Mask(always_true, initial_state), masked_input_values)

    scanned_nice = scanned.dimap(
        pre=_wrap_initial_state,
        post=lambda args, retval: retval,
    )

    return scanned_nice

### Usage example ###
@gen
def step(state, sigma):
    """
    One scan iteration: draw address "x" from Normal(state, sigma).

    Returns `(carry, output)`: the draw becomes the next state, and the draw
    plus one is emitted as this iteration's return value.
    """
    x = genjax.normal(state, sigma) @ "x"
    return x, x + 1

# The per-step input values are used as `sigma` by `step`, so they must be
# strictly positive. A Normal with scale 0 has a degenerate (NaN) log-density,
# which would poison the trace score. Hence 1..10 rather than 0..9.
trace = masked_scan_combinator(step, n=10).simulate(
    jax.random.PRNGKey(20), (
        2.,
        Mask(jnp.arange(10) < 5, 1.0 + jnp.arange(10, dtype=float))
    )
)

In [5]:
# Importance with an empty constraint (`C.n()`). As in the simulate example,
# the inputs are used as `sigma`, so they start at 1 to avoid a zero scale
# (degenerate, NaN log-density).
trace, weight = masked_scan_combinator(step, n=10).importance(
    jax.random.PRNGKey(20), C.n(), (
        2.,
        Mask(jnp.arange(10) < 5, 1.0 + jnp.arange(10, dtype=float))
    )
)

TypeError: Method genjax._src.generative_functions.combinators.scan.ScanCombinator.update_change_target() parameter argdiffs=(Mask(...), Mask(...)) violates type hint typing.Annotated[tuple, Is[lambda v: Diff.static_check_tree_diff(v)]], as tuple (Mask(...), Mask(...)) violates validator Is[lambda v: Diff.static_check_tree_diff(v)]:
    False == Is[lambda v: Diff.static_check_tree_diff(v)].