In [1]:
import jax.numpy as jnp
from genjax import ExactDensity, Pytree
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
    
""" Distribution Utilities """ 

unicat = lambda x: jnp.ones(len(x)) / len(x)
normalize = lambda x: x / jnp.sum(x)

""" Define Distributions """

@Pytree.dataclass
class LabeledCategorical(ExactDensity):
    def sample(self, key, probs, labels, **kwargs):
        cat = tfd.Categorical(probs=normalize(probs))
        cat_index = cat.sample(seed=key)
        return labels[cat_index]

    def logpdf(self, v, probs, labels, **kwargs):
        w = jnp.log(jnp.sum(normalize(probs) * (labels==v)))
        return w

@Pytree.dataclass
class UniformCategorical(ExactDensity):
    def sample(self, key, labels, **kwargs):
        cat = tfd.Categorical(probs=jnp.ones(len(labels)) / len(labels))
        cat_index = cat.sample(seed=key)
        return labels[cat_index]

    def logpdf(self, v, labels, **kwargs):
        probs = jnp.ones(len(labels)) / len(labels)
        logpdf = jnp.log(probs)
        w = logpdf[0]
        return w
    
labcat = LabeledCategorical()
uniformcat = UniformCategorical()

In [74]:
import jax
import chex
import jax.numpy as jnp
import genjax
from genjax import gen, Mask
from genjax import ChoiceMapBuilder as C
genjax.pretty()

# @chex.dataclass(frozen=True)
# class FireflyState:
#     sigma_x : float
#     sigma_y : float
#     motion_var_x : float
#     motion_var_y : float
#     scene_size : int = None
#     max_fireflies_arr : jnp.ndarray
#     n_fireflies : int = None
#     masked_fireflies : Mask = None

@gen
def sample_firefly_init(spatial_range):
    init_x = labcat(unicat(spatial_range), spatial_range) @ f"init_x"
    init_y = labcat(unicat(spatial_range), spatial_range) @ f"init_y"
    blink_rate = genjax.uniform(0.1, 0.25) @ f"blink_rate"
    return (init_x, init_y, blink_rate)

sample_firefly_init_masked = genjax.mask_combinator(sample_firefly_init)
sample_firefly_init_masked_map = sample_firefly_init_masked.vmap(in_axes=(0, None))

@gen
def initialize_fireflies(scene_size_arr, max_fireflies_arr):
    n_fireflies = labcat(unicat(max_fireflies_arr), max_fireflies_arr) @ "n_fireflies"

    jax.debug.print("{n}, {x}", n=n_fireflies, x=max_fireflies_arr < n_fireflies)
    masked_fireflies = sample_firefly_init_masked_map(
        max_fireflies_arr < n_fireflies,
        scene_size_arr
    ) @ "fireflies" 
    
    return masked_fireflies

In [75]:
key, subkeys = jax.random.split(jax.random.PRNGKey(0), 2)
tr = sample_firefly_init_masked_map.simulate(key, ([True, False], jnp.arange(10),))
tr.get_sample()

AttributeError: 'bool' object has no attribute 'shape'

In [43]:
key = jax.random.PRNGKey(0)
max_fireflies = jnp.arange(1, 5)
scene_size = jnp.arange(1, 65)
initialize_fireflies.simulate(key, (scene_size, max_fireflies,))

TypeError: Method[1m[35m genjax._src.generative_functions.combinators.mask_combinator.MaskCombinator.simulate()[0m parameter [1m[34mkey[0m=[1m[31m"Traced<ShapedArray(bool[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(bool..."[0m violates type hint [1m[32mtyping.Union[jaxtyping.Key[Array, ''], jaxtyping.UInt32[Array, '2']][0m, as [1m[33m<class "jax._src.interpreters.batching.BatchTracer"> [0m[1m[31m"Traced<ShapedArray(bool[])>with<BatchTrace(level=2/0)> with
  val = Traced<ShapedArray(bool..."[0m not [1m[32m<class "jaxtyping.UInt32[Array, '2']"> or <class "jaxtyping.Key[Array, '']">[0m.