In [3]:
import jax.numpy as jnp
from genjax import ExactDensity, Pytree
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions
    
""" Distribution Utilities """ 

def unicat(x):
    """Uniform probability vector over the entries of ``x``.

    Returns ``len(x)`` equal weights of ``1 / len(x)`` each.
    """
    n = len(x)
    return jnp.ones(n) / n


def normalize(x):
    """Rescale ``x`` by its total so the entries sum to 1."""
    total = jnp.sum(x)
    return x / total

""" Define Distributions """

@Pytree.dataclass
class LabeledCategorical(ExactDensity):
    """Categorical distribution over an explicit array of ``labels``.

    ``probs`` are weights aligned position-by-position with ``labels``; they
    are rescaled with ``normalize`` before use, so they need not sum to 1.
    """

    def sample(self, key, probs, labels, **kwargs):
        """Draw a position from Categorical(normalize(probs)) and return its label."""
        index = tfd.Categorical(probs=normalize(probs)).sample(seed=key)
        return labels[index]

    def logpdf(self, v, probs, labels, **kwargs):
        """Log of the total normalized weight of every label equal to ``v``."""
        weights = normalize(probs)
        is_match = labels == v
        return jnp.log(jnp.sum(weights * is_match))

@Pytree.dataclass
class UniformCategorical(ExactDensity):
    """Uniform categorical distribution over an explicit array of ``labels``.

    Every position in ``labels`` is drawn with probability ``1 / len(labels)``,
    so a value appearing ``k`` times has total probability ``k / len(labels)``.
    """

    def sample(self, key, labels, **kwargs):
        """Pick a position uniformly at random and return the label stored there."""
        cat = tfd.Categorical(probs=jnp.ones(len(labels)) / len(labels))
        cat_index = cat.sample(seed=key)
        return labels[cat_index]

    def logpdf(self, v, labels, **kwargs):
        """Log-probability of ``v`` under the uniform draw performed by ``sample``.

        Returns ``log(count(labels == v) / len(labels))``. This is ``-inf`` when
        ``v`` is not one of the labels, and accounts for repeated labels. It
        agrees with ``LabeledCategorical.logpdf`` when all weights are equal.
        """
        # The previous version returned log(1 / len(labels)) without looking at
        # ``v``. That gave finite density to values that ``sample`` can never
        # produce, and under-counted repeated labels.
        # ``jnp.asarray`` makes the elementwise comparison work for list inputs.
        return jnp.log(jnp.mean(jnp.asarray(labels) == v))
    
# Module-level distribution instances. They are called like distributions
# inside @gen code, e.g. ``labcat(probs, labels) @ "addr"``.
labcat, uniformcat = LabeledCategorical(), UniformCategorical()

In [13]:
import jax
import chex
import jax.numpy as jnp
import genjax
from genjax import gen, Mask
from genjax import ChoiceMapBuilder as C
# NOTE(review): presumably switches genjax objects (traces, choice maps) to a
# pretty/rich repr for notebook display — confirm against the genjax version in use.
genjax.pretty()

# @chex.dataclass(frozen=True)
# class FireflyState:
#     sigma_x : float
#     sigma_y : float
#     motion_var_x : float
#     motion_var_y : float
#     scene_size : int = None
#     max_fireflies_arr : jnp.ndarray
#     n_fireflies : int = None
#     masked_fireflies : Mask = None


@gen
def sample_firefly_init(scene_size):
    """Sample one firefly's initial state.

    Draws ``init_x`` and ``init_y`` uniformly from ``jnp.arange(1, scene_size)``
    (i.e. ``1 .. scene_size - 1``). Draws ``blink_rate`` from
    ``genjax.uniform(0.1, 0.25)``.

    Returns:
        Tuple ``(init_x, init_y, blink_rate)``.
    """
    # NOTE(review): ``scene_size`` is the ``stop`` of ``jnp.arange``, which
    # needs a concrete value. If it arrives as a traced array, this raises
    # ConcretizationTypeError — confirm how callers/genjax supply it.
    # NOTE(review): 0 and ``scene_size`` itself are excluded from the range;
    # confirm that is intended.
    spatial_range = jnp.arange(1, scene_size)
    # Same uniform weights for both coordinates; compute them once.
    position_probs = unicat(spatial_range)
    init_x = labcat(position_probs, spatial_range) @ "init_x"
    init_y = labcat(position_probs, spatial_range) @ "init_y"
    blink_rate = genjax.uniform(0.1, 0.25) @ "blink_rate"
    return (init_x, init_y, blink_rate)

# One firefly slot: ``sample_firefly_init`` gated by a boolean flag.
# NOTE(review): presumably MaskCombinator takes the flag as its first argument,
# followed by the inner function's arguments — confirm against the genjax version.
sample_firefly_init_masked = genjax.MaskCombinator(sample_firefly_init)
# Vectorize as a *generative function*, so it can be addressed with ``@`` inside
# @gen code. Flags are mapped along axis 0; ``scene_size`` is shared (None).
# The previous ``jax.vmap(sample_firefly_init_masked.simulate, in_axes=(0, None))``
# mapped over the PRNG key instead, so the flag array landed in the key slot.
# It also returned traces rather than a generative function.
sample_firefly_init_masked_map = sample_firefly_init_masked.vmap(in_axes=(0, None))

@gen
def initialize_fireflies(scene_size, max_fireflies, steps_arr):
    """Sample how many fireflies exist, then the initial state of each slot.

    Args:
        scene_size: forwarded to the per-firefly initializer for every slot.
        max_fireflies: the count is drawn uniformly from
            ``jnp.arange(1, max_fireflies)``, which also fixes the number of
            slots. Must be a concrete scalar, because it is a ``jnp.arange`` stop.
        steps_arr: currently unused.

    Returns:
        The value produced at address ``"fireflies"`` by the masked, vectorized
        per-firefly initializer.
    """
    # NOTE(review): ``jnp.arange`` excludes ``max_fireflies``, so the sampled
    # count never equals ``max_fireflies`` itself — confirm this is intended.
    max_fireflies_arr = jnp.arange(1, max_fireflies)
    n_fireflies = labcat(unicat(max_fireflies_arr), max_fireflies_arr) @ "n_fireflies"
    # ``max_fireflies_arr`` holds 1-based slot numbers. Using ``<=`` activates
    # exactly ``n_fireflies`` slots. The previous ``<`` activated only
    # ``n_fireflies - 1``, i.e. none at all when ``n_fireflies == 1``.
    active = max_fireflies_arr <= n_fireflies
    masked_fireflies = sample_firefly_init_masked_map(active, scene_size) @ "fireflies"
    return masked_fireflies

In [14]:
key = jax.random.PRNGKey(423234)
# NOTE(review): these arguments look mismatched with
# initialize_fireflies(scene_size, max_fireflies, steps_arr).
#   - scene_size receives an array of 64 zeros.
#   - max_fireflies receives the array jnp.arange(1, 4).
# Both are used as the `stop` of jnp.arange, which needs a concrete scalar.
# That is the ConcretizationTypeError recorded for this cell ('stop' is a
# traced int32[3]). Presumably scalars were intended (e.g. 64 and 4) — confirm,
# and check whether genjax also traces plain Python int arguments.
initialize_fireflies.simulate(key, (jnp.zeros(64), jnp.arange(1,4), None))

ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: traced array with shape int32[3]
It arose in the jnp.arange argument 'stop'

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError

In [None]:
# Scratch check of the array previously passed as max_fireflies: [1, 2, 3].
jnp.arange(start=1, stop=4)

In [None]:
key = jax.random.PRNGKey(3746478)
# NOTE(review): `yoni` is not defined in any visible cell, so this cell raises
# NameError on a fresh kernel (Restart & Run All). Define it above, or drop the
# cell. From this call, `yoni` appears to be a generative function taking no
# arguments; `.mask().vmap()` then receives one boolean flag per vectorized copy
# (four flags here).
tr = yoni.mask().vmap().simulate(key, (jnp.array([False, False, False, False]),))

In [None]:
from genjax import Diff, NoChange, UpdateProblem
from genjax import UpdateProblemBuilder as U


# Scratch check: build a genjax Diff pairing the value 3 with NoChange.
# NOTE(review): presumably this tags 3 as "unchanged" for incremental updates — confirm.
Diff(3, NoChange)