In [None]:
!nvidia-smi

In [None]:
!pip3 install torch torchvision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import os
import numpy as np
import sklearn
import sklearn.datasets
from sklearn.utils import shuffle as util_shuffle
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from tqdm import tqdm


In [None]:
# Dataset iterator
def make_dataset(rng=None, size=200):
    """Generate a 2-D toy dataset made of five curved arms around the origin.

    Points are first drawn from an axis-aligned Gaussian centred at (1, 0).
    Each point is then rotated by its arm's base angle plus an extra angle that
    grows with the point's first coordinate, which bends every arm.

    Args:
        rng: Optional ``np.random.RandomState``. It drives both the sampling
            and the final shuffle, so a seeded generator reproduces the exact
            same dataset. A fresh, unseeded generator is created when omitted.
        size: Requested number of points. It is rounded down to a multiple of
            the number of arms (5) so every arm gets the same point count.

    Returns:
        Tuple ``(data, condition)``. ``data`` has shape ``(n, 2)`` and
        ``condition`` has shape ``(n,)`` and holds the arm index (0-4) of each
        row, with ``n = 5 * (size // 5)``.
    """
    if rng is None:
        rng = np.random.RandomState()

    radial_std = 0.3
    tangential_std = 0.1
    num_classes = 5
    num_per_class = size // num_classes
    rate = 0.25
    # Base angle of each arm, evenly spaced around the circle.
    rads = np.linspace(0, 2 * np.pi, num_classes, endpoint=False)

    features = rng.randn(num_classes * num_per_class, 2) \
        * np.array([radial_std, tangential_std])
    features[:, 0] += 1.
    labels = np.repeat(np.arange(num_classes), num_per_class)

    # Per-point rotation angle: arm base angle plus a twist that depends on
    # the first (radial) coordinate.
    angles = rads[labels] + rate * np.exp(features[:, 0])
    cos, sin = np.cos(angles), np.sin(angles)
    # One 2x2 rotation matrix per point. The explicit leading dimension keeps
    # the reshape valid when there are zero points.
    rotations = np.stack([cos, -sin, sin, cos], axis=-1).reshape(len(angles), 2, 2)
    data = 2 * np.einsum("ti,tij->tj", features, rotations)

    # Shuffle with the caller's generator so a seeded rng gives a reproducible
    # dataset (a global-state shuffle would ignore the seed).
    order = rng.permutation(len(data))
    return data[order], labels[order]

In [None]:
# Draw the toy point cloud once; the labels are not needed here.
data, _ = make_dataset(size=10000)
# float32 tensor copy of the points, used by the later cells.
dataset = torch.from_numpy(data).float()

# Quick visual check of the raw 2-D samples.
plt.clf()
plt.scatter(*data.T, s=40, color='red', edgecolor='white', alpha=0.5)
plt.show()

In [None]:
def beta_schedule(beta1, beta2, T, schedule='sigmoid'):
    """Build the per-step noise levels ``beta_t`` of the diffusion process.

    Args:
        beta1: Value of the schedule at its start.
        beta2: Value of the schedule at its end.
        T: Number of diffusion steps (length of the returned tensor).
        schedule: Shape of the interpolation between ``beta1`` and ``beta2``:
            ``'linear'`` (evenly spaced), ``'quad'`` (evenly spaced in
            square-root space, then squared) or ``'sigmoid'`` (a sigmoid
            evaluated on [-6, 6] and rescaled to span ``beta2 - beta1``).

    Returns:
        1-D tensor of length ``T``.

    Raises:
        ValueError: If ``schedule`` is not one of the supported names.
    """
    if schedule == 'linear':
        return torch.linspace(beta1, beta2, T)
    if schedule == 'quad':
        return torch.linspace(beta1 ** 0.5, beta2 ** 0.5, T) ** 2
    if schedule == 'sigmoid':
        ramp = torch.sigmoid(torch.linspace(-6, 6, T))
        return ramp * (beta2 - beta1) + beta1
    # Fail loudly instead of hitting an unbound local variable further down.
    raise ValueError(
        f"unknown schedule {schedule!r}; expected 'linear', 'quad' or 'sigmoid'"
    )

def ddpm_schedules(beta1, beta2, T, schedule='sigmoid'):
    """
    Returns pre-computed schedules for DDPM sampling, training process.

    Args:
        beta1: First noise level; must satisfy ``0 < beta1 < beta2``.
        beta2: Last noise level; must satisfy ``beta2 < 1``.
        T: Number of diffusion steps.
        schedule: Schedule name forwarded to ``beta_schedule``.

    Returns:
        Dict of 1-D tensors of length ``T``, indexed by timestep:
        ``beta_t``, ``alpha_t`` (= 1 - beta_t), ``oneover_sqrta``,
        ``sqrt_beta_t``, ``alphabar_t`` (cumulative product of alpha_t),
        ``sqrtab``, ``sqrtmab`` and ``mab_over_sqrtmab``.

    Raises:
        AssertionError: If the bounds on ``beta1`` / ``beta2`` are violated.
    """
    # The lower bound matters: beta1 <= 0 can make alphabar reach or exceed 1 at
    # the first step, which turns sqrtmab into 0/NaN and poisons the ratio below.
    assert 0 < beta1 < beta2 < 1.0, "beta1 and beta2 must be in (0, 1)"

    beta_t = beta_schedule(beta1, beta2, T, schedule)
    alpha_t = 1 - beta_t
    # Cumulative product of alpha_t, accumulated in log space.
    alphabar_t = torch.cumsum(torch.log(alpha_t), dim=0).exp()
    sqrtmab = torch.sqrt(1 - alphabar_t)

    return {
        "beta_t": beta_t,    # \beta_t
        "alpha_t": alpha_t,  # \alpha_t
        "oneover_sqrta": 1 / torch.sqrt(alpha_t),  # 1/\sqrt{\alpha_t}
        "sqrt_beta_t": torch.sqrt(beta_t),  # \sqrt{\beta_t}
        "alphabar_t": alphabar_t,  # \bar{\alpha_t}
        "sqrtab": torch.sqrt(alphabar_t),  # \sqrt{\bar{\alpha_t}}
        "sqrtmab": sqrtmab,  # \sqrt{1-\bar{\alpha_t}}
        "mab_over_sqrtmab": (1 - alpha_t) / sqrtmab,  # (1-\alpha_t)/\sqrt{1-\bar{\alpha_t}}
    }



In [None]:
# ddpm scheduling check: plot beta, alpha and alpha_bar against the timestep

import matplotlib.pyplot as plt

n_T = 100

ddpm_scheduling_dict = ddpm_schedules(1e-5, 1e-2, n_T)

beta, alpha, alpha_bar = (
    ddpm_scheduling_dict[key] for key in ('beta_t', 'alpha_t', 'alphabar_t')
)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20,6))

# One panel per schedule, all sharing the same x axis meaning.
for panel, curve, name in zip(axes, (beta, alpha, alpha_bar), ('beta', 'alpha', 'alpha_bar')):
    panel.plot(np.arange(len(curve)), curve)
    panel.set_xlabel('timesteps')
    panel.set_ylabel(name)

plt.show()
plt.close()

In [None]:
# Forward (noising) process, computed two ways that should give visually
# matching results:
#   1) iterating the one-step transition q(x_t | x_{t-1})
#   2) jumping straight from x_0 with the closed form q(x_t | x_0)
# Uses `beta`, `alpha_bar` and `n_T` from the schedule-check cell.
sample_batch = dataset

forward_list = []   # snapshots of the step-by-step chain
forward_list2 = []  # snapshots of the closed-form jump

# Keep a snapshot at 5 evenly spaced timesteps (guarded so tiny n_T cannot
# produce a modulo-by-zero).
snapshot_every = max(n_T // 5, 1)

# 1) x_t = sqrt(1 - beta_t) * x_{t-1} + sqrt(beta_t) * eps,  eps ~ N(0, I)
x = sample_batch
for t in range(n_T):
    x = torch.sqrt(1 - beta[t]) * x + torch.sqrt(beta[t]) * torch.randn_like(x)
    if t % snapshot_every == 0:
        forward_list.append(x.detach().cpu())

# 2) Reparameterization trick:
#    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)
for t in range(n_T):
    if t % snapshot_every == 0:
        x = (
            torch.sqrt(alpha_bar[t]) * sample_batch
            + torch.sqrt(1 - alpha_bar[t]) * torch.randn_like(sample_batch)
        )
        forward_list2.append(x.detach().cpu())

# One row of 5 panels per method; zip stops after the 5 available axes.
for snapshots in (forward_list, forward_list2):
    fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(25, 4))
    for ax, points in zip(axes, snapshots):
        ax.scatter(points[:, 0], points[:, 1], alpha=0.5, color='red', edgecolor='white', s=40)
    plt.show()
    plt.close()

In [None]:
class ConditionalLinear(nn.Module):
    """Linear layer whose output is scaled element-wise by a learned,
    per-timestep embedding vector.

    Args:
        num_in: Input feature size of the linear map.
        num_out: Output feature size (also the embedding width).
        n_steps: Number of distinct timestep indices the embedding table holds.
    """

    def __init__(self, num_in, num_out, n_steps):
        super().__init__()
        self.num_out = num_out
        self.lin = nn.Linear(num_in, num_out)
        self.embed = nn.Embedding(n_steps, num_out)
        # Overwrite the default embedding initialisation with uniform draws.
        self.embed.weight.data.uniform_()

    def forward(self, x, t):
        """Return ``embed(t) * lin(x)``.

        The embedding is flattened to ``(-1, num_out)`` first, so the index
        tensor ``t`` may carry extra singleton dimensions.
        """
        scale = self.embed(t).view(-1, self.num_out)
        return scale * self.lin(x)

class TimeConditionalModel(nn.Module):
    """Small MLP mapping a 2-D input and a timestep index to a 2-D output.

    Three timestep-conditioned hidden layers of width 128, each followed by a
    softplus, then a plain linear read-out.

    Args:
        n_steps: Number of timestep indices each conditional layer supports.
    """

    def __init__(self, n_steps):
        super().__init__()
        width = 128
        self.lin1 = ConditionalLinear(2, width, n_steps)
        self.lin2 = ConditionalLinear(width, width, n_steps)
        self.lin3 = ConditionalLinear(width, width, n_steps)
        self.lin4 = nn.Linear(width, 2)

    def forward(self, x, t):
        """Run ``x`` through the conditioned layers using timestep indices ``t``."""
        hidden = x
        for conditioned in (self.lin1, self.lin2, self.lin3):
            hidden = F.softplus(conditioned(hidden, t))
        return self.lin4(hidden)

In [None]:
class DDPM(nn.Module):
    """Diffusion wrapper around a noise-prediction network.

    ``forward`` returns the training loss for a batch of clean samples and
    ``sample`` runs the reverse process from Gaussian noise.

    Args:
        nn_model: Network called as ``nn_model(x_t, t)`` that predicts the
            noise contained in ``x_t``.
        betas: Pair ``(beta1, beta2)`` passed to ``ddpm_schedules``.
        n_T: Number of diffusion steps; timesteps are indexed ``0 .. n_T - 1``.
        device: Device on which ``nn_model`` is placed and timesteps are drawn.
    """

    def __init__(self, nn_model, betas, n_T, device):
        super().__init__()
        self.nn_model = nn_model.to(device)

        # register_buffer allows accessing dictionary produced by ddpm_schedules
        # e.g. can access self.sqrtab later
        for k, v in ddpm_schedules(betas[0], betas[1], n_T).items():
            self.register_buffer(k, v)

        self.n_T = n_T
        self.device = device
        self.loss_mse = nn.MSELoss()

    def forward(self, x):
        """
        this method is used in training, so samples t and noise randomly

        Args:
            x: Batch of clean samples; the first dimension is the batch.

        Returns:
            Scalar MSE between the injected noise and the network's prediction.
        """
        _ts = torch.randint(0, self.n_T, (x.shape[0],)).to(self.device)  # t ~ Uniform{0, ..., n_T-1}
        noise = torch.randn_like(x)  # eps ~ N(0, 1)

        # Shape the per-sample coefficients as (batch, 1, ..., 1) so they
        # broadcast against x whatever its number of trailing dimensions.
        coef_shape = (-1,) + (1,) * (x.dim() - 1)
        # x_t = sqrt(alphabar) * x_0 + sqrt(1 - alphabar) * eps
        x_t = (
            self.sqrtab[_ts].view(coef_shape) * x
            + self.sqrtmab[_ts].view(coef_shape) * noise
        )

        # return MSE between added noise, and our predicted noise
        eps = self.nn_model(x_t, _ts)
        return self.loss_mse(noise, eps)

    def sample(self, n_sample, size, device):
        """Generate samples by running the reverse process from pure noise.

        Args:
            n_sample: Number of samples to generate.
            size: Shape of a single sample, e.g. ``(2,)``.
            device: Device on which sampling runs.

        Returns:
            Tuple ``(x_0, history)``: the final samples as a tensor, and a
            numpy array stacking intermediate states (the first reverse step,
            every 20th step and the last 8 steps).
        """
        x_i = torch.randn(n_sample, *size).to(device)  # x_T ~ N(0, 1), sample initial noise
        x_i_store = []  # keep track of generated steps in case want to plot something
        for i in range(self.n_T - 1, -1, -1):
            # 1-D timestep indices, the same layout `forward` feeds the network.
            t_is = torch.tensor([i]).to(device).repeat(n_sample)

            eps = self.nn_model(x_i, t_is)
            x_i = self.oneover_sqrta[i] * (x_i - eps * self.mab_over_sqrtmab[i])
            # Timesteps are 0-indexed, so i == 0 is the final step and is the
            # only one that stays noise-free.
            if i > 0:
                z = torch.randn(n_sample, *size).to(device)
                x_i = x_i + self.sqrt_beta_t[i] * z

            if i % 20 == 0 or i == self.n_T - 1 or i < 8:
                x_i_store.append(x_i.detach().cpu().numpy())

        x_i_store = np.array(x_i_store)
        return x_i, x_i_store

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs end to end.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
n_epoch = 1000                            # total training epoch
batch_size = 256                          # number of data in each iteration
n_T = 200                                 # total timesteps of diffusion process
lrate = 1e-3                              # learning rate
save_dir = './toy_data_results/'
os.makedirs(save_dir, exist_ok=True)

ddpm = DDPM(nn_model=TimeConditionalModel(n_steps=n_T), betas=(1e-5, 1e-2), n_T=n_T, device=device)
# Also moves the schedule buffers, which are created on the CPU.
ddpm.to(device)

optim = torch.optim.Adam(ddpm.parameters(), lr=lrate)


In [None]:
# Figures are only written to disk here, so turn interactive mode off.
plt.ioff()

for ep in range(n_epoch):
    print(f'epoch {ep}')
    ddpm.train()

    # Fresh shuffle every epoch; mini-batches are consecutive slices of it.
    permutation = torch.randperm(dataset.size()[0])
    pbar = tqdm(range(0, dataset.size()[0], batch_size))
    loss_ema = None
    for i in pbar:
        indices = permutation[i:i+batch_size]
        x = dataset[indices].to(device)

        optim.zero_grad()
        loss = ddpm(x)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(ddpm.parameters(), 1.)
        optim.step()

        # Exponential moving average of the loss, shown on the progress bar.
        if loss_ema is None:
            loss_ema = loss.item()
        else:
            loss_ema = 0.95 * loss_ema + 0.05 * loss.item()
        pbar.set_description(f"loss: {loss_ema:.4f}")

    # for eval, save a scatter plot of freshly generated samples and a gif of
    # the reverse process that produced them
    ddpm.eval()
    if ep%100==0 or ep == int(n_epoch-1):
        with torch.no_grad():
            x_gen, x_gen_store = ddpm.sample(10000, (2,), device)

        sample_fig, sample_ax = plt.subplots(figsize=(16, 12))
        sample_ax.scatter(x_gen[:, 0].cpu(), x_gen[:, 1].cpu(), alpha=0.5, color='red', edgecolor='white', s=40)
        sample_fig.savefig(save_dir + f"samples_ep{ep}.png")
        print('visualize samples at ' + save_dir + f"samples_ep{ep}.png")

        # create gif of images evolving over time, based on x_gen_store
        fig, axs = plt.subplots(nrows=1, ncols=1, sharex=True,sharey=True,figsize=(8,8))
        def animate_diff(i, x_gen_store):
            """Redraw the single axes with the i-th stored sampling state."""
            print(f'gif animating frame {i} of {x_gen_store.shape[0]}', end='\r')
            plots = []
            axs.clear()
            plots.append(axs.scatter(x_gen_store[i, :, 0], x_gen_store[i, :, 1], alpha=0.5, color='red', edgecolor='white', s=40))
            return plots
        ani = FuncAnimation(fig, animate_diff, fargs=[x_gen_store],  interval=200, blit=False, repeat=True, frames=x_gen_store.shape[0])
        ani.save(save_dir + f"gif_ep{ep}.gif", dpi=100, writer=PillowWriter(fps=5))
        print('saved image at ' + save_dir + f"gif_ep{ep}.gif")
        # Release both figures so repeated evaluations do not accumulate them.
        plt.close('all')

    # optionally save model
    if ep == int(n_epoch-1):
        torch.save(ddpm.state_dict(), save_dir + f"model_{ep}.pth")
        print('saved model at ' + save_dir + f"model_{ep}.pth")

In [None]:
# Reverse (denoising) process written out by hand with the trained network.
num_samples = 10000

x = torch.randn((num_samples, 2), device=device) # xT ~ N(0, I)
x_store = []
with torch.no_grad():
    for i in range(n_T-1, -1, -1):
        t_is = torch.tensor([i]).to(device)
        t_is = t_is.repeat(num_samples, 1)

        eps = ddpm.nn_model(x, t_is)
        # x_{t-1} = 1/sqrt(alpha_t) * (x_t - (1 - alpha_t)/sqrt(1 - alphabar_t) * eps)
        #           + sqrt(beta_t) * z,   z ~ N(0, I)
        # The last step (i == 0) is deterministic, so no noise is added there.
        z = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = (
            ddpm.oneover_sqrta[i] * (x - ddpm.mab_over_sqrtmab[i] * eps)
            + ddpm.sqrt_beta_t[i] * z
        )
        # Keep the first reverse step, every 20th step and the last 8 steps.
        if i % 20 == 0 or i == n_T - 1 or i < 8:
            x_store.append(x.detach().cpu().numpy())

plt.scatter(x[:, 0].cpu(), x[:, 1].cpu(), alpha=0.5, color='red', edgecolor='white', s=40)
plt.show()
plt.close()