# Train a score-based diffusion model to generate a 1d distribution.

We use **seaborn** to plot probability density functions.

For installation and introduction, see:

https://seaborn.pydata.org/

https://seaborn.pydata.org/installing.html




In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import torch
import torch.nn as nn
import math

from torch.utils.data import DataLoader

import seaborn as sns

## Step 1, generate data.

We generate data by simulating the Brownian dynamics $dX_t = -\nabla V(X_t)\,dt + \sqrt{2/\beta}\,dB_t$ with the Euler-Maruyama scheme.

We simply reuse the potential $V$ we studied before for Markov state models and learning eigenfunctions.

### 1.1. functions that define potential and sampling, and also set parameters

In [None]:
# potential V, one-dimensional
def V(x):
    """Evaluate the potential V at x (scalar or NumPy array).

    V is twice the sum of a confining term x**8 and three Gaussian bumps
    of width parameter 80, centred at 0, 0.5 and -0.5 with heights
    0.8, 0.55 and 0.3 respectively.
    """
    # (height, centre) of each Gaussian bump, in the order they are summed
    bumps = ((0.8, 0.0), (0.55, 0.5), (0.3, -0.5))

    total = x**8
    for height, centre in bumps:
        total = total + height * np.exp(-80 * (x - centre)**2)

    return 2 * total

# gradient of V
def gradV(x):
    """Evaluate the derivative V'(x) at x (scalar or NumPy array).

    Term-by-term derivative of twice the sum of x**8 and three Gaussian
    bumps centred at 0, 0.5 and -0.5 with heights 0.8, 0.55 and 0.3.
    """
    # (height, centre) of each Gaussian bump, in the order they are summed
    bumps = ((0.8, 0.0), (0.55, 0.5), (0.3, -0.5))

    total = 8 * x**7
    for height, centre in bumps:
        # d/dx [h * exp(-80 (x-c)^2)] = -160 h (x-c) exp(-80 (x-c)^2)
        total = total + (-height * 160 * (x - centre) * np.exp(-80 * (x - centre)**2))

    return 2 * total

# sample the SDE using Euler-Maruyama scheme
def sample(beta=1.0, dt=0.001, N=10000, seed=42):
    """Simulate dX = -gradV(X) dt + sqrt(2/beta) dB starting from X = 0.

    Parameters:
        beta: coefficient in the noise amplitude sqrt(2 * dt / beta).
        dt:   step-size of the Euler-Maruyama scheme.
        N:    number of recorded states (one per step).
        seed: seed of the NumPy random generator.

    Returns:
        (times, states): two arrays of length N; times[i] = dt * i and
        states[i] is the state before the i-th update.
    """
    rng = np.random.default_rng(seed=seed)

    # constant noise amplitude of a single step
    noise_scale = np.sqrt(2 * dt / beta)

    times = dt * np.arange(N)
    states = np.empty(N)

    X = 0.0
    for k in range(N):
        # record the current state, then advance by one step
        states[k] = X
        X = X - gradV(X) * dt + noise_scale * rng.normal()

    return times, states

# beta: coefficient in the SDE noise term; dt: step-size; N: number of sampling steps
beta, dt, N = 2.0, 0.005, 10000

# lower and upper end of the domain shown in the plots
xmin = -1.0
xmax = 1.0

### 1.2 sample the SDE and display the trajectory 

**dataset** contains the training data we will use later.

From the figure on the right, we see that our target density has 4 modes.

In [None]:
# sample the SDE; `dataset` holds the visited states, `tvec` the matching times
tvec, dataset = sample(beta, dt=dt, N=N)

# show how many states are sampled
print ('dataset has %d states.\n' % dataset.shape[0])

fig = plt.figure(figsize=(12,4))

# left panel: trajectory vs time
ax = fig.add_subplot(1, 2, 1)
ax.plot(tvec, dataset, alpha=0.7)
ax.set_ylim([xmin, xmax])
ax.set_xlabel(r'time')
ax.set_ylabel(r'x')
ax.set_title('trajectory')

# right panel: normalized histogram (50 bins) of the sampled states
ax1 = fig.add_subplot(1, 2, 2)
ax1.hist(dataset, 50, density=True)

# title typo fixed: 'impirical' -> 'empirical'
ax1.set_title('empirical density')

plt.show()

## Step 2, define forward and backward process 


#### VESDE 

Given $0 < \eta_{\min}\le \eta_{\max}$, consider
$$dX_t = \sqrt{\beta(t)} dB_t,$$  where 
  \begin{equation}
      \beta(t) = \frac{2\eta^2_{\mathrm{min}}}{T}
      \Big(\frac{\eta_{\mathrm{max}}}{\eta_{\mathrm{min}}}\Big)^{2t/T}
      \ln\Big(\frac{\eta_{\mathrm{max}}}{\eta_{\mathrm{min}}}\Big) , \quad t \in [0,T]\,.
  \end{equation}
The solution is  
  \begin{equation}
     X_t = X_0 + \int_0^t \sqrt{\beta(s)} dB_s 
     \end{equation}

1. When $X_0=x_0$ is fixed, then 
  \begin{equation*}
     X_t \sim \mathcal{N}\Big(x_0, \eta^2(t)\mathbf{1}_d\Big)\,,\quad
\mathrm{where}\quad 
    \eta(t) = \Big(\int_0^t \beta(s) ds\Big)^{\frac{1}{2}} = 
    \eta_{\mathrm{min}}
      \sqrt{\Big(\frac{\eta_{\mathrm{max}}}{\eta_{\mathrm{min}}}\Big)^{2t/T} -1}.
  \end{equation*}

2. When $X_0\sim \mathcal{N}(x_0, \eta^2_{\min}\mathbf{1}_d)$, 
  \begin{equation*}
    X_t = \big(x_0 + \eta_{\min} z\big) + \int_0^t\sqrt{\beta(s)} dB_s \sim \mathcal{N}(x_0, \tilde{\eta}^2(t)\mathbf{1}_d)
  \end{equation*}
  where $z\sim \mathcal{N}(0,\mathrm{I}_d)$ and 
  \begin{equation*}
    \tilde{\eta}(t) = \sqrt{\eta^2(t) + \eta^2_{\min}} =  \eta_{\mathrm{min}} \Big(\frac{\eta_{\mathrm{max}}}{\eta_{\mathrm{min}}}\Big)^{t/T}.
  \end{equation*}
  
In particular, we have 
\begin{equation*}
\tilde{\eta}(0) = \eta_{\min}, \quad \tilde{\eta}(T) = \eta_{\max}.
\end{equation*}

As for prior, we use 
\begin{equation*}
\mathcal{N}(0, \eta^2 _{\max}\mathbf{1}_d)
\end{equation*}

#### Backward process 

For a general SDE
\begin{equation}
    dX_t = f(X_t, t)\,dt + \sigma(t) dB_t\,, \quad t \in [0,T]\,,    
\end{equation}
the backward process is
\begin{equation}
  \begin{aligned}
    dY_t =& \Big(- f(Y_t, T-t) + \sigma^2(T-t) \nabla \ln p(Y_t,T-t)\Big)\,dt + \sigma(T-t) dB_t\,,  \\
    Y_0 \sim& p(\cdot, T)\,.
  \end{aligned}
\end{equation}
We assume that $p(\cdot, T)$ is close to a prior that is easy to sample from, and we sample $Y_0$ from prior (instead of $p(\cdot,T)$) when simulating the backward process.

We implement the VESDE in a class, which contains:

1. **beta_sqrt**: compute $\sqrt{\beta(t)}$: 

2. **marginal_std**: compute $\tilde{\eta}(t)$: 

3. **prior**: sample from prior $\mathcal{N}(0, \eta^2 _{\max}\mathbf{1}_d)$

4. **forward_sampling**: sample forward process $dX_t = \sqrt{\beta(t)} dB_t$

5. **backward_sampling**: sample backward process 
$dY_t =  \beta(T-t) \nabla\ln p(Y_t, T-t)dt + \sqrt{\beta(T-t)} dB_t$  


In [None]:
class VESDE:
    """Variance-exploding SDE dX_t = sqrt(beta(t)) dB_t on [0, T].

    The noise schedule is chosen such that the marginal standard deviation
    is eta_min * (eta_max / eta_min) ** (t / T), i.e. it grows from eta_min
    at t = 0 to eta_max at t = T.

    Parameters:
        eta_min: standard deviation at time 0 (must be positive).
        eta_max: standard deviation at time T.
        dim:     dimension of a single state.
        T:       length of the time interval.
    """

    def __init__(self, eta_min, eta_max, dim=1, T=1):

        self.T = T
        self.dim = dim

        self.eta_min = eta_min
        self.eta_max = eta_max

    def _as_states(self, X0):
        """Return X0 as a torch tensor of shape [M, dim]."""
        if not torch.is_tensor(X0):
            X0 = torch.tensor(X0)
        return X0.reshape(-1, self.dim)

    def beta_sqrt(self, t):
        """Return sqrt(beta(t)), the diffusion coefficient at time t."""

        sigma = self.eta_min * (self.eta_max/self.eta_min) ** (t/self.T)

        # square root of \beta(t) = sigma(t) * sqrt(2 * ln(eta_max/eta_min) / T)
        ret = sigma  * torch.sqrt(1.0 / self.T \
                                  * torch.tensor(2 * (math.log(self.eta_max) - math.log(self.eta_min))))

        return ret

    def marginal_std(self, X, t):
        """Return \\tilde{\\eta}(t), the standard deviation of the Gaussian at time t.

        X is accepted for interface compatibility and is not used.
        """
        std = self.eta_min * (self.eta_max/self.eta_min) ** (t/self.T)
        return std

    def prior(self, M):
        """Draw M samples from the prior N(0, eta_max^2 I); shape [M, dim].

        The samples are drawn directly with shape [M, dim], so M states are
        returned for every dim (reshaping M scalars would yield only M/dim
        states when dim > 1).
        """
        return torch.randn(M, self.dim) * self.eta_max

    # sample the forward process
    def forward_sampling(self, X0, N=100):
        """Simulate the forward SDE with N Euler-Maruyama steps.

        Parameters:
            X0: initial states (tensor or array-like), reshaped to [M, dim].
            N:  number of time steps on [0, T].

        Returns:
            Tensor of shape [N+1, M, dim]; entry i holds the states at
            time i * T / N.
        """
        X = self._as_states(X0)

        traj = [X]
        delta_t = self.T / N

        for i in range(N):

            b = torch.randn_like(X)

            # time tensor with the dtype/device of the states
            t = i * delta_t * torch.ones_like(X)

            diffusion_coeff = self.beta_sqrt(t)

            X = X + diffusion_coeff * math.sqrt(delta_t) * b

            traj.append(X)

        return torch.stack(traj)

    # sample the backward SDE
    def backward_sampling(self, X0, model, N=100):
        """Simulate the backward SDE with N Euler-Maruyama steps.

        Parameters:
            X0:    initial states (tensor or array-like), reshaped to [M, dim].
            model: callable model(X, t) returning the score; X has shape
                   [M, dim] and t has shape [M, 1].
            N:     number of time steps on [0, T].

        Returns:
            Tensor of shape [N+1, M, dim]; entry i holds the states of the
            backward process after i steps (forward time T - i * T / N).
        """
        X = self._as_states(X0)

        traj = [X]
        delta_t = self.T / N

        for i in range(N):

            b = torch.randn_like(X)

            # reverse the time; one time value per state (shape [M, 1]),
            # which coincides with ones_like(X) when dim = 1
            t = self.T - i * delta_t * torch.ones_like(X[:, :1])

            # evaluate the score
            score = model(X, t)

            diffusion_coeff = self.beta_sqrt(t)

            X = X + (diffusion_coeff**2 * score) * delta_t + math.sqrt(delta_t) * diffusion_coeff * b

            traj.append(X)

        return torch.stack(traj)

### specify parameters  

In [None]:
# the processes live on the time interval [0, T]
T = 1

# number of time steps and the resulting step-size
N = 500
dt = T/N

# standard deviation of the noise at time 0 and at time T
eta_min, eta_max = 0.03, 2

sde = VESDE(eta_min=eta_min, eta_max=eta_max, dim=1, T=T)

### We can simulate the forward process starting from dataset.

1. Dataset contains M=10000 states. 

2. For each state, we generate a trajectory in N=500 steps (each trajectory contains N+1 states).  

3. The dimension is dim=1.

In the end, we get a tensor that contains M trajectories. 

The shape of the tensor is $[N+1, M, dim]$: the first index is the time step, the second is the trajectory.

In [None]:
# change the dataset to a float32 PyTorch tensor of shape [M, 1].
# The sampled data must be kept here: overwriting it with random 2d
# Gaussian data would discard the training set and would not match the
# 2-dimensional (x, t) input expected by the score network.
# as_tensor accepts both NumPy arrays and tensors, so the cell can be re-run.
dataset = torch.as_tensor(dataset, dtype=torch.float32).reshape(-1,1)

# sampling forward process on [0,T]
forward_traj_set = sde.forward_sampling(dataset, N=N).detach().numpy()

# number of states in dataset
print (dataset.shape)

# trajectory set has shape [N+1, M, dim]: N+1 time steps, M trajectories of dimension dim=1.
print (forward_traj_set.shape)

### Using the trajectories of forward process, we plot the probability densities $p(x,t)$ at different time $t$.

We use the function seaborn.kdeplot for density estimation.

See: https://seaborn.pydata.org/generated/seaborn.kdeplot.html

In [None]:
fig,ax = plt.subplots(1,1)

# times at which to plot probability densities, as fractions of T.
# (Multiplying the list itself by T would repeat the list instead of
# scaling the times.)
t_list = [frac * T for frac in (0, 0.5, 0.7, 1.0)]
color_list = ['b', 'y', 'k', 'r', 'gray']

# plot densities for each time t in t_list
for i, t in enumerate(t_list):

    # the index corresponding to time t; round before converting because
    # the float quotient t / dt can land just below the exact integer
    t_idx = int(round(t / dt))

    # select states at time t from all trajectories, and change it to 1d array
    traj = forward_traj_set[t_idx, :, :].flatten()
    # estimate the density
    sns.kdeplot(traj, ax=ax, label='t=%.2f' % t, c=color_list[i])

# get samples from prior distribution
X_prior = sde.prior(20000).flatten()

# verify that the density p(x,T) is close to the prior (solid and dashed lines in red)
sns.kdeplot(X_prior, ax=ax, label='prior', c='r', linestyle='--')

plt.legend()
ax.set_xlim(-3, 3)
plt.show()

### Define a neural network to model the score function $u$

In this example, $u: \mathbb{R} \times [0,T] \rightarrow \mathbb{R}$.

Therefore, input dimension is 2 and output dimension is 1.

In [None]:
class MyScore(nn.Module):
    """Fully-connected network modelling the score u(x, t).

    The input is the concatenation of x and t (2 features); three hidden
    layers of width 50 with tanh activation are followed by a linear
    output layer with 1 feature.
    """

    def __init__(self):
        super().__init__()

        # layer widths from the input (x, t) through the hidden layers
        widths = [2, 50, 50, 50]

        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.Tanh())
        # linear output layer, no activation
        layers.append(nn.Linear(widths[-1], 1))

        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate the network on states x and times t (both of shape [M, 1])."""
        # stack x and t column-wise and feed them to the network
        return self.net(torch.cat((x, t), dim=1))


model = MyScore()

### display the (untrained) score function.

In [None]:
# regular 100 x 100 grid on [-2, 2] x [0, T]
n_grid = 100
x = torch.linspace(-2, 2, n_grid)
t = torch.linspace(0, T, n_grid)
xv, tv = torch.meshgrid(x, t, indexing='ij')

# evaluate the network on all grid points, then restore the grid layout
grid_x = xv.reshape(-1, 1)
grid_t = tv.reshape(-1, 1)
score_xt = model(grid_x, grid_t).reshape(n_grid, n_grid)

fig, ax = plt.subplots(1, 1, figsize=(7,5))

# colour map of the score over the (x, t) grid
im = ax.pcolormesh(
    xv.numpy(),
    tv.numpy(),
    score_xt.detach().numpy(),
    cmap='coolwarm',
    shading='auto',
)

cbar = fig.colorbar(im, ax=ax, shrink=1.0)
cbar.ax.tick_params(labelsize=15)

ax.set_title('score function', fontsize=20)
ax.set_xlabel(r'x', fontsize=20)
ax.set_ylabel(r't', fontsize=20)

### Let's start learning the score function.

#### general result

  \begin{equation}
    \begin{aligned}
      & \mathbb{E}_{t\sim U([0,T])} \mathbb{E}_{x\sim p(\cdot,t)}
      \Big[\frac{1}{2}\big|u(x,t) - \nabla \ln p(x,t)\big|^2 w(t)\Big] \\
    =&  \mathbb{E}_{t\sim U([0,T])} \mathbb{E}_{x_0\sim p_0}
      \mathbb{E}_{x\sim p(\cdot,t|x_0)} \Big[\frac{1}{2}\big|u(x,t) - \nabla \ln
      p(x,t|x_0)\big|^2 w(t)\Big] + C_2  \\
      =&
\mathbb{E}_{t\sim U([0,T])} \mathbb{E}_{x_0\sim p_0}
      \mathbb{E}_{x\sim p(\cdot,t|x_0)}\Big[ \Big(\frac{1}{2} |u(x,t)|^2 - u(x,t)
      \cdot \nabla \ln p(x,t|x_0)\Big) w(t)\Big] + C_3,
    \end{aligned}
  \end{equation}
  
#### practical loss with VESDE 

1. Instead of $p(x,t|x_0)$, we use 
$$\tilde{p}(x,t|x_0) = \big(2\pi \tilde{\eta}^2(t)\big)^{-\frac{d}{2}} \mathrm{e}^{-\frac{1}{2} |x-x_0|^2/ \tilde{\eta}^2(t)}\,, \quad \mathrm{where}\quad \tilde{\eta}(t) =  \eta_{\mathrm{min}} \Big(\frac{\eta_{\mathrm{max}}}{\eta_{\mathrm{min}}}\Big)^{t/T}.$$
  
2. (simulation-free) To get samples $x\sim \tilde{p}(\cdot,t|x_0)$, we can simply use 
$$x= x_0 + \tilde{\eta}(t)z, \quad \mathrm{where}\quad z\sim \mathcal{N}(0,\mathrm{I}_d).$$


3. Choose $w(t) = \tilde{\eta}^2(t)$
  
We obtain the loss

\begin{equation}
      \begin{aligned}
	\mathrm{Loss}(u) =&  
\mathbb{E}_{t\sim U([0,T])} \mathbb{E}_{x_0\sim p_0} \mathbb{E}_{x\sim
	\tilde{p}(\cdot,t|x_0)}\Big[ \Big(\frac{1}{2} |u(x,t)|^2 - u(x,t) \cdot \nabla \ln \tilde{p}(x,t|x_0)\Big) w(t)\Big] \\
	=& \mathbb{E}_{t\sim U([0,T])}
     \mathbb{E}_{x_0\sim p_0} \mathbb{E}_{x\sim \tilde{p}(\cdot,t|x_0)}
     \bigg[\Big(\frac{1}{2} |u(x,t)|^2 +u(x,t) \cdot \frac{x-
     x_0}{\tilde{\eta}(t)^2}\Big) \tilde{\eta}^2(t)\bigg]\\
	=& \mathbb{E}_{t\sim U([0,T])}
	\mathbb{E}_{x_0\sim p_0} \mathbb{E}_{z\sim \mathcal{N}(0,\mathrm{I}_d)} \bigg[\Big(\frac{1}{2} |u(x,t)|^2 + \frac{u(x,t) \cdot z}{\tilde{\eta}(t)}\Big) \tilde{\eta}^2(t)\bigg]
      \end{aligned}
\end{equation}  

In [None]:
# number of states per mini-batch
batch_size = 2000

# number of passes over the dataset
total_epochs = 5000

# Adam optimizer acting on the parameters of the score network
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# shuffled mini-batches; an incomplete last batch is dropped
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

# loss of the first mini-batch of every epoch
loss_list = []

for epoch in range(total_epochs):

    for idx, data in enumerate(data_loader):

        # one time per state, sampled uniformly on [0,T]
        n_states = data.shape[0]
        t = torch.rand(n_states) * T

        # \tilde{\eta}(t): standard deviation of the Gaussian density at time t
        std_t = sde.marginal_std(data, t)

        # standard Gaussian noise, one value per state entry
        z = torch.randn_like(data)

        # simulation-free sample of the state at time t: x_0 + \tilde{\eta}(t) z
        xt = data + std_t.reshape(-1, 1) * z

        # network prediction of the score at (xt, t)
        score = model(xt, t.reshape(-1,1))

        # per-sample terms |u|^2 / 2 and u.z / \tilde{\eta}(t), weighted by \tilde{\eta}^2(t)
        quadratic = 0.5*torch.sum(score**2, dim=1)
        cross = torch.sum(score * z, dim=1) / std_t
        loss = torch.mean((quadratic + cross)*std_t**2)

        # reset gradients, back-propagate, and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # only the first mini-batch of an epoch is recorded
        if idx != 0:
            continue

        loss_list.append(loss.item())
        if epoch % 200 == 0:
            print ('epoch=%d\n   loss=%.4f' % (epoch, loss.item()))

fig, ax = plt.subplots(1,1, figsize=(5, 4))

ax.plot(loss_list)
ax.set_xlabel('epoch')
ax.set_title('loss vs epoch')

### display the (trained) score function.

In [None]:
# regular 100 x 100 grid on [-2, 2] x [0, T]
n_grid = 100
x = torch.linspace(-2, 2, n_grid)
t = torch.linspace(0, T, n_grid)
xv, tv = torch.meshgrid(x, t, indexing='ij')

# evaluate the network on all grid points, then restore the grid layout
grid_x = xv.reshape(-1, 1)
grid_t = tv.reshape(-1, 1)
score_xt = model(grid_x, grid_t).reshape(n_grid, n_grid)

fig, ax = plt.subplots(1, 1, figsize=(7,5))

# colour map of the score over the (x, t) grid
im = ax.pcolormesh(
    xv.numpy(),
    tv.numpy(),
    score_xt.detach().numpy(),
    cmap='coolwarm',
    shading='auto',
)

cbar = fig.colorbar(im, ax=ax, shrink=1.0)
cbar.ax.tick_params(labelsize=15)

ax.set_title('score function', fontsize=20)
ax.set_xlabel(r'x', fontsize=20)
ax.set_ylabel(r't', fontsize=20)

### After learning the score function, we can generate new samples by simulating the backward process.

In [None]:
# number of samples to generate
n_generated = 10000

# no gradients are needed for sampling, so auto-differentiation is switched off
with torch.no_grad():

    # initial states of the backward process, drawn from the prior
    X = sde.prior(n_generated)

    # simulate the backward process driven by the learned score function
    backward_traj_set = sde.backward_sampling(X, model, N=N)

print ("shape of the backward trajectory tensor:", backward_traj_set.shape)

### Let's compare the densities between the forward and backward processes at different time

In [None]:
fig,ax = plt.subplots(1,2, figsize=(13, 5))

# left panel: densities of the forward (solid) and backward (dashed) process
# at the same forward time t
for i, t in enumerate(t_list):
    # index of time t in the forward trajectories; round before converting
    # because the float quotient can land just below the exact integer
    t_idx = int(round(t / dt))
    sns.kdeplot(forward_traj_set[t_idx,:,:].flatten(), ax=ax[0], label='t=%.2f,forward' % t, bw_adjust=0.2, linestyle="-", c=color_list[i])

    # the backward process reaches forward time t after (T-t)/dt steps
    b_t_idx = int(round((T - t) / dt))
    sns.kdeplot(backward_traj_set[b_t_idx,:, :].flatten(), label='t=%.2f,backward' % t, ax=ax[0], bw_adjust=0.2, linestyle="--", c=color_list[i])

ax[0].set_xlim(-2, 2)
ax[0].legend()

# right panel: data density (forward process at time 0) vs generated density
# (final states of the backward process)
sns.kdeplot(forward_traj_set[0,:, :].flatten(), ax=ax[1], linestyle="-", bw_adjust=0.2, c='b', label='truth')
sns.kdeplot(backward_traj_set[-1,:,:].flatten(), ax=ax[1], linestyle="--", bw_adjust=0.2, c='b', label='generated')
ax[1].set_xlim(-2, 2)
ax[1].legend()
