In [1]:
%run UNet.ipynb

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [2]:
from typing import Tuple, Optional
import torch
import torch.nn.functional as F
import torch.utils.data
from torch import nn

In [3]:
def gather(consts, t):
  """Select one entry of ``consts`` per index in ``t``, shaped for broadcasting.

  Args:
    consts: tensor of per-step constants, indexed along its last dimension.
    t: long tensor of indices into the last dimension of ``consts``.

  Returns:
    The selected values reshaped to ``(-1, 1, 1, 1)``, i.e. one value per
    sample, broadcastable against 4-D batch-first tensors.
  """
  picked = torch.gather(consts, dim=-1, index=t)
  return picked.reshape(-1, 1, 1, 1)

In [4]:
class DDPM:
  """Denoising Diffusion Probabilistic Model with a linear beta schedule.

  Holds the noise schedule and a noise-prediction network. Exposes the
  forward (noising) process q, a single reverse (denoising) step of p, and
  the simplified epsilon-prediction training loss.

  Args:
    eps: callable ``eps(xt, t)`` returning predicted noise with ``xt``'s shape
      (presumably the UNet loaded via ``%run UNet.ipynb`` -- confirm).
    steps: number of diffusion steps T; valid step indices are [0, steps).
    device: device on which the schedule tensors are allocated.
    beta_start: first value of the linear beta schedule.
    beta_end: last value of the linear beta schedule.
  """

  def __init__(self, eps, steps, device, beta_start=0.0001, beta_end=0.02):
    self.eps = eps
    self.steps = steps
    # Linear variance schedule beta_0 .. beta_{steps-1}.
    self.beta = torch.linspace(beta_start, beta_end, steps).to(device)
    self.alpha = 1 - self.beta
    # alpha_bar_t = prod_{s <= t} alpha_s enables closed-form sampling of q(x_t | x_0).
    self.alpha_bar = torch.cumprod(self.alpha, dim = 0)
    # Reverse-process variance: uses the sigma_t^2 = beta_t choice.
    self.sigma_2 = self.beta

  def q_xt_given_x0(self, x0, t):
    """Return (mean, variance) of q(x_t | x_0) = N(sqrt(alpha_bar_t) x_0, (1 - alpha_bar_t) I).

    Args:
      x0: clean batch.
      t: 1-D long tensor of per-sample step indices.
    """
    alpha_bar = gather(self.alpha_bar, t)
    mean = alpha_bar ** 0.5 * x0
    variance = 1 - alpha_bar
    return mean, variance

  def sample_from_q(self, x0, t, eps=None):
    """Draw x_t ~ q(x_t | x_0) as mean + sqrt(variance) * eps.

    Args:
      x0: clean batch.
      t: 1-D long tensor of per-sample step indices.
      eps: optional noise with ``x0``'s shape; standard normal noise is drawn
        when omitted or None.
    """
    if eps is None:
      eps = torch.randn_like(x0)
    mean, variance = self.q_xt_given_x0(x0, t)
    return mean + variance ** 0.5 * eps

  def sample_from_p(self, xt, t):
    """One reverse step: sample x_{t-1} ~ p(x_{t-1} | x_t).

    The model's noise prediction gives the posterior mean
    (1 / sqrt(alpha_t)) * (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps_theta).
    Gaussian noise with variance sigma_t^2 is added only for samples whose
    t > 0. At t == 0 the mean is returned unchanged, so the final denoised
    output is not re-noised.

    Args:
      xt: current noisy batch.
      t: 1-D long tensor of per-sample step indices.
    """
    eps_theta = self.eps(xt, t)
    alpha_bar = gather(self.alpha_bar, t)
    alpha = gather(self.alpha, t)
    beta = gather(self.beta, t)
    coefficient = beta / (1 - alpha_bar) ** 0.5
    mean = (1 / alpha) ** 0.5 * (xt - coefficient * eps_theta)
    variance = gather(self.sigma_2, t)
    z = torch.randn_like(xt)
    # 1.0 where t > 0, 0.0 where t == 0; shaped like gather's per-sample output.
    nonzero = (t > 0).to(xt.dtype).reshape(-1, 1, 1, 1)
    return mean + nonzero * variance ** 0.5 * z

  def loss(self, x0, noise=None):
    """Simplified DDPM objective: MSE between predicted and true noise.

    A step t is drawn uniformly from [0, steps) per sample. x_t is built
    from ``x0`` and ``noise``, and ``eps(x_t, t)`` is regressed onto ``noise``.

    Args:
      x0: clean batch; its first dimension is the batch size.
      noise: optional noise with ``x0``'s shape; standard normal noise is
        drawn when omitted or None.
    """
    batch_size = x0.shape[0]
    t = torch.randint(0, self.steps, (batch_size,), device = x0.device, dtype = torch.long)
    if noise is None:
      noise = torch.randn_like(x0)
    xt = self.sample_from_q(x0, t, noise)
    eps_theta = self.eps(xt, t)
    return F.mse_loss(eps_theta, noise)