# Normalizing flows with the Real NVP model


We use the **normflows** package, which implements normalizing flow models.

### github: 
    https://github.com/VincentStimper/normalizing-flows

### install:

```
pip install normflows
```


This notebook is based on [real_nvp_colab.ipynb](https://github.com/VincentStimper/normalizing-flows/blob/master/examples/real_nvp_colab.ipynb) in the **normflows** package, which you can run directly on [colab](https://colab.research.google.com/github/VincentStimper/normalizing-flows/blob/master/examples/real_nvp_colab.ipynb).

## Summary of the theory

**Idea**: Express data $x\in \mathbb{R}^d$ as a transformation of (Gaussian) $z\sim p_z(z)$.

**Change of variable formula**

Let $z\sim p_z(z)$ and $x=f(z)$, where $f: \mathbb{R}^d\rightarrow \mathbb{R}^d$ is invertible and differentiable. Then,  
$$
  p_x(x) = p_z(f^{-1}(x)) |\det J_{f^{-1}}(x)|\,.
$$
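
As a quick check of the formula, the sketch below applies it to a one-dimensional affine map $x = az + b$ with $z\sim\mathcal{N}(0,1)$, where the answer is known in closed form: $x\sim\mathcal{N}(b, a^2)$. The values of `a` and `b` and the helper names are illustrative assumptions, not part of **normflows**.

```python
import math

def normal_pdf(x, mean=0.0, std=1.0):
    """Density of a 1D Gaussian N(mean, std^2)."""
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))

# Hypothetical affine map x = f(z) = a * z + b with z ~ N(0, 1)
a, b = 2.0, 1.0

def f_inverse(x):
    return (x - b) / a

def density_via_change_of_variable(x):
    # p_x(x) = p_z(f^{-1}(x)) * |d f^{-1} / dx|, and d f^{-1} / dx = 1 / a
    return normal_pdf(f_inverse(x)) * abs(1.0 / a)

test_points = (-1.0, 0.0, 1.0, 3.5)
for x in test_points:
    # Compare with the closed form N(b, a^2)
    print(x, density_via_change_of_variable(x), normal_pdf(x, mean=b, std=abs(a)))
```

The two printed densities should agree at every point up to floating-point error.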

**Setup**:
1. target $p^*_x(x)$ 
2. transformation $x=f(z;\theta)$. 

**Goal**: find $\theta$ such that $p_x(x;\theta)$ is close to $p^*_x(x)$.
  
  
**(forward) KL divergence**

\begin{equation*}
  \begin{aligned}
    & D_{KL}\Big(p^*_x(x)\,\|\,p_x(x;\theta)\Big) \\
    =& \mathbb{E}_{x\sim p^*_x(x)} \Big(\log \frac{p^*_x(x)}{p_x(x;\theta)}\Big) \\
    =& -\mathbb{E}_{x\sim p^*_x(x)} \Big[\log p_z(f^{-1}(x;\theta)) + \log |\det J_{f^{-1}}(x;\theta)|\Big] + C\,.
  \end{aligned}
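\end{equation*}
The last step uses the change of variable formula, and $C = \mathbb{E}_{x\sim p^*_x(x)}\big[\log p^*_x(x)\big]$ does not depend on $\theta$.

**Loss function**

Dropping $C$ and replacing the expectation with an average over samples $x_1,\dots,x_N\sim p^*_x(x)$ gives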

  $$\mathrm{Loss}(\theta) = -\frac{1}{N}\sum_{n=1}^N \Big[\log p_z(f^{-1}(x_n;\theta)) + \log |\det J_{f^{-1}}(x_n;\theta)|\Big]\,.
  $$
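
To make the loss concrete, here is a minimal sketch for a one-dimensional affine flow $x = az + b$, $z\sim\mathcal{N}(0,1)$, so that $\theta=(a,b)$. The target $\mathcal{N}(2, 0.5^2)$ and the function names are assumptions chosen only for illustration; the notebook itself relies on `model.forward_kld` for this computation.

```python
import math
import random

def log_standard_normal(z):
    """log p_z(z) for z ~ N(0, 1)."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)

def loss(samples, a, b):
    """Average negative log-likelihood under the flow x = a * z + b."""
    total = 0.0
    for x in samples:
        z = (x - b) / a                  # f^{-1}(x; theta)
        log_det = -math.log(abs(a))      # log |det J_{f^{-1}}(x; theta)|
        total += log_standard_normal(z) + log_det
    return -total / len(samples)

random.seed(0)
# Stand-in target p*_x = N(2, 0.5^2), an assumption for this example
samples = [random.gauss(2.0, 0.5) for _ in range(5000)]

loss_matching = loss(samples, a=0.5, b=2.0)  # theta that reproduces the target
loss_identity = loss(samples, a=1.0, b=0.0)  # identity flow, far from the target
print(loss_matching, loss_identity)
```

The parameters that reproduce the target should give the smaller loss, which is what gradient descent on $\mathrm{Loss}(\theta)$ exploits.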
  
**Real NVP**

$x=(x_1, x_2) = f(z)$ is defined as 

\begin{equation*}
      \begin{aligned}
      x_1 =& z_1\,, \\
      x_2 =& \exp(\sigma_\theta(z_1)) \odot z_2 + \mu_\theta(z_1)\,, 
      \end{aligned}
\end{equation*}
where $\odot$ denotes elementwise product, and $\sigma_\theta, \mu_\theta: \mathbb{R}^{d'}\rightarrow \mathbb{R}^{d-d'}$. The inverse of $f$ is 

\begin{equation*}
  \begin{aligned}
  z_1 =& x_1\,, \\
  z_2 =& \exp(-\sigma_\theta(x_1)) \odot (x_2 - \mu_\theta(x_1))\,.
  \end{aligned}
\end{equation*}

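The Jacobian $J_{f^{-1}}(x)$ is lower block triangular: the $z_1$ block is the identity and the $z_2$ block is $\mathrm{diag}\big(\exp(-\sigma_\theta(x_1))\big)$. Its determinant is therefore the product of the diagonal entries, $\det J_{f^{-1}}(x) = \exp(-\sum_{j=1}^{d-d'}(\sigma_\theta(x_1))_j)$.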
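
The coupling layer can be sketched in a few lines for $d=2$, $d'=1$. The functions `sigma` and `mu` below are hypothetical stand-ins for the networks $\sigma_\theta$ and $\mu_\theta$, and this is not how **normflows** implements `AffineCouplingBlock`; the aim is only to check the inverse and the determinant numerically.

```python
import math

# Hypothetical stand-ins for the networks sigma_theta and mu_theta (d = 2, d' = 1)
def sigma(u):
    return 0.5 * math.tanh(u)

def mu(u):
    return u ** 2 - 1.0

def forward(z1, z2):
    """x = f(z): x1 is copied, x2 is an affine function of z2 that depends on z1."""
    return z1, math.exp(sigma(z1)) * z2 + mu(z1)

def inverse(x1, x2):
    """z = f^{-1}(x) together with log |det J_{f^{-1}}(x)|."""
    z1 = x1
    z2 = math.exp(-sigma(x1)) * (x2 - mu(x1))
    log_det = -sigma(x1)
    return z1, z2, log_det

z = (0.3, -1.2)
x = forward(*z)
z1, z2, log_det = inverse(*x)

# Finite-difference estimate of d z2 / d x2, the only non-trivial
# diagonal entry of the triangular Jacobian of f^{-1}
eps = 1e-6
dz2_dx2 = (inverse(x[0], x[1] + eps)[1] - inverse(x[0], x[1] - eps)[1]) / (2 * eps)
print(z, (z1, z2), math.exp(log_det), dz2_dx2)
```

The round trip should recover `z`, and `exp(log_det)` should match the finite-difference estimate. Neither direction requires inverting `sigma` or `mu`, which is why they can be arbitrary neural networks.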

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import torch
import torch.nn as nn
import math
from tqdm import tqdm
from torch.utils.data import DataLoader

### set up the normalizing flow model

The num_layers is set to 32 in the original notebook. 

Run this notebook with:

1. num_layers = 2
2. num_layers = 32

Compare the results.

In [None]:
import normflows as nf

# Define 2D Gaussian base distribution
base = nf.distributions.base.DiagGaussian(2)

# Define list of flows
num_layers = 2
flows = []
for i in range(num_layers):
    # Neural network with two hidden layers having 64 units each
    # Last layer is initialized by zeros making training more stable
    param_map = nf.nets.MLP([1, 64, 64, 2], init_zeros=True)
    # Add flow layer
    flows.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions
    flows.append(nf.flows.Permute(2, mode='swap'))
    
# Construct flow model
model = nf.NormalizingFlow(base, flows)

# Count the trainable parameters
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of parameters: %d' % params)
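
As a rough cross-check of the printed count, the sketch below computes the size of one parameter map by hand, assuming `nf.nets.MLP([1, 64, 64, 2])` is a plain fully connected network with a bias in every layer. The helper `mlp_param_count` is hypothetical. The total printed above may also include other trainable parameters, for example those of the base distribution, so only the per-layer figure is given here.

```python
def mlp_param_count(layer_sizes):
    """Weights plus biases of a fully connected network with the given layer sizes."""
    return sum(
        n_in * n_out + n_out
        for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
    )

# (1*64 + 64) + (64*64 + 64) + (64*2 + 2)
per_coupling_layer = mlp_param_count([1, 64, 64, 2])
print(per_coupling_layer)  # 4418
```

Under that assumption, the parameter maps contribute `num_layers * 4418` parameters in total.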

### display the flow model

In [None]:
print (model)

### display target distribution

In [None]:
# Define target distribution
target = nf.distributions.TwoMoons()

# make a grid
grid_size = 200
xx, yy = torch.meshgrid(
    torch.linspace(-3, 3, grid_size),
    torch.linspace(-3, 3, grid_size),
    indexing='ij',  # explicit; same as the legacy default, avoids a warning in recent PyTorch
)
# get grid points
zz = torch.stack((xx, yy), dim=2).reshape(-1, 2)              

# compute the log of density at points zz
log_prob = target.log_prob(zz).reshape(xx.shape[0], xx.shape[1])
# compute density from its log
prob = torch.exp(log_prob)

plt.figure(figsize=(5, 5))
plt.pcolormesh(xx, yy, prob.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()
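
A simple sanity check for any density evaluated on such a grid is that the values, multiplied by the cell area, sum to roughly 1 when the grid covers most of the mass. The sketch below illustrates this with a standard 2D Gaussian as a stand-in density (an assumption; it is not the two-moons target) on the same $[-3,3]^2$ grid with 200 points per axis.

```python
import math

def standard_normal_2d(x, y):
    """Stand-in density: standard 2D Gaussian."""
    return math.exp(-0.5 * (x * x + y * y)) / (2.0 * math.pi)

grid_size = 200
lo, hi = -3.0, 3.0
step = (hi - lo) / (grid_size - 1)
points = [lo + i * step for i in range(grid_size)]

# Riemann sum: density value times cell area, summed over the grid
mass = sum(standard_normal_2d(x, y) for x in points for y in points) * step * step
print(mass)
```

The result is slightly below 1 because the tails outside the square are cut off. A value far from 1 would point to a grid that is too small or too coarse, or to an unnormalized density.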

### Plot initial flow distribution

In [None]:
# compute log of the density at points zz
log_prob = model.log_prob(zz).reshape(xx.shape[0], xx.shape[1])
# compute density from its log
prob = torch.exp(log_prob)

plt.figure(figsize=(5, 5))
plt.pcolormesh(xx, yy, prob.detach().numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()

### train the model

In [None]:
# Train model
max_iter = 4000
num_samples = 512
show_iter = 500

loss_hist = np.array([])

optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=1e-5)

for it in tqdm(range(max_iter)):
    
    optimizer.zero_grad()
    
    # Get training samples
    x = target.sample(num_samples)
    
    # Compute loss
    loss = model.forward_kld(x)
    
    # Do backprop and optimizer step
    if ~(torch.isnan(loss) | torch.isinf(loss)):
        loss.backward()
        optimizer.step()
    
    # Log loss
    loss_hist = np.append(loss_hist, loss.item())
    
    # Plot learned distribution
    if (it + 1) % show_iter == 0:
        # compute the log of density
        log_prob = model.log_prob(zz)
        # compute density from its log
        prob = torch.exp(log_prob.reshape(xx.shape[0], xx.shape[1]))
        prob[torch.isnan(prob)] = 0

        plt.figure(figsize=(5, 5))
        plt.pcolormesh(xx, yy, prob.detach().numpy(), cmap='coolwarm')
        plt.gca().set_aspect('equal', 'box')
        plt.show()

### plot the loss

In [None]:
# Plot loss
plt.figure(figsize=(5, 5))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()

### compare the generated density with the target

In [None]:
# Plot target distribution
f, ax = plt.subplots(1, 2, sharey=True, figsize=(11, 5))

log_prob = target.log_prob(zz).reshape(xx.shape[0], xx.shape[1])
prob = torch.exp(log_prob)

ax[0].pcolormesh(xx, yy, prob.detach().numpy(), cmap='coolwarm')
ax[0].set_aspect('equal', 'box')
ax[0].set_axis_off()
ax[0].set_title('Target', fontsize=24)

# Plot learned distribution
log_prob = model.log_prob(zz).reshape(xx.shape[0], xx.shape[1])
prob = torch.exp(log_prob)
prob[torch.isnan(prob)] = 0
ax[1].pcolormesh(xx, yy, prob.detach().numpy(), cmap='coolwarm')

ax[1].set_aspect('equal', 'box')
ax[1].set_axis_off()
ax[1].set_title('Real NVP', fontsize=24)

plt.show()