# **0. Load Preliminary Functions**

# a. Import Libraries and Functions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision import datasets
from torchvision import models
from torchvision import transforms
from torch.autograd import Variable
from torchvision.utils import make_grid
from torchvision.transforms import ToPILImage

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation, PillowWriter
from mpl_toolkits.mplot3d import Axes3D
from tqdm.notebook import tqdm
import numpy as np
import pickle
import itertools
import math
from typing import List

# Run on the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


# b. MNIST Data Loader

In [None]:
# Preprocessing applied to every MNIST image:
#   1. Pad(2) adds a 2-pixel border on each side (28x28 MNIST digits -> 32x32),
#      so the two 0.5x down-samplings in the U-Net give integral sizes.
#   2. ToTensor converts the PIL image to a float tensor in [0, 1].
#   3. Normalize((0.5,), (0.5,)) computes (x - 0.5) / 0.5, i.e. maps to [-1, 1].
#      Both mean and std are 1-tuples (one entry per channel).
tensor_transform = transforms.Compose([
    transforms.Pad(2),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

batch_size = 256

# Both splits are downloaded into ./data on first use and share one transform.
train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=tensor_transform,
)
test_dataset = datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=tensor_transform,
)

# Training batches are reshuffled every epoch; the test order stays fixed.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    shuffle=False,
)


# **1. Consistency Model**


In [None]:
def kerras_boundaries(sigma, eps, N, T):
    """Return ``N`` increasing noise levels between ``eps`` and ``T``.

    The levels are evenly spaced in ``t ** (1 / sigma)`` space and mapped back
    with ``** sigma``::

        t_i = (eps^(1/sigma) + i / (N - 1) * (T^(1/sigma) - eps^(1/sigma))) ^ sigma

    so ``t_0 == eps`` and ``t_{N-1} == T`` (up to floating-point rounding).

    Args:
        sigma: Exponent controlling how strongly the points cluster near
            ``eps`` (this is the schedule exponent, not a noise level).
        eps: Smallest noise level (first boundary).
        N: Number of boundaries to generate.
        T: Largest noise level (last boundary).

    Returns:
        A 1-D ``torch.Tensor`` of length ``N`` (empty when ``N <= 0``). For
        ``N == 1`` the single boundary is ``eps``.
    """
    inv_sigma = 1 / sigma
    low = eps ** inv_sigma
    span = T ** inv_sigma - low
    # With a single point there is no interval to divide; the max() guard
    # avoids the division by zero and places that point at ``eps``.
    n_intervals = max(N - 1, 1)
    return torch.tensor(
        [(low + i / n_intervals * span) ** sigma for i in range(N)]
    )


def block(ic, oc):
    """Build a double-convolution unit mapping ``ic`` to ``oc`` channels.

    The unit is GroupNorm -> SiLU -> 3x3 Conv, applied twice. All convolutions
    use ``padding=1`` so the spatial size is preserved. GroupNorm is created
    with 32 groups, so both ``ic`` and ``oc`` must be divisible by 32.
    """
    return nn.Sequential(
        nn.GroupNorm(32, num_channels=ic),
        nn.SiLU(),
        nn.Conv2d(ic, oc, 3, padding=1),
        nn.GroupNorm(32, num_channels=oc),
        nn.SiLU(),
        nn.Conv2d(oc, oc, 3, padding=1),
    )


class ConsistencyModel(nn.Module):
    """Small U-Net-style consistency model ``f(x, t)``.

    The raw network ``F(x, t)`` is a two-level encoder/decoder with skip
    connections and a sinusoidal time embedding added after every encoder
    stage. The model output is

        f(x, t) = c_skip(t) * x + c_out(t) * F(x, t)

    with ``c_skip(eps) = 1`` and ``c_out(eps) = 0`` so that ``f(x, eps) = x``
    (the consistency-model boundary condition).

    Args:
        n_channel: Number of image channels.
        eps: Smallest noise level; ``f(x, eps)`` is the identity.
        n_feat: Base feature width. Must be divisible by 32 (GroupNorm groups).
        sigma_data: Assumed standard deviation of the data, used by
            ``c_skip`` / ``c_out``. 0.5 matches images scaled to [-1, 1].
    """

    def __init__(
        self,
        n_channel: int,
        eps: float = 0.002,
        n_feat: int = 128,
        sigma_data: float = 0.5,
    ) -> None:
        super().__init__()

        self.eps = eps
        self.sigma_data = sigma_data

        ### UNet
        # Frequencies of the sinusoidal time embedding. Registered as a
        # non-persistent buffer: it follows ``.to(device)`` with the module
        # but does not add a key to ``state_dict``.
        self.register_buffer(
            "freqs",
            torch.exp(
                -math.log(10000)
                * torch.arange(start=0, end=n_feat, dtype=torch.float32)
                / n_feat
            ),
            persistent=False,
        )

        self.down = nn.Sequential(
            nn.Conv2d(n_channel, n_feat, 3, padding=1),
            block(n_feat, n_feat),
            block(n_feat, 2 * n_feat),
            block(2 * n_feat, 2 * n_feat),
        )

        # One projection of the (2 * n_feat)-wide time embedding per encoder
        # stage. These are only ever indexed, never chained (their shapes do
        # not compose), hence a ModuleList. Parameter names are unchanged.
        self.time_downs = nn.ModuleList(
            [
                nn.Linear(2 * n_feat, n_feat),
                nn.Linear(2 * n_feat, n_feat),
                nn.Linear(2 * n_feat, 2 * n_feat),
                nn.Linear(2 * n_feat, 2 * n_feat),
            ]
        )

        self.mid = block(2 * n_feat, 2 * n_feat)

        self.up = nn.Sequential(
            block(2 * n_feat, 2 * n_feat),
            block(2 * 2 * n_feat, n_feat),
            block(n_feat, n_feat),
            nn.Conv2d(2 * n_feat, 2 * n_feat, 3, padding=1),
        )
        self.last = nn.Conv2d(2 * n_feat + n_channel, n_channel, 3, padding=1)

    def forward(self, x, t) -> torch.Tensor:
        """Evaluate ``f(x, t)``.

        Args:
            x: Noisy images of shape ``(B, n_channel, H, W)``; ``H`` and ``W``
                must be divisible by 4 because of the two 0.5x down-samplings.
            t: Noise level, either a Python number (shared by the whole batch)
                or a tensor of shape ``(B, 1)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if isinstance(t, (int, float)):
            t = torch.full(
                (x.shape[0], 1), float(t), dtype=torch.float32, device=x.device
            )
        else:
            t = t.to(x.device).float()

        # time embedding: (B, 1) * (1, n_feat) -> (B, n_feat), then sin/cos
        args = t * self.freqs[None].to(t.device)
        t_emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

        x_ori = x

        # perform F(x, t)
        hs = []
        for idx, layer in enumerate(self.down):
            if idx % 2 == 1:
                # odd stages are residual and keep the resolution
                x = layer(x) + x
            else:
                # even stages change width, halve the resolution and are
                # remembered for the decoder's skip connections
                x = layer(x)
                x = F.interpolate(x, scale_factor=0.5)
                hs.append(x)

            x = x + self.time_downs[idx](t_emb)[:, :, None, None]

        x = self.mid(x)

        for idx, layer in enumerate(self.up):
            if idx % 2 == 0:
                x = layer(x) + x
            else:
                # concatenate the matching encoder feature, then up-sample
                x = torch.cat([x, hs.pop()], dim=1)
                x = F.interpolate(x, scale_factor=2, mode="nearest")
                x = layer(x)

        x = self.last(torch.cat([x, x_ori], dim=1))

        # Problem 1 (a): skip / output scalings enforcing f(x, eps) = x.
        #   c_skip(t) = sigma_data^2 / ((t - eps)^2 + sigma_data^2)
        #   c_out(t)  = sigma_data * (t - eps) / sqrt(sigma_data^2 + t^2)
        sigma_sq = self.sigma_data ** 2
        t_shift = t - self.eps
        c_skip_t = sigma_sq / (t_shift.pow(2) + sigma_sq)
        c_out_t = self.sigma_data * t_shift / (t.pow(2) + sigma_sq).sqrt()

        return c_skip_t[:, :, None, None] * x_ori + c_out_t[:, :, None, None] * x

    def loss(self, x, z, t1, t2, ema_model):
        """Consistency-training loss for one pair of adjacent noise levels.

        Problem 1 (c). The same noise ``z`` perturbs the clean batch ``x`` to
        both levels; the online model at the larger level ``t2`` is regressed
        onto the EMA (target) model at the smaller level ``t1``.

        Args:
            x: Clean images, shape ``(B, C, H, W)``.
            z: Standard-normal noise with the same shape as ``x``.
            t1: Smaller noise levels, shape ``(B, 1)``.
            t2: Adjacent larger noise levels, shape ``(B, 1)``.
            ema_model: Target network; it receives no gradient.

        Returns:
            Scalar tensor: mean squared error between the two predictions.
        """
        online = self(x + z * t2[:, :, None, None], t2)

        with torch.no_grad():
            target = ema_model(x + z * t1[:, :, None, None], t1)

        return F.mse_loss(online, target)

    @torch.no_grad()
    def sample(self, x, ts: List[float]):
        """Multistep consistency sampling.

        Problem 1 (b). ``x`` is first denoised from noise level ``ts[0]``.
        For every following level ``t`` the current estimate is re-noised to
        ``t`` with fresh Gaussian noise and denoised again.

        Args:
            x: Initial noise at level ``ts[0]`` (e.g. ``ts[0] * randn``).
            ts: Noise levels, typically decreasing, each ``>= eps``. A single
                entry gives one-step generation; an empty list returns ``x``
                unchanged.

        Returns:
            Generated samples with the same shape as ``x``.
        """
        if not ts:
            return x

        x = self(x, ts[0])

        for t in ts[1:]:
            z = torch.randn_like(x)
            # x is an estimate at level eps, so adding variance t^2 - eps^2
            # brings it to level t; max() guards against t slightly below eps.
            x = x + math.sqrt(max(t ** 2 - self.eps ** 2, 0.0)) * z
            x = self(x, t)

        return x

# c. Training Functions

In [None]:
def train(
    n_epoch: int = 100,
    n_channels: int = 1,
    n_feat: int = 256,
    dataloader=None,
):
    """Train a ``ConsistencyModel`` with consistency training and an EMA target.

    Per epoch the number of discretisation steps ``N`` grows with the epoch
    index (reaching 150 in the final epoch), the noise-level boundaries are
    rebuilt for that ``N``, and the EMA decay ``mu`` is derived from ``N``.
    After every epoch a grid of 5-step samples is shown with matplotlib.

    Args:
        n_epoch: Number of epochs to run.
        n_channels: Number of image channels passed to the model.
        n_feat: Base feature width passed to the model.
        dataloader: Iterable yielding ``(images, labels)`` batches. Defaults
            to the module-level ``train_loader``.

    Returns:
        The trained online model (in ``eval`` mode once at least one epoch
        has processed a batch).
    """
    if dataloader is None:
        dataloader = train_loader

    model = ConsistencyModel(n_channels, n_feat=n_feat)
    model.to(device)
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Define \theta_{-}, which is EMA of the params
    ema_model = ConsistencyModel(n_channels, n_feat=n_feat)
    ema_model.to(device)
    ema_model.load_state_dict(model.state_dict())
    # The target network is only updated through the EMA rule below, never by
    # the optimizer, so its parameters do not need gradients.
    ema_model.requires_grad_(False)

    # 1..n_epoch inclusive: runs exactly n_epoch epochs and lets the schedule
    # reach its final value in the last one.
    for epoch in range(1, n_epoch + 1):
        N = math.ceil(math.sqrt((epoch * (150**2 - 4) / n_epoch) + 4) - 1) + 1
        boundaries = kerras_boundaries(7.0, 0.002, N, 80.0).to(device)
        # EMA decay depends only on N, so compute it once per epoch.
        mu = math.exp(2 * math.log(0.95) / N)

        pbar = tqdm(dataloader)
        loss_ema = None
        x = None
        model.train()
        for x, _ in pbar:
            optim.zero_grad()
            x = x.to(device)

            z = torch.randn_like(x)
            # Index n in [0, N - 2] so that n + 1 is still a valid boundary.
            t = torch.randint(0, N - 1, (x.shape[0], 1), device=device)
            t_0 = boundaries[t]
            t_1 = boundaries[t + 1]

            loss = model.loss(x, z, t_0, t_1, ema_model=ema_model)

            loss.backward()
            optim.step()

            # Exponentially smoothed loss, for display only.
            loss_value = loss.item()
            if loss_ema is None:
                loss_ema = loss_value
            else:
                loss_ema = 0.9 * loss_ema + 0.1 * loss_value

            with torch.no_grad():
                # update \theta_{-}
                for p, ema_p in zip(model.parameters(), ema_model.parameters()):
                    ema_p.mul_(mu).add_(p, alpha=1 - mu)

            pbar.set_description(f"loss: {loss_ema:.10f}, mu: {mu:.10f}")

        if x is None:
            # Empty dataloader: there is no batch to take the preview shape from.
            continue

        model.eval()
        with torch.no_grad():
            # Sample 5 Steps, starting from pure noise at the largest level.
            xh = model.sample(
                torch.randn_like(x) * 80.0,
                [80.0, 40.0, 20.0, 10.0, 5.0],
            )
            # Undo the [-1, 1] normalisation for display.
            xh = (xh * 0.5 + 0.5).clamp(0, 1)
            grid = make_grid(xh[:81], nrow=9, padding=0)

            img = ToPILImage()(grid)
            plt.imshow(img)
            plt.show()

    return model


# d. Training


In [None]:
##################
### Problem 1 (d)
##################
# Hyper-parameters of the training run: epoch count and base U-Net width.
n_epoch, n_feat = 100, 256

train(n_feat=n_feat, n_epoch=n_epoch)

# **2. Ablation Study**

In [None]:
##################
### Problem 2: Ablation Study
##################