# Flow-based generative model 


We use the package **torchdiffeq** to solve ODEs.

### github: 
    https://github.com/rtqichen/torchdiffeq

### install:

```
pip install torchdiffeq
```


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import torch
import torch.nn as nn
import math

from torch.utils.data import DataLoader

import seaborn as sns

from torchdiffeq import odeint

from sklearn.datasets import make_circles

### Define a simple feedforward neural network to model the vector field $u(x,t):\mathbb{R}^d \times [0,1] \rightarrow \mathbb{R}^d$


In [None]:
class VectorField(nn.Module):
    """Feed-forward network modelling a time-dependent vector field u(x, t).

    The state ``x`` and the time ``t`` are concatenated and passed through a
    stack of tanh hidden layers. The output has the same dimension as ``x``.
    The defaults reproduce four tanh hidden layers of width 100.

    Args:
        dim: dimension d of the state space.
        hidden_dim: width of every hidden layer (default 100).
        n_hidden_layers: number of tanh hidden layers (default 4).

    Raises:
        ValueError: if ``n_hidden_layers`` is smaller than 1.
    """

    def __init__(self, dim, hidden_dim=100, n_hidden_layers=4):
        super().__init__()

        if n_hidden_layers < 1:
            raise ValueError("n_hidden_layers must be at least 1")

        # input layer: (state, time) in R^{dim+1} -> hidden
        layers = [nn.Linear(dim + 1, hidden_dim), nn.Tanh()]
        # remaining hidden layers
        for _ in range(n_hidden_layers - 1):
            layers += [nn.Linear(hidden_dim, hidden_dim), nn.Tanh()]
        # output layer: hidden -> R^dim
        layers.append(nn.Linear(hidden_dim, dim))

        # layers are created in the same order (and with the same Sequential
        # indices) as the original hand-written stack
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate u(x, t).

        Args:
            x: tensor of shape (N, dim).
            t: times. Accepts a tensor of shape (N, 1) or (N,), a
                single-element tensor, or a Python number. A single time is
                broadcast to all N states.

        Returns:
            tensor of shape (N, dim).
        """
        # bring t to shape (N, 1) with the dtype/device of x so it can be
        # concatenated with the states
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        t = t.reshape(-1, 1).expand(x.shape[0], 1)

        # combine x and t into one tensor and pass it through the network
        state = torch.cat((x, t), dim=1)
        return self.net(state)

### Training 

1. target density $p_1 = p_{\mathrm{target}}$;
2. prior density $p_0$, which we choose to be the standard Gaussian density.

#### Linear interpolation: 

\begin{equation}
  X_t = (1-t) X_0 + tX_1 \,, \quad \mathrm{where}~ X_0\sim p_0 ~\mathrm{and}~ X_1\sim p_1\,.
\end{equation}

Let $p(\cdot,t)$ be the probability density of $X_t$.

**Idea**: learn an ODE 
\begin{equation}
  \frac{dY_t}{dt} = u(Y_t,t)\,, \quad t\in [0,1]
\end{equation}

such that, when $Y_0\sim p_0$, then $Y_t \sim p(\cdot, t)$ for any $t\in[0,1]$.

**Main theoretical result**

 The probability density $p(x,t)$ of $X_t$ solves the equation
 \begin{equation*}
  \frac{\partial p(x,t)}{\partial t} + \mathrm{div}\Big(\mathbb{E}\big(X_1 - X_0\big|X_t=x \big) p(x,t)\Big) = 0
\end{equation*}

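  This is exactly the continuity equation of the ODE $\frac{dY_t}{dt} = u(Y_t,t)$ with $u(x,t) = \mathbb{E}\big(X_1 - X_0|X_t=x \big)$. If $Y_0\sim p_0$, the density of $Y_t$ solves the same equation with the same initial density $p(\cdot,0)=p_0$. Assuming uniqueness of solutions, it therefore equals $p(\cdot,t)$. Therefore, we learn $u(x,t) = \mathbb{E}\big(X_1 - X_0|X_t=x \big)$.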
  
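**Flow-matching loss**: among all functions $u$, this loss is minimized by the conditional expectation above. The reason is that $\mathbb{E}\big(X_1 - X_0|X_t\big)$ is the mean-square-optimal prediction of $X_1 - X_0$ from $X_t$.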

\begin{equation}
  \mathrm{Loss}(u) =  \mathbb{E}_{t\sim U[0,1]} \mathbb{E}_{X_0\sim p_0,
  X_1\sim p_1}\Big(\big|u\big((1-t)X_0+tX_1,t\big) -  (X_1 - X_0)\big|^2\Big) \,.
\end{equation}


In [None]:
def training(X, model, learning_rate=1e-3, batch_size=1000, total_epochs=1000):
    """Train ``model`` with the flow-matching loss using Adam.

    For every mini-batch of data points X_1, the function:

    1. draws times t ~ U[0, 1] and prior samples X_0 ~ N(0, I);
    2. forms X_t = (1 - t) X_0 + t X_1;
    3. minimizes the batch mean of |model(X_t, t) - (X_1 - X_0)|^2.

    Args:
        X: training data, array-like of shape (N, dim). A 1-D array of length
            N is treated as N one-dimensional samples.
        model: module called as ``model(x, t)`` with x of shape (B, dim) and
            t of shape (B, 1), returning a tensor of shape (B, dim).
        learning_rate: Adam step size.
        batch_size: mini-batch size. It is capped at N so that a dataset
            smaller than ``batch_size`` still gives one batch per epoch.
        total_epochs: number of passes over the data.

    Returns:
        list with one float per epoch: the loss of the first mini-batch of
        that epoch.

    Raises:
        ValueError: if X contains no samples.
    """
    # change the dataset to a float32 PyTorch tensor of shape (N, dim);
    # a 1-D input is read as N one-dimensional samples
    dataset = torch.as_tensor(X, dtype=torch.float32)
    dim = 1 if dataset.dim() == 1 else dataset.shape[1]
    dataset = dataset.reshape(-1, dim)

    n_data = dataset.shape[0]
    if n_data == 0:
        raise ValueError("training data X contains no samples")

    # With drop_last=True a batch larger than the dataset would produce no
    # mini-batch at all and training would silently do nothing, so cap it.
    data_loader = DataLoader(dataset, batch_size=min(batch_size, n_data),
                             shuffle=True, drop_last=True)

    # Adam
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    loss_list = []

    for epoch in range(total_epochs):   # for each epoch

        for idx, data in enumerate(data_loader):  # loop over all mini-batches

            # for each state in the mini-batch, sample a time uniformly on [0, 1]
            t = torch.rand(data.shape[0], 1)

            # prior samples X_0 ~ N(0, I), one per data point
            x0 = torch.randn_like(data)

            # linear interpolation X_t = (1 - t) X_0 + t X_1
            xt = (1 - t) * x0 + t * data

            # predicted velocity u(X_t, t)
            u = model(xt, t)

            # flow-matching loss: batch mean of |u - (X_1 - X_0)|^2
            loss = torch.mean(torch.sum((u - (data - x0))**2, dim=1))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if idx == 0:
                # record the loss of the first mini-batch of each epoch
                loss_value = loss.item()
                loss_list.append(loss_value)
                if epoch % 100 == 0:
                    print('epoch=%d\n   loss=%.4f' % (epoch, loss_value))

    return loss_list

### generate new samples by solving the ODE:

\begin{equation}
  \frac{dY_t}{dt} = u(Y_t,t)\,, \quad t\in [0,1]
  \label{ode-flow}
\end{equation}
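starting from $Y_0\sim p_0$. The state $Y_1$ at the final time $t=1$ is then (approximately) distributed according to the target density $p_1$.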

In [None]:
def generative_ode(model, dim, N, t=None):
    """Generate N samples by integrating the learned flow ODE dY/dt = u(Y, t).

    The initial states Y_0 are drawn from the standard Gaussian prior in
    R^dim. The ODE is solved with ``odeint`` from torchdiffeq.

    Args:
        model: vector field, called as ``model(y, s)`` with y of shape
            (N, dim) and s of shape (N, 1).
        dim: dimension of the state space.
        N: number of samples.
        t: 1-D tensor of output times; t[0] is the initial time. Defaults to
            100 equally spaced points on [0, 1], built fresh on every call.

    Returns:
        (t, sol): the time grid and the solution returned by ``odeint``. The
        first axis of sol indexes t, so sol[-1] holds the generated samples.
        The solve runs without autograd, so ``sol`` does not require grad.
    """
    if t is None:
        # created per call instead of a tensor default evaluated once at
        # definition time and shared between calls
        t = torch.linspace(0, 1, 100)

    # vector field of the ODE is learnt by training
    def func(s, y):
        # odeint passes a single time s; give every state its own copy
        return model(y, torch.ones(y.shape[0], 1) * s)

    # sample y0 from prior (standard Gaussian)
    y0 = torch.randn(N * dim).reshape(N, dim)

    # Sampling needs only the trajectory, not gradients through the solver.
    # Disabling autograd avoids keeping every intermediate activation alive.
    with torch.no_grad():
        sol = odeint(func, y0, t)

    return t, sol

## Example 1 ---- 1d dataset 

This is the same dataset we studied for learning eigenfunctions and for diffusion models.

We generate the data by simulating Brownian (overdamped Langevin) dynamics $dX_t = -V'(X_t)\,dt + \sqrt{2\beta^{-1}}\,dW_t$.

### prepare dataset

1. potential $V$
2. its gradient
3. sampling SDE
4. set parameters

In [None]:
# potential V, one-dimensional
def V(x):
    """Return the potential V(x) = 2 * (x**8 + three Gaussian bumps).

    Each bump has the form h * exp(-80 * (x - c)**2). The (height, centre)
    pairs are (0.8, 0), (0.55, 0.5) and (0.3, -0.5). The expression is
    element-wise, so NumPy arrays are accepted as well as scalars.
    """
    bumps = ((0.8, 0.0), (0.55, 0.5), (0.3, -0.5))
    total = x**8
    for height, centre in bumps:
        total = total + height * np.exp(-80 * (x - centre)**2)
    return 2 * total

# gradient of V
def gradV(x):
    """Return the derivative dV/dx of the potential V, element-wise.

    The wall x**8 contributes 8 * x**7. Each Gaussian bump
    h * exp(-80 * (x - c)**2) contributes -160 * h * (x - c) * exp(...).
    The sum is scaled by 2, like V itself.
    """
    bumps = ((0.8, 0.0), (0.55, 0.5), (0.3, -0.5))
    slope = 8 * x**7
    for height, centre in bumps:
        offset = x - centre
        slope = slope + (-height) * 160 * offset * np.exp(-80 * offset**2)
    return 2 * slope

# sample the SDE using Euler-Maruyama scheme
def sample(beta=1.0, dt=0.001, N=10000, seed=42, x0=0.0):
    """Simulate dX = -V'(X) dt + sqrt(2/beta) dW with Euler-Maruyama.

    The update is X <- X - gradV(X) * dt + sqrt(2 * dt / beta) * xi, where
    xi is drawn from ``np.random.default_rng(seed).normal()``.

    Args:
        beta: coefficient of the SDE; the noise amplitude per step is
            sqrt(2 * dt / beta).
        dt: step size.
        N: number of recorded states; the initial state is the first.
        seed: seed of the random number generator.
        x0: initial state (default 0.0, the value previously hard-coded).

    Returns:
        (times, traj): float arrays of length N with times[i] = i * dt and
        traj[i] the state after i steps.
    """
    rng = np.random.default_rng(seed=seed)
    # the noise amplitude is the same for every step
    noise_scale = np.sqrt(2 * dt / beta)

    times = dt * np.arange(N)
    traj = np.empty(N)

    x = x0
    for i in range(N):
        traj[i] = x
        x = x - gradV(x) * dt + noise_scale * rng.normal()

    return times, traj

# SDE settings: coefficient beta (noise amplitude sqrt(2 dt / beta)),
# Euler-Maruyama step size dt and number of sampling steps N
beta, dt, N = 2.0, 0.005, 10000

# range of the domain, used as the vertical limits of the trajectory plot
xmin, xmax = -1.0, 1.0

### 1.2 sample the SDE and display the trajectory 

The array **X** contains the training data we will use later.

From the figure on the right, we see that our target density has 4 modes.

In [None]:
# sampling SDE
tvec, X = sample(beta, dt=dt, N=N)

# show how many states are sampled
print ('dataset has %d states.\n' % X.shape[0])

fig = plt.figure(figsize=(12,4))
ax = fig.add_subplot(1, 2, 1)

# plot trajectory vs time
ax.plot(tvec, X, alpha=0.7)
ax.set_ylim([xmin, xmax])
ax.set_xlabel(r'time')
ax.set_ylabel(r'x')
ax.set_title('trajectory')

ax1 = fig.add_subplot(1, 2, 2)

# plot the empirical density of the data (normalized 50-bin histogram)
ax1.hist(X, 50, density=True)
ax1.set_xlabel(r'x')
ax1.set_ylabel('density')
ax1.set_title('empirical density')

plt.show()

### display neural network

We write it as a function so that we can reuse it.

In [None]:
def plot_vf(model, xmin=-2.0, xmax=2.0, n_grid=100):
    """Plot a one-dimensional vector field u(x, t) as a heat map over (x, t).

    The model is evaluated on an n_grid x n_grid grid covering
    [xmin, xmax] x [0, 1]. The defaults reproduce the original [-2, 2] range
    with 100 points per axis.

    Args:
        model: module called as ``model(x, t)`` with x and t of shape (M, 1),
            returning a tensor with M entries.
        xmin, xmax: range of x on the horizontal axis.
        n_grid: number of grid points along each axis.
    """
    x = torch.linspace(xmin, xmax, n_grid)
    t = torch.linspace(0, 1, n_grid)
    xv, tv = torch.meshgrid(x, t, indexing='ij')

    # plotting only needs values, so do not build an autograd graph
    with torch.no_grad():
        u_xt = model(xv.reshape(-1, 1), tv.reshape(-1, 1)).reshape(n_grid, n_grid)

    fig = plt.figure(figsize=(5, 4))
    ax = fig.add_subplot(1, 1, 1)

    im = ax.pcolormesh(xv.numpy(), tv.numpy(), u_xt.numpy(), cmap='coolwarm', shading='auto')

    cbar = fig.colorbar(im, ax=ax, shrink=1.0)
    cbar.ax.tick_params(labelsize=15)

    ax.set_xlabel(r'x', fontsize=20)
    ax.set_ylabel(r't', fontsize=20)
    ax.set_title('vector field', fontsize=20)

### model the vector field by a neural network and plot

In [None]:
model = VectorField(1)  # one-dimensional states: input (x, t), output u

# The network is still untrained here, so this shows its initial field.
plot_vf(model)

### training the neural network by flow-matching

In [None]:
# training hyper-parameters: mini-batch size and number of epochs
batch_size, total_epochs = 1000, 5000

# training() expects an (N, dim) array; here dim = 1
X = X.reshape(-1, 1)

# flow-matching training
loss_list = training(
    X,
    model,
    learning_rate=1e-3,
    batch_size=batch_size,
    total_epochs=total_epochs,
)

# loss recorded once per epoch during training
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.plot(loss_list)
ax.set_xlabel('epoch')
ax.set_title('loss vs epoch')

# the vector field after training
plot_vf(model)

### generate new samples by simulating the ODE 

In [None]:
# integrate the learned ODE from 10000 Gaussian starting points
t, sol = generative_ode(model, N=10000, dim=1)
sol = sol.detach().numpy()  # NumPy copy of the trajectories for plotting

### compare the distribution of generated samples and the data distribution

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(6, 5))

# same colour for both curves: solid = training data, dashed = final ODE states
for values, style, label in (
    (X.flatten(), "-", 'data distribution'),
    (sol[-1].flatten(), "--", 'generated'),
):
    sns.kdeplot(values, ax=ax, linestyle=style, bw_adjust=0.2, c='b', label=label)

ax.set_xlabel('x')
ax.legend()
plt.show()

## Example 2 ----  2d dataset

### prepare the dataset

In [None]:
n_samples = 10000

# two noisy concentric circles; the second return value Y (the labels from
# make_circles) is not used by flow matching
X, Y = make_circles(n_samples=n_samples, factor=0.5, noise=0.05, random_state=170)

fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.scatter(X[:, 0], X[:, 1])
ax.set_title("dataset")
plt.tight_layout()
plt.show()

print(X.shape)

### model the vector field by a neural network and plot

In [None]:
model = VectorField(2)  # fresh network for the two-dimensional data

### training with flow-matching loss

In [None]:
# training hyper-parameters: mini-batch size and number of epochs
batch_size, total_epochs = 1000, 5000

# flow-matching training on the 2-D circles data
loss_list = training(
    X,
    model,
    learning_rate=1e-3,
    batch_size=batch_size,
    total_epochs=total_epochs,
)

# loss recorded once per epoch during training
fig, ax = plt.subplots(1, 1, figsize=(5, 4))
ax.plot(loss_list)
ax.set_xlabel('epoch')
ax.set_title('loss vs epoch')

### generate new samples by simulating the ODE

In [None]:
# integrate the learned ODE from 10000 Gaussian starting points in R^2
t, sol = generative_ode(model, N=10000, dim=2)
sol = sol.detach().numpy()  # NumPy copy of the trajectories for plotting

In [None]:
# Reference: the exact linear interpolation X_t = (1 - t) X_0 + t X_1 on the
# same time grid t used for the ODE solve, with fresh Gaussian samples X_0.
X0 = np.random.randn(X.shape[0] * 2).reshape(-1, 2)

t_np = t.numpy()
Xt = (1 - t_np)[:, None, None] * X0[None, :, :] + t_np[:, None, None] * X[None, :, :]

# Columns use fixed indices into the time grid. The titles show the actual
# grid time: with the default linspace(0, 1, 100), t[50] = 50/99 ~ 0.505,
# not exactly 0.5.
time_indices = [50, 90, 95, -1]

fig, ax = plt.subplots(2, 4, figsize=(10, 3))

for col, k in enumerate(time_indices):
    title = f"t={t_np[k]:.3f}"
    ax[0, col].scatter(Xt[k, :, 0], Xt[k, :, 1])
    ax[0, col].set_title(title)
    ax[1, col].scatter(sol[k, :, 0], sol[k, :, 1])
    ax[1, col].set_title(title)

# top row: interpolation X_t; bottom row: ODE solution Y_t
ax[0, 0].set_ylabel("interpolation")
ax[1, 0].set_ylabel("generated")

plt.tight_layout()
plt.show()

fig, ax = plt.subplots(1, 2, figsize=(8, 4))

# left: training data; right: final states Y_1 of the generative ODE
for panel, (points, name) in zip(ax, ((X, "data"), (sol[-1], "generated"))):
    panel.scatter(points[:, 0], points[:, 1])
    panel.set_title(name)

plt.tight_layout()
plt.show()