In [5]:
!pip install labml_nn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [6]:
import math
from typing import Optional, Tuple, Union, List
import torch
from torch import nn
from labml_helpers.module import Module

In [7]:
class Swish(Module):
  """Swish activation, applied element-wise: ``x * sigmoid(x)``."""

  def forward(self, x):
    # Gate each element by its own sigmoid.
    gate = torch.sigmoid(x)
    return gate * x

In [8]:
class TimeEmbedding(nn.Module):
  """Sinusoidal time-step embedding followed by a two-layer MLP.

  Each time step ``t`` is encoded as ``[sin(t * f_i) ..., cos(t * f_i) ...]``
  over ``half_dim = n_channels // 8`` geometrically spaced frequencies
  ``f_i = 10000 ** (-i / (half_dim - 1))`` (a single frequency ``f_0 = 1`` when
  ``half_dim == 1``). The encoding is then mapped to ``n_channels`` features by
  ``Linear -> Swish -> Linear``.

  Args:
    n_channels: Output feature size. Must be at least 8 so that the encoding
      has at least one frequency.

  Raises:
    ValueError: If ``n_channels < 8``.
  """

  def __init__(self, n_channels):
    super().__init__()
    if n_channels < 8:
      raise ValueError(f"n_channels must be >= 8, got {n_channels}")
    self.n_channels = n_channels
    half_dim = n_channels // 8
    # The sinusoidal encoding has exactly 2 * half_dim features. That equals
    # n_channels // 4 whenever n_channels % 8 < 4 (e.g. any multiple of 8), so
    # parameter shapes are unchanged for those sizes; for other sizes
    # n_channels // 4 is one too large and would not match the encoding.
    self.lin1 = nn.Linear(2 * half_dim, self.n_channels)
    self.act = Swish()
    self.lin2 = nn.Linear(self.n_channels, self.n_channels)

  def _sinusoidal_embedding(self, t: torch.Tensor) -> torch.Tensor:
    """Return the raw ``[sin | cos]`` encoding of ``t``.

    Assumes ``t`` is 1-D with shape ``(batch,)``; the result then has shape
    ``(batch, 2 * (n_channels // 8))``.
    """
    half_dim = self.n_channels // 8
    # With a single frequency the spacing is irrelevant (exp(0) == 1), so
    # guard the denominator instead of dividing by zero.
    scale = math.log(10000) / max(half_dim - 1, 1)
    freqs = torch.exp(torch.arange(half_dim, device=t.device) * -scale)
    angles = t[:, None] * freqs[None, :]
    return torch.cat((angles.sin(), angles.cos()), dim=1)

  def forward(self, t: torch.Tensor) -> torch.Tensor:
    """Embed time steps ``t`` (assumed shape ``(batch,)``) into ``(batch, n_channels)``."""
    emb = self._sinusoidal_embedding(t)
    emb = self.act(self.lin1(emb))
    return self.lin2(emb)

In [10]:
class ResidualBlock(Module):
  """Residual block: two 3x3 convolutions with group normalization and Swish
  activations, a per-channel time-embedding bias, and a skip connection.

  Args:
    in_channels: Number of channels of the input feature map.
    out_channels: Number of channels of the output feature map.
    time_channels: Size of the time-embedding vector ``t``.
    n_groups: Number of groups for ``nn.GroupNorm``; ``in_channels`` and
      ``out_channels`` must both be divisible by it.
    dropout: Dropout probability applied before the second convolution.
  """

  def __init__(self, in_channels, out_channels, time_channels, n_groups = 32, dropout = 0.1):
    super().__init__()
    # First stage: in_channels -> out_channels.
    self.norm1 = nn.GroupNorm(n_groups, in_channels)
    self.activation1 = Swish()
    self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
    # Second stage: out_channels -> out_channels. These need their own
    # attributes; assigning to norm1/conv1 here would clobber the first stage.
    self.norm2 = nn.GroupNorm(n_groups, out_channels)
    self.activation2 = Swish()
    self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=(3, 3), padding=(1, 1))
    # The skip path needs a 1x1 projection only when channel counts differ.
    if in_channels != out_channels:
      self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 1))
    else:
      self.shortcut = nn.Identity()
    # Projects the time embedding to one additive bias per output channel.
    self.time_emb = nn.Linear(time_channels, out_channels)
    self.time_activation = Swish()
    self.dropout = nn.Dropout(dropout)

  def forward(self, x, t):
    """Apply the block.

    Args:
      x: Feature map, expected shape ``(batch, in_channels, height, width)``.
      t: Time embedding, expected shape ``(batch, time_channels)``.

    Returns:
      Tensor of shape ``(batch, out_channels, height, width)``.
    """
    h = self.conv1(self.activation1(self.norm1(x)))
    # Broadcast the (batch, out_channels) time bias over the spatial dims.
    h = h + self.time_emb(self.time_activation(t))[:, :, None, None]
    h = self.conv2(self.dropout(self.activation2(self.norm2(h))))
    # The skip connection carries the block's *input*, projected if needed.
    return h + self.shortcut(x)