In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../src/")
import jax.numpy as jnp
import jax
import jax.tree_util as jtu
import genjax
from genjax import gen, Target, smc
from genjax import ChoiceMapBuilder as C
from config import *
from IPython.display import HTML
genjax.pretty()

from maskcombinator_model import multifirefly_model
from utils import *

# Testing a different observation model

Before running this notebook, make sure the observation model is `get_observed_blinks` and not the pixel-based model.

In [None]:
# JIT-compile the model's importance method once; later cells call it repeatedly.
_importance = multifirefly_model.importance
importance_jit = jax.jit(_importance)
del _importance

In [None]:
# Sample a ground-truth scene with the number of fireflies pinned to 1.
key = jax.random.PRNGKey(914)
key, subkey = jax.random.split(key)
max_fireflies = jnp.arange(1, 3)  # candidate firefly counts [1, 2]; the last entry is used as the maximum

# Mask of active time steps: `arange(T) < T` keeps every one of the TIME_STEPS steps.
time_mask = jnp.arange(TIME_STEPS) < TIME_STEPS
constraint = C["n_fireflies"].set(jnp.int32(1))

tr, gt_w = importance_jit(subkey, constraint, (max_fireflies, time_mask))
print(gt_w)
gt_chm = tr.get_sample()
gt_chm

In [None]:
def scatter_animation(observed_xs, observed_ys, gt_xs=None, gt_ys=None):
    """
    Animate observed firefly locations, optionally overlaid on ground-truth locations.

    Args:
        observed_xs, observed_ys: arrays indexed as [frame, firefly]. A point is drawn
            only when both of its coordinates are > 0 (non-positive entries mean
            "no observation").
        gt_xs, gt_ys: optional arrays with the same [frame, firefly] indexing holding
            ground-truth positions (points with a coordinate <= 0 are skipped). They are
            drawn as hollow green circles.

    Returns:
        matplotlib.animation.FuncAnimation with one frame per row of observed_xs.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_xlim(0, SCENE_SIZE)
    ax.set_ylim(SCENE_SIZE, 0)  # inverted y-axis: y = 0 is at the top
    ax.set_title('Scatter Plot Animation')
    ax.set_facecolor("black")

    # facecolors='none' (the string) gives hollow markers. facecolors=None would fall
    # back to the default colour cycle and fill the circles.
    gt_scatter = ax.scatter([], [], edgecolors='g', facecolors='none', s=200, alpha=0.25, animated=True)
    obs_scatter = ax.scatter([], [], c='red', s=200, animated=True)

    def _valid_points(xs, ys):
        """Return an (m, 2) array of the (x, y) pairs whose coordinates are both > 0."""
        xs = np.asarray(xs)
        ys = np.asarray(ys)
        # Mask x and y jointly so the two coordinates of a point can never be mispaired.
        keep = (xs > 0) & (ys > 0)
        return np.column_stack([xs[keep], ys[keep]])

    def update(frame):
        if gt_xs is not None:
            gt_scatter.set_offsets(_valid_points(gt_xs[frame, :], gt_ys[frame, :]))
        obs_scatter.set_offsets(_valid_points(observed_xs[frame, :], observed_ys[frame, :]))
        # With blit=True, FuncAnimation redraws only the artists returned here.
        return obs_scatter, gt_scatter

    return animation.FuncAnimation(
        fig,
        update,
        frames=len(observed_xs),  # one frame per time step (row)
        interval=100,  # milliseconds between frames
        blit=True,
    )

In [None]:
# Extract the ground-truth positions and the simulated observations from the sampled choicemap.
gt_xs, gt_ys = get_gt_locations(gt_chm)
observed_xs, observed_ys = get_observations(gt_chm)

# Animate the observations only; gt_xs / gt_ys are not overlaid in this cell.
anim = scatter_animation(observed_xs, observed_ys)
HTML(anim.to_jshtml())

In [None]:
def set_observations(observed_xs, observed_ys, max_fireflies):
    """
    Build a choicemap constraining the observation and blinking addresses at every step.

    Args:
        observed_xs, observed_ys: jnp.array of shape (steps, n) holding the observed x
            and y locations of fireflies. A slot counts as observed when its x value is > 0.
        max_fireflies: int (or a concrete integer scalar array): number of firefly slots
            per step. Columns at index >= max_fireflies are ignored.

    Returns:
        A choicemap with, for each step t:
        - the observed x/y values per slot (0 where unobserved);
        - a boolean per slot that is True where an observation exists.
    """
    steps, n = observed_xs.shape
    max_fireflies = int(max_fireflies)
    n_used = min(n, max_fireflies)

    # A slot is observed (and therefore blinking) when its x coordinate is positive.
    observed = observed_xs[:, :n_used] > 0

    vals_x = jnp.zeros((steps, max_fireflies), jnp.float32)
    vals_y = jnp.zeros((steps, max_fireflies), jnp.float32)
    blinks = jnp.zeros((steps, max_fireflies), jnp.bool_)

    # Vectorized fill. x goes into vals_x and y into vals_y; unobserved slots stay 0 / False.
    vals_x = vals_x.at[:, :n_used].set(jnp.where(observed, observed_xs[:, :n_used], 0.0))
    vals_y = vals_y.at[:, :n_used].set(jnp.where(observed, observed_ys[:, :n_used], 0.0))
    blinks = blinks.at[:, :n_used].set(observed)

    # NOTE(review): these addresses use ("observations", :, "observed_xs") and
    # ("dynamics", "blinking", :), whereas get_observations reads ("observations", "observed_xs")
    # and the visualisation cell reads ("dynamics", :, "blink"). Confirm they match the model.
    time_idx = jnp.arange(steps)
    chm = jax.vmap(lambda t: C["steps", t, "observations", :, "observed_xs"].set(vals_x[t]))(time_idx)
    chm = chm | jax.vmap(lambda t: C["steps", t, "observations", :, "observed_ys"].set(vals_y[t]))(time_idx)
    # Merge (do not overwrite) so the x/y observation constraints are kept.
    chm = chm | jax.vmap(lambda t: C["steps", t, "dynamics", "blinking", :].set(blinks[t]))(time_idx)
    return chm

In [None]:
# Turn the ground-truth observations into constraints sized for the largest
# candidate firefly count (the last entry of max_fireflies).
observed_xs, observed_ys = get_observations(gt_chm)
largest_count = max_fireflies[-1]
observations = set_observations(observed_xs, observed_ys, largest_count)
observations

In [None]:
# Score the same observations under each candidate number of fireflies.
for n_candidate in max_fireflies:
    print("Number of Fireflies: ", n_candidate)
    key, subkey = jax.random.split(key)
    count_constraint = C["n_fireflies"].set(jnp.int32(n_candidate))
    constraint = observations | count_constraint
    tr, w = importance_jit(subkey, constraint, (max_fireflies, time_mask))
    print("Importance Weight: ", w)

In [None]:
def get_observations(chm):
    """
    Pull the observed x and y arrays out of a model choicemap.

    Each address is read as a masked value exposing `.value` and `.flag`. The value
    is truncated along its second axis to `len(flag)` entries.

    NOTE(review): unlike get_gt_locations, entries are not masked by `flag` here;
    only the column count comes from it. Confirm this is intended.

    Returns:
        (observed_xs, observed_ys)
    """
    def _truncate(masked):
        n_keep = len(masked.flag)
        return masked.value[:, :n_keep]

    xs_masked = chm["steps", :, "observations", "observed_xs"]
    ys_masked = chm["steps", :, "observations", "observed_ys"]
    return _truncate(xs_masked), _truncate(ys_masked)


def get_gt_locations(chm):
    """
    Read the ground-truth x and y positions from a model choicemap.

    Each address is read as a masked value exposing `.value` and `.flag`. Entries
    whose flag is False are replaced with -1.

    Returns:
        (xs, ys) arrays with the masked-out entries set to -1.
    """
    fill_value = -1.
    located = []
    for coord in ("x", "y"):
        masked = chm["steps", :, "dynamics", :, coord]
        located.append(jnp.where(masked.flag, masked.value, fill_value))
    return tuple(located)

In [None]:
def scatter_animation(observed_xs, observed_ys, gt_xs=None, gt_ys=None):
    """
    Animate observed firefly locations, optionally overlaid on ground-truth locations.

    NOTE(review): this is a redefinition of scatter_animation from an earlier cell;
    one copy could be removed.

    Args:
        observed_xs, observed_ys: arrays indexed as [frame, firefly]. A point is drawn
            only when both of its coordinates are > 0 (non-positive entries mean
            "no observation").
        gt_xs, gt_ys: optional arrays with the same [frame, firefly] indexing holding
            ground-truth positions (points with a coordinate <= 0 are skipped). They are
            drawn as hollow green circles.

    Returns:
        matplotlib.animation.FuncAnimation with one frame per row of observed_xs.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.set_xlim(0, SCENE_SIZE)
    ax.set_ylim(SCENE_SIZE, 0)  # inverted y-axis: y = 0 is at the top
    ax.set_title('Scatter Plot Animation')
    ax.set_facecolor("black")

    # facecolors='none' (the string) gives hollow markers. facecolors=None would fall
    # back to the default colour cycle and fill the circles.
    gt_scatter = ax.scatter([], [], edgecolors='g', facecolors='none', s=200, alpha=0.25, animated=True)
    obs_scatter = ax.scatter([], [], c='red', s=200, animated=True)

    def _valid_points(xs, ys):
        """Return an (m, 2) array of the (x, y) pairs whose coordinates are both > 0."""
        xs = np.asarray(xs)
        ys = np.asarray(ys)
        # Mask x and y jointly so the two coordinates of a point can never be mispaired.
        keep = (xs > 0) & (ys > 0)
        return np.column_stack([xs[keep], ys[keep]])

    def update(frame):
        if gt_xs is not None:
            gt_scatter.set_offsets(_valid_points(gt_xs[frame, :], gt_ys[frame, :]))
        obs_scatter.set_offsets(_valid_points(observed_xs[frame, :], observed_ys[frame, :]))
        # With blit=True, FuncAnimation redraws only the artists returned here.
        return obs_scatter, gt_scatter

    return animation.FuncAnimation(
        fig,
        update,
        frames=len(observed_xs),  # one frame per time step (row)
        interval=100,  # milliseconds between frames
        blit=True,
    )

# Overlay the simulated observations (red) on the ground-truth positions (green).
# Uses gt_chm: `chm` is never defined at the top level and would raise NameError on a fresh kernel.
observed_xs, observed_ys = get_observations(gt_chm)
gt_xs, gt_ys = get_gt_locations(gt_chm)

anim = scatter_animation(observed_xs, observed_ys, gt_xs, gt_ys)
HTML(anim.to_jshtml())

In [None]:
N = 20_000  # number of importance samples drawn per SIR call
K = 100     # number of indices resampled from the weighted pool

def SIR(N, K, model, chm):
    """
    Build a jitted sampling-importance-resampling routine.

    Args:
        N: number of importance samples drawn from `model` under the constraints `chm`.
        K: number of indices resampled according to the importance weights.
        model: generative function exposing `.importance(key, constraints, args)`.
        chm: constraints applied to every importance sample.

    Returns:
        A jax.jit-compiled function `(key, args) -> samples`. `samples` is the choicemap
        of the resampled traces, with a leading axis of size K. The importance weights
        are passed directly as genjax.categorical's argument (presumably logits,
        i.e. log-weights).
    """
    batched_importance = jax.vmap(model.importance, in_axes=(0, None, None))
    batched_categorical = jax.vmap(jax.jit(genjax.categorical.simulate), in_axes=(0, None))

    @jax.jit
    def _inner(key, args):
        key, subkey = jax.random.split(key)
        proposal_keys = jax.random.split(key, N)
        traces, log_weights = batched_importance(proposal_keys, chm, args)

        resample_keys = jax.random.split(subkey, K)
        chosen = batched_categorical(resample_keys, (log_weights,)).get_retval()

        pool = traces.get_sample()
        # Gather the chosen particle from every leaf of the sample pytree.
        return jax.vmap(lambda idx: jtu.tree_map(lambda leaf: leaf[idx], pool))(chosen)

    return _inner

In [None]:
def make_constraints(trace):
    """
    Copy the observed x/y values out of a choicemap into a fresh constraint choicemap.

    Args:
        trace: a choicemap (called here with the ground-truth choicemap) holding
            ("steps", "observations", "observed_xs"/"observed_ys") entries.

    Returns:
        A choicemap combining the x and y observation constraints with `^`.

    NOTE(review): the values are read without a `:` after "steps" but written with one,
    and get_observations reads them with a `:`. Confirm both forms address the same choices.
    """
    xs_constraint = C["steps", :, "observations", "observed_xs"].set(
        trace["steps", "observations", "observed_xs"]
    )
    ys_constraint = C["steps", :, "observations", "observed_ys"].set(
        trace["steps", "observations", "observed_ys"]
    )
    return xs_constraint ^ ys_constraint

In [None]:
# Display the ground-truth choicemap (`chm` is undefined at the top level; gt_chm holds it).
gt_chm

In [None]:
# Condition on the ground-truth observations and run SIR.
# Uses gt_chm: `chm` is never defined at the top level and would raise NameError on a fresh kernel.
constraints = make_constraints(gt_chm)
args = (max_fireflies, time_mask)
key = jax.random.PRNGKey(12094323)
# SIR already returns a jax.jit-compiled function, so no extra jit wrapper is needed.
samples = SIR(N, K, multifirefly_model, constraints)(key, args)

In [None]:
def scatter_multiple_animation(observed_xs, observed_ys, gt_xs=None, gt_ys=None, gt_blinks=None):
    """
    Animate many sampled firefly positions per frame, optionally on top of ground truth.

    Args:
        observed_xs, observed_ys: arrays indexed as [sample, frame, firefly]. All samples
            for a frame are drawn together in red. Points with a coordinate <= 0 (e.g. a
            -1 fill for inactive fireflies) are skipped.
        gt_xs, gt_ys: optional [frame, firefly] ground-truth positions. Points with a
            coordinate <= 0 are skipped.
        gt_blinks: optional [frame, firefly] booleans. When given, ground-truth markers are
            coloured red where True and green where False.

    Returns:
        matplotlib.animation.FuncAnimation with observed_xs.shape[1] frames.
    """
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.set_xlim(0, SCENE_SIZE)
    ax.set_ylim(SCENE_SIZE, 0)  # inverted y-axis: y = 0 is at the top
    ax.set_title('Scatter Plot Animation')

    obs_scatter = ax.scatter([], [], c='red', s=200, alpha=0.9, animated=True)
    # facecolors='none' (the string) gives hollow markers; None would fill them.
    gt_scatter = ax.scatter([], [], edgecolors='green', facecolors='none', s=200, alpha=0.4, animated=True)

    def update(frame):
        # Pool the positions of every sample at this frame.
        xs = np.asarray(observed_xs[:, frame, :]).ravel()
        ys = np.asarray(observed_ys[:, frame, :]).ravel()
        keep = (xs > 0) & (ys > 0)
        obs_scatter.set_offsets(np.column_stack([xs[keep], ys[keep]]))

        if gt_xs is not None:
            frame_xs = np.asarray(gt_xs[frame, :])
            frame_ys = np.asarray(gt_ys[frame, :])
            # Mask x and y jointly so the coordinates of a point are never mispaired.
            keep = (frame_xs > 0) & (frame_ys > 0)
            gt_scatter.set_offsets(np.column_stack([frame_xs[keep], frame_ys[keep]]))

            if gt_blinks is not None:
                # Filter with the same mask so each colour lines up with its marker.
                frame_blinks = np.asarray(gt_blinks[frame, :])[keep]
                gt_scatter.set_color(["r" if b else "g" for b in frame_blinks])

        # With blit=True, FuncAnimation redraws only the artists returned here.
        return obs_scatter, gt_scatter

    return animation.FuncAnimation(
        fig,
        update,
        frames=observed_xs.shape[1],  # one frame per time step
        interval=100,  # milliseconds between frames
        blit=True,
    )

In [None]:
# Draw the resampled particles against the ground truth, colouring ground-truth markers by blink state.
sampled_xs, sampled_ys = get_gt_locations(samples)

# Uses gt_chm: `chm` is never defined at the top level and would raise NameError on a fresh kernel.
gt_blinks = gt_chm["steps", :, "dynamics", :, "blink"].value
anim = scatter_multiple_animation(sampled_xs, sampled_ys, gt_xs, gt_ys, gt_blinks)
HTML(anim.to_jshtml())

Proposal ideas:

Target proposal (a 3-way flip):

1. Is the firefly detected?
    - Flip 1: sample from the prior or from the proposal.
        - Flip 2: propose near firefly X's current estimated location, or near its last observed location.
            - Update velocity, position, blinking, etc.

