In [2]:
%load_ext autoreload
%autoreload 2

import jax
import jax.numpy as jnp
import numpy as np
import genjax
from genjax import Diff
from genjax import ChoiceMapBuilder as C

import matplotlib.pyplot as plt
from IPython.display import display, HTML
from tqdm.notebook import tqdm

import sys
sys.path.append("../src/")
from maskcombinator_model import *
from distributions import *
from config import *
from render import *
from utils import *

# Turn on genjax's pretty display of its objects.
# NOTE(review): presumably gives rich reprs for traces / choice maps — confirm in genjax docs.
genjax.pretty()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
def and_new_keys(key, N):
    """Split ``key`` into a carry-forward key and a batch of ``N`` fresh subkeys.

    Returns:
        ``(next_key, subkeys)`` where ``subkeys`` comes from splitting the second
        half of ``jax.random.split(key)`` into ``N`` keys.
    """
    next_key, spawn_key = jax.random.split(key)
    return next_key, jax.random.split(spawn_key, N)

def summary_stats(weights):
    """Print the mean, standard deviation, minimum and maximum of ``weights`` on one line."""
    stats = dict(
        mean=jnp.mean(weights),
        std=jnp.std(weights),
        min=jnp.min(weights),
        max=jnp.max(weights),
    )
    print('mean={mean}, std={std}, min={min}, max={max}'.format(**stats))

In [None]:
def make_step_choicemap(observation_chm, step_index):
    """Build a choice map holding only the observed pixels of one step.

    Reads ``("steps", step_index, "observations", "pixels")`` from
    ``observation_chm`` and returns a new choice map that sets just that address.
    """
    address = ("steps", step_index, "observations", "pixels")
    observed_pixels = observation_chm[address].value
    return C[address].set(observed_pixels)

def make_masked_combinator_step_update_problem(init_carry, observations, step, num_steps):
    """Build the update request that activates model step ``step``.

    The first model argument (``init_carry``) is marked unchanged. The second is
    the boolean mask ``arange(num_steps) <= step`` (steps ``0..step`` active),
    marked as an unknown change. The step's observed pixels are added as constraints.
    """
    active_mask = jnp.arange(num_steps) <= step
    step_constraints = make_step_choicemap(observations, step)
    argdiffs = (Diff.no_change(init_carry), Diff.unknown_change(active_mask))
    return U.g(argdiffs, step_constraints)

In [None]:
NUM_PARTICLES = 100
NUM_STEPS = TIME_STEPS

def sequential_monte_carlo_sampler(key, masked_model, step_init_args, observations_dict):
    """Particle filter that resamples at every step of a masked, step-indexed model.

    The model is called with args ``(step_init_args, mask)``, where ``mask`` is the
    boolean vector ``jnp.arange(NUM_STEPS) <= t`` that activates steps ``0..t``.

    1. Importance-sample NUM_PARTICLES traces with only step 0 active, constrained
       on step 0's observed pixels.
    2. For t = 1 .. NUM_STEPS - 1: resample the particles according to the previous
       log weights, then ``update`` each one to activate step t and constrain step
       t's observed pixels. The update weights become the new log weights.

    Args:
        key: JAX PRNG key.
        masked_model: generative function exposing ``importance`` and ``update``.
        step_init_args: first model argument, unchanged across steps.
        observations_dict: choice map holding
            ``("steps", t, "observations", "pixels")`` for every step t.

    Returns:
        ``(final_particles, particle_history, init_particles)``. ``particle_history``
        stacks the post-update particles of steps 1 .. NUM_STEPS - 1 along a leading
        axis, so history entry ``h`` corresponds to model step ``h + 1``.
    """
    key, init_keys = and_new_keys(key, NUM_PARTICLES)

    # Constrain the initial particles on step 0's observation. Previously this
    # choice map was built but an empty C.n() was passed instead, so step 0's
    # pixels never entered the weights and the first resampling was uniform.
    init_chm = make_step_choicemap(observations_dict, 0)
    model_init_args = (step_init_args, jnp.arange(NUM_STEPS) <= 0)
    init_particles, init_weights = jax.vmap(masked_model.importance, in_axes=(0, None, None))(
        init_keys, init_chm, model_init_args)

    def scan_fn(smc_scan_state, scan_input):
        prev_weights, step_particles = smc_scan_state
        key, time_step = scan_input
        key, resample_key = jax.random.split(key)

        # Resampling: draw NUM_PARTICLES parent indices, using the previous log
        # weights as the categorical logits.
        parents = jax.random.categorical(resample_key, prev_weights, shape=(NUM_PARTICLES,))
        step_particles = jax.tree.map(lambda x: x[parents], step_particles)

        key, step_keys = and_new_keys(key, NUM_PARTICLES)
        update_problem = make_masked_combinator_step_update_problem(
            step_init_args, observations_dict, time_step, NUM_STEPS)

        # We can overwrite the weights since we've resampled unconditionally. If / when we
        # modify the code to sample based on ESS, we'll have to be more careful.
        step_particles, step_weights, _, _ = jax.vmap(masked_model.update, in_axes=(0, 0, None))(
            step_keys, step_particles, update_problem)

        return (step_weights, step_particles), step_particles

    # One key per remaining step; step 0 was handled by importance sampling above.
    scan_keys = jax.random.split(key, NUM_STEPS - 1)

    (_, final_particles), particle_history = jax.lax.scan(
        scan_fn, (init_weights, init_particles), (scan_keys, jnp.arange(1, NUM_STEPS)))

    return final_particles, particle_history, init_particles



In [None]:
# JIT-compiled entry points: SMC inference and ground-truth generation.
sampler = jax.jit(sequential_monte_carlo_sampler)
generator = jax.jit(multifirefly_model.importance)
num_steps = TIME_STEPS
num_particles = 100  # NOTE(review): not read by the sampler, which uses NUM_PARTICLES

In [None]:
# Sample a ground-truth trace with exactly two fireflies and every time step active.
max_fireflies = jnp.arange(1, 5)
key = jax.random.PRNGKey(210)
key, subkey = jax.random.split(key)
run_until = jnp.ones(TIME_STEPS, dtype=bool)  # all-True mask: no step is masked out
constraints = C.d({"n_fireflies": 2})
model_args = (max_fireflies, run_until)
gt_trace, weight = generator(subkey, constraints, model_args)
gt_choices = gt_trace.get_sample()


In [None]:
# Render the ground-truth frames as an inline HTML5 video (second arg to animate: 20).
frames = get_frames(gt_choices)
anim = animate(frames, 20)
video_html = anim.to_html5_video()
display(HTML(video_html))

In [None]:
# Run SMC on the ground-truth observations using the firefly model.
key, smc_key = jax.random.split(key)
masked_model = multifirefly_model
step_init_args = max_fireflies
num_particles = 100  # NOTE(review): not read by the sampler, which uses NUM_PARTICLES
smc_particles, smc_history, init_particles = sampler(
    smc_key, masked_model, step_init_args, gt_choices
)

In [None]:
# Choices of every initial particle (vmapped over the particle axis).
init_choices = jax.vmap(lambda particle: particle.get_sample())(init_particles)
n_fireflies = init_choices["n_fireflies"]

In [None]:
# Choice maps for the whole SMC history (leading axis = history entry) and for
# the final particle set.
def _batched_samples(traces):
    return jax.vmap(lambda tr: tr.get_sample())(traces)

history_chm = _batched_samples(smc_history)
final_chm = _batched_samples(smc_particles)

In [None]:
def get_latents(trace):
    """Return ``(xs, ys, blinks)`` read from the dynamics addresses of a choice map.

    Each value is read at ``("steps", ..., "dynamics", ..., <name>)`` and passed
    through ``get_masked_values`` with ``np.nan`` as the fill value.
    NOTE(review): presumably masked-out steps/fireflies become NaN; confirm
    against ``get_masked_values``.
    """
    return tuple(
        get_masked_values(trace["steps", ..., "dynamics", ..., name], np.nan)
        for name in ("x", "y", "blink")
    )

def get_normalized_weights(trace, axis=0):
    """Softmax of ``trace.get_score()`` along ``axis``.

    Note: this normalizes the traces' own scores, not SMC importance weights.
    """
    log_scores = trace.get_score()
    return jax.nn.softmax(log_scores, axis=axis)

In [None]:
# Trace scores (trace.get_score()) of the particles stored in the SMC history.
# NOTE: these are the traces' own scores, not SMC importance weights.
weights = np.asarray(smc_history.get_score())  # (history entries, particles)

# History entry h holds the particles after the update at model step h + 1 (the
# sampler's scan runs over steps 1..T-1), so place entry h at x = h + 1.
smc_steps = np.arange(1, weights.shape[0] + 1)

fig, ax = plt.subplots(figsize=(6, 5))
# One scatter call for all particles: each step index repeated once per particle,
# matching the row-major order of weights.ravel().
ax.scatter(np.repeat(smc_steps, weights.shape[1]), weights.ravel(), c="#FF968A")
ax.set_xlabel('Time')
ax.set_ylabel('Trace log score')
ax.set_title('Distribution of particle log scores over time')
plt.show()

In [None]:
# Shape of the firefly-count array over the initial particle set.
np.shape(init_choices["n_fireflies"])

In [None]:
# Firefly count of every particle at every SMC step: row 0 holds the initial
# particles, row t the particles after assimilating step t.
n_fireflies = np.concatenate([init_choices["n_fireflies"][np.newaxis, :], history_chm["n_fireflies"]])
# Derive the number of steps from the data. It was previously hard-coded as 50,
# which raised IndexError or silently truncated when TIME_STEPS != 50.
num_smc_steps = n_fireflies.shape[0]

fig, ax = plt.subplots(1, 1, figsize=(6, 5))
for i in range(num_smc_steps):
    # Colour steps along a uniform colormap. All particles of step i are drawn as
    # overlapping bars at x = i, so the visible bar height is the largest count.
    color = plt.cm.viridis(i / num_smc_steps)
    ax.bar(i, n_fireflies[i], color=color, label="Particles")

ax.set_ylim(0, 4.5)
ax.set_yticks(np.arange(5))
ax.set_title("Estimated Number of Fireflies")
ax.set_xlabel("Steps")
ax.set_ylabel("Number of Fireflies")
# Every bar call carries the same label; de-duplicate the legend entries.
handles, labels = ax.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
ax.legend(by_label.values(), by_label.keys())
plt.show()

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import numpy as np
import matplotlib.colors as mcolors

def visualize_latents_two_panel(particle_trace, gt_trace, color_seed=0):
    """
    Visualize latents from a trace using a two-panel layout:
    left panel for observations, right panel for scatter plot.
    Ground truth fireflies are yellow; each particle/firefly point gets a random colour.
    Markers are "o" where the sampled blink equals 1 and "x" otherwise.

    Args:
        particle_trace: batched trace whose leading axis indexes particles
            (e.g. the final SMC particles).
        gt_trace: ground-truth trace supplying the observed frames and the
            yellow reference positions.
        color_seed: seed for the per-point particle colours. The colours were
            previously seeded from the notebook-global ``key``, which made the
            output depend on hidden kernel state.

    Returns:
        ``IPython.display.HTML`` with a JavaScript animation, one frame per
        observed time step.
    """
    import matplotlib.markers as mmarkers

    def marker_path(marker):
        # Axes.scatter applies the marker's own transform to its path. Do the same
        # so paths swapped in per frame match the size of the initial markers
        # (the raw path is twice as large).
        style = mmarkers.MarkerStyle(marker)
        return style.get_path().transformed(style.get_transform())

    gt_choices = gt_trace.get_sample()
    particle_choices = jax.vmap(lambda tr: tr.get_sample(), in_axes=(0,))(particle_trace)
    observations = gt_choices["steps", ..., "observations", "pixels"].value

    gt_xs, gt_ys, gt_blinks = get_latents(gt_choices)  # (T, Nfireflies)
    xs, ys, blinks = get_latents(particle_choices)  # (Nparticles, T, Nfireflies)
    alphas = np.array(get_normalized_weights(particle_trace))

    # Colours are the same in every frame, so build them once. Points are ordered
    # particle-major (xs[:, t, :].flatten()), so each particle's alpha is repeated
    # once per firefly slot. A local RandomState avoids reseeding numpy's global RNG.
    particle_alphas = np.repeat(alphas, xs.shape[2])
    particle_colors = np.empty((particle_alphas.shape[0], 4))
    particle_colors[:, :3] = np.random.RandomState(color_seed).rand(particle_alphas.shape[0], 3)
    particle_colors[:, 3] = particle_alphas
    gt_color = np.array([[1., 1., 0., 1.]])  # yellow regardless of blink state

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
    fig.suptitle("Firefly Visualization", fontsize=16)

    # Left panel - Observations
    cmap = mcolors.LinearSegmentedColormap.from_list("custom", ["black", "yellow"])
    obs_img = ax1.imshow(observations[0], cmap=cmap, vmin=0, vmax=1)
    ax1.set_title("Observations")

    # Right panel - Scatter plot
    scatter_gt = ax2.scatter([], [], c="yellow", marker="o", s=100, label="GT", linewidth=2)
    # NOTE(review): a collection-level alpha appears to override the alpha column of
    # facecolors set later (Collection uses to_rgba_array(c, alpha)), so the
    # weight-based particle_alphas are likely not visible. Drop alpha=0.7 if
    # weight-proportional opacity is intended.
    scatter_particles = ax2.scatter([], [], c="red", alpha=0.7, marker="o", s=50, label="Particles", linewidth=2)
    ax2.set_xlim(0, 64)
    ax2.set_ylim(64, 0)  # Invert y-axis to match image coordinates
    ax2.set_title("Firefly Positions")
    ax2.legend(loc="lower left")
    ax2.set_facecolor('lightgray')  # Set a light background for better contrast

    def init():
        return obs_img, scatter_gt, scatter_particles

    def update(t):
        obs_img.set_array(observations[t])

        scatter_gt.set_offsets(np.column_stack((gt_xs[t], gt_ys[t])))
        scatter_particles.set_offsets(
            np.column_stack((xs[:, t, :].flatten(), ys[:, t, :].flatten())))

        gt_markers = np.where(gt_blinks[t] == 1, "o", "x")
        particle_markers = np.where(blinks[:, t, :].flatten() == 1, "o", "x")
        scatter_gt.set_paths([marker_path(m) for m in gt_markers])
        scatter_particles.set_paths([marker_path(m) for m in particle_markers])

        scatter_gt.set_facecolors(gt_color)
        scatter_particles.set_facecolors(particle_colors)

        ax2.set_title(f"Firefly Positions - Time step {t}")
        return obs_img, scatter_gt, scatter_particles

    anim = animation.FuncAnimation(fig, update, frames=observations.shape[0], init_func=init, blit=True)
    plt.close(fig)  # Prevent display of static plot
    return HTML(anim.to_jshtml())

# Usage: animate the final SMC particle set against the ground truth.
display(anim := visualize_latents_two_panel(smc_particles, gt_trace))

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML, display
import numpy as np
import matplotlib.colors as mcolors
from tqdm.notebook import tqdm

def visualize_particle_filter_evolution(particle_history, gt_trace, steps_to_show):
    """
    Visualize the evolution of particle filter predictions over time.

    History entry ``h`` holds the particles after the sampler updated model step
    ``h + 1`` (its scan starts at step 1), so the animation for ``h`` covers model
    steps ``0 .. h + 1``. The previous version stopped at step ``h`` and left out
    the step that had just been assimilated.

    Args:
    particle_history: Batched trace with leading axes (history entry, particle),
        e.g. the particle history returned by the SMC sampler.
    gt_trace: Ground truth trace.
    steps_to_show: Either an int ``n`` (show history entries ``0 .. n-1``) or an
        iterable of history indices to visualize.

    Returns:
    List of ``IPython.display.HTML`` animations, one per entry of ``steps_to_show``.
    """
    import matplotlib.markers as mmarkers

    def marker_path(marker):
        # Axes.scatter applies the marker's own transform to its path. Do the same
        # so paths swapped in per frame match the size of the initial markers.
        style = mmarkers.MarkerStyle(marker)
        return style.get_path().transformed(style.get_transform())

    gt_choices = gt_trace.get_sample()
    observations = gt_choices["steps", ..., "observations", "pixels"].value
    gt_xs, gt_ys, gt_blinks = get_latents(gt_choices)

    # None of this depends on the requested step, so compute it once. Previously
    # the whole history was re-extracted for every animation (quadratic work).
    history_choices = jax.vmap(lambda tr: tr.get_sample(), in_axes=(0,))(particle_history)
    hist_xs, hist_ys, hist_blinks = get_latents(history_choices)
    # Softmax of trace scores across particles (axis 1), per history entry.
    hist_alphas = np.array(get_normalized_weights(particle_history, axis=1))
    cmap = mcolors.LinearSegmentedColormap.from_list("custom", ["black", "yellow"])

    def create_animation_for_timestep(timestep):
        last_step = timestep + 1  # model step assimilated by history entry `timestep`
        xs = hist_xs[timestep, :, :last_step + 1]
        ys = hist_ys[timestep, :, :last_step + 1]
        blinks = hist_blinks[timestep, :, :last_step + 1]

        # Every frame shows this history entry's particles, so use its weights
        # throughout. Previously alphas[t] indexed the history axis by frame number.
        # Points are particle-major, so each alpha is repeated once per firefly slot.
        particle_alphas = np.repeat(hist_alphas[timestep], xs.shape[2])
        particle_colors = np.empty((particle_alphas.shape[0], 4))
        # Local RandomState(0): same colours as the old np.random.seed(0) + rand,
        # without mutating numpy's global RNG on every frame.
        particle_colors[:, :3] = np.random.RandomState(0).rand(particle_alphas.shape[0], 3)
        particle_colors[:, 3] = particle_alphas
        gt_color = np.array([[1., 1., 0., 1.]])  # yellow regardless of blink state

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
        fig.suptitle(f"Particle Filter Evolution - Timestep {timestep}", fontsize=16)

        obs_img = ax1.imshow(observations[0], cmap=cmap, vmin=0, vmax=1)
        ax1.set_title("Observations")

        scatter_gt = ax2.scatter([], [], c="yellow", marker="o", s=100, label="GT", linewidth=2)
        # NOTE(review): the collection-level alpha=0.7 appears to override the alpha
        # column of the facecolors set below, so weight-based opacity is likely not
        # visible. Drop alpha=0.7 if that is intended.
        scatter_particles = ax2.scatter([], [], c="red", alpha=0.7, marker="o", s=100, label="Particles", linewidth=2)
        ax2.set_xlim(0, 64)
        ax2.set_ylim(64, 0)
        ax2.set_title("Firefly Positions")
        ax2.legend(loc="lower left")
        ax2.set_facecolor('lightgray')

        def update(t):
            obs_img.set_array(observations[t])

            scatter_gt.set_offsets(np.column_stack((gt_xs[t], gt_ys[t])))
            scatter_particles.set_offsets(
                np.column_stack((xs[:, t, :].flatten(), ys[:, t, :].flatten())))

            gt_markers = np.where(gt_blinks[t] == 1, "o", "x")
            particle_markers = np.where(blinks[:, t, :].flatten() == 1, "o", "x")
            scatter_gt.set_paths([marker_path(m) for m in gt_markers])
            scatter_particles.set_paths([marker_path(m) for m in particle_markers])

            scatter_gt.set_facecolors(gt_color)
            scatter_particles.set_facecolors(particle_colors)

            ax2.set_title(f"Firefly Positions - Time step {t}")
            return obs_img, scatter_gt, scatter_particles

        anim = animation.FuncAnimation(fig, update, frames=last_step + 1, blit=True)
        plt.close(fig)
        return HTML(anim.to_jshtml())

    # Accept numpy integers too (e.g. an element taken from np.arange).
    if isinstance(steps_to_show, (int, np.integer)):
        steps_to_show = np.arange(steps_to_show)

    return [create_animation_for_timestep(timestep) for timestep in tqdm(steps_to_show)]


In [None]:
# History indices to animate. The history has NUM_STEPS - 1 entries, so cap the
# old hard-coded bound of 49 at that length.
steps = np.arange(1, min(49, NUM_STEPS - 1), 5)
animations = visualize_particle_filter_evolution(smc_history, gt_trace, steps_to_show=steps)
# Label each animation with the history index it shows, not its list position.
for step, anim in zip(steps, animations):
    print(f"Particle Filter Timestep {step}")
    display(anim)