# NumPyro で VAE を試す
https://zenn.dev/asei/articles/ee0525e452fdb3

<a href="https://colab.research.google.com/gist/AseiSugiyama/2e0211035bd14ebbbe60fdb3b48e438f/numpyro_vae.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
#!pip install --upgrade pip
#!pip install --upgrade numpyro
#!pip install --upgrade jax jaxlib==0.1.56+cuda101 -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [2]:
import inspect
import os
import time

import matplotlib.pyplot as plt

from jax import jit, lax, random
from jax.experimental import stax  # jaxでニューラルネット組むためのモジュール
import jax.numpy as jnp
from jax.random import PRNGKey

import numpyro
from numpyro import optim
import numpyro.distributions as dist
from numpyro.examples.datasets import MNIST, load_dataset
from numpyro.infer import SVI, Trace_ELBO

In [3]:
# Training the VAE is very slow without a GPU, so request the GPU backend.
numpyro.set_platform("gpu")

In [4]:
# Directory containing the file that defines this code (resolved through the
# code object of a throwaway lambda); images are written to a hidden
# ".results" folder next to it.
_here = os.path.dirname(inspect.getfile(lambda: None))
RESULTS_DIR = os.path.abspath(os.path.join(_here, '.results'))
os.makedirs(RESULTS_DIR, exist_ok=True)  # no error if it already exists
RESULTS_DIR

'/tmp/ipykernel_3265/.results'

In [5]:
def encoder(hidden_dim, z_dim):
    """Build the stax encoder network.

    Args:
        hidden_dim: Width of the shared hidden layer.
        z_dim: Width of each of the two output heads.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
        Applying the network yields two outputs of width ``z_dim``: the
        first straight from a dense layer, the second passed through
        ``Exp`` so that it is strictly positive.
    """
    # Shared trunk: one dense layer followed by a softplus non-linearity.
    trunk = (stax.Dense(hidden_dim, W_init=stax.randn()), stax.Softplus)
    # Two heads fed from the same hidden activations.
    loc_head = stax.Dense(z_dim, W_init=stax.randn())
    scale_head = stax.serial(stax.Dense(z_dim, W_init=stax.randn()), stax.Exp)
    # FanOut(2) does not shrink the layer to 2 units: it duplicates the hidden
    # activations so that `parallel` can feed one copy to each head.
    return stax.serial(*trunk, stax.FanOut(2), stax.parallel(loc_head, scale_head))


In [6]:
def decoder(hidden_dim, out_dim):
    """Build the stax decoder network.

    Args:
        hidden_dim: Width of the hidden layer.
        out_dim: Width of the output layer.

    Returns:
        The ``(init_fun, apply_fun)`` pair produced by ``stax.serial``.
        The final ``Sigmoid`` keeps every output in the unit interval.
    """
    layers = [
        stax.Dense(hidden_dim, W_init=stax.randn()),
        stax.Softplus,
        stax.Dense(out_dim, W_init=stax.randn()),
        stax.Sigmoid,
    ]
    return stax.serial(*layers)


In [7]:
def model(batch, hidden_dim=400, z_dim=100):
    """Generative model: decode a latent ``z`` and score ``batch`` against it.

    Args:
        batch: Array whose first axis indexes examples; all remaining axes
            are flattened into one feature axis.
        hidden_dim: Hidden width passed to ``decoder``.
        z_dim: Length of the latent vector.

    Returns:
        The value of the ``'obs'`` sample site (observed as the flattened
        batch).
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(flat)
    # Register the decoder network with NumPyro under the name 'decoder'.
    decode = numpyro.module('decoder', decoder(hidden_dim, out_dim), (batch_dim, z_dim))
    # Latent variable with a Normal(0, 1) prior of length z_dim.
    prior = dist.Normal(jnp.zeros((z_dim,)), jnp.ones((z_dim,)))
    z = numpyro.sample('z', prior)
    # The decoder output parameterises a Bernoulli over the flattened batch.
    return numpyro.sample('obs', dist.Bernoulli(decode(z)), obs=flat)

In [8]:
def guide(batch, hidden_dim=400, z_dim=100):
    """Variational guide: encode ``batch`` and sample the latent ``z``.

    Args:
        batch: Array whose first axis indexes examples; all remaining axes
            are flattened into one feature axis.
        hidden_dim: Hidden width passed to ``encoder``.
        z_dim: Width of each encoder head.

    Returns:
        The value drawn at the ``'z'`` sample site.
    """
    flat = jnp.reshape(batch, (batch.shape[0], -1))
    batch_dim, out_dim = jnp.shape(flat)
    # Register the encoder network with NumPyro under the name 'encoder'.
    encode = numpyro.module('encoder', encoder(hidden_dim, z_dim), (batch_dim, out_dim))
    # The two encoder heads give the location and scale of the Normal for z.
    z_loc, z_std = encode(flat)
    return numpyro.sample('z', dist.Normal(z_loc, z_std))


In [9]:
@jit
def binarize(rng_key, batch):
    """Draw a random 0/1 array using ``batch`` as per-element probabilities.

    Args:
        rng_key: JAX PRNG key used for the Bernoulli draw.
        batch: Array of probabilities; its dtype is kept for the result.

    Returns:
        Array of the same dtype as ``batch`` holding the Bernoulli draws.
    """
    draws = random.bernoulli(rng_key, batch)
    # bernoulli returns booleans; cast back to the input dtype.
    return draws.astype(batch.dtype)

In [10]:
# --- Hyperparameters -------------------------------------------------------
hidden_dim, z_dim = 400, 50
learning_rate = 1.0e-3
batch_size = 256
num_epochs = 100

# --- Networks, optimiser and SVI object ------------------------------------
# Standalone (init_fun, apply_fun) pairs, kept so the networks can be applied
# directly to the learned parameters outside of model/guide.
encoder_nn = encoder(hidden_dim, z_dim)
decoder_nn = decoder(hidden_dim, 28 * 28)

adam = optim.Adam(learning_rate)
# hidden_dim / z_dim are forwarded as keyword arguments to model and guide.
svi = SVI(model, guide, adam, Trace_ELBO(), hidden_dim=hidden_dim, z_dim=z_dim)

# --- Data ------------------------------------------------------------------
train_init, train_fetch = load_dataset(MNIST, batch_size=batch_size, split='train')
test_init, test_fetch = load_dataset(MNIST, batch_size=batch_size, split='test')
num_train, train_idx = train_init()

# --- Initial SVI state -----------------------------------------------------
rng_key = PRNGKey(0)
rng_key, rng_key_binarize, rng_key_init = random.split(rng_key, 3)
# Binarize the first training batch and use it to initialise the SVI state.
sample_batch = binarize(rng_key_binarize, train_fetch(0, train_idx)[0])
svi_state = svi.init(rng_key_init, sample_batch)


In [11]:
# Notebook display: show what encoder() returned (an (init_fun, apply_fun) pair, per the cell output).
encoder_nn

(<function jax.experimental.stax.serial.<locals>.init_fun(rng, input_shape)>,
 <function jax.experimental.stax.serial.<locals>.apply_fun(params, inputs, **kwargs)>)

In [12]:
@jit
def _epoch_train_jit(svi_state, rng_key, idx):
    """Jitted body of one training epoch.

    ``idx`` is passed as an argument (instead of being read from the
    module-level ``train_idx``) so that it is a traced input: a global read
    inside a jitted function is frozen at the value it had on first trace.
    """
    def body_fn(i, val):
        loss_sum, svi_state = val
        # Derive a distinct binarisation key for every batch index.
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, train_fetch(i, idx)[0])
        svi_state, loss = svi.update(svi_state, batch)
        loss_sum += loss
        return loss_sum, svi_state

    # NOTE: num_train is still a module-level value read when this function
    # is traced; it is used as the (static) number of batches per epoch.
    return lax.fori_loop(0, num_train, body_fn, (0., svi_state))


def epoch_train(svi_state, rng_key):
    """Run one epoch of SVI updates over the training batches.

    Args:
        svi_state: Current SVI state.
        rng_key: PRNG key; folded with the batch index to binarize each batch.

    Returns:
        Tuple ``(loss_sum, svi_state)``: the summed per-batch loss and the
        updated SVI state.
    """
    # Read the module-level train_idx at call time and hand it to the jitted
    # helper, so a fresh index array from train_init() is actually used on
    # every epoch rather than the one captured at first compilation.
    return _epoch_train_jit(svi_state, rng_key, train_idx)

@jit
def _eval_test_jit(svi_state, rng_key, idx):
    """Jitted body of the test-set evaluation.

    ``idx`` is passed as an argument (instead of being read from the
    module-level ``test_idx``) so that it is a traced input: a global read
    inside a jitted function is frozen at the value it had on first trace.
    """
    def body_fun(i, loss_sum):
        # Derive a distinct binarisation key for every batch index.
        rng_key_binarize = random.fold_in(rng_key, i)
        batch = binarize(rng_key_binarize, test_fetch(i, idx)[0])
        # FIXME: does this lead to a requirement for an rng_key arg in svi_eval?
        loss = svi.evaluate(svi_state, batch) / len(batch)
        loss_sum += loss
        return loss_sum

    # NOTE: num_test is still a module-level value read when this function
    # is traced; it is used as the (static) number of test batches.
    loss = lax.fori_loop(0, num_test, body_fun, 0.)
    loss = loss / num_test
    return loss


def eval_test(svi_state, rng_key):
    """Evaluate the current SVI state on the test batches.

    Args:
        svi_state: Current SVI state.
        rng_key: PRNG key; folded with the batch index to binarize each batch.

    Returns:
        The per-example loss of each batch, averaged over ``num_test``
        batches.
    """
    # Read the module-level test_idx at call time and hand it to the jitted
    # helper, so the index array from the latest test_init() is used rather
    # than the one captured at first compilation.
    return _eval_test_jit(svi_state, rng_key, test_idx)

In [13]:
%%time
def reconstruct_img(epoch, rng_key):
    """Save one test image and its VAE reconstruction as PNG files.

    Reads the module-level ``svi_state``, ``test_idx``, ``encoder_nn`` and
    ``decoder_nn``. Both images are written into ``RESULTS_DIR``.

    Args:
        epoch: Value substituted into the output file names.
        rng_key: PRNG key, split into one key for binarising the image and
            one for sampling the latent vector.
    """
    original_path = os.path.join(RESULTS_DIR, 'original_epoch={}.png'.format(epoch))
    recons_path = os.path.join(RESULTS_DIR, 'recons_epoch={}.png'.format(epoch))

    # First image of the first test batch.
    img = test_fetch(0, test_idx)[0][0]
    plt.imsave(original_path, img, cmap='gray')

    key_binarize, key_sample = random.split(rng_key)
    flat_input = binarize(key_binarize, img).reshape([1, -1])

    params = svi.get_params(svi_state)
    encode_apply, decode_apply = encoder_nn[1], decoder_nn[1]
    # The second encoder head is used as the Normal scale, as in guide().
    z_loc, z_std = encode_apply(params['encoder$params'], flat_input)
    z = dist.Normal(z_loc, z_std).sample(key_sample)

    recon = decode_apply(params['decoder$params'], z).reshape([28, 28])
    plt.imsave(recons_path, recon, cmap='gray')

for i in range(num_epochs):
    # Only the first two keys of this 4-way split are consumed; the test and
    # reconstruction keys are taken from the second split further down.
    rng_key, rng_key_train, _, _ = random.split(rng_key, 4)
    t_start = time.time()

    # Training pass: fetch a new index array, then run one epoch of updates.
    num_train, train_idx = train_init()
    _, svi_state = epoch_train(svi_state, rng_key_train)

    # Evaluation pass on the test split, plus one reconstruction image.
    rng_key, rng_key_test, rng_key_reconstruct = random.split(rng_key, 3)
    num_test, test_idx = test_init()
    test_loss = eval_test(svi_state, rng_key_test)
    reconstruct_img(i, rng_key_reconstruct)

    # The reported loss is the test loss; the timing covers the whole epoch.
    elapsed = time.time() - t_start
    print("Epoch {}: loss = {} ({:.2f} s.)".format(i, test_loss, elapsed))

Epoch 0: loss = 192.72499084472656 (3.79 s.)
Epoch 1: loss = 179.18263244628906 (0.10 s.)
Epoch 2: loss = 152.52359008789062 (0.10 s.)
Epoch 3: loss = 135.58538818359375 (0.09 s.)
Epoch 4: loss = 126.55363464355469 (0.09 s.)
Epoch 5: loss = 119.89852905273438 (0.09 s.)
Epoch 6: loss = 115.1008529663086 (0.09 s.)
Epoch 7: loss = 111.9155502319336 (0.09 s.)
Epoch 8: loss = 109.52226257324219 (0.09 s.)
Epoch 9: loss = 107.90716552734375 (0.09 s.)
Epoch 10: loss = 106.52417755126953 (0.09 s.)
Epoch 11: loss = 105.74427032470703 (0.09 s.)
Epoch 12: loss = 104.68888092041016 (0.09 s.)
Epoch 13: loss = 104.14228820800781 (0.09 s.)
Epoch 14: loss = 103.52923583984375 (0.09 s.)
Epoch 15: loss = 102.91914367675781 (0.09 s.)
Epoch 16: loss = 102.47251892089844 (0.09 s.)
Epoch 17: loss = 102.09611511230469 (0.09 s.)
Epoch 18: loss = 101.822998046875 (0.09 s.)
Epoch 19: loss = 101.4325942993164 (0.09 s.)
Epoch 20: loss = 101.27059936523438 (0.09 s.)
Epoch 21: loss = 100.95370483398438 (0.09 s.)
Epo