In [1]:
import jax

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platform_name", "cpu")
import numpy as np
import scipy.stats
from jax import numpy as jnp
from matplotlib.figure import Figure
from tqdm.notebook import trange

%matplotlib inline

In [2]:
import os
import sys

lib_path = os.path.join(os.path.curdir, "src")
sys.path.insert(0, lib_path)

In [3]:
import importlib

In [5]:
import activation
import network
import normal
import random_matrix
import unscented

importlib.reload(normal)
importlib.reload(unscented)

importlib.reload(random_matrix)
importlib.reload(network)

from network import Layer, Network
from normal import Normal
from random_matrix import RandomGaussian, RandomOrthogonalProjection, ZeroMatrix

In [6]:
key = jax.random.PRNGKey(123)

In [16]:
σ = activation.Sinusoid()

# single layer

In [17]:
σ.L(1., 1., 2., 2., 0.)

Array(0., dtype=float64, weak_type=True)

In [18]:
f = Layer(
    2,
    1,
    key=key,
    activation=σ,
    A=RandomOrthogonalProjection(1.0),
    b=RandomOrthogonalProjection(0.0),
    C=RandomOrthogonalProjection(1.0),
    d=RandomOrthogonalProjection(0.0),
)
network = Network(f)

In [19]:
# f._propagate_cov(\m)

In [20]:
dist = Normal(μ=jnp.ones(f.in_size), Σ=1 * jnp.eye(f.in_size))

In [21]:
analytic_output = f(dist, method="analytic")
linear_output = f(dist, method="linear")
unscented_output = f(dist, method="unscented")
mc_output = f._mc_mean_cov(dist, key, 10_000_000)

In [22]:
print("linear mean", linear_output.μ)
print("unscented mean", unscented_output.μ)
print("analytic mean", analytic_output.μ)
print("monte carlo mean", mc_output.μ)

linear mean [0.31762117]
unscented mean [0.0250322]
analytic mean [0.08760158]
monte carlo mean [0.0876176]


In [23]:
print("linear covariance", linear_output.Σ)
print("unscented covariance", unscented_output.Σ)
print("analytic covariance", analytic_output.Σ)
print("monte carlo covariance", mc_output.Σ)

linear covariance [[0.09324568]]
unscented covariance [[0.26455563]]
analytic covariance [[0.4036371]]
monte carlo covariance [[0.40420067]]


In [24]:
print("linear KL", mc_output.kl_divergence(linear_output))
print("unscented KL", mc_output.kl_divergence(unscented_output))
print("analytic KL", mc_output.kl_divergence(analytic_output))

linear KL 0.4141224649626808
unscented KL 0.044033021858611354
analytic KL 4.867850178502664e-07


# Small-variance test

In [None]:
key, key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(12), 5)
input_size = 1
output_size = 1
hidden_size = 10
network = Network(
    Layer(
        input_size,
        output_size,
        key=key1,
        activation=σ,
        A=RandomOrthogonalProjection(),
        b=RandomOrthogonalProjection(),
        C=ZeroMatrix(),
        d=ZeroMatrix(),
    ),
)
dist = Normal(μ=jnp.zeros(input_size), Σ=jnp.eye(input_size))

In [None]:
mc_output = network[0]._mc_mean_cov(dist, key=jax.random.PRNGKey(1), rep=1_000_000)
mc_output.μ, mc_output.Σ

In [None]:
analytic_output = network(dist, method="analytic")
analytic_output.μ, analytic_output.Σ

In [None]:
linear_output = network(dist, method="linear")
linear_output.μ, linear_output.Σ

In [None]:
unscented_output = network(dist, method="unscented")
unscented_output.μ, unscented_output.Σ

# Deep MLP

In [None]:
key, key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(12), 5)
input_size = 2
output_size = 1
hidden_size = 70
network = Network(
    Layer(
        input_size,
        hidden_size,
        key=key1,
        activation=σ,
        A=RandomGaussian(1.0),
        b=RandomGaussian(1.0),
        C=RandomGaussian(1.0),
        d=RandomGaussian(),
    ),
    Layer(
        hidden_size,
        hidden_size,
        key=key2,
        activation=σ,
        A=RandomGaussian(1.0),
        b=RandomGaussian(1.0),
        C=RandomGaussian(1.0),
        d=ZeroMatrix(),
    ),
    Layer(
        hidden_size,
        hidden_size,
        key=key3,
        activation=σ,
        A=RandomGaussian(1.0),
        b=RandomGaussian(1.0),
        C=RandomGaussian(1.0),
        d=ZeroMatrix(),
    ),
    Layer(
        hidden_size,
        output_size,
        key=key4,
        activation=σ,
        A=RandomGaussian(1.0),
        b=RandomGaussian(1.0),
        C=RandomGaussian(1.0),
        d=ZeroMatrix(),
    ),
)

In [None]:
μ = jnp.zeros(input_size)
Σ = jnp.eye(input_size)

In [None]:
num_samples = 2**10

In [None]:
def plot_mc(ax, μ, Σ):
    dist = Normal(μ, Σ)
    unscented_output = network(dist, method="unscented")
    linear_output = network(dist, method="linear")
    analytic_output = network(dist, method="analytic", mean_field=False)
    analytic_mean_field_output = network(dist, method="analytic", mean_field=True)

    input_samples = dist.qmc(num_samples=num_samples)
    output_samples = jax.vmap(network)(input_samples)

    pseudo = Normal(
        jnp.mean(output_samples).reshape(1, -1), jnp.var(output_samples).reshape(1, 1)
    )

    y_mesh = np.linspace(np.min(output_samples), np.max(output_samples), 2000)

    ax.plot(
        y_mesh,
        scipy.stats.gaussian_kde(output_samples.reshape(-1))(y_mesh),
        label="empirical KDE",
    )
    ax.plot(
        y_mesh,
        jax.vmap(pseudo.pdf)(y_mesh),
        label="pseudo-true Gaussian fit",
    )

    ax.plot(
        y_mesh,
        jax.vmap(unscented_output.pdf)(y_mesh),
        label="unscented approximation",
    )

    ax.plot(
        y_mesh,
        jax.vmap(linear_output.pdf)(y_mesh),
        label="linear approximation",
        linestyle="--",
    )

    ax.plot(
        y_mesh,
        jax.vmap(analytic_output.pdf)(y_mesh),
        label="analytic approximation",
        linestyle="--",
    )

    ax.plot(
        y_mesh,
        jax.vmap(analytic_mean_field_output.pdf)(y_mesh),
        label="mean-field analytic approximation",
        linestyle="--",
    )
    ax.legend()

In [None]:
fig = Figure(figsize=(12, 8), dpi=600, constrained_layout=1)
ax = fig.add_subplot(221)
ax.set_title("Covariance scale 1e-5")
plot_mc(ax, μ, Σ * 1e-5)
ax = fig.add_subplot(222)
ax.set_title("Covariance scale 1e-2")
plot_mc(ax, μ, Σ * 1e-2)
ax = fig.add_subplot(223)
ax.set_title("Covariance scale 1e1")
plot_mc(ax, μ, Σ * 1e1)
ax = fig.add_subplot(224)
ax.set_title("Covariance scale 1e4")
plot_mc(ax, μ, Σ * 1e4)

# fig.savefig('figures/example-uq-Phi.pdf')
fig