In [None]:
#! -*-coding:utf-8 -*-

import os
import torch
import time
import numpy as np

from lib.sde import SDE
from lib.utils import visualize_diffusion_process_1d, visualize_diffusion_process_2d, visualize_diffusion_process_2d_marginal, visualize_line
from lib.utils import weight_init
from lib.data import GMM
from lib.evaluate import evaluate_1d, evaluate_2d

# Use float64 as the default dtype for floating-point tensors torch creates from here on
# (e.g. the torch.randn draws below); presumably chosen to match the float64 NumPy arrays
# that are converted with torch.from_numpy — TODO confirm.
torch.set_default_dtype(torch.float64)


In [None]:
## hyper-parameters
T = 4.0                                     # the terminal time
sample_size = 10000                         # the number of random samples
N = 5                                       # control the number of visualizations (N + 1 time points are plotted)
h_alphas = [0.0, 1.0, 2.0, 3.0, 4.0]        # candidates for h

## reproducibility
# The notebook draws random samples (torch.randn, true_data.sample), so seed the global
# generators once to make "Restart & Run All" reproduce the same figures and numbers.
seed = 0                                    # global random seed
torch.manual_seed(seed)                     # torch global RNG (used by torch.randn)
np.random.seed(seed)                        # numpy legacy global RNG (presumably used by GMM.sample — confirm)


# GMM 1D

In [None]:
data_type = "gmm1d"
reverse_N = 40000                           # number of discretization steps

## define data
# two-component 1D Gaussian mixture: means -1 and +1, sigma 0.1 each, equal weights
mus = np.array([-1.0, 1.0]).reshape(-1, 1)
sigms = np.array([0.1, 0.1]).reshape(-1, 1)
ps = np.array([0.5, 0.5])

true_data = GMM(mus=mus, sigmas=sigms, ps=ps, data_dim=1)

x0, x1 = -3.0, 3.0         # the range of x coordinate to visualize and compute divergences

data_dim = true_data.data_dim   # dimension of the GMM samples (the GMM above is built with data_dim=1)

## define sde model
beta0 = 1.0                     # beta at t = 0
beta1 = 1.0                     # beta at t = T (same value as beta0)

model = SDE(T=T, beta0=beta0, beta1=beta1, sde_type="vp", beta_type="linear", data_dim=data_dim, hidden_dim=50)
model.apply(weight_init)

## create directories for saving results
model_dir = f"sde_{data_type}_T-{T:.2f}"
output_dir = f"outputs/{model_dir}"
ckpt_dir = f"ckpts/{model_dir}"

# exist_ok=True already makes makedirs a no-op for an existing directory, so a separate
# os.path.exists() pre-check is redundant (and racy between the check and the creation).
os.makedirs(output_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)

## Visualize the forward process

In [None]:
## visualize forward process
# N + 1 equally spaced time points from 0 to T
t_unit = T / N
t_schedule = np.arange(N + 1) * t_unit

# forward process: draw clean samples once, then diffuse the same batch to every time point
x_0 = torch.from_numpy(true_data.sample(sample_size)).view(-1, data_dim)

# one entry per time point, stacked along a new leading axis
x_t = np.stack(
    [model.forward_sde(x_0, t, to_numpy=True) for t in t_schedule],
    axis=0,
)

visualize_diffusion_process_1d(
    xs=x_t,
    titles=[f"t={t:.2f}" for t in t_schedule],
    savename=None,
)


## Define the score corrupter

In [None]:
class CorruptScore(object):
    """Apply a controlled multiplicative error to a score value.

    The constructor selects one error model and exposes it as
    ``self.err_fun(score, t, eps=1e-2)``:

    * ``"1"`` -- constant relative error: ``(1 + eps) * score``.
    * ``"2"`` -- time-varying relative error:
      ``(1 + eps * (1 + sin(2 * pi * t / T)) / 2) * score``.
    * ``"3"`` -- ``(1 + eps) * score`` when ``t > 0.05 * T``; the score is
      returned unchanged otherwise.
    """

    def __init__(self, T, error_type="1") -> None:
        """Select the error model.

        Args:
            T: terminal time; used by error models "2" and "3".
            error_type: one of ``"1"``, ``"2"``, ``"3"``.

        Raises:
            ValueError: if ``error_type`` is not one of the supported keys.
        """
        self.T = T

        error_models = {
            "1": self._error_1,
            "2": self._error_2,
            "3": self._error_3,
        }

        try:
            self.err_fun = error_models[error_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable inputs, which previously also ended in ValueError
            raise ValueError(
                f"unknown error_type {error_type!r}; expected one of {sorted(error_models)}"
            ) from None

    def _error_1(self, score, t, eps=1e-2):
        """Scale the score by the constant factor ``1 + eps``; ``t`` is ignored."""
        return (1 + eps) * score

    def _error_2(self, score, t, eps=1e-2):
        """Scale the score by ``1 + eps * (1 + sin(2 * pi * t / T)) / 2``."""
        return (1 + eps * (1 + np.sin(2 * np.pi * t / self.T)) / 2) * score

    def _error_3(self, score, t, eps=1e-2):
        """Scale the score by ``1 + eps`` only for ``t > 0.05 * T``; otherwise return it unchanged."""
        if t > 0.05 * self.T:
            return (1 + eps) * score
        return score
        


## Visualize the reverse process

In [None]:
# build the score corrupter
Corrupter = CorruptScore(T=T, error_type="1")
eps = -0.20     # forwarded to model.sample with the corrupter; presumably the eps of its error function — confirm in lib.sde

# terminal samples are drawn once, so every h starts the reverse process from the same x_T
x_T = torch.randn(sample_size, data_dim)

# run the reverse process
tilde_x_t = list()

for h_alpha in h_alphas:

    # perf_counter is a monotonic, high-resolution clock intended for measuring elapsed time;
    # time.time() is wall-clock time and can jump if the system clock is adjusted.
    tic = time.perf_counter()
    tilde_x_t.append(model.sample(x_t=x_T, T=0.0, N=reverse_N, to_numpy=True, sf_alpha=h_alpha, exact_score_fn=true_data.exact_score_t, corrupter=Corrupter, eps=eps))
    toc = time.perf_counter()

    print(f"sampling with alpha={h_alpha:.2f} done, cost {toc - tic:.2f}s")

# one entry per h, stacked along a new leading axis
tilde_x_t = np.stack(tilde_x_t, axis=0)

# visualization
visualize_diffusion_process_1d(xs=tilde_x_t, titles=[f"h={h:.2f}" for h in h_alphas], savename=None, density_func=true_data.p_t)

## Evaluate

In [None]:
# one row per h value: [h, js, kl, wd]
rows = list()

# 21 evenly spaced h values, visited from h_alphas[-1] down to 0
h_grid = np.linspace(start=0.0, stop=h_alphas[-1], num=21, endpoint=True)[::-1]

for h in h_grid:

    # fresh reference samples for this h
    x_0 = torch.from_numpy(true_data.sample(sample_size)).view(-1, data_dim)

    # generated samples; the reverse process starts from the x_T drawn in the previous cell
    x_0_gen = model.sample(
        x_t=x_T,
        T=0.0,
        N=reverse_N,
        to_numpy=True,
        sf_alpha=h,
        exact_score_fn=true_data.exact_score_t,
        corrupter=Corrupter,
        eps=eps,
    )

    js, kl, wd = evaluate_1d(true_data=x_0.detach().numpy(), fake_data=x_0_gen, x0=x0, x1=x1)

    print(f"h={h:.2f}: js={js:.4f}, kl={kl:.4f}, wd={wd:.4f}")

    rows.append([h, js, kl, wd])

data = np.array(rows)

# the three plots share everything except the plotted column and the title
line_kwargs = dict(yscale="log", xaxis=data[:, 0], xl=r"$\mathsf{h}$", yl="Error", savename=None)

visualize_line(data=data[:, 1], title="Error measured with JS divergence", **line_kwargs)
visualize_line(data=data[:, 2], title="Error measured with KL divergence", **line_kwargs)
visualize_line(data=data[:, 3], title="Error measured with Wasserstein distance", **line_kwargs)



# GMM 2D

In [None]:
data_type = "gmm2d"         # experiment label; was a copy-pasted "gmm1d" although this section uses the 2D mixture
reverse_N = 80000           # number of discretization steps

# define data
# four-component 2D Gaussian mixture: means at (+-1, +-1), sigma 0.05 each, equal weights
mus = np.array([[-1.0, -1.0], [-1.0, 1.0], [1.0, 1.0], [1.0, -1.0],]).reshape(-1, 2)
sigms = np.array([0.05, 0.05, 0.05, 0.05]).reshape(-1, 1)
ps = np.array([0.25, 0.25, 0.25, 0.25])

true_data = GMM(mus=mus, sigmas=sigms, ps=ps, data_dim=2)

data_dim = true_data.data_dim   # dimension of the GMM samples (the GMM above is built with data_dim=2)

x0, x1 = -2.5, 2.5          # the range of x coordinate to visualize and compute divergences
y0, y1 = -2.5, 2.5          # the range of y coordinate to visualize and compute divergences

# NOTE(review): within the visible notebook, `model`, `model_dir`, `output_dir` and `ckpt_dir`
# are not rebuilt in this section, so they still come from the GMM 1D cell (model built with
# data_dim=1, directories named after "gmm1d") — confirm this reuse is intended.


## Visualize the forward process

In [None]:
## visualize forward process
# N + 1 equally spaced time points from 0 to T
t_unit = T / N
t_schedule = np.arange(N + 1) * t_unit

# forward process: draw clean 2D samples once, then diffuse the same batch to every time point
x_0 = torch.from_numpy(true_data.sample(sample_size)).view(-1, data_dim)

# one entry per time point, stacked along a new leading axis
x_t = np.stack(
    [model.forward_sde(x_0, t, to_numpy=True) for t in t_schedule],
    axis=0,
)

visualize_diffusion_process_2d(
    xs=x_t,
    titles=[f"t={t:.2f}" for t in t_schedule],
    savename=None,
)


In [None]:
# build the score corrupter
Corrupter = CorruptScore(T=T, error_type="1")
eps = -0.20     # forwarded to model.sample with the corrupter; presumably the eps of its error function — confirm in lib.sde

# terminal samples are drawn once, so every h starts the reverse process from the same x_T
x_T = torch.randn(sample_size, data_dim)

# run the reverse process
tilde_x_t = list()

for h_alpha in h_alphas:

    # perf_counter is a monotonic, high-resolution clock intended for measuring elapsed time;
    # time.time() is wall-clock time and can jump if the system clock is adjusted.
    tic = time.perf_counter()
    tilde_x_t.append(model.sample(x_t=x_T, T=0.0, N=reverse_N, to_numpy=True, sf_alpha=h_alpha, exact_score_fn=true_data.exact_score_t, corrupter=Corrupter, eps=eps))
    toc = time.perf_counter()

    print(f"sampling with alpha={h_alpha:.2f} done, cost {toc - tic:.2f}s")

# one entry per h, stacked along a new leading axis
tilde_x_t = np.stack(tilde_x_t, axis=0)


In [None]:
# visualization of the empirical distribution of the generated samples, one panel per h
visualize_diffusion_process_2d(
    xs=tilde_x_t,
    titles=[f"h={h_alpha:.2f}" for h_alpha in h_alphas],
    savename=None,
)

In [None]:
# visualization of the marginal distribution of the generated samples, one panel per h,
# with true_data.marginal passed as the reference density
visualize_diffusion_process_2d_marginal(
    x_t=tilde_x_t,
    titles=[f"h={h_alpha:.2f}" for h_alpha in h_alphas],
    density_func=true_data.marginal,
    title=None,
    savename=None,
)

In [None]:
# one row per h value: [h, js, kl, wd]
rows = list()

# 21 evenly spaced h values, visited from h_alphas[-1] down to 0
h_grid = np.linspace(start=0.0, stop=h_alphas[-1], num=21, endpoint=True)[::-1]

for h in h_grid:

    # fresh reference samples for this h
    x_0 = torch.from_numpy(true_data.sample(sample_size)).view(-1, data_dim)

    # generated samples; the reverse process starts from the x_T drawn in an earlier cell
    x_0_gen = model.sample(
        x_t=x_T,
        T=0.0,
        N=reverse_N,
        to_numpy=True,
        sf_alpha=h,
        exact_score_fn=true_data.exact_score_t,
        corrupter=Corrupter,
        eps=eps,
    )

    js, kl, wd = evaluate_2d(true_data=x_0.detach().numpy(), fake_data=x_0_gen, x0=x0, x1=x1, y0=y0, y1=y1)

    print(f"h={h:.2f}: js={js:.4f}, kl={kl:.4f}, wd={wd:.4f}")

    rows.append([h, js, kl, wd])

data = np.array(rows)

# the three plots share everything except the plotted column and the title
line_kwargs = dict(yscale="log", xaxis=data[:, 0], xl=r"$\mathsf{h}$", yl="Error", savename=None)

visualize_line(data=data[:, 1], title="Error measured with JS divergence", **line_kwargs)
visualize_line(data=data[:, 2], title="Error measured with KL divergence", **line_kwargs)
visualize_line(data=data[:, 3], title="Error measured with Wasserstein distance", **line_kwargs)

