In [1]:
%matplotlib widget

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import matplotlib.font_manager as font_manager
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams
from scipy.stats import gaussian_kde
import time
from manim import *

import genjax
from genjax import vi
from genjax import gensp
import numpy as np
import numpy as np
from matplotlib import cm

def trunc(values, decs=0):
    """Truncate ``values`` toward zero, keeping ``decs`` decimal places.

    Works on scalars and arrays (anything ``np.trunc`` accepts).
    """
    scale = 10 ** decs
    return np.trunc(values * scale) / scale

# Global PRNG key plus plotting defaults shared by the cells below.
key = jax.random.PRNGKey(314159)
sns.set_theme(style="white")
label_fontsize = 70  # desired label font size

# manim render settings: white background, caching off, a square 360x360 pixel
# output covering a 28x28 frame.
config.background_color = WHITE
config.disable_caching = True
config.pixel_height = config.pixel_width = 360
config.frame_height = config.frame_width = 28.0

In [2]:
@genjax.gen
def model():
    """Generative model over (x, y, z).

    x and y are drawn from ``vi.normal_reparam(0.0, 10.0)``. z is Gaussian-reparam
    around the squared radius x**2 + y**2, with a scale that grows with that radius
    (0.2 + r²/100). Observing z therefore favours points near the circle
    x**2 + y**2 ≈ z, which is why the scenes draw a ring of radius sqrt(5).
    """
    x = vi.normal_reparam(0.0, 10.0) @ "x"
    y = vi.normal_reparam(0.0, 10.0) @ "y"
    radius_sq = x**2 + y**2
    noise_scale = 0.2 + (radius_sq / 100.0)
    z = vi.normal_reparam(radius_sq, noise_scale) @ "z"


# The observation every inference procedure below conditions on.
data = genjax.choice_map({"z": 5.0})

In [3]:
def importance_samples(key):
    """Draw approximate posterior samples of (x, y) given the global ``data``.

    A proposal matching the model's prior over x and y is wrapped by
    ``gensp.CustomImportance`` with 10000 particles. Presumably this is
    sampling-importance-resampling; confirm against genjax.gensp. The sampler is
    vmapped over 10000 independent keys.

    Args:
        key: JAX PRNG key.

    Returns:
        Tuple ``(x, y)`` of sampled coordinates, one entry per vmapped key.
    """
    @genjax.gen
    def proposal(tgt):
        # Same distributions as the model's x/y prior; ``tgt`` is unused.
        x = vi.normal_reparam(0.0, 10.0) @ "x"
        y = vi.normal_reparam(0.0, 10.0) @ "y"

    _, split_key = jax.random.split(key)
    batch_keys = jax.random.split(split_key, 10000)
    sampler = gensp.CustomImportance(10000, gensp.choice_map_distribution(proposal))
    posterior_target = gensp.target(model, (), data)
    # NOTE(review): 10000 particles per draw x 10000 draws is ~1e8 proposal samples;
    # lower one of these if rendering the reference posterior is too slow.
    _weights, sampled = jax.vmap(sampler.random_weighted, in_axes=(0, None))(
        batch_keys, posterior_target
    )
    choices = sampled.get_leaf_value()
    return (choices["x"], choices["y"])

def animate_reference_posterior(scene):
    """Draw a density plot of the reference posterior over (x, y) plus the z = 5.0 ring.

    Posterior samples come from ``importance_samples``. Their Gaussian KDE is
    evaluated on a regular ``x_bins`` x ``y_bins`` grid over [-10, 10]², and each cell
    is drawn as a filled square. The fill colour is viridis of the max-normalised
    density, and the opacity equals that normalised density. A red circle of radius
    sqrt(5) and a "z = 5.0" label are then animated in.

    Args:
        scene: the manim ``Scene`` to draw into.

    Returns:
        The red ``Circle`` mobject, so callers can position text relative to it.
    """
    # Approximate posterior samples (x, y) | z = 5.0.
    x_points, y_points = jax.jit(importance_samples)(jax.random.PRNGKey(314159))

    # Grid resolution of the density plot along each axis.
    x_bins = 100
    y_bins = 100

    # Kernel density estimate of the samples.
    xy = np.vstack([np.asarray(x_points), np.asarray(y_points)])
    kde = gaussian_kde(xy)

    # Grid points at which the density is evaluated.
    x_grid = np.linspace(-10, 10, x_bins)
    y_grid = np.linspace(-10, 10, y_bins)
    X, Y = np.meshgrid(x_grid, y_grid)
    positions = np.vstack([X.ravel(), Y.ravel()])
    Z = np.reshape(kde(positions).T, X.shape)

    # Normalise to [0, 1] so the values serve directly as colormap input and opacity.
    Z = Z / np.max(Z)

    # Colour the whole grid in one call. Drop the alpha channel, because opacity is
    # set separately and the colour converter expects RGB.
    rgb = plt.get_cmap('viridis')(Z)[..., :3]

    # Squares fill 99% of the grid spacing. This is 0.20 for the default 100 bins
    # over [-10, 10], and it scales if the bin counts change.
    side = 0.99 * min(x_grid[1] - x_grid[0], y_grid[1] - y_grid[0])

    squares = VGroup()
    for i in range(X.shape[0]):
        for j in range(X.shape[1]):
            square = Square(side_length=side)
            square.move_to([X[i, j], Y[i, j], 0])
            square.set_fill(rgb_to_color(rgb[i, j]), opacity=float(Z[i, j]))
            square.set_stroke(WHITE, width=0.001)
            squares.add(square)

    scene.add(squares)

    # z = 5.0 corresponds to x**2 + y**2 ≈ 5, i.e. a circle of radius sqrt(5).
    # Use a plain float so manim geometry is not built from JAX arrays.
    ring = Circle(radius=float(np.sqrt(5.0)), color=RED)
    scene.play(Create(ring))
    text = Text("z = 5.0", color=RED).next_to(ring, 2 * UP).shift(RIGHT)
    scene.play(Write(text))
    return ring

In [4]:
class ReferencePosterior(Scene):
    """Scene that renders only the reference posterior density and the z = 5.0 ring."""

    def construct(self):
        """Draw the reference posterior; the returned ring mobject is not needed here."""
        _ring = animate_reference_posterior(self)

%manim -qh ReferencePosterior

  mob.points = mob.points.astype("float")


                                                                                

                                                                                

In [5]:
@genjax.gen
def variational_family(data, ϕ):
    """Mean-field Gaussian guide over (x, y).

    ``ϕ = (μ1, μ2, log_σ1, log_σ2)``. The scales are passed through ``jnp.exp``, so
    they are parameterised in log space and stay positive under unconstrained
    gradient updates. ``data`` is accepted but not used in the body.
    """
    mean_x, mean_y, log_scale_x, log_scale_y = ϕ
    x = vi.normal_reparam(mean_x, jnp.exp(log_scale_x)) @ "x"
    y = vi.normal_reparam(mean_y, jnp.exp(log_scale_y)) @ "y"


# ELBO objective for fitting the guide to ``model`` conditioned on ``data``.
objective = vi.elbo(model, variational_family, data)

class NaiveGuideTraining(Scene):
    """Animate ELBO training of the mean-field guide over the reference posterior.

    Runs 1501 gradient-ascent steps on ``objective``. At step 3 and at every 100th
    step after that, it overlays 5000 samples from the current guide, the latest ELBO
    estimate, and the accumulated training wall-clock time.
    """

    def construct(self):
        # Training.
        key = jax.random.PRNGKey(314159)
        # ϕ = (μ1, μ2, log_σ1, log_σ2): a unit-scale guide centred at the origin.
        ϕ = (0.0, 0.0, 0.0, 0.0)

        ring = animate_reference_posterior(self)

        def block_train(key, ϕ):
            """One ascent step (step size 1e-3) on a 64-sample ELBO gradient estimate."""
            sub_keys = jax.random.split(key, 64)
            loss, (_, (_, ϕ_grads)) = jax.vmap(
                objective.value_and_grad_estimate, in_axes=(0, None)
            )(sub_keys, ((), (data, ϕ)))
            ϕ = jtu.tree_map(lambda v, g: v + 1e-3 * jnp.mean(g), ϕ, ϕ_grads)
            return loss, ϕ

        def wait_ready(tree):
            # JAX dispatches work asynchronously. Block on every output so the timer
            # measures the computation itself, not just its dispatch.
            return jtu.tree_map(lambda a: a.block_until_ready(), tree)

        train_step = jax.jit(block_train)
        # Build the sampler once, so it is not re-wrapped and re-traced on every overlay.
        sample_guide = jax.jit(jax.vmap(variational_family.simulate, in_axes=(0, None)))

        # Warmup: compile outside the timed loop; the result is discarded.
        key, sub_key = jax.random.split(key)
        wait_ready(train_step(sub_key, ϕ))

        total_time = 0.0

        for i in range(0, 1501):
            key, sub_key = jax.random.split(key)
            start = time.time()
            loss, ϕ = wait_ready(train_step(sub_key, ϕ))
            duration = time.time() - start
            # The first steps after warmup are left out of the timing, presumably to
            # skip any residual recompilation.
            if i > 2:
                total_time += duration

            # Sampling
            if i > 2 and (i % 100 == 0 or i == 3):
                loss_text = Text(f"(evidence lower bound) loss estimate = {trunc(jnp.mean(loss), 2)}", color=BLACK).next_to(ring, 13 * DOWN)
                duration_text = Text(f"(training) wall clock time = {trunc(total_time, 3)} (s)", color=BLACK).next_to(loss_text, DOWN)
                self.add(loss_text, duration_text)

                key, sub_key = jax.random.split(key)
                sub_keys = jax.random.split(sub_key, 5000)
                chm = sample_guide(sub_keys, (data, ϕ)).strip()
                # Copy to host once, so building the dots does not iterate over device
                # arrays. The dots are drawn in black.
                x_data, y_data = np.asarray(chm["x"]), np.asarray(chm["y"])
                new_dots = VGroup(*[Dot(point=(x, y, 1), radius=0.03, color=BLACK) for x, y in zip(x_data, y_data)])
                self.add(new_dots)
                self.wait(0.7)
                self.remove(new_dots, loss_text, duration_text)
                
%manim -qh NaiveGuideTraining

  mob.points = mob.points.astype("float")


                                                                                

                                                                                

In [6]:
@genjax.gen
def variational_family(data, ϕ):
    """Mean-field Gaussian guide over (x, y), with ``ϕ = (μ1, μ2, log_σ1, log_σ2)``.

    This is the same definition as the guide in the previous ELBO cell, repeated here
    (presumably so the cell can run on its own). ``data`` is not used in the body.
    """
    loc_x, loc_y, log_std_x, log_std_y = ϕ
    x = vi.normal_reparam(loc_x, jnp.exp(log_std_x)) @ "x"
    y = vi.normal_reparam(loc_y, jnp.exp(log_std_y)) @ "y"


# Importance-weighted ELBO. The trailing 20 is presumably the particle count, matching
# N_particles in importance_sampler_with_guide.
objective = vi.iwae_elbo(model, variational_family, data, 20)

def importance_sampler_with_guide(key, ϕ):
    """Draw 5000 (x, y) samples by importance resampling with the guide as proposal.

    For each sample, 20 particles are proposed from ``hacky_guide(ϕ)``. Each is
    weighted by ``hacky_model``'s score on the proposed choices merged with the global
    ``data``, minus the guide score. One particle is then selected with
    ``genjax.categorical`` over those log-weights.

    Args:
        key: JAX PRNG key.
        ϕ: guide parameters ``(μ1, μ2, log_σ1, log_σ2)``.

    Returns:
        ``(x, y, scores)``. ``scores`` is the selected particle's model score minus the
        log mean importance weight.
    """
    N_particles = 20

    @genjax.gen
    def hacky_model():
        # Local copy of the model's structure.
        x = vi.normal_reparam(0.0, 10.0) @ "x"
        y = vi.normal_reparam(0.0, 10.0) @ "y"
        radius_sq = (x**2 + y**2)
        z = vi.normal_reparam(radius_sq, 0.2 + (radius_sq / 100.0)) @ "z"

    @genjax.gen
    def hacky_guide(ϕ):
        # Local mean-field guide taking only ϕ.
        loc_x, loc_y, log_std_x, log_std_y = ϕ
        x = vi.normal_reparam(loc_x, jnp.exp(log_std_x)) @ "x"
        y = vi.normal_reparam(loc_y, jnp.exp(log_std_y)) @ "y"

    def propose_and_weight(particle_key, ϕ):
        # Returns (model log-score, log importance weight, proposed choices).
        guide_tr = hacky_guide.simulate(particle_key, (ϕ, ))
        constrained = guide_tr.get_choices().safe_merge(data)
        # NOTE(review): the same key feeds simulate and assess. This is harmless only
        # if assess is deterministic for fully constrained choices — confirm.
        _, model_score = hacky_model.assess(particle_key, constrained, ())
        return model_score, model_score - guide_tr.get_score(), guide_tr.get_choices()

    def resample_one(sample_key, ϕ):
        sample_key, particles_key = jax.random.split(sample_key)
        particle_keys = jax.random.split(particles_key, N_particles)
        model_scores, log_weights, candidates = jax.vmap(
            propose_and_weight, in_axes=(0, None)
        )(particle_keys, ϕ)
        log_mean_weight = jax.scipy.special.logsumexp(log_weights) - jnp.log(N_particles)
        chosen = genjax.categorical.sample(sample_key, log_weights)
        winner = jtu.tree_map(lambda leaf: leaf[chosen], candidates)
        return model_scores[chosen] - log_mean_weight, winner

    sample_keys = jax.random.split(key, 5000)
    scores, particles = jax.jit(jax.vmap(resample_one, in_axes=(0, None)))(sample_keys, ϕ)
    return particles["x"], particles["y"], scores

class IWELBONaiveGuideTraining(Scene):
    """Animate IWAE-ELBO training of the mean-field guide over the reference posterior.

    Runs 5001 gradient-ascent steps on ``objective``. At step 3 and at every 250th
    step after that, it overlays 5000 importance-resampled samples that use the
    current guide as proposal, together with the latest objective estimate and the
    accumulated training wall-clock time.
    """

    def construct(self):
        # Training.
        key = jax.random.PRNGKey(314159)
        # ϕ = (μ1, μ2, log_σ1, log_σ2): a wide initial guide (σ = e³ ≈ 20).
        ϕ = (0.0, 0.0, 3.0, 3.0)

        ring = animate_reference_posterior(self)

        def block_train(key, ϕ):
            """One ascent step (step size 1e-3) on a 64-sample objective gradient estimate."""
            sub_keys = jax.random.split(key, 64)
            loss, (_, (_, ϕ_grads)) = jax.vmap(
                objective.value_and_grad_estimate, in_axes=(0, None)
            )(sub_keys, ((), (data, ϕ)))
            ϕ = jtu.tree_map(lambda v, g: v + 1e-3 * jnp.mean(g), ϕ, ϕ_grads)
            return loss, ϕ

        def wait_ready(tree):
            # JAX dispatches work asynchronously. Block on every output so the timer
            # measures the computation itself, not just its dispatch.
            return jtu.tree_map(lambda a: a.block_until_ready(), tree)

        train_step = jax.jit(block_train)
        # Wrap once instead of calling jax.jit(...) again on every overlay.
        sample_posterior = jax.jit(importance_sampler_with_guide)

        # Warmup: compile outside the timed loop; the result is discarded.
        key, sub_key = jax.random.split(key)
        wait_ready(train_step(sub_key, ϕ))

        total_time = 0.0

        for i in range(0, 5001):
            key, sub_key = jax.random.split(key)
            start = time.time()
            loss, ϕ = wait_ready(train_step(sub_key, ϕ))
            duration = time.time() - start
            # The first steps after warmup are left out of the timing, presumably to
            # skip any residual recompilation.
            if i > 2:
                total_time += duration

            # Sampling
            if i > 2 and (i % 250 == 0 or i == 3):
                loss_text = Text(f"(evidence lower bound) loss estimate = {trunc(jnp.mean(loss), 2)}", color=BLACK).next_to(ring, 13 * DOWN)
                duration_text = Text(f"(training) wall clock time = {trunc(total_time, 3)} (s)", color=BLACK).next_to(loss_text, DOWN)
                self.add(loss_text, duration_text)

                # Use the fresh sub_key. Passing `key` here would reuse the key that
                # seeds the next training step.
                key, sub_key = jax.random.split(key)
                x_data, y_data, _scores = sample_posterior(sub_key, ϕ)
                x_data, y_data = np.asarray(x_data), np.asarray(y_data)
                # The dots are drawn in black.
                new_dots = VGroup(*[Dot(point=(x, y, 1), radius=0.03, color=BLACK) for x, y in zip(x_data, y_data)])
                self.add(new_dots)
                self.wait(0.7)
                self.remove(new_dots, loss_text, duration_text)
                
%manim -qh IWELBONaiveGuideTraining

  mob.points = mob.points.astype("float")


                                                                                

                                                                                

In [7]:
@genjax.gen
def expressive_variational_family(data, ϕ):
    """Two-component ring-shaped guide over (x, y) with auxiliary choices "u" and "v".

    ``ϕ = (r1, r2, p, log_σ1, log_σ2, _log_σ1, _log_σ2)``. The angle is θ = 2πu, with
    u drawn from ``vi.uniform()``. A coin v, drawn by ``vi.flip_reinforce`` with
    probability sigmoid(p), selects the radius and the log-scales:
    (r1, log_σ1, log_σ2) when true and (r2, _log_σ1, _log_σ2) otherwise. x and y are
    then Gaussian-reparam around (r cos θ, r sin θ). ``data`` is not used in the body.
    """
    u = vi.uniform() @ "u"
    θ = 2 * jnp.pi * u
    (r1, r2, p_logit, log_σ1, log_σ2, _log_σ1, _log_σ2) = ϕ
    v = vi.flip_reinforce(jax.nn.sigmoid(p_logit)) @ "v"
    r = jax.lax.select(v, r1, r2)
    scale_x = jnp.exp(jax.lax.select(v, log_σ1, _log_σ1))
    scale_y = jnp.exp(jax.lax.select(v, log_σ2, _log_σ2))
    x = vi.normal_reparam(r * jnp.cos(θ), scale_x) @ "x"
    y = vi.normal_reparam(r * jnp.sin(θ), scale_y) @ "y"


# Marginal guide over the choices "x", "y" and "v". The auxiliary "u" is not selected;
# the lambda presumably supplies the inference used to estimate the marginal
# (vi.sir(5)) — confirm against genjax.vi.marginal.
marginal_q = vi.marginal(
    genjax.select("x", "y", "v"),
    expressive_variational_family,
    lambda: vi.sir(5),
)

# The same observation as before, and the ELBO of the model against the marginal guide.
data = genjax.choice_map({"z": 5.0})
hvi_objective = vi.elbo(model, marginal_q, data)

class ExpressiveGuideTraining(Scene):
    """Animate training of the marginalised ring-shaped guide ``marginal_q``.

    Runs 1501 gradient-ascent steps on ``hvi_objective``. At step 1 and at every 100th
    step, it overlays 5000 (x, y) samples from ``marginal_q``, the latest objective
    estimate, and the accumulated training wall-clock time.
    """

    def construct(self):
        ring = animate_reference_posterior(self)

        # Training.
        key = jax.random.PRNGKey(314159)
        # ϕ = (r1, r2, p, log_σ1, log_σ2, _log_σ1, _log_σ2); p goes through a sigmoid in the guide.
        ϕ = (1.0, 5.0, 0.5, 0.0, 0.0, 0.0, 0.0)
        total_time = 0.0

        def block_train(key, ϕ):
            """One ascent step (step size 1e-3) on a 64-sample estimate of ``hvi_objective``."""
            key, sub_key = jax.random.split(key)
            sub_keys = jax.random.split(sub_key, 64)
            loss, (_, ((_, ϕ_grads), ())) = jax.vmap(hvi_objective.value_and_grad_estimate, in_axes=(0, None))(sub_keys, ((), ((data, ϕ), ())))
            ϕ = jtu.tree_map(lambda v, g: v + 1e-3 * jnp.mean(g), ϕ, ϕ_grads)
            return loss, ϕ

        def wait_ready(tree):
            # JAX dispatches work asynchronously. Block on every output so the timer
            # measures the computation itself, not just its dispatch.
            return jtu.tree_map(lambda a: a.block_until_ready(), tree)

        train_step = jax.jit(block_train)
        # Build the sampler once, so it is not re-wrapped and re-traced on every overlay.
        sample_marginal = jax.jit(jax.vmap(marginal_q.random_weighted, in_axes=(0, None, None)))

        # Warmup: compiles the training step. Its parameter update is kept.
        key, sub_key = jax.random.split(key)
        loss, ϕ = wait_ready(train_step(sub_key, ϕ))

        for i in range(0, 1501):
            key, sub_key = jax.random.split(key)
            start = time.time()
            loss, ϕ = wait_ready(train_step(sub_key, ϕ))
            duration = time.time() - start
            # Step 0 is left out of the timing, presumably to skip residual recompilation.
            if i > 0:
                total_time += duration

            # Sampling
            if i > 0 and (i % 100 == 0 or i == 1):
                loss_text = Text(f"(evidence lower bound) loss estimate = {trunc(jnp.mean(loss), 2)}", color=BLACK).next_to(ring, 13 * DOWN)
                duration_text = Text(f"(training) wall clock time = {trunc(total_time, 3)} (s)", color=BLACK).next_to(loss_text, DOWN)
                self.add(loss_text, duration_text)

                key, sub_key = jax.random.split(key)
                sub_keys = jax.random.split(sub_key, 5000)
                _score_estimates, v_chm = sample_marginal(sub_keys, (data, ϕ), ())
                chm = v_chm.get_leaf_value()
                # Copy to host once; the dots are drawn in black.
                x_data, y_data = np.asarray(chm["x"]), np.asarray(chm["y"])
                new_dots = VGroup(*[Dot(point=(x, y, 1), radius=0.03, color=BLACK) for x, y in zip(x_data, y_data)])
                self.add(new_dots)
                self.wait(0.7)
                # Clear this overlay before the next one is drawn.
                self.remove(new_dots, loss_text, duration_text)

%manim -qh ExpressiveGuideTraining

  mob.points = mob.points.astype("float")


                                                                                

                                                                                

In [8]:
@genjax.gen
def expressive_variational_family(data, ϕ):
    """Single-ring guide over (x, y), used by the ``ExpressiveSampler`` walkthrough.

    ``ϕ = (r, log_σ1, log_σ2)``. The angle is θ = 2πu, with u drawn from
    ``vi.uniform()``. x and y are Gaussian-reparam around (r cos θ, r sin θ), with
    scales exp(log_σ1) and exp(log_σ2). ``data`` is accepted but not used.

    NOTE(review): this rebinds the name of the 7-parameter ring-mixture guide defined
    earlier, which has a different ``ϕ`` layout. Objects built before this cell (e.g.
    ``marginal_q``) keep the earlier function; code run after this cell sees this one.
    """
    u = vi.uniform() @ "u"
    θ = 2 * jnp.pi * u
    r, log_σ1, log_σ2 = ϕ
    x = vi.normal_reparam(r * jnp.cos(θ), jnp.exp(log_σ1)) @ "x"
    y = vi.normal_reparam(r * jnp.sin(θ), jnp.exp(log_σ2)) @ "y"

class ExpressiveSampler(Scene):
    """Step through 10 draws from the 3-parameter ``expressive_variational_family``.

    Each draw shows an arrow from the origin toward angle θ = 2πu, ending on a circle
    of radius ϕ[0] that is never added to the scene. The sampled point (x, y) is then
    shown as a dot.
    """

    def construct(self):
        target_ring = Circle(radius=jnp.sqrt(5.0), color=RED)
        self.play(Create(target_ring))
        label = Text("z = 5.0", color=BLACK).next_to(target_ring, UP).shift(RIGHT)
        self.play(Write(label))

        key = jax.random.PRNGKey(314159)
        ϕ = (4.0, -0.5, -0.5)  # (r, log_σ1, log_σ2)
        guide_ring = Circle(radius=ϕ[0], color=RED)  # only used to locate arrow tips

        for _ in range(10):
            key, sub_key = jax.random.split(key)
            choices = expressive_variational_family.simulate(sub_key, (data, ϕ)).get_choices()
            u, px, py = choices["u"], choices["x"], choices["y"]
            angle = u * 2 * jnp.pi

            arrow = Arrow(start=ORIGIN, end=guide_ring.point_at_angle(angle), color=BLACK)
            self.play(Create(arrow))

            dot = Dot(point=(px, py, 1), color=BLACK)
            self.play(Create(dot))
            self.wait(0.3)

            self.remove(dot)
            self.remove(arrow)

In [9]:
%manim -qh ExpressiveSampler

  mob.points = mob.points.astype("float")


                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                

                                                                                