In [1]:
import torch
from torch import nn
from torch.nn import functional as F
from attention import SelfAttention, CrossAttention

In [22]:
# Display the version string of the PyTorch build this notebook is running against.
torch.__version__

'2.2.1+cpu'

In [2]:
# Random stand-in for a feature map laid out as (batch, channels, H/8, W/8).
z = torch.rand(2, 320, 64, 64)
z.size()

torch.Size([2, 320, 64, 64])

In [3]:
# Random stand-in for the conditioning sequence, laid out as (batch, seq_len, dim).
context = torch.rand(2, 77, 768)
context.size()

torch.Size([2, 77, 768])

In [4]:
# Build a single attention block on its own so its layer layout can be inspected.
from diffusion import UNET_AttentionBlock

# 8 * 40 = 320, which matches the 320-wide layers in the repr printed below;
# presumably the arguments are (number of heads, width per head) — TODO confirm in diffusion.py.
attention_block = UNET_AttentionBlock(8, 40)
attention_block

UNET_AttentionBlock(
  (groupnorm): GroupNorm(32, 320, eps=1e-06, affine=True)
  (conv_input): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
  (layernorm_1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  (attention_1): SelfAttention(
    (in_proj): Linear(in_features=320, out_features=960, bias=False)
    (out_proj): Linear(in_features=320, out_features=320, bias=True)
  )
  (layernorm_2): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  (attention_2): CrossAttention(
    (q_proj): Linear(in_features=320, out_features=320, bias=False)
    (k_proj): Linear(in_features=768, out_features=320, bias=False)
    (v_proj): Linear(in_features=768, out_features=320, bias=False)
    (out_proj): Linear(in_features=320, out_features=320, bias=True)
  )
  (layernorm_3): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
  (linear_geglu_1): Linear(in_features=320, out_features=2560, bias=True)
  (linear_geglu_2): Linear(in_features=1280, out_features=320, bias=True)
  (conv_

In [5]:
# This forward pass exists only to confirm the output shape, so run it without
# autograd: module parameters require grad by default, and tracking would keep the
# intermediate activations alive for a backward pass that never happens here.
with torch.no_grad():
    out_att = attention_block(z, context)
out_att.shape

torch.Size([2, 320, 64, 64])

In [6]:
# Instantiate the full UNET with its default constructor arguments and display its module tree.
from diffusion import UNET

unet = UNET()
unet

UNET(
  (encoders): ModuleList(
    (0): SwitchSequential(
      (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-2): 2 x SwitchSequential(
      (0): UNET_ResidualBlock(
        (groupnorm_feature): GroupNorm(32, 320, eps=1e-05, affine=True)
        (conv_feature): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (linear_time): Linear(in_features=1280, out_features=320, bias=True)
        (groupnorm_merged): GroupNorm(32, 320, eps=1e-05, affine=True)
        (conv_merged): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (residual_layer): Identity()
      )
      (1): UNET_AttentionBlock(
        (groupnorm): GroupNorm(32, 320, eps=1e-06, affine=True)
        (conv_input): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
        (layernorm_1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
        (attention_1): SelfAttention(
          (in_proj): Linear(in_features=320, out_features=960

In [7]:
# Display only the encoder stack of the model (a ModuleList, per the repr below).
unet.encoders

ModuleList(
  (0): SwitchSequential(
    (0): Conv2d(4, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1-2): 2 x SwitchSequential(
    (0): UNET_ResidualBlock(
      (groupnorm_feature): GroupNorm(32, 320, eps=1e-05, affine=True)
      (conv_feature): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (linear_time): Linear(in_features=1280, out_features=320, bias=True)
      (groupnorm_merged): GroupNorm(32, 320, eps=1e-05, affine=True)
      (conv_merged): Conv2d(320, 320, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (residual_layer): Identity()
    )
    (1): UNET_AttentionBlock(
      (groupnorm): GroupNorm(32, 320, eps=1e-06, affine=True)
      (conv_input): Conv2d(320, 320, kernel_size=(1, 1), stride=(1, 1))
      (layernorm_1): LayerNorm((320,), eps=1e-05, elementwise_affine=True)
      (attention_1): SelfAttention(
        (in_proj): Linear(in_features=320, out_features=960, bias=False)
        (out_proj): Linear(in_features=320

In [8]:
# Random stand-in for the 4-channel latent fed to the UNET (z = Conv2D of x).
x = torch.rand(2, 4, 64, 64)
x.size()

torch.Size([2, 4, 64, 64])

In [9]:
# Fresh random conditioning tensor for the full-UNET walk-through: (batch, seq_len, dim).
context = torch.rand(2, 77, 768)
context.size()

torch.Size([2, 77, 768])

In [10]:
# Random stand-in for the time input: 1280 features, with a leading dimension of 1
# rather than the batch size of 2 used for the other inputs.
time = torch.rand(1, 1280)
time.size()

torch.Size([1, 1280])

In [11]:
# Walk the encoder stack one stage at a time, recording every stage's output
# (later consumed as skip connections) and printing its shape.
# NOTE: this overwrites `x`, so the latent must be re-created before re-running the cell.
skip_connections = []
with torch.no_grad():  # shape inspection only; do not build an autograd graph
    for layers in unet.encoders:
        x = layers(x, context, time)
        skip_connections.append(x)
        print(x.shape)

x.shape

torch.Size([2, 320, 64, 64])
torch.Size([2, 320, 64, 64])
torch.Size([2, 320, 64, 64])
torch.Size([2, 320, 32, 32])
torch.Size([2, 640, 32, 32])
torch.Size([2, 640, 32, 32])
torch.Size([2, 640, 16, 16])
torch.Size([2, 1280, 16, 16])
torch.Size([2, 1280, 16, 16])
torch.Size([2, 1280, 8, 8])
torch.Size([2, 1280, 8, 8])
torch.Size([2, 1280, 8, 8])


torch.Size([2, 1280, 8, 8])

In [12]:
# One line per recorded encoder output, in the order the outputs were appended.
for skip in skip_connections:
    print(skip.size())

torch.Size([2, 320, 64, 64])
torch.Size([2, 320, 64, 64])
torch.Size([2, 320, 64, 64])
torch.Size([2, 320, 32, 32])
torch.Size([2, 640, 32, 32])
torch.Size([2, 640, 32, 32])
torch.Size([2, 640, 16, 16])
torch.Size([2, 1280, 16, 16])
torch.Size([2, 1280, 16, 16])
torch.Size([2, 1280, 8, 8])
torch.Size([2, 1280, 8, 8])
torch.Size([2, 1280, 8, 8])


In [13]:
# Push the current activations through the bottleneck.
# Only the resulting shape is inspected, so autograd tracking is switched off.
with torch.no_grad():
    x = unet.bottleneck(x, context, time)
x.shape

torch.Size([2, 1280, 8, 8])

In [14]:
# Walk the decoder stack. Before each stage the current activations are concatenated
# along dim=1 (channels) with the most recently recorded, not-yet-used encoder output,
# so the channel count grows before the stage runs.
# NOTE: pop() drains `skip_connections`; rebuild it by re-running the encoder cell
# before re-running this one.
with torch.no_grad():  # shape inspection only; do not build an autograd graph
    for layers in unet.decoders:
        x = torch.cat((x, skip_connections.pop()), dim=1)
        print("concat", x.shape)
        x = layers(x, context, time)
        print("decode", x.shape)
        print("="*40)

concat torch.Size([2, 2560, 8, 8])
decode torch.Size([2, 1280, 8, 8])
concat torch.Size([2, 2560, 8, 8])
decode torch.Size([2, 1280, 8, 8])
concat torch.Size([2, 2560, 8, 8])
decode torch.Size([2, 1280, 16, 16])
concat torch.Size([2, 2560, 16, 16])
decode torch.Size([2, 1280, 16, 16])
concat torch.Size([2, 2560, 16, 16])
decode torch.Size([2, 1280, 16, 16])
concat torch.Size([2, 1920, 16, 16])
decode torch.Size([2, 1280, 32, 32])
concat torch.Size([2, 1920, 32, 32])
decode torch.Size([2, 640, 32, 32])
concat torch.Size([2, 1280, 32, 32])
decode torch.Size([2, 640, 32, 32])
concat torch.Size([2, 960, 32, 32])
decode torch.Size([2, 640, 64, 64])
concat torch.Size([2, 960, 64, 64])
decode torch.Size([2, 320, 64, 64])
concat torch.Size([2, 640, 64, 64])
decode torch.Size([2, 320, 64, 64])
concat torch.Size([2, 640, 64, 64])
decode torch.Size([2, 320, 64, 64])


In [16]:
# Geometric ladder of 160 frequencies: 10000 ** (-k / 160) for k = 0 .. 159.
exponents = -torch.arange(160, dtype=torch.float32) / 160
freqs = torch.pow(10000, exponents)
freqs.size()

torch.Size([160])

In [18]:
# Prepend a length-1 axis so the frequency vector becomes a single row.
freqs.unsqueeze(0).shape

torch.Size([1, 160])

In [20]:
# Wrap a scalar timestep in a float tensor and append a trailing axis, giving shape (1, 1).
timestep = 980
torch.tensor([timestep], dtype=torch.float32).unsqueeze(-1).shape

torch.Size([1, 1])

In [21]:
# torch is already imported in the notebook's first cell; this repeat import is redundant but harmless.
import torch

# Report whether PyTorch can use a CUDA device in this environment.
torch.cuda.is_available()

False

In [26]:
# Scratch check of list.pop(): index -1 (the default) removes and returns the last element.
fruits = [
    'apple',
    'banana',
    'cherry',
]

fruits.pop(-1)

'cherry'