In [None]:
#! -*-coding:utf-8 -*-

import os
import torch
import time
import numpy as np
from tqdm.auto import trange

from sde import SDE
from utils import visualize_diffusion_process_2d, visualize_line
from utils import weight_init
from data import SwissRoll
from evaluate import evaluate_2d

# Create all new floating-point tensors and module parameters in float64 by default.
torch.set_default_dtype(torch.float64)

# Seed the numpy and torch global RNGs so Restart & Run All reproduces the same
# random draws (e.g. the torch.randn starting batch for the reverse process).
# NOTE(review): assumes SwissRoll sampling and weight_init draw from these global
# RNGs — TODO confirm in data.py / utils.py.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)


In [None]:
## hyper-parameters
T = 1.0                     # the terminal time
sample_size = 10000         # the number of random samples
hdim = 50                   # the width of neural network
steps = 20000               # training steps
batch_size = 400            # size of mini-batch for training
N = 5                       # control the number of visualizations

## define data
a = 2.0                     # control the scale of swissroll

true_data = SwissRoll(noise=0.0, a=a)

x0, x1 = -1.2 * a, 1.2 * a      # the range of x coordinate to visualize and compute divergences
y0, y1 = -1.2 * a, 1.2 * a      # the range of y coordinate to visualize and compute divergences

data_dim = true_data.data_dim   # dimension of swissroll (=2)

## define sde model
beta0 = 0.1                     # beta at t = 0
beta1 = 20.0                    # beta at t = T

model = SDE(T=T, beta0=beta0, beta1=beta1, sde_type="vp", beta_type="linear", data_dim=data_dim, hidden_dim=hdim)
model.apply(weight_init)

## create directories for saving results
# NOTE(review): the directory name does not encode beta0, beta1 or a. If any of
# them change, the training cell will silently load a checkpoint trained with the
# old settings. Consider adding them to model_dir.
model_dir = f"sde_swissroll_2D_T-{T:.2f}_trainsteps-{steps:d}_bs-{batch_size:d}_hdim-{hdim:d}"
output_dir = f"outputs/{model_dir}"
ckpt_dir = f"ckpts/{model_dir}"

# exist_ok=True already makes these no-ops for existing directories.
os.makedirs(output_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)

## Visualize forward process

In [None]:

## visualize forward process
# N + 1 evenly spaced snapshot times: 0, T/N, 2T/N, ..., T
t_unit = T / N
t_schedule = t_unit * np.arange(N + 1)

# forward process: push one fixed batch of data samples to every snapshot time
x_0 = torch.from_numpy(true_data.sample(sample_size)).view(-1, data_dim)

snapshots = [model.forward_sde(x_0, t, to_numpy=True) for t in t_schedule]
x_t = np.stack(snapshots, axis=0)

panel_titles = [f"t={t:.2f}" for t in t_schedule]
visualize_diffusion_process_2d(xs=x_t, titles=panel_titles, savename=None)


## Train the score matching model

In [None]:
# Reuse a cached checkpoint when one exists; otherwise fit the score network
# via model.estimate_score and cache the resulting state dict.
ckpt_path = os.path.join(ckpt_dir, "sde.pth")
lr = 1e-2                   # learning rate passed to estimate_score

if os.path.exists(ckpt_path):
    # map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
    # machine; load_state_dict then copies the tensors into the model's parameters.
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    print(f"model loaded from {ckpt_path}")

else:
    print("begin to train")
    model.estimate_score(data_iter=true_data.data_iter(batch_size=batch_size, maxiter=steps), steps=steps, lr=lr)
    print("finish training")

    torch.save(model.state_dict(), ckpt_path)
    print(f"model saved to {ckpt_path}")

## Visualize the reverse process

In [None]:
# Run the reverse process once per value of h (passed to model.sample as sf_alpha),
# every run starting from the same standard-normal batch x_T.
reverse_N = 20000                       # passed as N to model.sample (presumably the number of reverse steps)
h_alphas = [0.0, 1.0, 2.0, 3.0, 4.0]    # values of h to visualize

# NOTE(review): x_T is shared across runs; this assumes model.sample does not
# modify x_t in place — TODO confirm in sde.py.
x_T = torch.randn(sample_size, data_dim)

# reverse process

tilde_x_t = list()

for h_alpha in h_alphas:

    # perf_counter is monotonic, so the measured duration is not skewed by wall-clock adjustments.
    tic = time.perf_counter()
    tilde_x_t.append(model.sample(x_t=x_T, T=0.0, N=reverse_N, to_numpy=True, sf_alpha=h_alpha))
    toc = time.perf_counter()

    print(f"sampling with alpha={h_alpha:.2f} done, cost {toc - tic:.2f}s")

tilde_x_t = np.stack(tilde_x_t, axis=0)

visualize_diffusion_process_2d(xs=tilde_x_t, titles=[f"h={h:.2f}" for h in h_alphas], savename=None)

## Evaluate: divergence between generated and true samples as a function of h

In [7]:
# Sweep h from h_alphas[-1] down to 0 (21 values). For each h, compare a generated
# sample set against a fresh batch of true samples using evaluate_2d, which returns
# JS divergence, KL divergence and Wasserstein distance over the [x0, x1] x [y0, y1] window.
rows = list()

for h in np.linspace(start=0.0, stop=h_alphas[-1], num=21, endpoint=True)[::-1]:

    x_0 = true_data.sample(sample_size)
    x_0 = torch.from_numpy(x_0).view(-1, data_dim)

    # Same starting batch x_T and step count reverse_N as the reverse-process cell.
    x_0_gen = model.sample(x_t=x_T, T=0.0, N=reverse_N, to_numpy=True, sf_alpha=h)

    js, kl, wd = evaluate_2d(true_data=x_0.detach().numpy(), fake_data=x_0_gen, x0=x0, x1=x1, y0=y0, y1=y1)

    print(f"h={h:.2f}: js={js:.4f}, kl={kl:.4f}, wd={wd:.4f}")

    rows.append([h, js, kl, wd])

# columns: 0 = h, 1 = js, 2 = kl, 3 = wd
data = np.array(rows)

metric_titles = {
    1: "Error measured with JS divergence",
    2: "Error measured with KL divergence",
    3: "Error measured with Wasserstein distance",
}
for metric_col, metric_title in metric_titles.items():
    visualize_line(data=data[:, metric_col], yscale="log", xaxis=data[:, 0], xl=r"$\mathsf{h}$", yl="Error", title=metric_title, savename=None)

