<a href="https://colab.research.google.com/github/pooriaazami/deep_learning_class_notebooks/blob/main/21_Diffusion_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader, ConcatDataset

import torchvision as tv
import torchvision.transforms as T

from PIL import Image

import numpy as np
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm

In [2]:
START = 1e-4
END = .02
TIMESTEPS = 300

IMAGE_SIZE = 64
BATCH_SIZE = 64
EPOCHS = 5

LR = 1e-3

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
betas = torch.linspace(start=START, end=END, steps=TIMESTEPS)
alphas = 1 - betas

alpha_bars = torch.cumprod(alphas, dim=0)
alpha_bars_prev = F.pad(alpha_bars[:-1], (1, 0), value=1.0)
sqrt_one_over_alpha_bars = torch.sqrt(1. / alpha_bars)
sqrt_alpha_bars = torch.sqrt(alpha_bars)
sqrt_one_minus_alpha_bars = torch.sqrt(1 - alpha_bars)

posterior_variance = betas * (1. - alpha_bars_prev) / (1. - alpha_bars)

In [4]:
class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv_blocks = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_channels),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm2d(out_channels),
    )

  def forward(self, x):
    return self.conv_blocks(x)

In [5]:
class DownBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.conv_blocks = nn.Sequential(
        nn.MaxPool2d(2),
        ConvBlock(in_channels, out_channels)
    )

  def forward(self, x):
    return self.conv_blocks(x)

In [6]:
class UpBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()

    self.up = nn.ConvTranspose2d(in_channels=in_channels, out_channels=in_channels//2, kernel_size=2, stride=2)
    self.conv_blocks = ConvBlock(in_channels, out_channels)

  def forward(self, x, residual_inputs):
    x = self.up(x)

    diff_y = residual_inputs.size()[2] - x.size()[2]
    diff_x = residual_inputs.size()[3] - x.size()[3]

    x = F.pad(x, [diff_x // 2, diff_x - diff_x // 2,
                  diff_y // 2, diff_y - diff_y // 2])

    x = torch.cat([residual_inputs, x], dim=1)
    x = self.conv_blocks(x)

    return x

In [7]:
class OutBlock(nn.Module):
  def __init__(self, in_channels, num_classes):
    super().__init__()

    self.conv = nn.Conv2d(in_channels=in_channels, out_channels=num_classes, kernel_size=1)

  def forward(self, x):
    return self.conv(x)

In [8]:
class UNet(nn.Module):
  def __init__(self, in_channels, num_classes):
    super().__init__()

    self.input_block = ConvBlock(in_channels, 64)

    self.down_1 = DownBlock(64, 128)
    self.down_2 = DownBlock(128, 256)
    self.down_3 = DownBlock(256, 512)
    self.down_4 = DownBlock(512, 1024)

    self.up_4 = UpBlock(1024, 512)
    self.up_3 = UpBlock(512, 256)
    self.up_2 = UpBlock(256, 128)
    self.up_1 = UpBlock(128, 64)

    self.output_block = OutBlock(64, num_classes)

    self.embedding_up_1 = nn.Linear(1, 128)
    self.embedding_up_2 = nn.Linear(1, 256)
    self.embedding_up_3 = nn.Linear(1, 512)
    self.embedding_up_4 = nn.Linear(1, 1024)

  def forward(self, x, t):
    batch_size = x.size(0)
    down_cache_1 = self.input_block(x)

    down_cache_2 = self.down_1(down_cache_1)
    down_cache_3 = self.down_2(down_cache_2)
    down_cache_4 = self.down_3(down_cache_3)
    down_cache_5 = self.down_4(down_cache_4)

    t_embed = self.embedding_up_4(t).view(batch_size, -1, 1, 1)
    x = self.up_4(down_cache_5 + t_embed, down_cache_4)

    t_embed = self.embedding_up_3(t).view(batch_size, -1, 1, 1)
    x = self.up_3(x + t_embed, down_cache_3)

    t_embed = self.embedding_up_2(t).view(batch_size, -1, 1, 1)
    x = self.up_2(x + t_embed, down_cache_2)

    t_embed = self.embedding_up_1(t).view(batch_size, -1, 1, 1)
    x = self.up_1(x + t_embed, down_cache_1)

    x = self.output_block(x)

    return x

In [9]:
def get_index_for_batch(values, t, x_shape):
  batch_size = t.size(0)
  out = values.gather(-1, t.cpu())

  return out.reshape(batch_size, *((1,) * (len(x_shape) - 1))).to(DEVICE)

In [10]:
def forward_process(x_0, t):
  noise = torch.randn_like(x_0)

  sqrt_alpha_bars_for_batch = get_index_for_batch(sqrt_alpha_bars, t, x_0.shape).to(DEVICE)
  sqrt_one_minus_alpha_bars_for_batch = get_index_for_batch(sqrt_one_minus_alpha_bars, t, x_0.shape).to(DEVICE)

  z = sqrt_one_minus_alpha_bars_for_batch * noise + sqrt_alpha_bars_for_batch * x_0

  return z, noise

In [11]:
reverse_transforms = T.Compose([
    T.Lambda(lambda x: (x + 1) / 2),
    T.Lambda(lambda x: x.permute(1, 2, 0)),
    T.Lambda(lambda x: x * 255),
    T.Lambda(lambda x: x.cpu().numpy().astype(np.uint8)),
    T.ToPILImage()
])

def convert_tensor_image(image):
  if len(image.shape) == 4:
    image = image[0, :, :, :]

  image = reverse_transforms(image)
  return image

In [12]:
transforms = T.Compose([
    T.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Lambda(lambda x: (x * 2) - 1)
])

train = tv.datasets.CIFAR10(root='./dataset', download=True, transform=transforms, train=True)
val = tv.datasets.CIFAR10(root='./dataset', download=True, transform=transforms, train=False)

dataset = ConcatDataset([train, val])
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./dataset/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:13<00:00, 12645277.96it/s]


Extracting ./dataset/cifar-10-python.tar.gz to ./dataset
Files already downloaded and verified


In [13]:
network = UNet(3, 3).to(DEVICE)

In [14]:
optimizer = optim.Adam(network.parameters(), lr=LR)

In [15]:
criterion = nn.MSELoss()

In [20]:
for epoch in range(1, EPOCHS + 1):
  print(f'Epoch {epoch} / {EPOCHS}')
  total_loss = .0
  for images, _ in tqdm(dataloader):
    optimizer.zero_grad()

    images = images.to(DEVICE)

    batch_size = images.size(0)
    t = torch.randint(0, TIMESTEPS, (batch_size,)).to(DEVICE).long()
    noisy_image, noise = forward_process(images, t)
    noise_preds = network(noisy_image, t.unsqueeze(-1).float())

    loss = criterion(noise_preds, noise)
    loss.backward()

    optimizer.step()

    total_loss += loss.detach().cpu().item()

  print(f'Loss: {total_loss:.2f}')

Epoch 1 / 5


  0%|          | 0/938 [00:00<?, ?it/s]

Loss: 135.68
Epoch 2 / 5


  0%|          | 0/938 [00:00<?, ?it/s]

Loss: 44.75
Epoch 3 / 5


  0%|          | 0/938 [00:00<?, ?it/s]

Loss: 39.01
Epoch 4 / 5


  0%|          | 0/938 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [64]:
@torch.no_grad()
def sample_timestep(x, t):
  betas_for_batch = get_index_for_batch(betas, t, x.shape)
  sqrt_one_minus_alpha_for_batch = get_index_for_batch(
    sqrt_one_minus_alpha_bars, t, x.shape
  )
  sqrt_one_over_alpha_bars_for_batch = get_index_for_batch(
    sqrt_one_over_alpha_bars, t, x.shape
  )

  model_mean = sqrt_one_over_alpha_bars_for_batch * (
      x - betas_for_batch * network(x, t.float()) / sqrt_one_minus_alpha_for_batch
  )

  posterior_variance_for_batch = get_index_for_batch(
      posterior_variance, t, x.shape
  )

  if t == 0:
    return model_mean
  else:
    noise = torch.randn_like(x)
    return model_mean + torch.sqrt(posterior_variance_for_batch) * noise

In [73]:
@torch.no_grad()
def sample_plot_image():
    img_size = IMAGE_SIZE
    img = torch.randn((1, 3, img_size, img_size), device=DEVICE)
    num_images = 10
    stepsize = int(TIMESTEPS / num_images)

    for i in range(0, TIMESTEPS)[::-1]:
        t = torch.full((1,), i, device=DEVICE, dtype=torch.long)
        img = sample_timestep(img, t)

    return img

In [79]:
img = sample_plot_image()

In [None]:
img = convert_tensor_image(img)
plt.imshow(img)
plt.show()

In [81]:
i = 0
img = torch.randn((1, 3, IMAGE_SIZE, IMAGE_SIZE), device=DEVICE)
t = torch.full((1,), i, device=DEVICE, dtype=torch.long)
img = sample_timestep(img, t)

In [None]:
img