In [1]:
%load_ext autoreload
from jax import config
# NaN debugging is switched off here; set the flag to True to have JAX
# raise an error as soon as a computation produces a NaN.
config.update("jax_debug_nans", False)
# Uncomment to run everything eagerly (no jit) while debugging:
# config.update("jax_disable_jit", True)

import sys
import os
sys.path.append("../../../learning_particle_gradients/")
import json_tricks as json
import copy
from functools import partial

from tqdm import tqdm
import jax.numpy as np
from jax import grad, jit, vmap, random, lax, jacfwd, value_and_grad
from jax import lax
from jax.ops import index_update, index
import matplotlib.pyplot as plt
import numpy as onp
import jax
import pandas as pd
import haiku as hk
from jax.experimental import optimizers


import utils
import metrics
import time
import plot
import stein
import kernels
import distributions
import nets
import models
import flows

from jax.experimental import optimizers

# Root PRNG key for the notebook (fixed seed 0); later cells split it
# with random.split before use.
key = random.PRNGKey(0)



In [2]:
# Export setup: switch matplotlib to the pgf backend (compiled with
# pdflatex) so that saved figures can be included in a LaTeX document.
import matplotlib
matplotlib.use("pgf")
matplotlib.rcParams.update({
    "pgf.texsystem": "pdflatex",
#     'font.family': 'serif',
#     'text.usetex': False,
    'pgf.rcfonts': False,
})

# NOTE(review): absolute, user-specific path -- not portable to other machines.
figure_path = "/home/lauro/documents/msc-thesis/thesis/figures/"
# Save figures with plt.savefig('title of figure').
# The LaTeX text width is 5.4in,
# so use figsize=[5.4, 4], for example.

# Set up embedding

In [2]:
class SetEmbedding(hk.Module):
    """Embeds a set by applying an MLP to every element and averaging.

    The MLP `phi` is applied to `x` and the result is reduced with a mean
    over axis 0 (the set axis), so the output does not depend on the order
    of the elements along that axis.
    """

    def __init__(self, phi_sizes, w_init=hk.initializers.VarianceScaling(2.0), name=None):
        """
        phi_sizes: sequence of layer widths, passed as `output_sizes` to the
            MLP `phi`; the last entry is the width of the embedding.
        w_init: weight initializer handed to `phi`.
        name: optional haiku module name.
        """
        super().__init__(name=name)
        self.sizes = phi_sizes
        self.w_init = w_init

    def __call__(self, x):
        """x is a set of shape (n, ...), where n
        is the number of elements in the set.

        Computes:
        x1, ..., xn --> phi(x1), ..., phi(xn) --> mean(...)

        Returns the mean over axis 0 of phi(x).
        """
        # swish is applied after the last layer too (activate_final=True),
        # so the averaged features are post-activation values.
        phi = hk.nets.MLP(output_sizes=self.sizes,
                          w_init=self.w_init,
                          activation=jax.nn.swish,
                          activate_final=True)

        set_embedding = hk.Sequential([
            phi,
            partial(np.mean, axis=0),  # pool over the set axis
        ])
        return set_embedding(x)

# Set up RNN

In [3]:
# Width of the set embedding that is fed to the GRU as its input.
embedding_dim = 32
# GRU hidden size. The hidden state passed to the cell is a particle, so
# this presumably has to equal the particle dimension d -- confirm if d changes.
hidden_state_size = 2

def embed_fn(x):
    """Map a particle set to a single embedding vector.

    x: array of shape (n, d), one row per particle.

    The set is first pooled by SetEmbedding([32, 32]) and the pooled
    features are then passed through an MLP whose last layer has
    `embedding_dim` units.
    """
    pooled = SetEmbedding([32, 32])(x)
    head = hk.nets.MLP([32, embedding_dim])
    return head(pooled)

# embed = hk.transform(embed_fn)


def cell_fn(particles, state):
    """One GRU step driven by the embedding of the particle set.

    particles: the particle set (GRU input after embedding), shape (n, d)
    state: a particle used as the GRU hidden state, shape (d,)

    Returns whatever hk.GRU returns for (embedded particles, state).
    """
    gru = hk.GRU(hidden_state_size)
    set_summary = embed_fn(particles)
    return gru(set_summary, state)

# Turn cell_fn into a pure (init, apply) pair; used below as
# cell.init(rng, particles, state) and cell.apply(params, rng, particles, state).
cell = hk.transform(cell_fn)

# Get RNN to update particles

In [4]:
# Two 2-d Gaussians that differ only in their first argument
# (presumably the mean; second argument presumably the per-axis
# variance/covariance -- confirm against distributions.Gaussian).
# Particles start from `proposal`; `target.logpdf` is used in the loss.
target = distributions.Gaussian([0, 0], [1,1])
proposal = distributions.Gaussian([-2, 0], [1,1])

In [5]:
key, subkey = random.split(key)
# NOTE(review): sample() is not handed a key here, so how the draw is
# seeded cannot be seen from this cell -- confirm it is reproducible.
init_particles = proposal.sample(100)
# First particle, used as an example hidden state for the cell.
x = init_particles[0]

print(init_particles.shape)
print(x.shape)


# Initialise the cell parameters from example inputs (particle set, one state).
params = cell.init(subkey, init_particles, x)

(100, 2)
(2,)


In [6]:
# single update
key, subkey = random.split(key)
cell.apply(params, subkey, init_particles, x)

(DeviceArray([-1.0212189,  0.7661866], dtype=float32),
 DeviceArray([-1.0212189,  0.7661866], dtype=float32))

In [7]:
# Update the whole population at once: the full particle array is passed
# both as the set to embed and as the hidden state.
# NOTE(review): `subkey` is reused from the previous cell rather than split afresh.
particles = init_particles
particles, _ = cell.apply(params, subkey, particles, particles)

In [8]:
def cell_loss(params, subkey, particles):
    """Loss of a single cell application on `particles`.

    The cell defines an update direction delta(x) = cell(particles, x) - x.
    The loss is utils.l2_norm_squared(particles, delta) minus
    stein.stein_discrepancy(particles, target.logpdf, delta).
    """
    def delta(x):
        # Displacement the cell proposes for particle x.
        updated, _ = cell.apply(params, subkey, particles, x)
        return updated - x

    discrepancy = stein.stein_discrepancy(particles, target.logpdf, delta)
    penalty = utils.l2_norm_squared(particles, delta)
    return penalty - discrepancy

In [9]:
# Sanity check: evaluate the loss once on the updated particles.
cell_loss(params, subkey, particles)

DeviceArray(0.39911443, dtype=float32)

Multi-cell loss: propagate the particles through $m$ unrolled cells ($m$ is `unroll_length` in the code below) and average the losses from all $m$ cells.

In [10]:
def multi_cell_loss(params, subkey, particles, unroll_length):
    for i in range(unroll_length):
        ...