In [1]:
import jax
import jax.numpy as jnp
import chex
from jax.scipy.stats import norm
from annealed_flow_transport.utils import smc_utils as su
import matplotlib.pyplot as plt

In [2]:
# Seed the PRNG with 5 and split it once. `key` is the key consumed when the
# samples are drawn; `key_` is a second, independent key.
key, key_ = jax.random.split(jax.random.PRNGKey(5))

In [3]:
# Means and variances of the four 1-D Gaussians, labelled a through d.
mean_a, mean_b, mean_c, mean_d = 0., -1., 2., 0.1
var_a, var_b, var_c, var_d = 1., 4., 5., 1.

num_particles = 10_000_000

# Draw a (num_particles, 1) column of samples from N(mean_a, var_a).
samples = mean_a + jnp.sqrt(var_a) * jax.random.normal(
    key, shape=(num_particles, 1))

# Log-density of every sample under each Gaussian, flattened to 1-D.
log_density_a = jnp.ravel(
    norm.logpdf(samples, loc=mean_a, scale=jnp.sqrt(var_a)))
log_density_b = jnp.ravel(
    norm.logpdf(samples, loc=mean_b, scale=jnp.sqrt(var_b)))
log_density_c = jnp.ravel(
    norm.logpdf(samples, loc=mean_c, scale=jnp.sqrt(var_c)))
log_density_d = jnp.ravel(
    norm.logpdf(samples, loc=mean_d, scale=jnp.sqrt(var_d)))

# Log importance weights of Gaussian b relative to the sampling Gaussian a.
log_weights = log_density_b - log_density_a

In [4]:
# Self-normalised importance-sampling estimates of the mean and variance of
# Gaussian b from the samples drawn under Gaussian a. The softmax of the log
# weights and the flattened samples are computed once and reused for both
# moments instead of being re-evaluated over every particle.
normalized_weights = jax.nn.softmax(log_weights)
flat_samples = samples.flatten()
testmean_b = jnp.sum(normalized_weights * flat_samples)
testvar_b = jnp.sum(normalized_weights * (flat_samples - testmean_b)**2)
testmean_b, testvar_b

(Array(-0.9122688, dtype=float32), Array(3.6197534, dtype=float32))

In [5]:
def flow_apply(unused_params, samples):
  """Affine map that rescales and shifts samples from Gaussian b to Gaussian c.

  Args:
    unused_params: Ignored; the map is fixed by the notebook-level means and
      variances of b and c.
    samples: Array whose leading axis indexes particles.

  Returns:
    A tuple (transported_samples, log_det_jacs): the mapped samples, with the
    same shape as `samples`, and a 1-D array holding one log-determinant of
    the map's Jacobian per particle.
  """
  scale = jnp.sqrt(var_c/var_b)
  transported_samples = scale*(samples-mean_b) + mean_c
  # Size the per-particle vector from the input itself rather than from the
  # notebook-level num_particles, so batches of any size get a matching shape.
  log_det_jacs = jnp.log(scale) * jnp.ones(samples.shape[0])
  return transported_samples, log_det_jacs

In [6]:
# Weighted mean and variance of the samples after they are pushed through the
# flow. The flow and the softmax are each evaluated once and reused, rather
# than being recomputed over every particle for each moment.
transported_flat = flow_apply(None, samples)[0].flatten()
flow_weights = jax.nn.softmax(log_weights)
testmean_c = jnp.sum(flow_weights * transported_flat)
testvar_c = jnp.sum(flow_weights * (transported_flat - testmean_c)**2)
testmean_c, testvar_c

(Array(2.0980864, dtype=float32), Array(4.5246916, dtype=float32))

In [7]:
def step_density(beta, x):
  """Log-density interpolated linearly (in log space) between b and d.

  Args:
    beta: Interpolation weight; 0 gives the log-density of Gaussian b and 1
      gives the log-density of Gaussian d.
    x: Points at which to evaluate the densities.

  Returns:
    Flattened array of (1-beta)*log p_b(x) + beta*log p_d(x).
  """
  start_log_density = norm.logpdf(
      x, loc=mean_b, scale=jnp.sqrt(var_b)).flatten()
  end_log_density = norm.logpdf(
      x, loc=mean_d, scale=jnp.sqrt(var_d)).flatten()
  return (1-beta)*start_log_density + beta*end_log_density

In [8]:
def kl_div(mean0, var0, mean1, var1):
  """Closed-form KL divergence KL(N(mean0, var0) || N(mean1, var1)) in 1-D.

  Args:
    mean0: Mean of the first Gaussian.
    var0: Variance of the first Gaussian.
    mean1: Mean of the second Gaussian.
    var1: Variance of the second Gaussian.

  Returns:
    The KL divergence from the first Gaussian to the second.
  """
  variance_ratio = var0 / var1
  scaled_mean_gap = jnp.square(mean1 - mean0) / var1
  log_var1 = jnp.log(var1)
  log_var0 = jnp.log(var0)
  return 0.5 * (variance_ratio + scaled_mean_gap - 1. + log_var1 - log_var0)
# Analytic KL divergence from Gaussian c to Gaussian d.
kl_div(mean0=mean_c, var0=var_c, mean1=mean_d, var1=var_d)

Array(3.0002809, dtype=float32, weak_type=True)

In [9]:
def kl(m1, v1, m2, v2):
  """KL divergence KL(N(m1, v1) || N(m2, v2)) between two 1-D Gaussians.

  Args:
    m1: Mean of the first Gaussian.
    v1: Variance of the first Gaussian.
    m2: Mean of the second Gaussian.
    v2: Variance of the second Gaussian.

  Returns:
    The KL divergence from the first Gaussian to the second.
  """
  log_var_gap = jnp.log(v1) - jnp.log(v2)
  var_ratio = v1/v2
  scaled_mean_gap = jnp.square(m1-m2)/v2
  return -0.5 * (log_var_gap - var_ratio - scaled_mean_gap + 1)
# Evaluate the `kl` function defined in this cell (the cell previously called
# kl_div again, so this alternative formula was never exercised).
kl(mean_c, var_c, mean_d, var_d)

Array(3.0002809, dtype=float32, weak_type=True)

In [10]:
# Call the project's free-energy estimator with the samples, their log weights,
# the flow and the interpolated log-density.
# NOTE(review): the meaning of the remaining arguments (None, 1, 0) is defined
# by su.estimate_free_energy, which is not shown here — confirm against its
# signature.
su.estimate_free_energy(samples, log_weights, flow_apply, None, step_density, 1, 0)

Array(3.00037, dtype=float32)