# Use Real-NVP to generate the SAP logo from a Gaussian distribution

This notebook uses [Real-NVP](https://arxiv.org/abs/1605.08803), a type of normalizing flow, in order to generate a complex distribution (SAP logo) from a simple one (Gaussian).

It is implemented using [JAX](https://jax.readthedocs.io/) and most of the code is borrowed from Eric Jang's [implementation](https://github.com/ericjang/nf-jax) of normalizing flows.

Here is a summary of the changes made to the original implementation to make Real-NVP work for the SAP logo (as opposed to a simpler distribution like the half-moons used in Eric Jang's tutorial):

* The hidden layers in each normalizing flow are wider (128 => 2048 units) and deeper (> 2 stacks of Dense + Activation; 15 in the code below)
* Activation layer changed from ReLU to LeakyReLU
* Regularization (L2 weight decay) and gradient clipping

## Construct target image as a scatter plot

In [None]:
%matplotlib inline
import os

import jax.numpy as np
from jax import random

import matplotlib.pyplot as plt
import matplotlib.image as mpimg

# PRNG key shared by the random sampling and network initialisation code below
rng = random.PRNGKey(0)

In [None]:
# Load the logo image from disk and display it
img = mpimg.imread('sap-logo.jpg')
plt.imshow(img)

In [None]:
# Print the channel values at two (row, column) positions of the image;
# presumably used to pick the colour passed as `hit` further down.
print(img[350,401])
print(img[100,600])

In [None]:
def convert_img_to_scatter(img, stride, hit):
    """Convert an image into a scatter plot of the pixels matching a colour.

    The image is sampled every `stride` pixels in both directions, and each
    sampled pixel whose channel values equal `hit` contributes one point.

    Args:
        img: Image array indexed as img[row, column] -> channel values.
        stride: Step, in pixels, between sampled rows and between sampled
            columns.
        hit: Sequence of channel values to match exactly (list or tuple).

    Returns:
        Array of shape (n_points, 2) holding [x, y] pairs, where x is the
        column index and y is the row index counted from the bottom edge of
        the image. The shape is (0, 2) when no pixel matches.
    """
    # Normalise to a list so that a tuple `hit` still compares equal below
    # (a list never equals a tuple in Python).
    hit = list(hit)
    height, width = img.shape[0], img.shape[1]
    points = []
    # y counts rows from the bottom edge; stopping 50 short of the height
    # leaves the top 50 rows of the image unsampled.
    for y in range(0, height - 50, stride):
        row = img[height - (y + 1)]
        for x in range(0, width, stride):
            if list(row[x]) == hit:
                points.append([x, y])
    if not points:
        # Keep a two-column result so callers can still index X[:, 0] / X[:, 1].
        return np.zeros((0, 2), dtype=int)
    return np.array(points)

def sample_n01(N, D=2):
    """Draw an (N, D) array of standard-normal samples.

    The module-level key `rng` is used for every call, so the key is the
    same each time this function is invoked.
    """
    shape = (N, D)
    return random.normal(rng, shape)

In [None]:
# Sample every second pixel in each direction and keep those whose
# channel values equal [25, 119, 209]
X = convert_img_to_scatter(img, 2, [25, 119, 209])
print(X.shape)

In [None]:
# Scatter plot of the extracted points (column 0 on the x axis, column 1 on the y axis)
plt.scatter(X[:, 0], X[:, 1], s=5)

In [None]:
# Add one standard-normal noise vector to every point, then plot the result
noise = sample_n01(X.shape[0])
X_noisy = X + noise
plt.scatter(X_noisy[:, 0], X_noisy[:, 1], s=5)

In [None]:
def log_prob_n01(x, eps=1e-15):
    """Log-density of `x` under a standard normal with independent dimensions.

    Args:
        x: Array whose last axis holds the dimensions of each point.
        eps: Unused; kept so existing callers keep working.

    Returns:
        Array with the last axis summed out: one log-density per point.
    """
    # Per-dimension term: -x^2 / 2 - log(sqrt(2 * pi))
    per_dim = -np.square(x)/2 - np.log(np.sqrt(2*np.pi))
    return np.sum(per_dim, axis=-1)

In [None]:
# Histogram of the standard-normal log-density of the noisy points (before scaling)
plt.hist(log_prob_n01(X_noisy))

In [None]:
from sklearn.preprocessing import StandardScaler

# Standardise the noisy and the clean points, each with its own freshly
# fitted scaler, then re-plot the log-density histogram of the noisy points
X_noisy = StandardScaler().fit_transform(X_noisy)
X = StandardScaler().fit_transform(X)
plt.hist(log_prob_n01(X_noisy))

In [None]:
# Same histogram for the standardised points without added noise
plt.hist(log_prob_n01(X))

In [None]:
# Scatter plot of the standardised noisy points
plt.scatter(X_noisy[:, 0], X_noisy[:, 1], s=5)

## Real-NVP implementation

In [None]:
from jax.experimental import stax # neural network library
from jax.experimental.stax import Dense, Relu, LeakyRelu

def nvp_forward(net_params, shift_and_log_scale_fn, x, flip=False):
    """Apply one affine coupling layer in the forward direction.

    The features of `x` are split in two halves along the last axis. One half
    is passed through unchanged and fed to `shift_and_log_scale_fn`; the other
    half is multiplied by exp(log_scale) and shifted.

    Args:
        net_params: Parameters handed to `shift_and_log_scale_fn`.
        shift_and_log_scale_fn: Callable (net_params, half) -> (shift, log_scale).
        x: 2-D array indexed as x[row, feature].
        flip: If False the first half conditions and the second half is
            transformed; if True the roles are swapped.

    Returns:
        Array with the same column layout as `x`, with the transformed half
        replaced by its new values.
    """
    half = x.shape[-1] // 2
    first, second = x[:, :half], x[:, half:]
    # `conditioner` stays as it is; `target` gets scaled and shifted.
    conditioner, target = (second, first) if flip else (first, second)

    shift, log_scale = shift_and_log_scale_fn(net_params, conditioner)
    transformed = target * np.exp(log_scale) + shift

    # Put both halves back into their original column positions.
    parts = [transformed, conditioner] if flip else [conditioner, transformed]
    return np.concatenate(parts, axis=-1)

def nvp_inverse(net_params, shift_and_log_scale_fn, y, flip=False):
    """Undo one affine coupling layer.

    The features of `y` are split in two halves along the last axis. The
    untransformed half is fed to `shift_and_log_scale_fn`, and the other half
    has the shift subtracted and is multiplied by exp(-log_scale).

    Args:
        net_params: Parameters handed to `shift_and_log_scale_fn`.
        shift_and_log_scale_fn: Callable (net_params, half) -> (shift, log_scale).
        y: 2-D array indexed as y[row, feature].
        flip: If False the first half conditions and the second half is
            inverted; if True the roles are swapped.

    Returns:
        Tuple (x, log_scale): the recovered input in the column layout of `y`,
        and the log-scale produced by the network.
    """
    half = y.shape[-1] // 2
    first, second = y[:, :half], y[:, half:]
    # `conditioner` was left untouched by the forward pass.
    conditioner, target = (second, first) if flip else (first, second)

    shift, log_scale = shift_and_log_scale_fn(net_params, conditioner)
    recovered = (target - shift) * np.exp(-log_scale)

    # Put both halves back into their original column positions.
    parts = [recovered, conditioner] if flip else [conditioner, recovered]
    return np.concatenate(parts, axis=-1), log_scale

def init_nvp(key=None, hidden=2048, depth=15):
    """
    Build the shift / log-scale network of one Real-NVP coupling layer.

    Number of layers and hidden units can have a
    significant effect on performance and final output.

    Args:
        key: PRNG key used to initialise the network weights. Defaults to the
            notebook-level `rng`.
        hidden: Number of units in every hidden Dense layer.
        depth: Number of Dense + LeakyRelu pairs before the output layer.

    Returns:
        Tuple (net_params, shift_and_log_scale_fn). The function takes
        (net_params, x1) and returns the network output split in two along
        axis 1: the shift and the log-scale.
    """
    if key is None:
        key = rng
    D = 2
    hidden_layers = []
    for _ in range(depth):
        hidden_layers += [Dense(hidden), LeakyRelu]
    net_init, net_apply = stax.serial(*hidden_layers, Dense(D))
    # The network only ever sees one half of the D features.
    in_shape = (-1, D//2)
    _, net_params = net_init(key, in_shape)

    def shift_and_log_scale_fn(net_params, x1):
        s = net_apply(net_params, x1)
        return np.split(s, 2, axis=1)

    return net_params, shift_and_log_scale_fn

def sample_nvp(net_params, shift_and_log_scale_fn, base_sample_fn, N, flip=False):
    """Draw N base samples and push them through one coupling layer."""
    x = base_sample_fn(N)
    return nvp_forward(net_params, shift_and_log_scale_fn, x, flip)

def log_prob_nvp(net_params, shift_and_log_scale_fn, base_log_prob_fn, y, flip=False):
    """Log-density of `y` under one coupling layer applied to a base density.

    `y` is mapped back through the layer, scored with `base_log_prob_fn`, and
    corrected by the inverse log-det-Jacobian, which for this layer is minus
    the sum of the log-scales.
    """
    x, log_scale = nvp_inverse(net_params, shift_and_log_scale_fn, y, flip)
    ildj = -np.sum(log_scale, axis=-1)
    return base_log_prob_fn(x) + ildj

def init_nvp_chain(n=2, hidden=2048, depth=15):
    """Initialise a chain of `n` coupling layers with alternating `flip`.

    Each layer gets its own sub-key split from the notebook-level `rng`;
    reusing one key for every layer would give all layers identical initial
    weights.

    Args:
        n: Number of coupling layers.
        hidden: Hidden width forwarded to `init_nvp`.
        depth: Number of hidden Dense + LeakyRelu pairs forwarded to `init_nvp`.

    Returns:
        Tuple (ps, configs): a list with the parameters of each layer and a
        list of (shift_and_log_scale_fn, flip) tuples, in chain order.
    """
    ps, configs = [], []
    for i, key in enumerate(random.split(rng, n)):
        p, f = init_nvp(key, hidden, depth)
        ps.append(p)
        # Even positions use flip=False, odd positions flip=True.
        configs.append((f, i % 2 == 1))
    return ps, configs

def sample_nvp_chain(ps, configs, base_sample_fn, N):
    """Draw N base samples and push them through every layer of the chain.

    Args:
        ps: Per-layer parameters, in chain order.
        configs: Per-layer (shift_and_log_scale_fn, flip) tuples.
        base_sample_fn: Callable N -> array of base samples.
        N: Number of samples to draw.

    Returns:
        The samples after the last layer has been applied.
    """
    sample = base_sample_fn(N)
    for layer_params, (shift_and_log_scale_fn, flip) in zip(ps, configs):
        sample = nvp_forward(layer_params, shift_and_log_scale_fn, sample, flip)
    return sample

def make_log_prob_fn(p, log_prob_fn, config):
    """Wrap `log_prob_fn` with one more coupling layer.

    Returns a one-argument function that scores its input with
    `log_prob_nvp`, using `log_prob_fn` as the base density and the layer
    described by `p` and `config` = (shift_and_log_scale_fn, flip).
    """
    shift_and_log_scale_fn, flip = config

    def layer_log_prob(x):
        return log_prob_nvp(p, shift_and_log_scale_fn, log_prob_fn, x, flip)

    return layer_log_prob

def log_prob_nvp_chain(ps, configs, base_log_prob_fn, y):
    """Log-density of `y` under the whole chain of coupling layers.

    Starting from `base_log_prob_fn`, every layer wraps the density built so
    far, so the last layer of the chain is the outermost wrapper and is the
    first one to be inverted when `y` is scored.
    """
    composed = base_log_prob_fn
    for layer_params, layer_config in zip(ps, configs):
        composed = make_log_prob_fn(layer_params, composed, layer_config)
    return composed(y)

## Train

In [None]:
from jax.experimental import optimizers
from jax import jit, grad
import numpy as onp
from tqdm import tqdm

# Hyper-parameters of the run. The whole dict is also formatted into the
# output GIF file name further down.
hp = {
    'chains': 15,       # number of coupling layers in the chain
    'beta': 1e-3,       # weight of the L2 penalty in the loss
    'grad_clip': 1.0,   # value passed to optimizers.clip_grads
    'epochs': 5e5,      # used as the number of optimisation steps (see `iters`)
    'eta': 1e-5,        # Adam step size
    'hidden': 2048      # NOTE: not passed to init_nvp_chain below
}

ps, cs = init_nvp_chain(hp['chains'])

def loss(params, batch):
    """Mean negative log-likelihood of `batch` plus an L2 penalty on `params`.

    The likelihood is evaluated under the chain configured in `cs` with a
    standard-normal base density; the penalty is hp['beta'] times the L2 norm
    of the parameters.
    """
    log_likelihood = log_prob_nvp_chain(params, cs, log_prob_n01, batch)
    nll = -np.mean(log_likelihood)
    penalty = hp['beta'] * optimizers.l2_norm(params)
    return nll + penalty

# Adam optimizer triple: state initialiser, update function, parameter getter
opt_init, opt_update, get_params = optimizers.adam(step_size=hp['eta'])

@jit
def step(i, opt_state, batch):
    """Run one optimisation step on `batch`.

    Args:
        i: Step index passed on to the optimizer update.
        opt_state: Current optimizer state.
        batch: Mini-batch of data points handed to `loss`.

    Returns:
        Tuple (new_opt_state, l) where `l` is the loss at the parameters held
        by `opt_state` before the update.
    """
    # Local import: only `jit` and `grad` are imported at the top of this cell.
    from jax import value_and_grad

    params = get_params(opt_state)
    # One pass yields both the loss and its gradient, instead of evaluating
    # the loss a second time after grad().
    l, g = value_and_grad(loss)(params, batch)
    g = optimizers.clip_grads(g, hp['grad_clip'])
    return opt_update(i, g, opt_state), l
    
# Number of optimisation steps; hp['epochs'] is stored as a float, hence int().
iters = int(hp['epochs'])
# Lazily yields one mini-batch per step: 100 rows of X_noisy picked at random
# with onp.random.choice.
data_generator = (X_noisy[onp.random.choice(X_noisy.shape[0], 100)] for _ in range(iters))
opt_state = opt_init(ps)
losses = []
for i in tqdm(range(iters)):
    opt_state, l = step(i, opt_state, next(data_generator))
    # The log of the loss is recorded for plotting.
    # NOTE(review): this is NaN whenever the loss is not positive.
    losses.append(np.log(l))

# Read the trained parameters back out of the optimizer state.
ps = get_params(opt_state)

plt.plot(losses) 

In [None]:
# Draw 3000 samples from the trained chain and print their largest coordinate
y = sample_nvp_chain(ps, cs, sample_n01, 3000)
print(y.max())

In [None]:
# Scatter plot of the generated samples
plt.scatter(y[:, 0], y[:, 1], s=5)

## Animate as GIF

In [None]:
from matplotlib import animation, rc
from IPython.display import HTML, Image

# Record the samples before the first layer and after every layer of the chain
x = sample_n01(3000)
values = [x]
for p, config in zip(ps, cs):
    shift_log_scale_fn, flip = config
    x = nvp_forward(p, shift_log_scale_fn, x, flip=flip)
    values.append(x)
    
# First set up the figure, the axis, and the plot element we want to animate
fig, ax = plt.subplots()

# Repeat the first and the last key frame 5 times each; interpolating between
# identical frames holds the picture still at the start and at the end.
values = [values[0]] * 5 + values + [values[-1]] * 5

# Initial scatter; `animate` later moves these points via set_offsets.
y = values[0]
paths = ax.scatter(y[:, 0], y[:, 1], s=3)

In [None]:
# Number of interpolated frames between two consecutive key frames
n = 12

def animate(i):
    """Move the scatter points to their position for animation frame `i`.

    Frame `i` lies between key frames i // n and i // n + 1 of `values`; the
    points are placed by linear interpolation between the two.
    """
    segment, offset = divmod(i, n)
    t = float(offset) / n
    start, end = values[segment], values[segment + 1]
    paths.set_offsets((1 - t) * start + t * end)
    return (paths,)

In [None]:
# Encode the hyper-parameters in the file name so runs can be told apart
f_name = 'sap_logo_nf_{chains}_{beta}_{eta}_{grad_clip}_{epochs}_{hidden}.gif'.format(**hp)

# n frames per transition; len(cs) + 5 + 5 transitions exist once the first
# and last key frames have been repeated 5 times each.
anim = animation.FuncAnimation(fig, animate, frames=n*(len(cs) + 5 + 5), interval=1, blit=False)
anim.save(f_name, writer='imagemagick', fps=50)

In [None]:
# Display the saved GIF in the notebook
Image(url=f_name)