# Introduction

In this notebook we'll see what Normalizing Flows are exactly and play a bit with a standard implementation. Let's import what we need: we need to have [`pytorch`](https://pytorch.org/get-started/locally/) and [`nflows`](https://github.com/bayesiains/nflows) installed.

In [None]:
# standard python stuff
import os
import sys
import numpy as np
import scipy.stats as st
import matplotlib.pyplot as plt

# stuff for torch+nflows
import torch
from torch import nn
from tqdm import tqdm
from torch.nn.modules import Module

from torch import optim

from nflows.flows.base import Flow
from nflows.distributions.normal import StandardNormal, ConditionalDiagonalNormal
from nflows.transforms.base import CompositeTransform
from nflows.transforms.autoregressive import MaskedAffineAutoregressiveTransform, MaskedPiecewiseRationalQuadraticAutoregressiveTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.nn.nets import ResidualNet

# Transformation rules of probability density functions

The transformation rule of pdfs for an invertible, differentiable change of variable $x=f(z)$ is

$p_{X}(x)=p_{Z}\left(f^{-1}(x)\right)|\text{det} \nabla_{x}f^{-1}\left(x\right)|$

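By the inverse function theorem, $\nabla_{x}f^{-1}(x)$ is the inverse of $\nabla_{z}f$ evaluated at $z=f^{-1}(x)$, and the determinant of an inverse matrix is the reciprocal of the determinant. The Jacobian factor can therefore be rewritten in terms of $f$ for ease of computation as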

$p_{X}(x)=p_{Z}\left(f^{-1}(x)\right)|\text{det} \nabla_{x}f^{-1}\left(x\right)|=p_{Z}\left(f^{-1}(x)\right)|\text{det} \nabla_{z}f\left(f^{-1}(x)\right)|^{-1}$
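
As a quick numerical sanity check of this rule (an illustrative choice, not part of the original exercise), take $f(z)=e^{z}$ with $z\sim\mathcal{N}(0,1)$: then $\nabla_z f = e^{z} = x$, and the transformed density should coincide with the log-normal pdf, which scipy provides as `st.lognorm` with shape `s=1`. A minimal sketch:

```python
import numpy as np
import scipy.stats as st

# Change of variables x = f(z) = exp(z), with z ~ N(0, 1)
x = np.linspace(0.1, 5.0, 50)
z = np.log(x)        # f^{-1}(x)
jac = np.exp(z)      # df/dz evaluated at z = f^{-1}(x), which equals x
# p_X(x) = p_Z(f^{-1}(x)) * |det grad_z f|^{-1}
pdf_x = st.norm(loc=0, scale=1).pdf(z) / np.abs(jac)

# exp(z) is log-normal; scipy's lognorm with shape s=1 is exactly this distribution
pdf_ref = st.lognorm(s=1).pdf(x)
print(np.max(np.abs(pdf_x - pdf_ref)))
```

The printed maximum difference should be at the level of floating-point round-off.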

## Example:

Show how a standard normal variable $z\sim \mathcal{N}(0,1)$ can be transformed into a normally distributed variable $x\sim \mathcal{N}(\mu,\sigma)$ (with $\sigma$ the standard deviation) through the change of variables $x = \mu + \sigma z$, using the code below.

In [None]:
### an example of this
N = 10000
pdf_z = st.norm(loc=0,scale=1)
Z = pdf_z.rvs(N)
plt.hist(Z,histtype='step',density=True)
ztoplot = np.linspace(-5,5,100) # dummy variables for plotting the pdf
pdf_vals_z = pdf_z.pdf(ztoplot) # evaluate pdf for plotting
plt.plot(ztoplot,pdf_vals_z)
plt.xlabel('z')
plt.ylabel('Density')

In [None]:
mu = 1.0 # arbitrary values
sigma = 0.5 # arbitrary values
X = mu + sigma*Z # apply the change of variables to the samples Z drawn above
plt.hist(X,histtype='step',density=True)
xtoplot = mu+sigma*ztoplot # dummy variables for plotting the pdf
gradient_xtoplot_over_z = sigma # dx/dz for x = mu + sigma*z (constant in z)
pdf_vals_x = pdf_z.pdf(ztoplot) / np.abs(gradient_xtoplot_over_z) # old pdf at z = f^{-1}(x), divided by |det grad_z f|
plt.plot(xtoplot,pdf_vals_x)
plt.xlabel('x')
plt.ylabel('Density')
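
One way to confirm the result of the exercise, sketched here under the same arbitrary choice $\mu=1$, $\sigma=0.5$: since $x=\mu+\sigma z$ has $dx/dz=\sigma$, dividing the base pdf by $|\sigma|$ should reproduce scipy's own $\mathcal{N}(\mu,\sigma)$ pdf on the transformed grid.

```python
import numpy as np
import scipy.stats as st

mu, sigma = 1.0, 0.5                  # same arbitrary values as in the exercise
ztoplot = np.linspace(-5, 5, 100)
xtoplot = mu + sigma * ztoplot        # x = f(z)
# p_X(x) = p_Z(f^{-1}(x)) / |df/dz|, with df/dz = sigma for this affine map
pdf_vals_x = st.norm(loc=0, scale=1).pdf(ztoplot) / np.abs(sigma)
pdf_ref = st.norm(loc=mu, scale=sigma).pdf(xtoplot)  # N(mu, sigma) evaluated directly
print(np.allclose(pdf_vals_x, pdf_ref))  # True
```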


# Normalizing Flows

In essence, a Normalizing Flow is a bijective function that maps the sample space to a new space in which the data are distributed however we choose. That is, if we have data $x\sim p_{X}$, we want to learn an invertible function $x = f(z,\theta)$ such that $z$ follows a base distribution that is easy to sample from and to evaluate. The most common choice is a standard normal distribution $z \sim p_{Z}\equiv \mathcal{N}(0,1)$.

$f$ will be a learnable neural network with parameters $\theta$ and an easy-to-compute Jacobian determinant $|\text{det }\nabla_{z}f|$. The loss function that $\theta$ needs to minimize is nothing more than the negative log-likelihood, obtained using the transformation rule of pdfs:

$\mathcal{L}=- \sum_{x\in \mathcal{D}}\text{Ln }p_{X}(x) = - \sum_{x\in \mathcal{D}}\text{Ln }[p_{Z}(f^{-1}(x,\theta))|\text{det }\nabla_{z}f|^{-1}]$

$\mathcal{L}= \sum_{x\in \mathcal{D}}\left(-\text{Ln }[p_{Z}(f^{-1}(x,\theta))]+\text{Ln }[|\text{det }\nabla_{z}f|]\right)$

And, assuming a standard normal base distribution $p_{Z}$,

$\mathcal{L}= \sum_{x\in \mathcal{D}}\left(-\text{Ln }\mathcal{N}\left(f^{-1}(x,\theta);0,1\right)+\text{Ln }[|\text{det }\nabla_{z}f|]\right)$
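
To check the loss formula beyond one dimension, here is a minimal sketch with a hypothetical 2D affine flow $x = Az + b$ (the matrix `A`, the vector `b` and the function `nll_flow` are illustrative assumptions, not part of any library). Here $f^{-1}(x)=A^{-1}(x-b)$ and $\text{Ln }|\text{det }\nabla_{z}f|=\text{Ln }|\text{det }A|$. Because $x$ is then exactly $\mathcal{N}(b, AA^{T})$, the loss can be compared against scipy's `multivariate_normal`.

```python
import numpy as np
import scipy.stats as st

# Hypothetical 2D affine flow x = f(z) = A z + b, with z ~ N(0, I)
rng = np.random.default_rng(0)
A = np.array([[2.0, 0.0], [0.5, 1.0]])           # any invertible matrix
b = np.array([1.0, -1.0])
X2 = rng.standard_normal((1000, 2)) @ A.T + b     # samples x_i = A z_i + b

def nll_flow(x, A, b):
    """Negative log-likelihood from the change-of-variables formula."""
    z = np.linalg.solve(A, (x - b).T).T           # f^{-1}(x) = A^{-1}(x - b)
    log_pz = st.multivariate_normal(mean=np.zeros(2)).logpdf(z)  # Ln N(z; 0, I)
    _, logabsdet = np.linalg.slogdet(A)           # Ln |det grad_z f| = Ln |det A|
    return np.sum(-log_pz + logabsdet)

# Reference: x ~ N(b, A A^T), which scipy can evaluate directly
nll_ref = -np.sum(st.multivariate_normal(mean=b, cov=A @ A.T).logpdf(X2))
print(nll_flow(X2, A, b), nll_ref)
```

`np.linalg.slogdet` is used instead of `np.log(np.abs(np.linalg.det(A)))` because it stays numerically stable when the determinant is very large or very small.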

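The trick is how to choose a learnable $f$ whose Jacobian determinant is cheap to compute and which is also easy to invert, so that we can go back and forth between $x$ and $z$. Gradients with respect to $\theta$ are not the problem, since the chain rule and backpropagation handle them for standard NNs; for a generic network, however, computing $\text{det }\nabla_{z}f$ costs $O(D^3)$ operations in $D$ dimensions, and $f$ may not even be invertible.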
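
A common answer to this trick is an autoregressive structure, in the spirit of the `MaskedAffineAutoregressiveTransform` imported above (its exact parameterization in `nflows` differs). The sketch below is a hypothetical 3D toy: each $x_i = z_i e^{\alpha_i} + \mu_i$, where $\mu_i$ and $\alpha_i$ depend only on $x_1,\dots,x_{i-1}$ through a fixed toy `conditioner` standing in for a masked neural network. The Jacobian $\nabla_{z}x$ is then lower triangular, so $\text{Ln }|\text{det }\nabla_{z}f|=\sum_i \alpha_i$ costs only $O(D)$, and the inverse has a closed form. A finite-difference Jacobian is computed only to check this.

```python
import numpy as np

# Hypothetical toy conditioner: (mu_i, alpha_i) depend only on x_1..x_{i-1}.
# In a real flow this would be a masked neural network.
rng = np.random.default_rng(1)
D = 3
W_mu = rng.normal(size=(D, D))
W_alpha = rng.normal(size=(D, D))

def conditioner(x_prev, i):
    mu = np.tanh(W_mu[i, :i] @ x_prev)              # shift
    alpha = 0.5 * np.tanh(W_alpha[i, :i] @ x_prev)  # log-scale
    return mu, alpha

def forward(z):
    """z -> x, sequential: x_i = z_i * exp(alpha_i) + mu_i."""
    x = np.zeros(D)
    log_det = 0.0
    for i in range(D):
        mu, alpha = conditioner(x[:i], i)
        x[i] = z[i] * np.exp(alpha) + mu
        log_det += alpha  # triangular Jacobian: sum of the log-diagonal entries
    return x, log_det

def inverse(x):
    """x -> z: each z_i needs only the observed x, so the iterations are independent."""
    z = np.zeros(D)
    for i in range(D):
        mu, alpha = conditioner(x[:i], i)
        z[i] = (x[i] - mu) * np.exp(-alpha)
    return z

z = rng.normal(size=D)
x, log_det = forward(z)

# Central finite-difference Jacobian dx/dz, used only to verify the cheap log-det
eps = 1e-6
J = np.zeros((D, D))
for j in range(D):
    dz = np.zeros(D)
    dz[j] = eps
    J[:, j] = (forward(z + dz)[0] - forward(z - dz)[0]) / (2 * eps)

print(np.allclose(inverse(x), z), log_det, np.linalg.slogdet(J)[1])
```

Going from $x$ to $z$ (what the loss needs) is parallel across dimensions, while sampling ($z$ to $x$) is sequential; other designs trade these costs differently.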



## Example

In the previous, very simplified example, we know that a good choice of $f(z,\theta)$ is simply $f(z,\theta)=\theta_{0}+\theta_{1}z$, with inverse $f^{-1}(x,\theta)=(x-\theta_{0})/\theta_{1}$ and Jacobian determinant $|\text{det }\nabla_{z}f|=|\theta_{1}|$ (which, in this case, does not depend on the point $z = (x-\theta_{0})/\theta_{1}$ at which it is evaluated). We can thus simply write down the loss function and do a very naive grid minimization.

In [None]:
def loss_function(theta0,theta1):
    # negative log-likelihood of the affine flow x = theta0 + theta1*z, with z ~ N(0,1)
    return np.sum(-st.norm(loc=0,scale=1).logpdf((X-theta0)/theta1)+np.log(np.abs(theta1)))

In [None]:
theta0vals = np.linspace(0.5,1.5,100) # substitute adequate range if you changed mu, sigma before
theta1vals = np.linspace(0.3,0.7,100) # substitute adequate range if you changed mu, sigma before
theta0vals_plot, theta1vals_plot = np.meshgrid(theta0vals,theta1vals)
# print(theta0vals_plot.shape,theta1vals_plot.shape)
loss_function_vals = np.zeros(theta0vals_plot.shape)
for ntheta1val, theta1val in enumerate(theta1vals):
    for ntheta0val, theta0val in enumerate(theta0vals):
        loss_function_vals[ntheta1val,ntheta0val]=loss_function(theta0val,theta1val)
plt.contourf(theta0vals,theta1vals,loss_function_vals,cmap='gist_heat_r')
plt.axhline(sigma)
plt.axvline(mu)
plt.colorbar()

In [None]:
theta0min, theta1min = theta0vals_plot.flatten()[np.argmin(loss_function_vals)],theta1vals_plot.flatten()[np.argmin(loss_function_vals)]
print(theta0min,theta1min)

In [None]:
loss_function(mu,sigma),loss_function(theta0min,theta1min) # the grid minimum beats the true (mu, sigma): it fits the fluctuations of this finite sample
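
Why does the grid minimum beat the true parameters? For this affine flow the loss is exactly the Gaussian negative log-likelihood, whose minimizer is known in closed form: the sample mean and the (biased, `ddof=0`) sample standard deviation of `X`. These fluctuate around $\mu$ and $\sigma$ by an amount of order $\sigma/\sqrt{N}$, so on any finite sample they achieve a lower loss than the true values. A minimal sketch, regenerating a sample under the same assumed setup:

```python
import numpy as np
import scipy.stats as st

rng = np.random.default_rng(0)
mu, sigma = 1.0, 0.5
X = mu + sigma * rng.standard_normal(10000)  # a fresh finite sample

def loss_function(theta0, theta1):
    # NLL of the affine flow x = theta0 + theta1 * z, with z ~ N(0, 1)
    return np.sum(-st.norm(loc=0, scale=1).logpdf((X - theta0) / theta1)
                  + np.log(np.abs(theta1)))

# Closed-form maximum-likelihood estimate for this affine flow
theta0_mle, theta1_mle = X.mean(), X.std()  # np.std uses ddof=0 by default
print(theta0_mle, theta1_mle)
print(loss_function(mu, sigma), loss_function(theta0_mle, theta1_mle))
```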

In [None]:
plt.hist(X,histtype='step',density=True)
# plt.plot(xtoplot,pdf_vals_x)
# now we use the min parameters explicitly with xtoplot
pdf_vals_x_bis = st.norm(loc=0,scale=1).pdf((xtoplot-theta0min)/theta1min)/theta1min
plt.plot(xtoplot,pdf_vals_x_bis)
plt.xlabel('x')
plt.ylabel('Density')
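
The grid search does not scale beyond a couple of parameters. The same loss can instead be minimized by gradient descent with backpropagation, which is how flows are trained in practice. Here is a minimal PyTorch sketch (the names `theta0`, `log_theta1` and the optimizer settings are illustrative choices). The scale is parameterized through its logarithm so that $\theta_1>0$ holds automatically, and the per-sample loss is averaged rather than summed, which only rescales the gradient.

```python
import numpy as np
import torch

torch.manual_seed(0)
mu_true, sigma_true = 1.0, 0.5                       # same arbitrary values as above
X_t = mu_true + sigma_true * torch.randn(10000)      # only x is used for training

theta0 = torch.zeros(1, requires_grad=True)          # shift
log_theta1 = torch.zeros(1, requires_grad=True)      # log of the scale, keeps theta1 > 0
base = torch.distributions.Normal(0.0, 1.0)          # base distribution p_Z
opt = torch.optim.Adam([theta0, log_theta1], lr=0.05)

for step in range(1000):
    opt.zero_grad()
    z = (X_t - theta0) * torch.exp(-log_theta1)      # f^{-1}(x)
    # mean NLL: -Ln N(f^{-1}(x); 0, 1) + Ln|theta1|
    loss = (-base.log_prob(z) + log_theta1).mean()
    loss.backward()
    opt.step()

print(theta0.item(), torch.exp(log_theta1).item())
```

The learned values should approach the sample mean and standard deviation of `X_t`, as in the grid search.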

Note that we **haven't seen the true Z during training**: the loss only uses the samples X, since the technique is aimed at learning $p_{X}(x)$. We did cheat, however, by knowing in advance that the simple affine parameterization was flexible enough.

# Choice of f

There are [many](https://arxiv.org/pdf/1908.09257.pdf) ways to do this, but the usual trick consists of composing several individual, simpler modules. That is

$z_{1} = f_{1}(z)$

$z_{i} = f_{i}(z_{i-1})$ with $i=2,...,n-1$

$x=f_{n}(z_{n-1})$

and having each individual $f_{i}$ module as a simple, invertible function whose parameters are Neural Networks. For example, the very common 