In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors

import torch
import torch.nn as nn
import math

from torch.utils.data import DataLoader

import seaborn as sns

In [None]:
# one-dimensional potential energy V(x)
def V(x):
    """Potential V(x) = 2 * (x**8 + three positive Gaussian bumps).

    The bumps have heights 0.8, 0.55 and 0.3, are centred at 0, 0.5 and -0.5,
    and share the width parameter 80 in exp(-80 * (x - c)**2). The x**8 term
    confines the particle. Evaluates elementwise on floats or NumPy arrays.
    """
    return 2 * (x**8
                + 0.8 * np.exp(-80 * x**2)
                + 0.55 * np.exp(-80 * (x - 0.5)**2)
                + 0.3 * np.exp(-80 * (x + 0.5)**2))

# derivative dV/dx of the potential
def gradV(x):
    """Analytic derivative of V, evaluated elementwise.

    The confining term x**8 contributes 8 * x**7. Each Gaussian bump
    a * exp(-80 * (x - c)**2) contributes -a * 160 * (x - c) * exp(-80 * (x - c)**2).
    The overall factor 2 matches the one in V.
    """
    return 2 * (8 * x**7
                - 0.8 * 160 * x * np.exp(-80 * x**2)
                - 0.55 * 160 * (x - 0.5) * np.exp(-80 * (x - 0.5)**2)
                - 0.3 * 160 * (x + 0.5) * np.exp(-80 * (x + 0.5)**2))

In [None]:
# coefficient beta (inverse temperature) in  dX = -grad V(X) dt + sqrt(2 / beta) dW
beta = 2.0
# Euler-Maruyama time step
dt = 0.005
# number of steps, i.e. number of recorded states in the trajectory
N = 10000
# plotting range of the state variable x
xmin = -1.0
xmax = 1.0

In [None]:
# simulate the SDE with the Euler-Maruyama scheme
def sample(beta=1.0, dt=0.001, N=10000, seed=42):
    """Simulate dX = -gradV(X) dt + sqrt(2 / beta) dW starting from X = 0.

    Args:
        beta: coefficient controlling the noise amplitude sqrt(2 / beta).
        dt: step size.
        N: number of recorded states.
        seed: seed for the NumPy Generator that draws the Gaussian increments.

    Returns:
        (times, states): two 1-D arrays of length N. states[k] is the state at
        time times[k] = k * dt, recorded before the k-th update.
    """
    rng = np.random.default_rng(seed=seed)
    times, states = [], []
    x = 0.0
    for step in range(N):
        states.append(x)
        times.append(dt * step)
        kick = rng.normal()
        # deterministic gradient step plus a scaled Gaussian increment
        x = x - gradV(x) * dt + np.sqrt(2 * dt / beta) * kick
    return np.array(times), np.array(states)

In [None]:
# simulate the SDE with the parameters from the configuration cell
tvec, traj = sample(beta=beta, dt=dt, N=N)

print('Trajectory has %d states.\n' % len(traj))

In [None]:
# Left: the simulated trajectory over time.
# Right: its histogram compared with the invariant density p(x) ∝ exp(-beta * V(x))
# of  dX = -grad V(X) dt + sqrt(2 / beta) dW.
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(1, 2, 1)

# plot trajectory vs time
ax.plot(tvec, traj, alpha=0.5)
ax.set_ylim([xmin, xmax])
ax.set_xlabel(r'time')
ax.set_ylabel(r'x')
ax.set_title('trajectory')

ax1 = fig.add_subplot(1, 2, 2)

# plot empirical density of the trajectory data
ax1.hist(traj, 50, density=True, label='empirical density')

# Invariant density, normalized numerically on a uniform grid that extends past
# the plotting range so that the x**8 confinement makes the truncated mass negligible.
pad = 0.5 * (xmax - xmin)
x_grid = np.linspace(xmin - pad, xmax + pad, 2001)
unnormalized = np.exp(-beta * V(x_grid))
grid_step = x_grid[1] - x_grid[0]
invariant_density = unnormalized / (unnormalized.sum() * grid_step)
ax1.plot(x_grid, invariant_density, 'r-', label='invariant density')

ax1.set_xlim([xmin, xmax])
ax1.set_xlabel(r'x')
ax1.set_ylabel('density')
ax1.set_title('empirical and invariant density')
ax1.legend()

In [None]:
class VESDE:
    """Variance-exploding (VE) SDE  dX = g(t) dW  on the time interval [0, T].

    The noise scale grows geometrically:
        sigma(t) = sigma_min * (sigma_max / sigma_min) ** (t / T),
    and the diffusion coefficient
        g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min) / T)
    satisfies d[sigma(t)^2]/dt = g(t)^2. The drift is identically zero.

    Args:
        sigma_min: noise scale at t = 0.
        sigma_max: noise scale at t = T; also the std of the Gaussian prior.
        dim: dimension of the state space; states are handled as (M, dim) tensors.
        T: terminal time.
    """

    def __init__(self, sigma_min, sigma_max, dim=1, T=1):
        self.T = T
        self.dim = dim

        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    def _sigma(self, t):
        """Geometric noise schedule sigma(t); elementwise for tensor t."""
        return self.sigma_min * (self.sigma_max / self.sigma_min) ** (t / self.T)

    def drift(self, X, t):
        """Drift f(X, t); zero for the VE SDE (same shape and dtype as X)."""
        return torch.zeros_like(X)

    def diffusion(self, t):
        """Diffusion coefficient g(t) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min) / T)."""
        # Python-float constant: no per-call tensor allocation, computed in double precision.
        rate = 2.0 * (math.log(self.sigma_max) - math.log(self.sigma_min)) / self.T
        return self._sigma(t) * math.sqrt(rate)

    def marginal_prob(self, X, t):
        """Mean and std of the perturbation kernel p_t(x | X) used for training.

        Uses the convention std = sigma(t). Integrating g^2 shows the exact
        variance of the SDE started at X is sigma(t)^2 - sigma_min^2. The two
        agree closely when sigma_min is small relative to sigma(t).
        """
        return X, self._sigma(t)

    def prior(self, M):
        """Draw M samples of shape (M, dim) from N(0, sigma_max^2 I)."""
        # randn(M, dim) gives M samples for any dim. The old randn(M).reshape(-1, dim)
        # silently produced only M / dim samples when dim > 1.
        return torch.randn(M, self.dim) * self.sigma_max

    # simulate the forward SDE using the Euler-Maruyama scheme
    def forward_sampling(self, X0, N=100):
        """Diffuse the samples X0 forward from t = 0 to t = T in N Euler-Maruyama steps.

        Args:
            X0: array-like or tensor of initial states; reshaped to (M, dim).
            N: number of time steps.

        Returns:
            Tensor of shape (N + 1, M, dim) holding the initial and all subsequent states.
        """
        # as_tensor avoids the copy-construct warning torch.tensor emits for tensor input
        X = torch.as_tensor(X0).reshape(-1, self.dim)
        traj = [X]
        delta_t = self.T / N
        sqrt_dt = math.sqrt(delta_t)

        for i in range(N):
            b = torch.randn_like(X)
            t = torch.full_like(X, i * delta_t)

            drift = self.drift(X, t)
            diffusion_coeff = self.diffusion(t)

            X = X + drift * delta_t + diffusion_coeff * sqrt_dt * b
            traj.append(X)

        return torch.stack(traj)

    # simulate the reverse-time SDE using the Euler-Maruyama scheme
    def backward_sampling(self, X0, model, N=100):
        """Integrate the reverse-time SDE from t = T down to t = 0 using a learned score.

        Each step applies
            X <- X + (-f(X, t) + g(t)^2 * score(X, t)) * dt + g(t) * sqrt(dt) * z,
        with t running over T, T - dt, ..., dt.

        Args:
            X0: array-like or tensor of states at t = T; reshaped to (M, dim).
            model: callable model(X, t) returning a score with the same shape as X.
            N: number of time steps.

        Returns:
            Tensor of shape (N + 1, M, dim). Index k holds the state at time T - k * T / N.
        """
        X = torch.as_tensor(X0).clone().detach().reshape(-1, self.dim)
        traj = [X]
        delta_t = self.T / N
        sqrt_dt = math.sqrt(delta_t)

        for i in range(N):
            # noise is drawn before the model call, matching the original RNG order
            b = torch.randn_like(X)

            t = torch.full_like(X, self.T - i * delta_t)
            score = model(X, t)

            drift = self.drift(X, t)
            diffusion_coeff = self.diffusion(t)

            X = X + (-1.0 * drift + diffusion_coeff**2 * score) * delta_t + sqrt_dt * diffusion_coeff * b
            traj.append(X)

        return torch.stack(traj)
    


In [None]:
# terminal time and noise-scale range of the VE SDE
T = 1
sigma_min, sigma_max = 0.05, 2

sde = VESDE(sigma_min=sigma_min, sigma_max=sigma_max, dim=1, T=T)

In [None]:
# diffuse the trajectory samples forward in 1000 steps; XT has shape (N + 1, number of samples, dim)
XT = (
    sde.forward_sampling(X0=traj, N=1000)
    .detach()
    .numpy()
)
print(XT.shape)

In [None]:
# Kernel density estimates of the forward-diffused samples at several times,
# together with the Gaussian prior N(0, sigma_max^2) that the t = T marginal
# should approximately match.
n_fwd_steps = XT.shape[0] - 1  # number of Euler-Maruyama steps used to build XT

index_list = [0, 500, 700, 1000]
color_list = ['b', 'y', 'k', 'r', 'gray']

fig, ax = plt.subplots(1, 1)
for idx, color in zip(index_list, color_list):
    t = idx * T / n_fwd_steps
    sns.kdeplot(XT[idx, :, 0], ax=ax, label='t=%.2f' % t, color=color)

# Flatten the (M, 1) prior tensor to a 1-D NumPy array. A 2-D tensor would be
# interpreted by seaborn as wide-form data rather than a single series.
X = sde.prior(20000)
sns.kdeplot(X[:, 0].numpy(), ax=ax, label='prior', color=color_list[-1])

ax.legend()
ax.set_xlim(-3, 3)
ax.set_xlabel('x')
ax.set_title('forward diffusion of the data')

In [None]:
class MyScore(nn.Module):
    """Fully connected score network s(x, t) approximating grad_x log p_t(x).

    The state x and time t are concatenated along the feature axis and passed
    through an MLP with SiLU activations. The output has the same feature
    dimension as x.

    Args:
        dim: dimension of the state x (default 1).
        hidden: width of every hidden layer (default 50).
        n_hidden: number of hidden-to-hidden linear layers (default 3).
    """

    def __init__(self, dim=1, hidden=50, n_hidden=3):
        super().__init__()
        self.dim = dim

        # The defaults reproduce Linear(2,50) -> 3 x Linear(50,50) -> Linear(50,1),
        # with SiLU after every layer except the last. Module indices, and hence
        # state_dict keys, are unchanged.
        layers = [nn.Linear(dim + 1, hidden), nn.SiLU()]
        for _ in range(n_hidden):
            layers += [nn.Linear(hidden, hidden), nn.SiLU()]
        layers.append(nn.Linear(hidden, dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Evaluate the score at states x of shape (B, dim) and times t.

        t may be a scalar, shape (B,), or shape (B, k). For k > 1 only the first
        column is used. VESDE.backward_sampling passes t with the same shape as x,
        and every column holds the same time.
        """
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        if t.dim() == 0:
            t = t.expand(x.shape[0])
        t = t.reshape(x.shape[0], -1)[:, :1]

        state = torch.cat((x, t), dim=1)
        return self.net(state)


model = MyScore()

In [None]:
# Train the score network with denoising score matching (DSM).
# A data point x0 is perturbed to x_t = x0 + std(t) * z with z ~ N(0, 1).
# The target score is then -z / std(t), and the squared error is weighted by std(t)^2.

# mini-batch size
batch_size = 2000

# number of passes over the trajectory data
total_epochs = 5000

# Adam optimizer for the score network
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# training data: the simulated trajectory as float32 samples of shape (n_states, 1)
traj_data = torch.tensor(traj, dtype=torch.float32).reshape(-1, 1)

data_loader = DataLoader(traj_data, batch_size=batch_size, shuffle=True, drop_last=True)
# with drop_last=True a dataset smaller than one batch yields no batches and nothing would train
assert len(data_loader) > 0, "batch_size exceeds the number of training samples"

loss_list = []

for epoch in range(total_epochs):   # for each epoch
    epoch_loss = 0.0
    n_batches = 0

    for data in data_loader:  # loop over all mini-batches
        # one diffusion time per sample, uniform in [0, T)
        t = torch.rand(data.shape[0], 1) * T

        mean, std_t = sde.marginal_prob(data, t)
        z = torch.randn_like(data)
        xt = mean + std_t * z

        score = model(xt, t)

        # std^2-weighted DSM loss 0.5 * |score + z / std|^2 * std^2, with the
        # score-independent term 0.5 * z^2 dropped (zero gradient, so the value can be negative)
        loss = torch.mean((0.5 * score**2 + score * z / std_t) * std_t**2)

        optimizer.zero_grad()
        # gradient step
        loss.backward()
        # update weights
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # record the mean mini-batch loss of the epoch (less noisy than a single batch)
    loss_list.append(epoch_loss / n_batches)
    if epoch % 100 == 0:
        print('epoch=%d\n   loss=%.4f' % (epoch, loss_list[-1]))

fig, ax = plt.subplots(1, 1, figsize=(5, 4))

ax.plot(loss_list)
ax.set_xlabel('epoch')
ax.set_ylabel('mean mini-batch loss')
ax.set_title('loss vs epoch')

In [None]:
            
# Draw 5000 prior samples at t = T and integrate the reverse-time SDE back to
# t = 0 with the learned score. No gradients are needed, so autograd is disabled.
with torch.no_grad():
    X = sde.prior(5000)
    trajectory = sde.backward_sampling(X0=X, model=model, N=1000)

print("Number of states:", trajectory.shape)


In [None]:
# Left: forward (dashed) vs reverse-time (solid) marginals at matching times.
# Right: data distribution (forward index 0) vs generated samples (last reverse state).
# Forward index idx is at time idx * T / n_steps, and reverse index k is at time
# T - k * T / n_steps. So forward idx matches reverse n_steps - idx, which is only
# valid when both runs used the same number of steps.
assert XT.shape[0] == trajectory.shape[0], "forward and reverse runs must use the same number of steps"
n_steps = XT.shape[0] - 1
generated = trajectory.detach().numpy()

fig, ax = plt.subplots(1, 2, figsize=(12, 5))

for idx, color in zip(index_list, color_list):
    t = idx * T / n_steps
    sns.kdeplot(XT[idx, :, 0], ax=ax[0], label='t=%.2f' % t, bw_adjust=0.2, linestyle="--", color=color)
    sns.kdeplot(generated[n_steps - idx, :, 0], ax=ax[0], bw_adjust=0.2, linestyle="-", color=color)
ax[0].set_xlim(-2, 2)
ax[0].set_xlabel('x')
ax[0].set_title('forward (--) vs reverse (-) marginals')
# legend on ax[0] explicitly: plt.legend() would target the current axes, ax[1]
ax[0].legend()

sns.kdeplot(XT[0, :, 0], ax=ax[1], linestyle="--", bw_adjust=0.2, color='b', label='data')
sns.kdeplot(generated[-1, :, 0], ax=ax[1], linestyle="-", bw_adjust=0.2, color='b', label='generated')
ax[1].set_xlim(-2, 2)
ax[1].set_xlabel('x')
ax[1].set_title('data vs generated samples')
ax[1].legend()
