In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import jax
import time
import jax.numpy as jnp
import numpy as np
import genjax
from genjax import gen, Mask, Diff
from genjax import ChoiceMapBuilder as C
from genjax import UpdateProblemBuilder as U
from genjax import SelectionBuilder as S
import jax.tree_util as jtu
genjax.pretty()
import copy
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from IPython.display import HTML as HTML_Display
from matplotlib.animation import FuncAnimation

from maskcombinator_model import *
from distributions import *
from config import *
from render import *

In [None]:
def and_new_keys(key, N):
    """Advance `key` and derive a batch of `N` fresh subkeys.

    Returns ``(next_key, subkeys)``: ``next_key`` should replace the caller's key, and
    ``subkeys`` is the result of splitting a separate child key into ``N`` keys.
    """
    next_key, child_key = jax.random.split(key)
    subkeys = jax.random.split(child_key, N)
    return next_key, subkeys

def summary_stats(weights):
    """Print the mean, standard deviation, minimum and maximum of `weights` on one line."""
    stats = dict(
        mean=jnp.mean(weights),
        std=jnp.std(weights),
        min=jnp.min(weights),
        max=jnp.max(weights),
    )
    print('mean={mean}, std={std}, min={min}, max={max}'.format(**stats))

In [48]:
def make_step_choicemap(observation_chm, step_index):
    """Build a choicemap that constrains only the pixel observations at `step_index`.

    The pixels are read from `observation_chm` at
    ``("steps", step_index, "observations", "pixels")`` (taking ``.value`` of the indexed
    result) and set at that same address on a fresh choicemap.
    """
    address = ("steps", step_index, "observations", "pixels")
    pixels = observation_chm[address].value
    return C[address].set(pixels)

def make_masked_combinator_step_update_problem(init_carry, observations, step, num_steps):
    """Build the update problem that extends a masked trace to include time step `step`.

    The first model argument (`init_carry`) is marked unchanged. The second, the step mask
    ``arange(num_steps) <= step``, is marked as an unknown change. The constraints are the
    step-`step` pixel observations taken from `observations`.
    """
    carry_diff = Diff.no_change(init_carry)
    active_steps = jnp.arange(num_steps) <= step
    mask_diff = Diff.unknown_change(active_steps)
    step_constraints = make_step_choicemap(observations, step)
    return U.g((carry_diff, mask_diff), step_constraints)

In [49]:
def make_sequential_monte_carlo_sampler(masked_model, observations_dict, num_particles, num_steps=None,
                                        return_weights=False):
    """Build an SMC sampler that sweeps over the time steps of `masked_model`.

    Step 0 is initialised with ``masked_model.importance``. It constrains only the step-0
    pixel observations and activates only step 0 via the mask ``arange(num_steps) <= 0``.
    Each later step ``t`` does two things:
    (1) multinomially resamples the particles using the previous log weights as logits;
    (2) calls ``masked_model.update`` with a problem that activates steps ``<= t`` and
        constrains the step-``t`` pixels.

    Args:
        masked_model: generative function called with args ``(step_init_args, step_mask)``.
        observations_dict: container holding the observed pixels at
            ``("steps", t, "observations", "pixels")``.
        num_particles: number of particles.
        num_steps: number of time steps. If ``None``, it is inferred as the length of the
            first value of ``observations_dict.values()``, which requires a dict-like object.
        return_weights: if False (the default, the original behaviour), the sampler returns
            only the final particles. If True, it returns
            ``(final_particles, final_log_weights)``. The particles are NOT resampled after
            the last update, so they carry non-uniform weights. Those weights are needed to
            treat them as a posterior approximation.

    Returns:
        A function ``sequential_monte_carlo_sampler(key, step_init_args)``.

    Raises:
        ValueError: if ``num_steps`` (given or inferred) is less than 1.
    """
    # `is None` rather than falsiness, so an explicit 0 is not silently replaced.
    if num_steps is None:
        num_steps = len(next(iter(observations_dict.values())))
    if num_steps < 1:
        raise ValueError(f"num_steps must be >= 1, got {num_steps}")

    def sequential_monte_carlo_sampler(key, step_init_args):
        """Run SMC from `key`; returns particles (and log weights if requested)."""
        key, init_keys = and_new_keys(key, num_particles)

        # Step 0: importance-sample every particle against the first observation only.
        init_chm = make_step_choicemap(observations_dict, 0)
        model_init_args = (step_init_args, jnp.arange(num_steps) <= 0)
        init_particles, init_weights = jax.vmap(masked_model.importance, in_axes=(0, None, None))(
            init_keys, init_chm, model_init_args)

        def scan_fn(smc_scan_state, scan_input):
            """One SMC step: resample, then extend every particle to `time_step`."""
            prev_weights, step_particles = smc_scan_state
            key, time_step = scan_input

            key, resample_key = jax.random.split(key)
            # Multinomial resampling with the log weights as logits. Replacing `parents` with
            # jax.random.permutation(resample_key, jnp.arange(num_particles)) recovers SIS.
            parents = jax.random.categorical(resample_key, prev_weights, shape=(num_particles,))
            step_particles = jax.tree.map(lambda x: x[parents], step_particles)

            _, step_keys = and_new_keys(key, num_particles)
            update_problem = make_masked_combinator_step_update_problem(
                step_init_args, observations_dict, time_step, num_steps)

            # Resampling is unconditional, so the incoming weights are uniform and the update's
            # weights can simply replace them. If resampling becomes ESS-triggered, the update
            # weights must instead be added to the (non-reset) previous log weights.
            step_particles, step_weights, _, _ = jax.vmap(masked_model.update, in_axes=(0, 0, None))(
                step_keys, step_particles, update_problem)

            return (step_weights, step_particles), None

        # One key per remaining time step 1..num_steps-1 (empty when num_steps == 1).
        scan_keys = jax.random.split(key, num_steps - 1)

        (final_weights, final_particles), _ = jax.lax.scan(
            scan_fn, (init_weights, init_particles), (scan_keys, jnp.arange(1, num_steps)))

        if return_weights:
            return final_particles, final_weights
        return final_particles

    return sequential_monte_carlo_sampler


In [50]:
# Draw one ground-truth trace from the model: the `n_fireflies` choice is constrained to 2,
# and the `run_until` mask is all True (every entry of arange(T) is < T).
max_fireflies = jnp.arange(1, 5)
key = jax.random.PRNGKey(101)
key, subkey = jax.random.split(key)
constraints = C.d({"n_fireflies": 2})
run_until = jnp.arange(TIME_STEPS) < TIME_STEPS
trace, weight = jax.jit(multifirefly_model.importance)(
    subkey,
    constraints,
    (max_fireflies, run_until),
)
gt_choices = trace.get_sample()

In [51]:
# Plain-text summary of the ground-truth choicemap (print separates its arguments with a space).
print(f"Ground Truth Trace:  {gt_choices!s}")

Ground Truth Trace:  XorChm(...)


In [52]:
# Build (but do not yet run) an SMC sampler for the masked firefly model, conditioned on the
# ground-truth choices and using every time step.
masked_model = multifirefly_model
num_particles = 100
step_init_args = max_fireflies
smc_sampler = make_sequential_monte_carlo_sampler(
    masked_model,
    gt_choices,
    num_particles,
    num_steps=TIME_STEPS,
)
key, smc_key = jax.random.split(key)
# NOTE(review): the sampler invocation is disabled, presumably because a single
# `masked_model.update` step currently raises the TypeError recorded further down.
# Re-enable once that is resolved.
# smc_particles = jax.jit(smc_sampler)(smc_key, step_init_args)

In [53]:
# Show the ground-truth choicemap via the notebook's rich display (genjax.pretty() is called in
# the imports cell), rather than the one-line `print` form used earlier.
gt_choices

In [54]:
# Reproduce the sampler's step-0 initialisation by hand so the particles can be inspected:
# only step 0 is active, and only the step-0 pixels are constrained.
key, init_keys = and_new_keys(key, num_particles)
init_chm = make_step_choicemap(gt_choices, 0)
model_init_args = (max_fireflies, jnp.arange(TIME_STEPS) <= 0)
init_particles, init_weights = jax.vmap(
    masked_model.importance, in_axes=(0, None, None)
)(init_keys, init_chm, model_init_args)

In [57]:
# Advance the hand-built particles by one SMC step (time step 1), without resampling.
# NOTE(review): as the recorded output shows, this raises "TypeError: true_fun and false_fun
# output must have same type structure". In the two PyTreeDefs, one visible difference is the
# VmapTrace `args`: plain values in one branch, `Diff`-wrapped values in the other. That looks
# like an issue inside genjax's update path rather than in this cell. TODO confirm.
key, step_keys = and_new_keys(key, num_particles)
update_problem = make_masked_combinator_step_update_problem(
    step_init_args, gt_choices, 1, TIME_STEPS
)
step_particles, step_weights, _, _ = jax.vmap(
    masked_model.update, in_axes=(0, 0, None)
)(step_keys, init_particles, update_problem)

TypeError: true_fun and false_fun output must have same type structure, got PyTreeDef((CustomNode(MaskTrace[StructStaticMetadata(child_field_names=['mask_combinator', 'inner', 'check'], static_fields={})], [CustomNode(MaskCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_and_observe at 0x29dc60220>})], [()])])]), CustomNode(StaticTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'retval', 'addresses', 'subtraces', 'score'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_and_observe at 0x29dc60220>})], [()])]), ((CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}]), *),), (CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}]), *), CustomNode(AddressVisitor[StructStaticMetadata(child_field_names=[], static_fields={'visited': [('dynamics',), ('observations',)]})], []), [CustomNode(VmapTrace[StructStaticMetadata(child_field_names=['gen_fn', 'inner', 'args', 'retval', 'score'], static_fields={})], [CustomNode(VmapCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={'in_axes': (0, 0)})], [CustomNode(MaskCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], 
[CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_firefly at 0x29dc61300>})], [()])])])]), CustomNode(MaskTrace[StructStaticMetadata(child_field_names=['mask_combinator', 'inner', 'check'], static_fields={})], [CustomNode(MaskCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_firefly at 0x29dc61300>})], [()])])]), CustomNode(StaticTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'retval', 'addresses', 'subtraces', 'score'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_firefly at 0x29dc61300>})], [()])]), ({'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *},), {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}, CustomNode(AddressVisitor[StructStaticMetadata(child_field_names=[], static_fields={'visited': [('x',), ('y',), ('vx',), ('vy',), ('blink',)]})], []), [CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), 
CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], 
[CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bb6660>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bb6700>})], [()])]), (*,), *, *])], *]), *]), (*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}), CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}]), *]), CustomNode(StaticTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'retval', 'addresses', 'subtraces', 'score'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function observe_fireflies at 0x29dc60180>})], [()])]), (*, *, *, *), *, CustomNode(AddressVisitor[StructStaticMetadata(child_field_names=[], static_fields={'visited': [('pixels',)]})], []), [CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *])], *])], *]), *]), *, 
CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), (CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), {'blink_rate': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'blinking': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'state_duration': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'vx': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'vy': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'x': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'y': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], 
static_fields={})], [])])}]), CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]))]), CustomNode(MaskedProblem[StructStaticMetadata(child_field_names=['flag', 'problem'], static_fields={})], [*, CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'dynamics'})], [CustomNode(IdxChm[StructStaticMetadata(child_field_names=['addr', 'c'], static_fields={})], [*, CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(MaskedProblem[StructStaticMetadata(child_field_names=['flag', 'problem'], static_fields={})], [*, CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'x'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'y'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'vx'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], 
[CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'vy'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'blink'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])])])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'observations'})], [CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'pixels'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(MaskedProblem[StructStaticMetadata(child_field_names=['flag', 'problem'], static_fields={})], [*, *])])])])])]))) and PyTreeDef((CustomNode(MaskTrace[StructStaticMetadata(child_field_names=['mask_combinator', 'inner', 'check'], static_fields={})], [CustomNode(MaskCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_and_observe at 0x29dc60220>})], [()])])]), CustomNode(StaticTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'retval', 'addresses', 'subtraces', 'score'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_and_observe at 0x29dc60220>})], 
[()])]), ((CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}]), *),), (CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}]), *), CustomNode(AddressVisitor[StructStaticMetadata(child_field_names=[], static_fields={'visited': [('dynamics',), ('observations',)]})], []), [CustomNode(VmapTrace[StructStaticMetadata(child_field_names=['gen_fn', 'inner', 'args', 'retval', 'score'], static_fields={})], [CustomNode(VmapCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={'in_axes': (0, 0)})], [CustomNode(MaskCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_firefly at 0x29dc61300>})], [()])])])]), CustomNode(MaskTrace[StructStaticMetadata(child_field_names=['mask_combinator', 'inner', 'check'], static_fields={})], [CustomNode(MaskCombinator[StructStaticMetadata(child_field_names=['gen_fn'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_firefly at 0x29dc61300>})], [()])])]), CustomNode(StaticTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'retval', 'addresses', 'subtraces', 'score'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function step_firefly at 
0x29dc61300>})], [()])]), ({'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *},), {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}, CustomNode(AddressVisitor[StructStaticMetadata(child_field_names=[], static_fields={'visited': [('x',), ('y',), ('vx',), ('vy',), ('blink',)]})], []), [CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), 
CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *]), CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], [CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bb6660>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bb6700>})], [()])]), (*,), *, *])], *]), *]), (CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), {'blink_rate': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'blinking': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, 
CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'state_duration': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'vx': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'vy': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'x': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'y': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])])}), CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [*, {'blink_rate': *, 'blinking': *, 'state_duration': *, 'vx': *, 'vy': *, 'x': *, 'y': *}]), *]), CustomNode(StaticTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'retval', 'addresses', 'subtraces', 'score'], static_fields={})], [CustomNode(StaticGenerativeFunction[StructStaticMetadata(child_field_names=['source'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function observe_fireflies at 0x29dc60180>})], [()])]), (*, *, *, *), *, CustomNode(AddressVisitor[StructStaticMetadata(child_field_names=[], static_fields={'visited': [('pixels',)]})], []), [CustomNode(DistributionTrace[StructStaticMetadata(child_field_names=['gen_fn', 'args', 'value', 'score'], static_fields={})], 
[CustomNode(ExactDensityFromCallables[StructStaticMetadata(child_field_names=['sampler', 'logpdf_evaluator'], static_fields={})], [CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.sampler at 0x160bcc860>})], [()]), CustomNode(Closure[StructStaticMetadata(child_field_names=['dyn_args'], static_fields={'fn': <function tfp_distribution.<locals>.logpdf at 0x160bcc900>})], [()])]), (*, *, *, *), *, *])], *])], *]), *]), *, CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), (CustomNode(Mask[StructStaticMetadata(child_field_names=['flag', 'value'], static_fields={})], [CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), {'blink_rate': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'blinking': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'state_duration': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'vx': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'vy': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 
'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'x': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]), 'y': CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_NoChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])])}]), CustomNode(Diff[StructStaticMetadata(child_field_names=['primal', 'tangent'], static_fields={})], [*, CustomNode(_UnknownChange[StructStaticMetadata(child_field_names=[], static_fields={})], [])]))]), CustomNode(MaskedProblem[StructStaticMetadata(child_field_names=['flag', 'problem'], static_fields={})], [*, CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'dynamics'})], [CustomNode(IdxChm[StructStaticMetadata(child_field_names=['addr', 'c'], static_fields={})], [*, CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(MaskedProblem[StructStaticMetadata(child_field_names=['flag', 'problem'], static_fields={})], [*, CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(XorChm[StructStaticMetadata(child_field_names=['c1', 'c2'], static_fields={})], [CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'x'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])]), 
CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'y'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'vx'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'vy'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'blink'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(EmptyProblem[StructStaticMetadata(child_field_names=[], static_fields={})], [])])])])])])])]), CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'observations'})], [CustomNode(StaticChm[StructStaticMetadata(child_field_names=['c'], static_fields={'addr': 'pixels'})], [CustomNode(ValueChm[StructStaticMetadata(child_field_names=['v'], static_fields={})], [CustomNode(MaskedProblem[StructStaticMetadata(child_field_names=['flag', 'problem'], static_fields={})], [*, *])])])])])]))).