In [79]:
import jax
import jax.numpy as jnp
import genjax
from genjax import ChoiceMapBuilder as C
import proposals

import sys
sys.path.append("../src/")
from maskcombinator_model import *

import matplotlib.pyplot as plt

genjax.pretty()

In [7]:
class JaxKey():
    """Stateful helper that hands out fresh JAX PRNG keys on demand.

    The generator keeps one private "carry" key in ``self.key``. Every call
    splits the carry into a new carry plus the requested number of subkeys,
    and only the subkeys are returned. The carry itself is never handed to
    the caller, so no key is ever used both for sampling and for a later
    split (JAX requires each key to be consumed at most once).

    Attributes:
        seed: the integer seed the generator was created with.
        key: the current private carry key; do not sample with it directly.
    """

    def __init__(self, seed):
        self.seed = seed
        self.key = jax.random.PRNGKey(seed)

    def __call__(self, n_keys=1):
        """Return fresh PRNG key(s) and advance the internal state.

        Args:
            n_keys: number of keys to produce (default 1).

        Returns:
            A single key when ``n_keys == 1``; otherwise an array of
            ``n_keys`` keys stacked along the leading axis.
        """
        # One extra split: index 0 becomes the next carry and stays private,
        # the remaining entries are the keys given to the caller.
        keys = jax.random.split(self.key, n_keys + 1)
        self.key = keys[0]
        if n_keys == 1:
            return keys[1]
        return keys[1:]
        
# Module-level key generator built from a fixed seed, so the sequence of keys
# it produces is the same on every fresh run of the notebook.
seed = 123
keygen = JaxKey(seed)

Generative model functions:

assess -- log probability of an observation under a trace (requires complete choicemap)

importance -- like assess, but does not require a full choicemap (importance-sample with a proposal to complete the choicemap,
then get the model logprobs using assess)
                                                          
update -- convenience to save computation when trace isn't changing

What I want are the following:

1. initialization proposal, initialization prior
2. dynamics proposal, dynamics prior
3. observation / likelihood function

Then, at each step (after initialization), I want to:

1. compute $Q(x_t; x_{t-1}, y_t)$ by calling `.importance` on the dynamics proposal
2. compute $P(x_t | x_{t-1})$ by using the choices from our proposal and calling `.assess` on the prior (model dynamics)
3. compute $P(y_t | x_t)$ using observation model
4. Calculate weights as: $\frac{P(y_t | x_t) P(x_t | x_{t-1})}{Q(x_t; x_{t-1}, {y_t})}$
5. Resample particles by drawing from `Categorical(weights)` M times


Let's try to take apart our existing model and re-factor it to make it easier to use those components

### 1. Initialization

Our initialization model will sample N and initialize fireflies using a mask over the maximum number of fireflies

Our proposal can take the first observation, and initialize in the vicinity of any observed blinks

In [None]:
@gen
def init_firefly_at_random():
    """
    Samples the initial state of one firefly without using any observation.

    Traced addresses: "x", "y" (uniform between 1 and SCENE_SIZE),
    "vx", "vy" (truncated_normal(0, 0.5, MIN_VELOCITY, MAX_VELOCITY)) and
    "blink_rate" (normal(0.1, 0.01)).

    Returns a dict with keys "x", "y", "vx", "vy", "blink_rate" and
    "blinking"; the firefly always starts with blinking == False.
    """
    # Cast once and reuse as the upper bound for both coordinates.
    scene_extent = SCENE_SIZE.astype(jnp.float32)
    x0 = genjax.uniform(1., scene_extent) @ "x"
    y0 = genjax.uniform(1., scene_extent) @ "y"

    vel_x = genjax.truncated_normal(0., .5, MIN_VELOCITY, MAX_VELOCITY) @ "vx"
    vel_y = genjax.truncated_normal(0., .5, MIN_VELOCITY, MAX_VELOCITY) @ "vy"

    rate = genjax.normal(0.1, 0.01) @ "blink_rate"

    # NOTE: a "state_duration" entry was sketched for this state at one point
    # but is currently disabled, so it is not part of the returned dict.
    return {
        "x": x0,
        "y": y0,
        "vx": vel_x,
        "vy": vel_y,
        "blink_rate": rate,
        "blinking": False,
    }

@gen
def init_firefly_at_loc(obs_x, obs_y):
    """
    Samples the initial state of one firefly, placing it near an observed
    location when that observation is usable.

    obs_x, obs_y: observed coordinates. A coordinate is treated as valid only
        when it is strictly greater than 0; otherwise that coordinate is
        drawn uniformly between 0 and SCENE_SIZE instead.

    Traced addresses: "vx", "vy", "x", "y", "blink_rate" (sampled in that
    order). For a valid coordinate the position is drawn from a normal with
    scale 0.1 centred on the observation minus the sampled velocity.

    Returns a dict with keys "x", "y", "vx", "vy", "blink_rate" and
    "blinking" (always False).
    """
    scene_extent = SCENE_SIZE.astype(jnp.float32)
    vel_x = genjax.truncated_normal(0., .5, MIN_VELOCITY, MAX_VELOCITY) @ "vx"
    vel_y = genjax.truncated_normal(0., .5, MIN_VELOCITY, MAX_VELOCITY) @ "vy"

    # Switch between "normal around the observation" and "uniform over the
    # scene"; the boolean flag passed first selects which one is used.
    near_obs_or_uniform = genjax.normal.or_else(genjax.uniform)
    pos_x = near_obs_or_uniform(obs_x > 0., (obs_x - vel_x, 0.1), (0., scene_extent)) @ "x"
    pos_y = near_obs_or_uniform(obs_y > 0., (obs_y - vel_y, 0.1), (0., scene_extent)) @ "y"

    rate = genjax.normal(0.1, 0.01) @ "blink_rate"

    return {
        "x": pos_x,
        "y": pos_y,
        "vx": vel_x,
        "vy": vel_y,
        "blink_rate": rate,
        "blinking": False,
    }

@gen 
def model_init_fireflies(max_fireflies):
    """
    Prior over the initial scene: how many fireflies exist and their states.

    max_fireflies: array of candidate counts; it is used both as the label
        set for the "n_fireflies" draw and as the per-slot index that is
        compared against the sampled count to build the mask.

    Returns the masked, vmapped result of init_firefly_at_random, traced
    under the address "init".
    """
    count = labcat(unicat(max_fireflies), max_fireflies) @ "n_fireflies"
    # A slot is switched on when its entry does not exceed the sampled count.
    active = jnp.array(max_fireflies <= count)
    # Will produce traces with the form: trace["init", :, f"{variable}"]
    return init_firefly_at_random.mask().vmap(in_axes=0)(active) @ "init"


@gen 
def proposal_init_fireflies(max_fireflies, x_obs, y_obs):
    """
    Data-driven proposal for the initial scene.

    max_fireflies: jnp.array of the form [1, 2, 3, ..., max_fireflies]
    x_obs, y_obs: (max_fireflies,) array with observations. Valid observations 
          are anything inside the scene limits

    Samples "n_fireflies", builds a per-slot mask from it, and initializes
    each active slot with init_firefly_at_loc (traced under "init").
    Returns the masked, vmapped initial states.
    """
    # Get number valid observations (only x_obs is inspected here).
    # NOTE(review): validity is `> -1` here but `> 0.` inside
    # init_firefly_at_loc -- confirm which sentinel convention is intended.
    num_valid_obs = jnp.sum(jnp.where(x_obs > -1, 1, 0))
    # Guard the reciprocal: with no valid observation 1/0 would put `inf`
    # into every weight. For num_valid_obs >= 1 the values are unchanged.
    safe_count = jnp.maximum(num_valid_obs, 1)
    # Zero weight for counts below the number of observed blinks.
    firefly_probs = jnp.where(max_fireflies < num_valid_obs, 0., 1/safe_count)
    # NOTE(review): these weights do not generally sum to 1, and `unicat` is
    # defined outside this notebook -- verify that it actually uses the
    # weights (rather than only the array length) and normalizes them.
    n_fireflies = labcat(unicat(firefly_probs), max_fireflies) @ "n_fireflies"
    masks = jnp.array(max_fireflies <= n_fireflies)
    # Our single firefly proposal function gets an observed x and y
    # and initializes at random if the values are < 0. 
    init_states = init_firefly_at_loc.mask().vmap(in_axes=(0, 0, 0))(masks, x_obs, y_obs) @ "init"
    return init_states

In [None]:
# JIT-compiled handles to the `importance` and `update` methods of
# `multifirefly_model`, so repeated calls reuse the compiled computation.
# NOTE(review): `multifirefly_model` is not defined in the cells above --
# presumably it is provided by `from maskcombinator_model import *`; confirm.
model_importance = jax.jit(multifirefly_model.importance)
model_update = jax.jit(multifirefly_model.update)

In [62]:
# Scratch check of the "use the observation when it is valid, otherwise fall
# back to a random location" selection, expressed with jax.lax.cond.
random_x, random_y = 12., 30.
obs_x, obs_y = 10., 1.

# A coordinate counts as valid only when it is strictly positive.
valid_x, valid_y = obs_x > 0., obs_y > 0.

init_x_at = jax.lax.cond(valid_x, lambda: obs_x, lambda: random_x)
init_y_at = jax.lax.cond(valid_y, lambda: obs_y, lambda: random_y)