In [6]:
import torch
import torch.nn.functional as F
from torch import nn

In [7]:
class SpatialTransformer(nn.Module):
  """Applies a stack of transformer blocks across the positions of a 4-D feature map.

  The input is group-normalised, passed through a 1x1 convolution, flattened to a
  (batch, height * width, channels) sequence, fed through every transformer block
  together with `cond`, restored to (batch, channels, height, width), passed
  through a second 1x1 convolution and finally added to the untouched input.
  """

  def __init__(self, channels, n_heads, n_layers, d_cond):
    """
    Args:
      channels: channel count of the feature map; GroupNorm is built with 32 groups.
      n_heads: number of heads handed to each block; every head gets `channels // n_heads` features.
      n_layers: how many BasicTransformerBlock instances to stack.
      d_cond: forwarded to each block as its `d_cond` argument.
    """
    super().__init__()
    self.norm = nn.GroupNorm(num_groups = 32, num_channels = channels, eps = 1e-5, affine = True)
    self.proj_in = nn.Conv2d(channels, channels, kernel_size = 1, stride = 1, padding = 0)

    d_head = channels // n_heads
    layers = []
    for _ in range(n_layers):
      layers.append(BasicTransformerBlock(channels, n_heads, d_head, d_cond = d_cond))
    self.transformer_blocks = nn.ModuleList(layers)

    self.proj_out = nn.Conv2d(channels, channels, kernel_size = 1, stride = 1, padding = 0)

  def forward(self, x, cond):
    """Returns a tensor shaped like `x`; `cond` is passed unchanged to every block."""
    b, c, h, w = x.shape
    skip = x

    hidden = self.proj_in(self.norm(x))

    # Channels-last, then merge the two spatial axes into one sequence axis.
    hidden = hidden.permute(0, 2, 3, 1)
    hidden = hidden.view(b, h * w, c)

    for layer in self.transformer_blocks:
      hidden = layer(hidden, cond)

    # Undo the flattening and move channels back to axis 1.
    hidden = hidden.view(b, h, w, c).permute(0, 3, 1, 2)
    return self.proj_out(hidden) + skip

In [8]:
class BasicTransformerBlock(nn.Module):
  """Pre-norm transformer block: self-attention, attention over `cond`, feed-forward.

  Each of the three sub-layers is applied to a LayerNorm of its input and the
  result is added back to that input.
  """

  def __init__(self, d_model, n_heads, d_head, d_cond):
    """
    Args:
      d_model: feature size of the sequence being transformed.
      n_heads: number of attention heads for both attention layers.
      d_head: feature size per head.
      d_cond: feature size of `cond`, used for the second attention layer's keys/values.
    """
    super().__init__()
    # Keys/values come from the sequence itself, hence d_cond = d_model here.
    self.attention1 = CrossAttention(d_model, d_model, n_heads, d_head)
    self.norm1 = nn.LayerNorm(d_model)
    # Keys/values come from the conditioning input.
    self.attention2 = CrossAttention(d_model, d_cond, n_heads, d_head)
    self.norm2 = nn.LayerNorm(d_model)
    self.feedforward = FeedForward(d_model)
    self.norm3 = nn.LayerNorm(d_model)

  def forward(self, x, cond):
    """Runs the three residual sub-layers in order and returns the updated sequence."""
    # Self-attention.
    attended = self.attention1(self.norm1(x))
    x = attended + x

    # Cross-attention conditioning on `cond`.
    conditioned = self.attention2(self.norm2(x), cond = cond)
    x = conditioned + x

    # Position-wise feed-forward.
    expanded = self.feedforward(self.norm3(x))
    return expanded + x

In [10]:
class CrossAttention(nn.Module):
  """Multi-head scaled dot-product attention of `x` (queries) over `cond` (keys/values).

  If `cond` is not supplied to `forward`, `x` is used for keys and values as well,
  which makes the layer a self-attention layer (this requires d_cond == d_model).
  """

  def __init__(self, d_model, d_cond, n_heads, d_head, inplace = True):
    """
    Args:
      d_model: feature size of the query input and of the returned tensor.
      d_cond: feature size of the key/value input.
      n_heads: number of attention heads.
      d_head: feature size of each head; scores are scaled by d_head ** -0.5.
      inplace: when True, the softmax is computed one half of the batch at a time
        and written back into the score tensor (presumably to lower peak memory);
        when False, a single out-of-place softmax is used.
    """
    super().__init__()
    self.inplace = inplace
    self.n_heads = n_heads
    self.d_head = d_head
    self.scale = d_head ** -0.5
    d_attention = d_head * n_heads
    self.q_mapping = nn.Linear(d_model, d_attention, bias = False)
    self.k_mapping = nn.Linear(d_cond, d_attention, bias = False)
    self.v_mapping = nn.Linear(d_cond, d_attention, bias = False)
    self.out = nn.Sequential(nn.Linear(d_attention, d_model))

  def forward(self, x, cond = None):
    """Projects `x` to queries and `cond` (or `x` when `cond` is None) to keys/values, then attends.

    Returns a tensor with the same leading two dimensions as `x` and `d_model` features.
    """
    if cond is None:
      cond = x
    q = self.q_mapping(x)
    k = self.k_mapping(cond)
    v = self.v_mapping(cond)
    return self.normal_attention(q, k, v)

  def normal_attention(self, q, k, v):
    """Computes softmax(q k^T * scale) v per head and applies the output projection.

    `q`, `k` and `v` are 3-D tensors whose last dimension is n_heads * d_head.
    """
    # Split the feature axis into (head, per-head feature).
    q = q.view(*q.shape[:2], self.n_heads, -1)
    k = k.view(*k.shape[:2], self.n_heads, -1)
    v = v.view(*v.shape[:2], self.n_heads, -1)
    attention = torch.einsum('bihd,bjhd->bhij', q, k) * self.scale
    if self.inplace:
      half = attention.shape[0] // 2
      attention[half:] = attention[half:].softmax(dim=-1)
      attention[:half] = attention[:half].softmax(dim=-1)
    else:
      attention = attention.softmax(dim=-1)

    # The weighted sum must exist before its shape can be read: the original
    # single-statement form referenced `out` while defining it and raised
    # UnboundLocalError on every call.
    out = torch.einsum('bhij,bjhd->bihd', attention, v)
    # Merge (head, per-head feature) back into one feature axis.
    out = out.reshape(*out.shape[:2], -1)
    return self.out(out)

In [11]:
class FeedForward(nn.Module):
  """Position-wise feed-forward network: GeGLU expansion, dropout, linear projection back."""

  def __init__(self, d_model, d_mult = 4):
    """
    Args:
      d_model: feature size of the input and of the output.
      d_mult: the hidden width is d_model * d_mult.
    """
    super().__init__()
    self.net = nn.Sequential(
        GeGLU(d_model, d_model * d_mult),
        nn.Dropout(0.),
        nn.Linear(d_model * d_mult, d_model)
    )

  def forward(self, x):
    """Returns `self.net(x)`, which has the same feature size as `x`."""
    # The result must be returned: without it the module yields None and the
    # residual addition in the calling block fails.
    return self.net(x)

In [12]:
class GeGLU(nn.Module):
  """GELU-gated linear unit: one linear layer yields a value half and a gate half."""

  def __init__(self, d_in, d_out):
    """
    Args:
      d_in: feature size of the input.
      d_out: feature size of the output; the projection is twice as wide.
    """
    super().__init__()
    # A single projection produces both the values and their gates.
    self.proj = nn.Linear(in_features = d_in, out_features = 2 * d_out)

  def forward(self, x):
    """Returns the first half of the projection multiplied by GELU of the second half."""
    projected = self.proj(x)
    value, gate = torch.chunk(projected, chunks = 2, dim = -1)
    return value * F.gelu(gate)