In [7]:
import functools
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
import tqdm
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from tqdm.auto import trange

In [8]:
# If run locally: make the parent of the current working directory importable
# so that the project-local `models` module can be found.
parent_dir = os.path.dirname(os.getcwd())
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from models import ScoreNet

<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Load-dataset" data-toc-modified-id="Load-dataset-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Load dataset</a></span></li><li><span><a href="#Diffusion:-Incrementally-add-noise-to-an-image" data-toc-modified-id="Diffusion:-Incrementally-add-noise-to-an-image-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Diffusion: Incrementally add noise to an image</a></span></li><li><span><a href="#Training:-Estimate-the-score" data-toc-modified-id="Training:-Estimate-the-score-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Training: Estimate the score</a></span><ul class="toc-item"><li><span><a href="#Diffusion-coefficient" data-toc-modified-id="Diffusion-coefficient-3.1"><span class="toc-item-num">3.1&nbsp;&nbsp;</span>Diffusion coefficient</a></span></li><li><span><a href="#Marginals" data-toc-modified-id="Marginals-3.2"><span class="toc-item-num">3.2&nbsp;&nbsp;</span>Marginals</a></span></li></ul></li></ul></div>

## Load dataset

In [None]:
# Fetch the MNIST training split into the current directory (downloading it
# when it is not there yet); every sample goes through the ToTensor transform.
dataset = MNIST(root='.', train=True, download=True,
                transform=transforms.ToTensor())
# Shuffled mini-batches of 32 samples
data_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)

## Diffusion: Incrementally add noise to an image

In [None]:
def perturb(x_0, t, sigma=5):
    ''' Perturb a raw image x_0 with Gaussian diffusion noise at level t,
        i.e. draw x_t ~ N(x_0, beta^2(t) I) with
        beta^2(t) = (sigma^(2t) - 1) / (2 log sigma)
        args:
            x_0: torch tensor, clean image
            t: float, level of perturbation (from 0 to 1)
            sigma: influences the magnitude of noise (must be > 0 and != 1,
                   since log(sigma) is used as a divisor)
        return:
            x_t: torch tensor with the shape of x_0, perturbed image
    '''
    # noise follows Normal(0,I) --> we use randn (rand is for uniform)
    noise = torch.randn_like(x_0)

    # Closed-form *variance* of the perturbation at level t
    var_t = (sigma**(2*t) - 1) / (2*np.log(sigma))
    # Unit Gaussian noise has to be scaled by the standard deviation, i.e. the
    # square root of the variance (`** 0.5` works for floats and tensors alike)
    std_t = var_t ** 0.5
    x_t = x_0 + noise * std_t
    return x_t

# Pull a single mini-batch from the loader and keep its first image as the
# running example for the visualisation below (labels are discarded).
batch_iter = iter(data_loader)
images, _ = next(batch_iter)
image = images[0]

In [None]:
# Show the same image at nine noise levels t evenly spaced in [0, 1]
noise_levels = np.linspace(0., 1., num=9)
fig, axs = plt.subplots(1, 9, figsize=(20, 2))
for ax, t in zip(axs, noise_levels):
    perturbed = perturb(image, t, sigma=5)
    # Move the leading axis last and drop singleton axes before plotting;
    # the color scale is left to matplotlib's defaults (no vmin/vmax).
    ax.imshow(perturbed.permute(1, 2, 0).squeeze())
    ax.title.set_text("{:.2f}".format(t))
plt.savefig('progressive_diffusion.pdf')
plt.show()

## Training: Estimate the score

We perturb the data distribution $p_0$ to our prior $p_T$ using a simple diffusion SDE with parameter $\sigma$:
\begin{align}
d \mathbf{x} = \sigma^t d\mathbf{w}, \quad t\in[0,1]
\end{align}

This follows the general SDE form $d \mathbf{x} = f(\mathbf{x}, t) dt + g(t) d \mathbf{w}$. When the drift $f$ is linear in $\mathbf{x}$, the perturbation kernel is a conditional linear Gaussian: $p(x_t|x_0) = \mathcal{N}(x_t; \alpha(t)x_0, \beta^2(t)I)$ where $\alpha: [0,1] \rightarrow \mathbb{R}$,  $\beta: [0,1] \rightarrow \mathbb{R}$.

In $\textit{Applied Stochastic Differential Equations}$ by Särkkä and Solin (2019), we learn that $\alpha, \beta$ can be derived analytically from $f(\mathbf{x}, t), g(t)$. In our case, we have:

\begin{align*}
\begin{cases}
  f(\mathbf{x}, t) = 0 \\
  g(t) = \sigma ^t
\end{cases}
\longrightarrow
\begin{cases}
  \alpha(t) = 1 \\
  \beta^2(t) = \frac{\sigma^{2t}-1}{2\log\sigma}
\end{cases}
\end{align*}

Therefore:
\begin{align*}
  p(x_t|x_0) = \mathcal{N}(x_t; x_0, \frac{\sigma^{2t}-1}{2\log\sigma}I)
\end{align*}

### Diffusion coefficient
We define our coefficient following our SDE: $d \mathbf{x} = \sigma^t d\mathbf{w}$.

In [None]:
def diffusion_coeff(t, sigma):
    ''' Diffusion coefficient g(t) = sigma**t of the SDE dx = sigma^t dw
        args:
            t: torch vector, vector of time steps
            sigma: diffusion parameter in our SDE
        return:
            sigma raised to the power t (elementwise when t is a vector)
    '''
    return pow(sigma, t)

### Marginals
We define our mean and variance for $p(x_t|x_0)$ following Särkkä and Solin: $\mu = x_0$, $\mathrm{Var} = \frac{\sigma^{2t}-1}{2\log\sigma}I$

In [None]:
def marginal_prob(t, sigma):
    ''' Compute the standard deviation of p(x(t)|x(0)) for each given
        timestep t, specifically for perturbation f=0, g=sigma**t.
        The mean coefficient alpha(t) is always 1 for this SDE, so only the
        standard deviation is returned.
        args:
            t: torch vector, vector of time steps
            sigma: diffusion parameter in our SDE (must be > 0 and != 1,
                   since log(sigma) is used as a divisor)
        return:
            std: standard deviation beta(t) of the conditional linear
                 Gaussian, one value per entry of t
    '''
    # Closed-form variance beta^2(t) of the perturbation kernel
    var = (sigma**(2*t) - 1) / (2*np.log(sigma))
    # Callers scale unit Gaussian noise by the returned value
    # (x_t = x_0 + value * z) and hand this function to ScoreNet as
    # `marginal_prob_std`, so it has to be the standard deviation, i.e. the
    # square root of the variance (`** 0.5` works for floats and tensors).
    std = var ** 0.5
    return std

The denoising score matching objective is built upon:
\begin{align}
\min_\theta \|s_\theta(\mathbf{x}(t), t) - \nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))\|_2^2
\end{align}

Rewriting the score function: 

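$\log p_{0t}(x_t \mid x_0) = -\frac{1}{2\sigma_t^2} \|x_t - x_0\|_2^2 + \text{const}$ with $x_t = x_0 + \sigma_t z$, $z \sim \mathcal{N}(0, I)$ $\quad \Rightarrow \quad \nabla_{x_t} \log p_{0t}(x_t \mid x_0) = -\frac{1}{\sigma_t^2}(x_t - x_0) = -\frac{1}{\sigma_t}z$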

The loss is the L2 norm of the difference between the score function and our learned estimate $s_\theta(\mathbf{x}, t)$:
\begin{align}
\|s_\theta(\mathbf{x}, t) - \left(-\frac{1}{\sigma_t}z\right)\|_2^2
&= \|\frac{1}{\sigma_t}z + s_\theta(\mathbf{x}, t)\|_2^2 \\
&= \|\frac{1}{\sigma_t}(z + \sigma_t s_\theta(\mathbf{x}, t))\|_2^2 \\
&= \frac{1}{\sigma_t^2} \|z + \sigma_t s_\theta(\mathbf{x}, t)\|_2^2
\end{align}

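Noting that $\frac{1}{\sigma_t^2}$ does not depend on $\theta$, we drop it: this amounts to weighting the objective at each time step by $\lambda(t) = \sigma_t^2$, which keeps the loss on a comparable scale across noise levels. We therefore establish our loss function to compute: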

\begin{align}
\text{Loss} = \|z + \sigma_t s_\theta(\mathbf{x}, t)\|_2^2
\end{align}

In [None]:
def loss_function(model, x, marginal_prob_t, eps=1e-5):
    ''' Denoising score-matching loss for a time-dependent score model
        args:
            model: torch model called as model(perturbed_x, t); its output is
                   combined elementwise with noise shaped like x
            x: torch tensor, mini-batch of training data (4 axes, batch first)
            marginal_prob_t: callable mapping a vector of time steps to the
                   per-sample scale that multiplies the noise
            eps: smallest time step that can be drawn (t is sampled in
                   [eps, 1)), kept away from 0 for numerical stability
        return:
            loss: scalar tensor, squared residual summed over axes 1-3 and
                  averaged over the mini-batch
    '''
    n_samples = x.shape[0]

    # One random time step per sample first, then the unit Gaussian noise
    # (the draw order matters for seeded reproducibility)
    random_t = torch.rand(n_samples, device=x.device) * (1. - eps) + eps
    z = torch.randn_like(x)

    # Noise scale at each drawn time, expanded so that it broadcasts over the
    # three trailing axes of x
    scale = marginal_prob_t(random_t)[:, None, None, None]

    # Our SDE has alpha = 1, so the scaled noise is simply added to the data
    perturbed_x = x + z * scale

    # Score estimates of the model at the perturbed points
    scores = model(perturbed_x, random_t)

    # ||scale * score + z||^2 per sample, then the mean over the mini-batch
    residual = scores * scale + z
    per_sample = residual.pow(2).sum(dim=(1, 2, 3))
    return per_sample.mean()

In [None]:
# Hyper-parameters of the run
device = 'cpu'
sigma = 25.
n_epochs = 10
batch_size = 32
lr = 1e-4

# Bind sigma once so that both helpers become functions of t only
marginal_prob_fn = functools.partial(marginal_prob, sigma=sigma)
diffusion_coeff_fn = functools.partial(diffusion_coeff, sigma=sigma)

# Setup data
dataset = MNIST('.', train=True, transform=transforms.ToTensor(), download=True)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

# Setup model and optimizer
score_model = torch.nn.DataParallel(ScoreNet(marginal_prob_std=marginal_prob_fn))
score_model = score_model.to(device)
optimizer = Adam(score_model.parameters(), lr=lr)

# `trange` is imported from tqdm.auto: a bare `import tqdm` does not load the
# `tqdm.notebook` submodule, so `tqdm.notebook.trange` can fail with an
# AttributeError on a fresh kernel.
tqdm_epoch = trange(n_epochs)
for epoch in tqdm_epoch:
    # Running sum of per-sample losses and number of samples seen this epoch
    avg_loss = 0.
    num_items = 0
    for x, y in data_loader:  # the labels y are not used by the loss
        x = x.to(device)
        loss = loss_function(score_model, x, marginal_prob_fn)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size so the epoch figure is a true
        # per-sample average even when the last batch is smaller
        avg_loss += loss.item() * x.shape[0]
        num_items += x.shape[0]

    tqdm_epoch.set_description('Average Loss: {:5f}'.format(avg_loss / num_items))
    # Update the checkpoint after each epoch of training.
    #torch.save(score_model.state_dict(), 'ckpt.pth')