In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
import torchvision
from torchvision.utils import make_grid, save_image
import matplotlib.pyplot as plt
import imageio
import math
import numpy as np
from IPython.display import Image
# !pip install tqdm
from tqdm import tqdm

In [None]:
# Colab-only: mount Google Drive at /content/drive. The training cell writes
# checkpoints, sample images and the loss plot under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Global compute device used by every module below: the GPU when PyTorch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [None]:
def save_and_display(grid: torch.Tensor, nrow: int, path: str) -> tuple:
    """Tile a batch of images into a single grid image and write it to disk.

    Despite the name, nothing is displayed here; the tiled image and the path
    are returned so the caller can show them.

    Args:
        grid: Tensor of images handed to ``make_grid``.
        nrow: Forwarded to ``make_grid`` as ``nrow`` (images per grid row).
        path: Destination file passed to ``save_image``.

    Returns:
        ``(tiled, path)`` - the tiled image tensor and the path it was saved to.
    """
    tiled = make_grid(grid, nrow=nrow)
    save_image(tiled, path)
    return tiled, path

def save_gif(list_of_grids: list, path: str) -> None:
    """Write a sequence of image batches to ``path`` as an animated GIF.

    Every entry of ``list_of_grids`` is tiled with ``make_grid`` (4 images per
    row), converted to a PIL image and then to a numpy array; the resulting
    frames are written with imageio at 10 fps.

    Args:
        list_of_grids: Iterable of image batches. The original author's note
            says each one is shaped [4, c, h, w] - not enforced here.
        path: Output file for the GIF.
    """
    to_pil = transforms.ToPILImage()

    frames = []
    for batch in list_of_grids:
        tiled = make_grid(batch, nrow=4)
        frames.append(np.array(to_pil(tiled)))

    imageio.mimsave(path, frames, format = 'GIF', fps = 10)

## Datasets

## Forward Pass

## U-Net Architecture

In [None]:
#   Sinusoidal Embedding
class TimeEmbedding(nn.Module):
    """Parameter-free sinusoidal embedding of diffusion time steps.

    Each time step is multiplied by ``embed_dim // 2`` geometrically spaced
    frequencies (from 1 down to 1/10000); the sines and cosines of those
    products are concatenated along the last axis.

    Args:
        embed_dim: Requested embedding width. The output has
            ``2 * (embed_dim // 2)`` features, i.e. one fewer than requested
            when ``embed_dim`` is odd.
    """

    def __init__(self, embed_dim):

        super(TimeEmbedding, self).__init__()
        self.dim = embed_dim

    def forward(self, t: torch.Tensor):
        """Embed a 1-D tensor of time steps.

        Args:
            t: Tensor of shape (batch,) holding the time steps.

        Returns:
            Tensor of shape (batch, 2 * (embed_dim // 2)) on the same device
            as ``t``.
        """
        half_dim = self.dim // 2
        # max(..., 1) avoids a ZeroDivisionError when embed_dim is 2 or 3.
        scale = math.log(10000) / max(half_dim - 1, 1)
        # Build the frequencies on t's device rather than the notebook-global
        # one, so the multiply below can never mix devices.
        freqs = torch.exp(torch.arange(half_dim, device = t.device) * -scale)
        angles = t[:, None] * freqs[None, :]
        return torch.cat([angles.sin(), angles.cos()], dim = -1)

In [None]:
#  Residual Block
class ResBlock(nn.Module):
    """Residual convolution block, optionally conditioned on the time step.

    Main path:  GroupNorm -> SiLU -> 3x3 conv -> (+ time embedding in "down"
    mode) -> GroupNorm -> SiLU -> Dropout -> 5x5 conv.
    Skip path:  7x7 conv on the block input, added to the main path output.
    All convolutions use 'same' padding, so the spatial size is preserved.

    Args:
        in_channel: Channels of the input feature map.
        out_channel: Channels of the output feature map.
        num_groups: Number of groups for both GroupNorm layers; must divide
            both ``in_channel`` and ``out_channel``.
        dropout_rate: Dropout probability applied before the second conv.
            A rate of 0 installs an identity layer, so no dropout is applied.
        down_up: ``"down"`` adds a learned projection of the sinusoidal time
            embedding after the first conv; any other value (the U-Net passes
            ``"up"``) skips time conditioning.
    """

    def __init__(self, in_channel: int, out_channel: int, num_groups: int, dropout_rate: float, down_up: str):
        super(ResBlock, self).__init__()

        self.swish = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels = in_channel, out_channels = out_channel, kernel_size = (3,3), padding = 'same', stride = 1, device=device)
        self.conv2 = nn.Conv2d(in_channels = out_channel, out_channels = out_channel, kernel_size = (5,5), padding = 'same', stride = 1, device= device)
        self.res_conv = nn.Conv2d(in_channels = in_channel, out_channels=out_channel, kernel_size=(7,7), padding = 'same', stride = 1, device=device)

        self.group_norm1 = nn.GroupNorm(num_groups=num_groups, num_channels = in_channel, device = device)
        self.group_norm2 = nn.GroupNorm(num_groups=num_groups, num_channels = out_channel, device = device)
        # `dropout_rate` used to be accepted but silently ignored. Identity is
        # used for a zero rate so existing 0.0 callers behave exactly as before.
        self.dropout = nn.Dropout(dropout_rate) if dropout_rate > 0 else nn.Identity()
        self.mode = down_up
        if down_up == "down":
            self.time_embedding = TimeEmbedding(out_channel)
            self.time_mlp = nn.Sequential(
                nn.Linear(out_channel, out_channel, device=device),
                nn.SiLU(),
            )

    def forward(self, x: torch.Tensor, t: torch.Tensor):
        """
            X -> GroupNorm1 -> Swish -> Conv1 -> Out1
            Time embedding -> Linear -> Swish -> Out2        ("down" mode only)
            (Out1 + Out2) -> GroupNorm2 -> Swish -> Dropout -> Conv2 -> + res_conv(X)

            x: feature map of shape (B, C, H, W)
            t: time steps of shape (B,); ignored unless the block is in "down" mode
            returns: tensor of shape (B, out_channel, H, W)
        """
        B, C, H, W = x.shape

        # X -> GroupNorm -> Swish -> Conv1 -> Out1
        out = self.group_norm1(x)
        out = self.swish(out)
        out = self.conv1(out)
        assert (H,W) == (out.shape[2], out.shape[3]), "Not compatible shape"

        # Time Embedding in the case of downblock
        if self.mode == "down":
            time_embed = self.time_embedding(t)
            time_embed = self.time_mlp(time_embed)
            # Reshape to (B, channels, 1, 1) so it broadcasts over H and W.
            time_embed = time_embed.view(B, time_embed.shape[1], 1, 1)
            out = out + time_embed

        #   Last
        out = self.group_norm2(out)
        out = self.swish(out)
        out = self.dropout(out)
        out = self.conv2(out)
        # Learned skip connection from the block input.
        out = out + self.res_conv(x)

        return out

In [None]:
#   Downsample Block
class DownBlock(nn.Module):
    """Halve the spatial size of a feature map.

    Pipeline: 3x3 'same' conv (in_channel -> out_channel), then a 4x4
    stride-2 conv with padding 1, BatchNorm and SiLU.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=1, padding='same', device=device)
        self.activation = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channel, out_channel, kernel_size=4, stride=2, padding=1, device=device)
        self.norm = nn.BatchNorm2d(out_channel, device=device)

    def forward(self, x, t = None):
        """Downsample ``x`` (B, C, H, W) to (B, out_channel, H//2, W//2).

        ``t`` is accepted so the block can be called like a ResBlock; it is
        not used.
        """
        _, _, in_h, in_w = x.shape
        features = self.conv2(self.conv1(x))
        features = self.activation(self.norm(features))
        out_h, out_w = features.shape[2], features.shape[3]
        assert (out_h, out_w) == (in_h // 2, in_w // 2), "Not compatible size!"
        return features

#   Upsample Block
class UpBlock(nn.Module):
    """Double the spatial size of a feature map.

    Pipeline: 4x4 stride-2 transposed conv with padding 1
    (in_channel -> out_channel) followed by BatchNorm.

    Args:
        in_channel: Channels of the incoming feature map.
        out_channel: Channels of the upsampled feature map.
    """

    def __init__(self, in_channel, out_channel):
        super(UpBlock, self).__init__()
        # The transposed conv previously ignored `in_channel` and was built
        # with in_channels=out_channel, which only worked when both were equal.
        self.deconv = nn.ConvTranspose2d(in_channels = in_channel, out_channels = out_channel, kernel_size=(4,4), padding =1, stride = 2, device=device)
        self.norm = nn.BatchNorm2d(out_channel, device = device)

    def forward(self, x, t = None):
        """Upsample ``x`` (B, C, H, W) to (B, out_channel, 2H, 2W).

        ``t`` is accepted so the block can be called like a ResBlock; it is
        not used.
        """
        B,C,H,W = x.shape
        out = self.deconv(x)
        out = self.norm(out)
        assert (out.shape[2], out.shape[3]) == (H*2, W*2), "Size is not compatible!"
        return out

In [None]:
class VisualAttention(nn.Module):
    """Placeholder for an attention layer - not implemented yet.

    The constructor arguments are accepted but unused, no sub-modules are
    created, and ``forward`` returns ``None``.
    """

    def __init__(self, num_heads, in_channel, out_channel):
        super().__init__()

    def forward(self, x):
        # Stub: performs no computation.
        return None

In [None]:
class Unet(nn.Module):
    """
    Small U-Net used as the noise-prediction network.

    One stage is built per (in_channels[i], out_channels[i]) pair.

    Down path, stage i:
        4 ResBlocks in "down" mode: (in_i, out_i) -> (out_i, out_i) x 3,
        followed by a DownBlock (out_i, out_i) that halves H and W.
        Consecutive stages are chained, so in_channels[i + 1] is expected to
        equal out_channels[i].
    Middle:
        a 3x3 'same' conv on out_channels[-1] followed by an Identity.
    Up path (stages visited in reverse order), stage i:
        the current features are concatenated with the stored output of down
        stage i (2 * out_i channels), then
        4 ResBlocks in "up" mode: (2 * out_i, in_i) -> (in_i, in_i) x 3,
        followed by an UpBlock (in_i, in_i) that doubles H and W.

    Example (the configuration used in this notebook):
        in_channels = [3, 12, 18], out_channels = [12, 18, 24]
        32x32 input -> 16x16 -> 8x8 -> 4x4 -> 8x8 -> 16x16 -> 32x32 output
        with in_channels[0] = 3 output channels.

    Args:
        resolutions: One entry per stage; only its length is used.
        in_channels: Input channels of each down stage.
        out_channels: Output channels of each down stage.
        num_groups: GroupNorm group count passed to every ResBlock.
        image_size: Currently unused.
    """
    def __init__(self, resolutions:list, in_channels: list, out_channels: list, num_groups: int, image_size: tuple):
        super(Unet, self).__init__()

        assert len(resolutions) == len(in_channels), "Given {} resolutions in down sampling but just have {} expected in channels".format(len(resolutions), len(in_channels))
        # Without this check zip() below would silently drop stages.
        assert len(in_channels) == len(out_channels), "Given {} in channels but {} out channels".format(len(in_channels), len(out_channels))

        n_stages = len(resolutions)
        self.down_residual = nn.ModuleList([nn.ModuleList([]) for _ in range(n_stages)])
        self.midsample = nn.ModuleList([])
        self.upsample = nn.ModuleList([nn.ModuleList([]) for _ in range(n_stages)])
        self.downsample = nn.ModuleList([])
        # NOTE: `downsample_res` is never used in forward(); it is kept so the
        # parameter set (and therefore saved checkpoints) stays unchanged.
        self.downsample_res = nn.ModuleList([])

        # Setting for the down-sample and upsample blocks
        for idx, (in_channel, out_channel) in enumerate(zip(in_channels, out_channels)):

            for resblock_idx in range(4):
                in_channel_down = in_channel if resblock_idx == 0 else out_channel
                self.down_residual[idx].append(
                    ResBlock(in_channel_down, out_channel, num_groups, 0.0, down_up = "down")
                )

                # The first up block receives the skip concatenation, hence 2x.
                in_channel_up = out_channel * 2 if resblock_idx == 0 else in_channel
                self.upsample[idx].append(
                    ResBlock(in_channel_up, in_channel, num_groups, 0.0, down_up = "up")
                )
            self.downsample.append(
                DownBlock(out_channel, out_channel)
            )
            self.downsample_res.append(
                DownBlock(out_channel, out_channel)
            )
            self.upsample[idx].append(
                UpBlock(in_channel, in_channel)
            )
        # The up path runs from the coarsest stage back to the finest one.
        self.upsample = self.upsample[::-1]

        self.midsample.append(nn.Conv2d(out_channels[-1], out_channels[-1], kernel_size=(3,3), padding = 'same', stride = 1, device = device))
        self.midsample.append(nn.Identity())
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.resolutions = resolutions

    def forward(self, x, t):

        """
        Predict a tensor with the same shape as ``x``.

        :input: x: batch of images (B, C, H, W), t: time steps (B,)

        Down path:
            each stage applies its ResBlocks and then halves H and W; the
            stage output is stored as a skip connection.
        Middle:
            keeps the spatial size unchanged.
        Up path:
            each stage concatenates the matching skip tensor on the channel
            axis, applies its ResBlocks and then doubles H and W.

        :return: tensor of shape (B, C, H, W)
        """
        B,C,H,W = x.shape
        out = x
        skips = []
        for down_idx, stage in enumerate(self.down_residual):

            # Pass through the residual blocks and down sample
            for block in stage:
                out = block(out, t)
            out = self.downsample[down_idx](out, t)

            scale_down_factor = 2 ** (down_idx + 1)
            expected = (B, self.out_channels[down_idx], H // scale_down_factor, W // scale_down_factor)
            assert tuple(out.shape) == expected,\
            "Shape is not compatible. Expected outshape of {} but {}".format(expected, tuple(out.shape))
            skips.append(out)

        # Mid sample block: spatial size is unchanged
        for mid_block in self.midsample:
            out = mid_block(out)

        # Passing through the residual blocks and up sample
        for stage in self.upsample:
            # Skips are consumed last-in-first-out, i.e. coarsest first.
            out = torch.cat((out, skips.pop()), dim = 1)
            for block in stage:
                out = block(out, t)

        assert tuple(out.shape) == (B, C, H, W), "Output shape is not compatible. Expect {} but {}".format((B,C,H,W), tuple(out.shape))
        return out

## Diffusion Models

In [None]:

class Diffusion(nn.Module):
    """Diffusion wrapper around a noise-prediction network.

    Holds a linear beta schedule and implements forward diffusion (used for
    training), a single reverse step and the full sampling loop.

    Args:
        beta_start: First value of the linear beta schedule.
        beta_end: Last value of the linear beta schedule.
        time_steps: Number of entries in the schedule.
        sampling_steps: Number of reverse steps run by ``sampling``; the loop
            starts at index ``sampling_steps - 1`` of the schedule.
        network: Module called as ``network(x_t, t)`` that predicts the noise.

    The schedule tensors are plain attributes (not parameters or buffers), so
    they are not part of ``state_dict()``.
    """

    def __init__(self, beta_start, beta_end, time_steps, sampling_steps, network):
        super(Diffusion, self).__init__()

        self.beta = torch.linspace(beta_start, beta_end, time_steps, device=device, requires_grad=False)
        self.alpha = 1.0 - self.beta
        self.cum_prod_alpha = torch.cumprod(self.alpha, dim = 0)
        self.one_minus_cumprod = 1.0 - self.cum_prod_alpha
        self.denoise_net = network
        self.sampling_steps = sampling_steps
        self.time_steps = time_steps

    def _posterior_sample(self, x, t):
        """Forward-diffuse clean images to step ``t``.

        Computes ``sqrt(cum_prod_alpha[t]) * x + sqrt(1 - cum_prod_alpha[t]) * noise``
        with fresh standard-normal noise.

        Args:
            x: Clean images of shape (B, C, H, W).
            t: Integer tensor of shape (B,) indexing the schedule.

        Returns:
            ``(x_t, noise)`` - the noised images and the noise that was added.
        """
        batch, _, _, _ = x.shape
        cumprod_t = self.cum_prod_alpha[t].view(batch, 1, 1, 1)
        one_minus_cumprod_t = self.one_minus_cumprod[t].view(batch, 1, 1, 1)

        # Noise is drawn on x's own device instead of the notebook-global one.
        noise = torch.randn_like(x)
        std = torch.sqrt(one_minus_cumprod_t)
        mean = torch.sqrt(cumprod_t) * x

        return mean + std*noise, noise

    @torch.no_grad()
    def _reverse(self, noise, t):
        """Run one reverse step from x_t to x_{t-1}.

        Args:
            noise: Current sample x_t of shape (B, C, H, W).
            t: Schedule index of the current step (same for the whole batch).

        Returns:
            The sample for the previous step. Fresh noise scaled by
            ``sqrt(beta[t])`` is added for every step except ``t == 0``.
        """
        batch, _, _, _ = noise.shape
        z = torch.randn_like(noise) if t >= 1 else 0

        # Time-step vector lives on the same device as the sample.
        time = torch.ones(batch, dtype=torch.int64, device=noise.device)*t

        eps_theta = self.denoise_net(noise, time)
        eps_coff = (1.0-self.alpha[t]) / ((1-self.cum_prod_alpha[t])**0.5)

        x_previous = (1.0 / (self.alpha[t] ** 0.5)) * (noise - eps_coff * eps_theta) + z * ((1-self.alpha[t])**0.5)

        return x_previous

    @torch.no_grad()
    def sampling(self, image_shape: list, batch: int):
        """Generate images by running the reverse process from pure noise.

        NOTE(review): the network is not switched to eval mode here; if it
        contains BatchNorm/Dropout layers they run in training mode while
        sampling - confirm this is intended.

        Args:
            image_shape: ``[C, H, W]`` of the images to generate.
            batch: Number of images to generate.

        Returns:
            ``(image, tracks)`` - the final samples and the list of every
            intermediate batch, starting with the initial noise.
        """
        C,H,W = image_shape
        # Start from noise on the schedule's device so indexing never mixes devices.
        image = torch.randn(batch, C, H, W, device = self.beta.device, requires_grad=False)
        tracks = [image]

        t = self.sampling_steps - 1

        while t >= 0:
            image = self._reverse(image, t) #   Sample x_{t-1} from p(x_t-1|x_t)
            tracks.append(image)
            t-=1

        return image, tracks

    def forward(self, x, t):
        """Noise ``x`` to step ``t`` and predict the added noise.

        Returns:
            ``(predicted_noise, true_noise)``, both shaped like ``x``.
        """
        out, noise = self._posterior_sample(x, t)    # Diffuse data
        out = self.denoise_net(out, t)               # Predict noise
        B,C,H,W = x.shape
        oB, oC, oH, oW = out.shape
        assert (B,C,H,W) == (oB, oC, oH, oW), "Output shape is not compatible with input shape. Expect {} but {}".format((B,C,H,W),(oB, oC, oH, oW) )
        return out, noise

In [None]:
# Channel plan for the three U-Net stages: stage i maps in_c[i] -> out_c[i]
# channels on the way down, and each stage's output feeds the next stage.
in_c = [3, 12, 18]
out_c = [12, 18, 24]
resolutions = [32, 16, 8]
unet = Unet(
    resolutions=resolutions,
    in_channels=in_c,
    out_channels=out_c,
    num_groups=3,
    image_size=(32, 32),
).to(device)


## Training

In [None]:
def visualize(model, batch, save_path):
    """Sample ``batch`` 3x32x32 images from ``model`` and save them as one grid.

    Args:
        model: Object exposing ``sampling(image_shape, batch)`` that returns
            ``(images, tracks)``.
        batch: Number of images to sample.
        save_path: Sequence whose first element is the output image path.
    """
    gen_img, _ = model.sampling([3,32,32], batch)
    # Training images are scaled to [-1, 1], but save_image expects [0, 1]:
    # map the samples back so negative pixels are not clipped to black.
    gen_img = (gen_img.clamp(-1.0, 1.0) + 1.0) / 2.0
    grid = torchvision.utils.make_grid(gen_img, nrow=2)
    torchvision.utils.save_image(grid, save_path[0])

In [None]:
def train(model, data_loader, epochs, lr, save_per_epochs, clip_val, loss_type, save_path = None):
    """Train the noise-prediction network of a diffusion model.

    For every batch a random time step is drawn per image, the model noises
    the images to that step and predicts the noise, and the prediction is
    regressed onto the true noise.

    Args:
        model: Diffusion model exposing ``denoise_net``, ``time_steps`` and a
            forward returning ``(predicted_noise, true_noise)``.
        data_loader: Iterable of batches whose first element is the image tensor.
        epochs: Number of passes over ``data_loader``.
        lr: Adam learning rate.
        save_per_epochs: A checkpoint and a sample grid are written whenever
            ``epoch % save_per_epochs == 0``.
        clip_val: Maximum gradient norm for ``clip_grad_norm_``.
        loss_type: ``"l1"`` or ``"l2"``.
        save_path: Currently unused; outputs go to fixed Google Drive paths.

    Raises:
        ValueError: If ``loss_type`` is neither ``"l1"`` nor ``"l2"``.
    """
    # Validate up front instead of hitting a NameError on the first batch.
    if loss_type == "l1":
        loss_fn = torch.nn.L1Loss(reduction = 'mean')
    elif loss_type == "l2":
        loss_fn = torch.nn.MSELoss(reduction = 'mean')
    else:
        raise ValueError("Unknown loss_type {!r}; expected 'l1' or 'l2'".format(loss_type))

    optimizer = torch.optim.Adam(model.denoise_net.parameters(), lr=lr)
    print(optimizer)
    print("Starting training...")

    losses = []
    for epoch in range(epochs):

        # Wrapping the loader itself lets tqdm show the number of batches.
        for batch in tqdm(data_loader):

            x_0 = batch[0].to(device)
            # Draw t uniformly from {0, ..., time_steps - 1}. The previous
            # formula never produced t = 0, the step sampling finishes on.
            t = torch.randint(0, model.time_steps, (x_0.shape[0],), device = device)
            optimizer.zero_grad()
            eps_theta, eps = model(x_0, t)

            loss = loss_fn(eps_theta, eps)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.denoise_net.parameters(), clip_val)
            optimizer.step()
            losses.append(loss.item())

        # Reports the loss of the last batch of the epoch.
        print("Loss after {} = {}".format(epoch, loss.item()))
        if epoch % save_per_epochs == 0:
            torch.save(model.state_dict(), "/content/drive/MyDrive/Diffusion-Model/ckpts/ckpt_"+str(epoch))
            visualize(model, 4, ["/content/drive/MyDrive/Diffusion-Model/gif/diff_"+str(epoch)+".png"] )
            print('Saved!')

    print("End training!")
    plt.plot(losses)
    plt.savefig("/content/drive/MyDrive/Diffusion-Model/ckpts/loss.png")


In [None]:
# Diffusion schedule settings passed to Diffusion(...) below.
# time_steps: length of the linear beta schedule.
time_steps = 1000
# sampling_steps: number of reverse steps run when generating images.
sampling_steps = 1000
# First and last value of the linear beta schedule.
# NOTE(review): the DDPM paper's linear schedule runs from 1e-4 to 2e-2;
# confirm that 2e-4 .. 2e-1 is intended.
beta_start = 2e-4
beta_end = 2e-1
def get_model():
    """Build a fresh Diffusion model around the global ``unet``.

    Uses the module-level ``beta_start``, ``beta_end``, ``time_steps`` and
    ``sampling_steps`` settings.
    """
    model = Diffusion(beta_start, beta_end, time_steps, sampling_steps, unet)
    # Previously returned the undefined name `models`, raising NameError.
    return model

def get_model_pretrain(ckpt_path):
    """Build a Diffusion model around the global ``unet`` and load a checkpoint.

    Args:
        ckpt_path: Path to a ``state_dict`` saved with ``torch.save``.

    Returns:
        The model with the checkpoint weights loaded.
    """
    model = Diffusion(beta_start, beta_end, time_steps, sampling_steps, unet)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only runtime.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state)
    return model

# To resume from a saved checkpoint instead of starting from scratch, use the
# two commented lines below in place of the get_model() call:
# ckpt_path = "/content/drive/MyDrive/Diffusion-Model/ckpts/ckpt_240"
# model = get_model_pretrain(ckpt_path)
model = get_model()

In [None]:
# Training hyperparameters passed to train(...) in the last cell.
# Number of full passes over the data loader.
epochs = 500000
# Adam learning rate.
lr = 0.001
# Maximum gradient norm used for gradient clipping.
clip_val = 1.0
# A checkpoint and a sample grid are written every `save_per_epochs` epochs.
save_per_epochs = 10
# 'l1' selects L1Loss, 'l2' selects MSELoss.
loss_type = 'l1'

In [None]:
# CIFAR-10 input pipeline: resize to SIZE x SIZE, convert to a tensor and
# rescale pixel values with x * 2 - 1 so they are centred on zero.
SIZE = 32
batch_size = 100
transform = transforms.Compose(
    [
        transforms.Resize((SIZE, SIZE)),
        transforms.ToTensor(),
        transforms.Lambda(lambda img: (img * 2) - 1),
    ]
)
dataset = torchvision.datasets.CIFAR10(root=".", transform=transform, download=True)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Launch training with the hyperparameters defined above.
train(
    model,
    dataloader,
    epochs=epochs,
    lr=lr,
    save_per_epochs=save_per_epochs,
    clip_val=clip_val,
    loss_type=loss_type,
)