In [None]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import math

In [None]:
# Preprocessing: ToTensor converts the images to float tensors in [0, 1];
# Normalize with mean 0 and std 1 leaves the values unchanged, so no
# effective normalization is applied.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.,), (1,))]
)

# MNIST training split, stored under ./data (downloaded if missing)
train_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Quick summary: class names, per-image digit targets, raw data tensor shape.
for caption, value in (
    ("\nLabels:", train_dataset.classes),
    ("\nClasses:", train_dataset.targets),
    ("\nData shape:", train_dataset.data.shape),
):
    print (caption, value)

In [None]:
# how many random training images to display
num_examples = 10

# Shuffled loader over the full training set; only its first batch is used.
train_loader = DataLoader(train_dataset, batch_size=num_examples, shuffle=True)
img, label = next(iter(train_loader))

# batch tensor shape (expected (num_examples, 1, 28, 28) for MNIST)
print (img.shape)

# digit label of each displayed image
print ("\nLabels:", label)

# draw the images side by side in a single row
plt.figure(figsize=(8, 4))
for col in range(num_examples):
    ax = plt.subplot(1, num_examples, col + 1)
    ax.imshow(img[col].numpy().reshape(28,28), cmap='gray')
    ax.axis('off')

In [None]:
from torch.utils.data import Subset

# number of images to keep (and show) per digit
num_examples = 10

# For each digit 0-9 keep the indices of its first `num_examples` training
# images. Concatenating them groups the subset by digit (all 0s, then all 1s, ...).
selected_indices = torch.cat([
    torch.where(train_dataset.targets == digit)[0][:num_examples]
    for digit in range(10)
])

# dataset restricted to the selected images
dataset = Subset(train_dataset, selected_indices)
# number of images in the subset
print ("\nnumber of images in dataset:", len(dataset))

# Unshuffled loader whose batch size equals the number of images per digit.
# Provided every digit has at least `num_examples` images, each batch (and so
# each figure below) contains a single digit.
data_loader = DataLoader(dataset, batch_size=num_examples, shuffle=False)

# one row of images per batch
for img, labels in data_loader:
    plt.figure(figsize=(4, 2))
    for i in range(len(img)):
        plt.subplot(1, num_examples, i + 1)
        plt.imshow(img[i].numpy().reshape(28,28), cmap='gray')
        plt.axis('off')

In [None]:
class VESDE:
    """Variance-exploding (VE) SDE  dX = g(t) dW  on the time interval [0, T].

    The noise scale grows geometrically from ``sigma_min`` to ``sigma_max``::

        sigma(t) = sigma_min * (sigma_max / sigma_min) ** (t / T)

    and the diffusion coefficient satisfies g(t)**2 = d[sigma(t)**2]/dt::

        g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min) / T)

    The samplers treat states as 2-D tensors of shape (batch, dim); their
    inputs are reshaped to (-1, dim).

    Args:
        sigma_min: noise scale at t = 0.
        sigma_max: noise scale at t = T; also the std of ``prior``.
        dim: size of one (flattened) state vector.
        T: terminal time.
    """

    def __init__(self, sigma_min, sigma_max, dim=1, T=1):
        self.T = T
        self.dim = dim

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def _sigma(self, t):
        """Noise scale sigma(t); accepts Python numbers or tensors (elementwise)."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** (t / self.T)

    def _as_batch(self, X0):
        """Convert ``X0`` (tensor, array or nested list) to a floating tensor of shape (-1, dim)."""
        X = torch.as_tensor(X0)
        if not torch.is_floating_point(X):
            # Integer input (e.g. a list of ints) would make torch.randn_like
            # fail, so promote it to the default floating dtype.
            X = X.to(torch.get_default_dtype())
        return X.reshape(-1, self.dim)

    @staticmethod
    def _time_column(X, value):
        """A (batch, 1) column filled with ``value`` on the same device/dtype as ``X``."""
        return torch.full((X.shape[0], 1), value, dtype=X.dtype, device=X.device)

    def drift(self, X, t):
        """Drift f(X, t); identically zero for the VE SDE."""
        return torch.zeros_like(X)

    def diffusion(self, t):
        """Diffusion coefficient g(t).

        Returns a tensor shaped like ``t`` (a 0-d tensor when ``t`` is a
        Python number).
        """
        # Time-independent factor sqrt(2 * log(sigma_max / sigma_min) / T).
        rate = torch.sqrt(torch.tensor(
            2.0 * (math.log(self.sigma_max) - math.log(self.sigma_min)) / self.T))
        return self._sigma(t) * rate

    def marginal_prob(self, X, t):
        """Mean and std of the perturbation kernel p_t(x_t | x_0 = X).

        The mean is ``X`` because the drift is zero. The std is taken to be
        sigma(t). The exact std of the SDE started from a noise-free ``X``
        would be sqrt(sigma(t)**2 - sigma_min**2), which is close to sigma(t)
        once sigma(t) >> sigma_min.
        """
        mean = X
        std = self._sigma(t)
        return mean, std

    def prior(self, M):
        """Draw M samples of shape (M, dim) from N(0, sigma_max**2 I)."""
        return torch.randn(M * self.dim).reshape(-1, self.dim) * self.sigma_max

    def forward_sampling(self, X0, N=100):
        """Simulate the forward SDE from t = 0 to t = T with N Euler-Maruyama steps.

        Args:
            X0: initial states; reshaped to (-1, dim).
            N: number of time steps.
        Returns:
            Tensor of shape (N + 1, batch, dim): the initial state followed by
            the state after every step.
        """
        X = self._as_batch(X0)
        delta_t = self.T / N
        sqrt_dt = math.sqrt(delta_t)
        traj = [X]

        for i in range(N):
            noise = torch.randn_like(X)
            t = self._time_column(X, i * delta_t)

            drift = self.drift(X, t)
            diffusion_coeff = self.diffusion(t)

            X = X + drift * delta_t + diffusion_coeff * sqrt_dt * noise
            traj.append(X)

        return torch.stack(traj)

    def backward_sampling(self, X0, model, N=100):
        """Simulate the reverse-time SDE from t = T to t = 0 with N Euler-Maruyama steps.

        The update is X <- X + (-f + g**2 * score) * dt + g * sqrt(dt) * noise,
        i.e. the discretisation of dX = [f - g**2 * score] dt + g dW_bar run
        backwards in time.

        Args:
            X0: states at time T (e.g. from ``prior``); reshaped to (-1, dim).
            model: callable ``model(X, t)`` returning the score with X's shape;
                ``t`` is passed as a (batch, 1) tensor.
            N: number of time steps.
        Returns:
            Tensor of shape (N + 1, batch, dim) with every intermediate state.
        """
        X = self._as_batch(X0)
        delta_t = self.T / N
        sqrt_dt = math.sqrt(delta_t)
        traj = [X]

        for i in range(N):
            noise = torch.randn_like(X)
            # current time goes T, T - dt, ..., dt
            t = self._time_column(X, self.T - i * delta_t)

            score = model(X, t)

            drift = self.drift(X, t)
            diffusion_coeff = self.diffusion(t)

            X = X + (-1.0 * drift + diffusion_coeff**2 * score) * delta_t \
                + sqrt_dt * diffusion_coeff * noise
            traj.append(X)

        return torch.stack(traj)

In [None]:
# VE SDE hyper-parameters
T = 1              # terminal time
sigma_min = 0.03   # noise scale at t = 0
sigma_max = 2      # noise scale at t = T
dim = 28 * 28      # size of one flattened MNIST image

sde = VESDE(sigma_min=sigma_min, sigma_max=sigma_max, dim=dim, T=T)

In [None]:
# First batch of `data_loader` (the unshuffled per-digit subset loader when
# the cells are run in order).
img, label = next(iter(data_loader))

# simulate the forward (noising) SDE starting from these images
traj = sde.forward_sampling(img)


In [None]:
print (traj.shape)

# One row per 10th time step, showing the first 10 samples of the batch.
for step in range(0, traj.shape[0], 10):
    plt.figure(figsize=(4, 2))
    axes = [plt.subplot(1, 10, col + 1) for col in range(10)]
    for col, ax in enumerate(axes):
        ax.imshow(traj[step, col, :].reshape(28,28), cmap='gray')
        ax.axis('off')
    axes[0].set_title("step=%d" % step)

In [None]:
class MyScore(nn.Module):
    """MLP score model s_theta(x, t) for flattened states.

    The time is appended to the state as one extra input feature, so the
    first layer sees ``dim + 1`` inputs, and the network returns a vector of
    size ``dim`` (the estimated score).

    Args:
        dim: size of one flattened state (default 28*28, an MNIST image).
        hidden: width of the two Tanh hidden layers.
    """

    def __init__(self, dim=28 * 28, hidden=128):
        super().__init__()

        self.net = nn.Sequential(
            nn.Linear(dim + 1, hidden),
            nn.Tanh(),
            nn.Linear(hidden, hidden),
            nn.Tanh(),
            nn.Linear(hidden, dim),
        )

    def forward(self, x, t):
        """Evaluate the score network.

        Args:
            x: states of shape (batch, dim).
            t: times of shape (batch, 1). A (batch,) vector or a single
                scalar is also accepted; a scalar is broadcast to every sample.
        Returns:
            Tensor of shape (batch, dim).
        """
        # bring t to x's dtype/device and to column shape so it can be concatenated
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device).reshape(-1, 1)
        if t.shape[0] == 1 and x.shape[0] != 1:
            t = t.expand(x.shape[0], 1)

        state = torch.cat((x, t), dim=1)
        return self.net(state)


model = MyScore()

In [None]:
# mini-batch size
batch_size = 20

# total number of passes over the training subset
total_epochs = 10

# print the epoch-averaged loss every `print_every` epochs
print_every = max(1, total_epochs // 10)

# Adam optimizer over the score network's parameters.
# NOTE: `model` is created in an earlier cell, so re-running this cell keeps
# training the same weights instead of starting from scratch.
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

loss_list = []  # mean training loss of each epoch

for epoch in range(total_epochs):
    epoch_loss = 0.0
    num_batches = 0

    for img, label in data_loader:  # loop over all mini-batches
        img = img.reshape(-1, dim)

        # one random diffusion time per sample, uniform on [0, T)
        t = torch.rand(img.shape[0]).reshape(-1, 1) * T

        # perturb the clean images: x_t = mean + std_t * z, z ~ N(0, I)
        mean, std_t = sde.marginal_prob(img, t)
        z = torch.randn_like(img)
        xt = mean + std_t * z

        score = model(xt, t)

        # std_t^2-weighted denoising score matching:
        #   std_t^2 * (0.5 |s|^2 + s.z / std_t) = 0.5 |std_t * s + z|^2 - 0.5 |z|^2
        # Written without dividing by std_t and multiplying back.
        loss_batch = 0.5 * std_t**2 * torch.sum(score**2, dim=1, keepdim=True) \
                     + std_t * torch.sum(score * z, dim=1, keepdim=True)
        loss = torch.mean(loss_batch)

        optimizer.zero_grad()
        # gradient step
        loss.backward()
        # update weights
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # record the average loss over the epoch's mini-batches
    loss_list.append(epoch_loss / max(num_batches, 1))
    if epoch % print_every == 0:
        print ('epoch=%d\n   loss=%.4f' % (epoch, loss_list[-1]))

fig, ax = plt.subplots(1,1, figsize=(5, 4))

ax.plot(loss_list)
ax.set_xlabel('epoch')
ax.set_ylabel('mean training loss')
ax.set_title('loss vs epoch');

In [None]:
            
# Reverse-time sampling: start from the N(0, sigma_max^2 I) prior and
# integrate the reverse SDE using the learned score. No gradients are needed.
num_samples, num_steps = 10, 100

with torch.no_grad():
    X = sde.prior(num_samples)
    trajectory = sde.backward_sampling(X, model, N=num_steps)

print ("Number of states:", trajectory.shape)

In [None]:
# Show every 10th state of the reverse trajectory, one row of 10 samples each.
snapshot_steps = [s for s in range(trajectory.shape[0]) if s % 10 == 0]

for step in snapshot_steps:
    plt.figure(figsize=(4, 2))
    for i in range(10):
        plt.subplot(1, 10, i + 1)
        frame = trajectory[step, i, :].reshape(28, 28)
        plt.imshow(frame, cmap='gray')
        plt.axis('off')
        if i == 0:
            plt.title("step=%d" % step)