In [5]:
%run UNet.ipynb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting labml_nn
  Downloading labml_nn-0.4.133-py3-none-any.whl (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.9/434.9 KB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
Collecting labml-helpers>=0.4.89
  Downloading labml_helpers-0.4.89-py3-none-any.whl (24 kB)
Collecting fairscale
  Downloading fairscale-0.4.13.tar.gz (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 KB[0m [31m30.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 KB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m

In [9]:
from typing import Tuple, Optional, List
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn
import torchvision
from labml_helpers.device import DeviceConfigs
from labml.configs import BaseConfigs, option
from labml import lab, tracker, experiment, monit
from PIL import Image

In [7]:
def gather(consts, t):
  """Select ``consts[..., t]`` for each entry of ``t``, shaped for broadcasting.

  Args:
    consts: tensor of per-step constants, indexed along its last dimension.
    t: integer index tensor (same number of dims as ``consts``).

  Returns:
    The selected values flattened and reshaped to ``(-1, 1, 1, 1)`` so they
    broadcast against a 4-D image batch.
  """
  picked = torch.gather(consts, -1, t)
  return torch.reshape(picked, (-1, 1, 1, 1))

In [8]:
class DDPM:
  """Denoising diffusion probabilistic model with a linear beta schedule.

  Wraps a noise-prediction network ``eps`` and provides the forward (noising)
  process q(x_t | x_0), a single reverse (denoising) step p(x_{t-1} | x_t),
  and the simplified epsilon-prediction MSE training loss.
  """

  def __init__(self, eps, steps, device):
    """
    Args:
      eps: callable ``eps(xt, t)`` returning the predicted noise for ``xt`` at
        step(s) ``t`` (here the UNet).
      steps: number of diffusion steps T.
      device: device on which the schedule tensors are stored.
    """
    super().__init__()
    self.eps = eps
    self.steps = steps
    # Linear variance schedule beta_1..beta_T.
    self.beta = torch.linspace(0.0001, 0.02, steps).to(device)
    self.alpha = 1 - self.beta
    self.alpha_bar = torch.cumprod(self.alpha, dim = 0)
    # Reverse-process variance sigma_t^2 is taken equal to beta_t.
    self.sigma_2 = self.beta

  def q_xt_given_x0(self, x0, t):
    """Mean and variance of q(x_t | x_0).

    Returns ``(sqrt(alpha_bar_t) * x0, 1 - alpha_bar_t)``. The schedule values
    are gathered per batch element with shape (batch, 1, 1, 1).
    """
    mean = gather(self.alpha_bar, t) ** 0.5 * x0
    variance = 1 - gather(self.alpha_bar, t)
    return mean, variance

  def sample_from_q(self, x0, t, eps=None):
    """Draw x_t ~ q(x_t | x_0).

    Args:
      x0: clean input batch.
      t: per-sample step indices (long tensor, one entry per batch element).
      eps: optional noise to use; standard normal noise is drawn when None.
    """
    if eps is None:
      eps = torch.randn_like(x0)
    mean, variance = self.q_xt_given_x0(x0, t)
    return mean + (variance ** 0.5) * eps

  def sample_from_p(self, xt, t):
    """Take one reverse step, sampling x_{t-1} ~ p(x_{t-1} | x_t).

    Args:
      xt: current noisy batch.
      t: per-sample step indices (long tensor, one entry per batch element).
    """
    eps_theta = self.eps(xt, t)
    alpha_bar = gather(self.alpha_bar, t)
    alpha = gather(self.alpha, t)
    beta = 1 - alpha
    coefficient = beta / (1 - alpha_bar) ** 0.5
    mean = ((1 / alpha) ** 0.5) * (xt - coefficient * eps_theta)
    variance = gather(self.sigma_2, t)
    z = torch.randn_like(xt)
    # The final step (t == 0, i.e. t = 1 in the paper's 1-indexing) returns the
    # mean without noise, as in DDPM sampling Algorithm 2 (z = 0 when t = 1).
    add_noise = (t > 0).to(xt.dtype).reshape(-1, 1, 1, 1)
    return mean + add_noise * (variance ** 0.5) * z

  def loss(self, x0, noise=None):
    """Simplified DDPM loss: MSE between the true and the predicted noise.

    Args:
      x0: clean input batch; a random step is drawn for each element.
      noise: optional noise to use; standard normal noise is drawn when None.

    Returns:
      Scalar loss tensor.
    """
    batch_size = x0.shape[0]
    t = torch.randint(0, self.steps, (batch_size,), device = x0.device, dtype = torch.long)
    if noise is None:
      noise = torch.randn_like(x0)
    xt = self.sample_from_q(x0, t, noise)
    eps_model = self.eps(xt, t)
    return F.mse_loss(noise, eps_model)

In [14]:
class Configs(BaseConfigs):
  """labml configuration for training a DDPM image model.

  Call ``init()`` before ``training_loop()``. ``init()`` builds the UNet noise
  predictor, the DDPM wrapper, the data loader and the optimizer.
  """
  device: torch.device = DeviceConfigs()
  epsilon_model: UNet
  diffusion: DDPM
  img_channels: int = 3
  img_size: int = 32
  n_channels: int = 64
  ch_mults: List[int] = [1, 2, 2, 4]
  # One flag per resolution level in ch_mults.
  has_attention: List[bool] = [False, False, False, True]
  n_steps: int = 1000
  batch_size: int = 64
  n_samples: int = 16
  learning_rate: float = 2e-5
  epochs: int = 1000
  dataset: torch.utils.data.Dataset
  data_loader: torch.utils.data.DataLoader
  optimizer: torch.optim.Adam

  def init(self):
    """Build the model, diffusion process, data loader and optimizer."""
    self.epsilon_model = UNet(
        image_channels = self.img_channels,
        n_channels = self.n_channels,
        ch_mults = self.ch_mults,
        has_attention = self.has_attention
    ).to(self.device)

    self.diffusion = DDPM(
        eps = self.epsilon_model,
        steps = self.n_steps,
        device = self.device
    )

    # Stored on self because train() iterates self.data_loader.
    self.data_loader = torch.utils.data.DataLoader(
        self.dataset, self.batch_size, shuffle = True, pin_memory = True)
    self.optimizer = torch.optim.Adam(self.epsilon_model.parameters(), lr=self.learning_rate)
    # parameters() is a generator; report the number of weights instead of its repr.
    n_params = sum(p.numel() for p in self.epsilon_model.parameters())
    print(f"parameters: {n_params:,}")
    tracker.set_image("sample", True)

  def sample(self):
    """Generate n_samples images by running the full reverse chain from pure noise."""
    with torch.no_grad():
      x = torch.randn([self.n_samples, self.img_channels, self.img_size, self.img_size], device = self.device)
      for t_ in monit.iterate('Sample', self.n_steps):
        # Walk the steps backwards: n_steps - 1, ..., 0.
        t = self.n_steps - t_ - 1
        x = self.diffusion.sample_from_p(x, x.new_full((self.n_samples,), t, dtype=torch.long))
      tracker.save('sample', x)

  def train(self):
    """Run one epoch of optimisation over the data loader."""
    for data in monit.iterate('Train', self.data_loader):
      tracker.add_global_step()
      data = data.to(self.device)
      self.optimizer.zero_grad()
      # noise=None lets the diffusion model draw fresh Gaussian noise.
      loss = self.diffusion.loss(data, noise=None)
      loss.backward()
      self.optimizer.step()
      tracker.save('loss', loss)

  def training_loop(self):
    """Alternate training epochs with sampling, checkpointing after each epoch."""
    for _ in monit.loop(self.epochs):
      self.train()
      self.sample()
      tracker.new_line()
      experiment.save_checkpoint()

In [11]:
import helper

# Download into the folder CelebADataset reads from (lab.get_data_path() / 'celebA');
# a separate hard-coded path would never be seen by the dataset.
data_dir = lab.get_data_path() / 'celebA'
helper.download_extract('celeba', str(data_dir))

Downloading celeba: 1.44GB [00:19, 75.6MB/s]                            


Extracting celeba...


In [16]:
class CelebADataset(torch.utils.data.Dataset):
  """CelebA images as square tensors of side ``image_size``.

  Uses every ``*.jpg`` found recursively under ``lab.get_data_path() / 'celebA'``.
  """

  def __init__(self, image_size: int):
    """
    Args:
      image_size: side length of the square output images.
    """
    super().__init__()
    folder = lab.get_data_path() / 'celebA'
    # Sorted so the index -> file mapping is reproducible (glob order is arbitrary).
    self._files = sorted(folder.glob('**/*.jpg'))
    self._transform = torchvision.transforms.Compose([
        # Resize(int) scales only the shorter side and keeps the aspect ratio;
        # CenterCrop then makes every sample image_size x image_size, matching
        # the square images Configs.sample generates.
        torchvision.transforms.Resize(image_size),
        torchvision.transforms.CenterCrop(image_size),
        torchvision.transforms.ToTensor(),
    ])

  def __len__(self):
    return len(self._files)

  def __getitem__(self, index: int):
    # The context manager closes the file once the pixels have been transformed;
    # convert('RGB') guarantees 3 channels for every sample.
    with Image.open(self._files[index]) as img:
      return self._transform(img.convert('RGB'))


@option(Configs.dataset, 'CelebA')
def celeb_dataset(c: Configs):
  """Configs.dataset option 'CelebA': CelebA at the configured image size."""
  return CelebADataset(c.img_size)