In [None]:
!pip install labml_nn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting labml_nn
  Downloading labml_nn-0.4.133-py3-none-any.whl (434 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m434.9/434.9 KB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting labml-helpers>=0.4.89
  Downloading labml_helpers-0.4.89-py3-none-any.whl (24 kB)
Collecting labml>=0.4.158
  Downloading labml-0.4.161-py3-none-any.whl (129 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m129.2/129.2 KB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
Collecting fairscale
  Downloading fairscale-0.4.13.tar.gz (266 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m266.3/266.3 KB[0m [31m15.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ...

In [None]:
import math
from typing import Optional, Tuple, Union, List
import torch
from torch import nn
from labml_helpers.module import Module

In [None]:
class Swish(Module):
  """Swish (a.k.a. SiLU) activation, applied element-wise: ``x * sigmoid(x)``."""

  def forward(self, x):
    gate = torch.sigmoid(x)
    return x * gate

In [None]:
class TimeEmbedding(nn.Module):
  """Sinusoidal embedding of time steps followed by a two-layer MLP.

  Each time step is encoded as ``n_channels // 8`` sines and the same number of
  cosines of geometrically spaced frequencies (``n_channels // 4`` features in
  total). The result is projected to ``n_channels`` by ``lin1 -> Swish -> lin2``.

  Args:
    n_channels: width of the output embedding.

  Raises:
    ValueError: if ``n_channels`` cannot produce a valid sinusoidal encoding.
      Either fewer than two frequencies are available (the frequency spacing
      would divide by zero), or ``2 * (n_channels // 8)`` differs from
      ``n_channels // 4`` (``lin1`` would receive the wrong width).
  """

  def __init__(self, n_channels):
    super().__init__()
    half_dim = n_channels // 8
    if half_dim < 2 or 2 * half_dim != n_channels // 4:
      raise ValueError(
          f"TimeEmbedding: n_channels={n_channels} is not supported; "
          "need n_channels // 8 >= 2 and 2 * (n_channels // 8) == n_channels // 4")
    self.n_channels = n_channels
    self.lin1 = nn.Linear(self.n_channels // 4, self.n_channels)
    self.act = Swish()
    self.lin2 = nn.Linear(self.n_channels, self.n_channels)

  def forward(self, t: torch.Tensor):
    """Embed a 1-D tensor of time steps ``t`` (shape ``(batch,)``).

    Returns a tensor of shape ``(batch, n_channels)``.
    """
    half_dim = self.n_channels // 8
    # Frequencies decay geometrically from 1 down to 1/10000.
    scale = math.log(10000) / (half_dim - 1)
    freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
    angles = t[:, None] * freqs[None, :]
    emb = torch.cat((angles.sin(), angles.cos()), dim=1)
    emb = self.act(self.lin1(emb))
    return self.lin2(emb)

In [None]:
class ResidualBlock(Module):
  """Residual block: two 3x3 convolutions with group normalisation and Swish,
  an additive time embedding, and a skip connection.

  The skip connection is a 1x1 convolution when the channel count changes,
  otherwise the identity.

  Args:
    in_channels: channels of the input feature map.
    out_channels: channels of the output feature map.
    time_channels: width of the time embedding ``t``.
    n_groups: number of groups for group normalisation. Must divide both
      ``in_channels`` and ``out_channels``.
    dropout: dropout probability applied before the second convolution.
  """
  def __init__(self, in_channels, out_channels, time_channels, n_groups = 32, dropout = 0.1):
    super().__init__()
    self.norm1 = nn.GroupNorm(n_groups, in_channels)
    self.activation1 = Swish()
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
    self.norm2 = nn.GroupNorm(n_groups, out_channels)
    self.activation2 = Swish()
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
    if in_channels != out_channels:
      self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
    else:
      self.shortcut = nn.Identity()
    self.time_emb = nn.Linear(time_channels, out_channels)
    self.time_activation = Swish()
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, t):
    """Apply the block.

    Args:
      x: feature map of shape ``(batch, in_channels, H, W)``.
      t: time embedding of shape ``(batch, time_channels)``.

    Returns:
      Tensor of shape ``(batch, out_channels, H, W)``.
    """
    h = self.conv1(self.activation1(self.norm1(x)))
    # Project the time embedding to out_channels and broadcast it over H and W.
    h = h + self.time_emb(self.time_activation(t))[:, :, None, None]
    h = self.conv2(self.dropout(self.activation2(self.norm2(h))))
    # The skip path uses the block's *input*, not the transformed features.
    return h + self.shortcut(x)

In [None]:
class AttentionBlock(Module):
  """Multi-head self-attention over the spatial positions of a feature map,
  with a residual connection.

  Args:
    n_channels: channels of the input feature map.
    n_heads: number of attention heads.
    d_k: per-head width of queries/keys/values; defaults to ``n_channels``.
    n_groups: number of groups for ``self.norm``.
  """
  def __init__(self, n_channels, n_heads = 1, d_k = None, n_groups = 32):
    super().__init__()
    if d_k is None:
      d_k = n_channels
    # NOTE(review): self.norm is registered but never applied in forward;
    # confirm whether pre-normalisation of x was intended.
    self.norm = nn.GroupNorm(n_groups, n_channels)
    # One projection produces queries, keys and values for all heads.
    self.projection = nn.Linear(n_channels, n_heads * d_k * 3)
    self.output = nn.Linear(n_heads * d_k, n_channels)
    self.scale = d_k ** -0.5
    self.n_heads = n_heads
    self.d_k = d_k

  def forward(self, x, t: Optional[torch.Tensor] = None):
    """Attend over the ``H * W`` positions of ``x``.

    Args:
      x: feature map of shape ``(batch, n_channels, H, W)``.
      t: ignored; accepted so the block can be called like the others.

    Returns:
      Tensor with the same shape as ``x``.
    """
    batch_size, n_channels, height, width = x.shape
    # (batch, channels, H, W) -> (batch, H*W, channels)
    seq = x.reshape(batch_size, n_channels, -1).permute(0, 2, 1)
    qkv = self.projection(seq).view(batch_size, -1, self.n_heads, 3 * self.d_k)
    query, key, value = torch.chunk(qkv, 3, dim=-1)
    # attention[b, i, j, h]: query position i against key position j, head h.
    attention = torch.einsum('bihd,bjhd->bijh', query, key) * self.scale
    attention = attention.softmax(dim=2)  # normalise over key positions j
    result = torch.einsum('bijh,bjhd->bihd', attention, value)
    result = result.reshape(batch_size, -1, self.n_heads * self.d_k)
    result = self.output(result) + seq  # residual connection
    return result.permute(0, 2, 1).reshape(batch_size, n_channels, height, width)

In [None]:
class DownBlock(Module):
  """Encoder stage: a ResidualBlock, optionally followed by self-attention.

  Args:
    in_channels: input channels of the residual block.
    out_channels: output channels of the residual block.
    time_channels: width of the time embedding forwarded to the residual block.
    attention: if truthy, an ``AttentionBlock(out_channels)`` is applied after
      the residual block; otherwise that step is the identity.
  """
  def __init__(self, in_channels, out_channels, time_channels, attention):
    super().__init__()
    self.residual = ResidualBlock(in_channels, out_channels, time_channels)
    # NOTE(review): relies on AttentionBlock accepting n_channels alone
    # (defaults for its remaining arguments).
    self.attention = AttentionBlock(out_channels) if attention else nn.Identity()

  def forward(self, x, t):
    return self.attention(self.residual(x, t))

In [None]:
class UpBlock(Module):
  """Decoder stage: a ResidualBlock, optionally followed by self-attention.

  The residual block takes ``in_channels + out_channels`` input channels and
  produces ``out_channels``. NOTE(review): presumably the caller concatenates an
  encoder skip connection onto ``x`` to reach that width; confirm in the model's
  forward pass.
  """
  def __init__(self, in_channels, out_channels, time_channels, attention):
    super().__init__()
    concat_channels = in_channels + out_channels
    self.residual = ResidualBlock(concat_channels, out_channels, time_channels)
    if not attention:
      self.attention = nn.Identity()
    else:
      self.attention = AttentionBlock(out_channels)

  def forward(self, x, t):
    features = self.residual(x, t)
    return self.attention(features)

In [None]:
class MiddleBlock(Module):
  """U-Net bottleneck: residual -> self-attention -> residual, all at
  ``n_channels`` channels and conditioned on the time embedding ``t``."""
  def __init__(self, n_channels, time_channels):
    super().__init__()
    self.first_res = ResidualBlock(n_channels, n_channels, time_channels)
    self.attention = AttentionBlock(n_channels)
    self.second_res = ResidualBlock(n_channels, n_channels, time_channels)

  def forward(self, x: torch.Tensor, t: torch.Tensor):
    out = self.first_res(x, t)
    out = self.attention(out)
    return self.second_res(out, t)

In [None]:
class Upsample(nn.Module):
  """Double the spatial resolution with a 4x4, stride-2, padding-1 transposed
  convolution that keeps the channel count at ``n_channels``."""
  def __init__(self, n_channels):
    super().__init__()
    self.conv = nn.ConvTranspose2d(
        in_channels=n_channels,
        out_channels=n_channels,
        kernel_size=(4, 4),
        stride=(2, 2),
        padding=(1, 1),
    )

  def forward(self, x, t):
    del t  # accepted for a uniform block interface; not used
    return self.conv(x)

In [None]:
class Downsample(nn.Module):
  """Halve the spatial resolution with a 3x3, stride-2, padding-1 convolution.

  The output is ``ceil(H / 2) x ceil(W / 2)``, and the channel count stays at
  ``n_channels``. A transposed convolution here would *enlarge* the map
  (to ``2H - 1``), which is why a regular ``Conv2d`` is used.
  """
  def __init__(self, n_channels):
    super().__init__()
    self.conv = nn.Conv2d(n_channels, n_channels, (3, 3), (2, 2), (1, 1))

  def forward(self, x, t):
    _ = t  # accepted for a uniform block interface; not used
    return self.conv(x)

In [None]:
class UNet(Module):
  """U-Net built from the Down/Middle/Up blocks, conditioned on a time embedding.

  The constructor builds an image projection, a time embedding of width
  ``n_channels * 4``, an encoder ``self.down``, a bottleneck ``self.middle``
  and a decoder ``self.up``.

  NOTE(review): no ``forward`` and no projection back to ``image_channels`` are
  visible in this cell. Confirm they are defined before using the model.
  """
  def __init__(self, image_channels = 3, n_channels = 64, channels_per_res = (1, 2, 2, 4), has_attention = (False, False, True, True), n_blocks = 2):
    """Build the encoder, bottleneck and decoder.

    Args:
      image_channels: channels of the input image (input of image_projection).
      n_channels: channel count after image_projection. The time embedding
        width is ``n_channels * 4``.
      channels_per_res: per-resolution channel multipliers. Each multiplier is
        applied to the previous stage's channel count (cumulatively), so the
        defaults give 64, 128, 256, 1024 channels.
      has_attention: per-resolution flags for attention in the Down/Up blocks.
        Indexed by resolution, so it needs at least ``len(channels_per_res)``
        entries.
      n_blocks: DownBlocks per resolution. The decoder uses ``n_blocks + 1``
        UpBlocks per resolution.
    """
    super().__init__()
    n_resolutions = len(channels_per_res)
    # Project the image to n_channels feature maps (3x3 conv, same spatial size).
    self.image_projection = nn.Conv2d(image_channels, n_channels, kernel_size=(3, 3), padding=(1, 1))
    self.time_embedding = TimeEmbedding(n_channels * 4)
    # Encoder: n_blocks DownBlocks per resolution, with a Downsample between
    # resolutions (none after the last one).
    down = []
    out_channels = in_channels = n_channels
    for i in range(n_resolutions):
      # Cumulative: multiplies the *current* channel count, not n_channels.
      out_channels = in_channels * channels_per_res[i]
      for block in range(n_blocks):
        down.append(DownBlock(in_channels, out_channels, n_channels * 4, has_attention[i]))
        in_channels = out_channels
      if i < n_resolutions - 1:
        down.append(Downsample(in_channels))
    self.down = nn.ModuleList(down)

    # Bottleneck at the deepest channel count.
    self.middle = MiddleBlock(out_channels, n_channels * 4, )

    # Decoder: mirror of the encoder, walking resolutions from deepest to shallowest.
    up = []
    in_channels = out_channels
    for i in reversed(range(n_resolutions)):
      # n_blocks UpBlocks that keep the channel count...
      out_channels = in_channels
      for block in range(n_blocks):
        up.append(UpBlock(in_channels, out_channels, n_channels * 4, has_attention[i]))
      # ...then one more UpBlock that undoes this resolution's multiplier.
      # NOTE(review): presumably this extra block consumes the skip saved at the
      # Downsample/stem of the encoder; confirm against forward().
      out_channels = in_channels // channels_per_res[i]
      up.append(UpBlock(in_channels, out_channels, n_channels * 4, has_attention[i]))
      in_channels = out_channels
      # Upsample between resolutions (none after the shallowest one).
      if i > 0:
        up.append(Upsample(in_channels))
    self.up = nn.ModuleList(up)

