## Diffusion Module

In [None]:
## The necessary imports
import math
import torch
import numpy as np
from torch import nn
from tqdm import tqdm
from torch.nn import functional as F

In [None]:
def get_time_embedding(timestep, dtype = torch.float32):
    """
    Build the sinusoidal embedding of a single timestep.

    timestep: scalar timestep value
    dtype: dtype used for the frequencies and the returned tensor
    Returns a tensor of shape (1, 320): 160 cosine terms followed by 160 sine terms.
    """
    half_dim = 160
    # Geometric frequency ladder: 10000 ** (-i / 160) for i = 0 .. 159
    exponents = -torch.arange(start = 0, end = half_dim, dtype = dtype) / half_dim
    frequencies = torch.pow(10000, exponents)
    # (1, 1) timestep column times (1, 160) frequency row -> (1, 160) angles
    angles = torch.tensor([timestep], dtype = dtype).unsqueeze(1) * frequencies.unsqueeze(0)
    return torch.cat((angles.cos(), angles.sin()), dim = 1)

#### Time Embedding module for the vector in the attention mechanism

In [None]:
class TimeEmbedding(nn.Module):
    """Two-layer MLP that widens a time embedding from n_embed to 4 * n_embed features."""

    def __init__(self, n_embed):
        """
        n_embed: Embedding vector dimension
        Linear layers to add more parameters for learning.
        """
        super().__init__()
        self.linear_1 = nn.Linear(n_embed, 4*n_embed)
        # This layer must be registered as linear_2: forward() reads it under that
        # name, and reusing linear_1 would overwrite the first projection.
        self.linear_2 = nn.Linear(4*n_embed, 4*n_embed)

    def forward(self, x):
        """
        x: tensor whose last dimension is n_embed
        Returns a tensor whose last dimension is 4 * n_embed.
        """
        x = self.linear_1(x)
        x = F.silu(x)
        x = self.linear_2(x)
        return x

In [None]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with an optional causal mask."""

    def __init__(self, n_heads, d_embed, in_proj_bias = True, out_proj_bias = True):
        """
        Param n_heads: the number of heads in the attention block
        Param d_embed: the embedding dimension of the token, i.e. the length of the vector for each token
        Param in_proj_bias: whether the fused query/key/value projection has a bias
        Param out_proj_bias: whether the output projection has a bias
        """
        super().__init__()
        # A single linear layer produces queries, keys and values side by side.
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias = in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias = out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask = False):
        """
        Param x: 3-D tensor, unpacked as (batch, sequence, embedding)
        Param causal_mask: when True, a token cannot attend to later positions
        Returns a tensor with the same shape as x.
        """
        original_shape = x.shape
        n_batch, n_tokens, _ = original_shape
        per_head_shape = (n_batch, n_tokens, self.n_heads, self.d_head)

        def to_heads(t):
            # (batch, seq, embed) -> (batch, heads, seq, d_head)
            return t.view(per_head_shape).transpose(1, 2)

        q, k, v = (to_heads(part) for part in self.in_proj(x).chunk(3, dim = -1))

        scores = torch.matmul(q, k.transpose(-1, -2))
        if causal_mask:
            # Strict upper triangle marks the "future" positions to hide.
            future = torch.ones_like(scores, dtype = torch.bool).triu(1)
            scores.masked_fill_(future, -torch.inf)
        scores /= math.sqrt(self.d_head)
        attention = F.softmax(scores, dim = -1)

        # (batch, heads, seq, d_head) -> (batch, seq, embed)
        merged = torch.matmul(attention, v).transpose(1, 2).reshape(original_shape)
        return self.out_proj(merged)

In [None]:
class CrossAttention(nn.Module):
    """Multi-head attention where queries come from x and keys/values come from y."""

    def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
        """
        n_heads: number of attention heads
        d_embed: feature size of the query input and of the output
        d_cross: feature size of the key/value input
        in_proj_bias: whether the query/key/value projections have a bias
        out_proj_bias: whether the output projection has a bias
        """
        super().__init__()
        self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
        self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.v_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, y):
        """
        x: 3-D query tensor, unpacked as (batch, sequence, embedding)
        y: key/value tensor whose last dimension is d_cross
        Returns a tensor with the same shape as x.
        """
        query_shape = x.shape
        n_batch, _, _ = query_shape
        # -1 lets queries and keys/values have different sequence lengths.
        per_head_shape = (n_batch, -1, self.n_heads, self.d_head)

        def to_heads(t):
            # (batch, seq, embed) -> (batch, heads, seq, d_head)
            return t.view(per_head_shape).transpose(1, 2)

        q = to_heads(self.q_proj(x))
        k = to_heads(self.k_proj(y))
        v = to_heads(self.v_proj(y))

        scores = torch.matmul(q, k.transpose(-1, -2))
        scores /= math.sqrt(self.d_head)
        attention = F.softmax(scores, dim=-1)

        # (batch, heads, seq, d_head) -> (batch, seq, embed)
        merged = torch.matmul(attention, v).transpose(1, 2).contiguous().view(query_shape)
        return self.out_proj(merged)

In [None]:
class ResidualBlock(nn.Module):
    """Convolutional residual block whose features are shifted by a projected time vector."""

    def __init__(self, in_channels, out_channels, n_time = 1280):
        """
        in_channels: channels of the incoming feature map
        out_channels: channels of the returned feature map
        n_time: size of the time vector passed to forward
        """
        super().__init__()
        self.groupnorm_feature = nn.GroupNorm(32, in_channels)
        self.conv_feature = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1)
        self.linear_time = nn.Linear(n_time, out_channels)

        self.groupnorm_merged = nn.GroupNorm(32, out_channels)
        self.conv_merged = nn.Conv2d(out_channels, out_channels, kernel_size = 3, padding = 1)

        # The skip path needs a 1x1 convolution only when the channel count changes.
        self.residual_layer = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size = 1, padding = 0)
        )

    def forward(self, feature, time):
        """
        feature: feature map with in_channels channels
        time: time vector whose last dimension is n_time
        Returns a feature map with out_channels channels.
        """
        hidden = self.conv_feature(F.silu(self.groupnorm_feature(feature)))

        # Two trailing singleton axes let the time term broadcast over height and width.
        time_term = self.linear_time(F.silu(time))[..., None, None]

        hidden = hidden + time_term
        hidden = self.conv_merged(F.silu(self.groupnorm_merged(hidden)))

        return hidden + self.residual_layer(feature)

In [None]:
class AttentionBlock(nn.Module):
    """
    Attention block over a feature map: self-attention, cross-attention
    against a context sequence and a GeGLU feed-forward layer, each with a short
    residual connection, wrapped in a long residual connection around the block.
    """

    def __init__(self, n_heads, n_embed, d_context = 768):
        """
        n_heads: number of attention heads
        n_embed: embedding size per head; the block works on n_heads * n_embed channels
        d_context: feature size of the context sequence used for cross-attention
        """
        super().__init__()
        channels = n_heads * n_embed

        self.groupnorm = nn.GroupNorm(32, channels, eps = 1e-6)
        self.conv_input = nn.Conv2d(channels, channels, kernel_size = 1, padding = 0)

        self.layernorm_1 = nn.LayerNorm(channels)
        self.attention_1 = SelfAttention(n_heads, channels, in_proj_bias = False)
        self.layernorm_2 = nn.LayerNorm(channels)
        self.attention_2 = CrossAttention(n_heads, channels, d_context, in_proj_bias = False)
        self.layernorm_3 = nn.LayerNorm(channels)
        # GeGLU: the first projection yields value and gate halves (hence the * 2).
        self.linear_geglu_1 = nn.Linear(channels, 4*channels*2)
        self.linear_geglu_2 = nn.Linear(4 * channels, channels)

        self.conv_output = nn.Conv2d(channels, channels, kernel_size = 1, padding = 0)

    def forward(self, x, context):
        """
        x: feature map unpacked as (n, c, h, w)
        context: sequence handed to the cross-attention layer as keys/values
        Returns a feature map with the same shape as x.
        """
        residue_long = x

        x = self.groupnorm(x)
        x = self.conv_input(x)

        n, c, h, w = x.shape
        x = x.view((n, c, h*w)) # Transforms to size (n, c, hw)
        x = x.transpose(-1, -2) # Reorder to (n, hw, c)

        # Self-attention with a short residual connection
        residue_short = x
        x = self.layernorm_1(x)
        x = self.attention_1(x)
        x += residue_short

        # Cross-attention against the context with a short residual connection
        residue_short = x
        x = self.layernorm_2(x)
        x = self.attention_2(x, context)
        x += residue_short

        # GeGLU feed-forward with a short residual connection
        residue_short = x
        x = self.layernorm_3(x)
        x, gate = self.linear_geglu_1(x).chunk(2, dim = -1)
        x = x * F.gelu(gate)
        x = self.linear_geglu_2(x)
        x += residue_short

        x = x.transpose(-1, -2) # Reorder to (n, c, hw)
        x = x.view(n, c, h, w) # (Reshape to n, c, h, w)
        # Add the block input back: residue_long was captured for this long skip
        # connection but was previously never used.
        return self.conv_output(x) + residue_long

In [None]:
class Upsample(nn.Module):
    """Double the spatial size with nearest-neighbour interpolation, then apply a 3x3 convolution."""

    def __init__(self, channels):
        """
        channels: number of input channels; the convolution keeps this count
        """
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size = 3, padding = 1)

    def forward(self, x):
        """
        x: feature map with `channels` channels
        Returns the feature map upsampled by a factor of 2 and convolved.
        """
        enlarged = F.interpolate(x, scale_factor = 2, mode = "nearest")
        return self.conv(enlarged)

In [None]:
class SwitchSequential(nn.Sequential):
    """Sequential container that passes each layer the extra inputs it expects."""

    def forward(self, x, context, time):
        """
        x: feature map threaded through every layer
        context: passed to AttentionBlock layers only
        time: passed to ResidualBlock layers only
        Returns the output of the last layer.
        """
        for layer in self:
            if isinstance(layer, AttentionBlock):
                x = layer(x, context)
            elif isinstance(layer, ResidualBlock):
                x = layer(x, time)
            else:
                # Plain layers (e.g. Conv2d, Upsample) transform the feature map itself;
                # they must receive x, not the time tensor.
                x = layer(x)
        return x

In [None]:
class UNet(nn.Module):
    """
    U-shaped network: encoder stages whose outputs are saved as skip
    connections, a bottleneck, and decoder stages that each consume one skip connection.
    """

    def __init__(self):
        super().__init__()
        # Twelve encoder stages; every stage output is stored as a skip connection.
        # The stride-2 convolutions downsample the feature map.
        self.encoders = nn.ModuleList([
            SwitchSequential(nn.Conv2d(4, 320, kernel_size = 3, padding = 1)),
            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
            SwitchSequential(ResidualBlock(320, 320), AttentionBlock(8, 40)),
            SwitchSequential(nn.Conv2d(320, 320, kernel_size = 3, stride = 2, padding = 1)),
            SwitchSequential(ResidualBlock(320, 640), AttentionBlock(8, 80)),
            SwitchSequential(ResidualBlock(640, 640), AttentionBlock(8, 80)),
            SwitchSequential(nn.Conv2d(640, 640, kernel_size = 3, stride = 2, padding = 1)),
            SwitchSequential(ResidualBlock(640, 1280), AttentionBlock(8, 160)),
            # The previous stage already outputs 1280 channels, so this block takes 1280 in.
            SwitchSequential(ResidualBlock(1280, 1280), AttentionBlock(8, 160)),
            SwitchSequential(nn.Conv2d(1280, 1280, kernel_size = 3, stride = 2, padding = 1)),
            SwitchSequential(ResidualBlock(1280, 1280)),
            SwitchSequential(ResidualBlock(1280, 1280)),
        ])

        self.bottleneck = nn.ModuleList([
            SwitchSequential(ResidualBlock(1280, 1280)),
            SwitchSequential(AttentionBlock(8, 160)),
            SwitchSequential(ResidualBlock(1280, 1280)),
        ])

        # Twelve decoder stages; each input channel count is the running feature
        # channels plus the channels of the skip connection concatenated onto it.
        self.decoders = nn.ModuleList([
            SwitchSequential(ResidualBlock(2560, 1280)),
            SwitchSequential(ResidualBlock(2560, 1280)),
            SwitchSequential(ResidualBlock(2560, 1280), Upsample(1280)),
            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
            SwitchSequential(ResidualBlock(2560, 1280), AttentionBlock(8, 160)),
            SwitchSequential(ResidualBlock(1920, 1280), AttentionBlock(8, 160), Upsample(1280)),
            SwitchSequential(ResidualBlock(1920, 640), AttentionBlock(8, 80)),
            SwitchSequential(ResidualBlock(1280, 640), AttentionBlock(8, 80)),
            SwitchSequential(ResidualBlock(960, 640), AttentionBlock(8, 80), Upsample(640)),
            SwitchSequential(ResidualBlock(960, 320), AttentionBlock(8, 40)),
            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
            SwitchSequential(ResidualBlock(640, 320), AttentionBlock(8, 40)),
        ])

    def forward(self, x, context, time):
        """
        x: input feature map with 4 channels
        context: sequence forwarded to every AttentionBlock
        time: time vector forwarded to every ResidualBlock
        Returns the output of the last decoder stage (320 channels).
        """
        skip_connections = []
        for layers in self.encoders:
            x = layers(x, context, time)
            skip_connections.append(x)

        # nn.ModuleList is only a container and cannot be called, so run its stages in turn.
        for layers in self.bottleneck:
            x = layers(x, context, time)

        for layers in self.decoders:
            # Skip connections are consumed in reverse order of creation.
            x = torch.cat((x, skip_connections.pop()), dim = 1)
            x = layers(x, context, time)
        return x

In [None]:
class FinalLayer(nn.Module):
    """Output head: group normalisation, SiLU activation and a 3x3 convolution."""

    def __init__(self, in_channels, out_channels):
        """
        in_channels: channels of the incoming feature map
        out_channels: channels produced by the convolution
        """
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, in_channels)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size = 3, padding = 1)

    def forward(self, x):
        """
        x: feature map with in_channels channels
        Returns a feature map with out_channels channels.
        """
        activated = F.silu(self.groupnorm(x))
        return self.conv(activated)

In [None]:
class Diffusion(nn.Module):
    """Full diffusion model: time embedding MLP, UNet and the final output layer."""

    def __init__(self):
        super().__init__()
        self.time_embedding = TimeEmbedding(320)
        self.unet = UNet()
        self.final = FinalLayer(320, 4)

    def forward(self, latent, context, time):
        """
        latent: feature map given to the UNet
        context: sequence given to the UNet for its attention blocks
        time: time embedding, widened by the TimeEmbedding module before the UNet
        Returns the output of the final layer and prints its shape.
        """
        time_features = self.time_embedding(time)
        unet_features = self.unet(latent, context, time_features)
        output = self.final(unet_features)
        print(f"The shape of the output from the diffusion module is {output.shape}")
        return output

#### To simulate how the Diffusion module works, we need:
- Timesteps from the sampler used to generate noise
- The time_embedding from the utils
- Context from the CLIP Module, i.e. the context derived from the input text prompts

In [None]:
# Random stand-ins for the three inputs of the diffusion module. The statement order is
# kept as-is because every call below draws from the same global random stream.
context = torch.randn(1, 77, 768, dtype = torch.float32) # Stand-in for the CLIP context of the text prompt: (1, 77, 768)
latents = torch.randn(1, 4, 64, 64) # Stand-in for the encoded input image: (1, 4, 64, 64)
time_embedding = torch.randn(1, 320, dtype = torch.float32) # Stand-in for the sampler's time embedding: (1, 320)
diffusion = Diffusion()
# NOTE(review): 0.386 is an unexplained scaling constant - confirm the intended latent scale factor.
input_latents = latents * 0.386

In [None]:
## First the TimeEmbedding Module
## One of the inputs to the UNet is the time embedding of shape (320 * 4), if 320 is the embedding dimension
linear_1 = nn.Linear(320, 320 * 4) # 320 as the embedding dimension
linear_2 = nn.Linear(320 * 4, 320 * 4)
## Apply linear_1 -> SiLU -> linear_2, checking the (1, 1280) shape after every step
time = time_embedding
for step in (linear_1, F.silu, linear_2):
    time = step(time)
    assert time.shape == torch.Size([1, 1280])

##### Next comes the Time-varying UNet module
- Encoder block

In [None]:
## After the time embedding, the inputs are run through the UNet
## The encoder is composed of ResidualBlocks, AttentionBlocks and Convolution layers
## One skip connection is stored per encoder stage (12 in total), matching the 12
## concatenations performed by the decoder walk-through. Storing one after every single
## layer would leave unused skip connections behind.
x = input_latents
skip_connections = []
print(f"Prior to the UNet module, the input latents are of shape {x.shape}")
print(f"Prior to the UNet module, the context is of shape {context.shape}")
print(f"Prior to the UNet module, the time-embedding input is of shape {time.shape}")
## Let us run through the encoder block of the UNet first
# Stage 1: a Convolution layer lifts the 4 latent channels to 320
x = nn.Conv2d(4, 320, kernel_size = 3, padding = 1)(x)
assert x.shape == torch.Size([1, 320, 64, 64])
skip_connections.append(x)
# Stage 2: a ResidualBlock followed by an AttentionBlock
x = ResidualBlock(320, 320)(x, time)
assert x.shape == torch.Size([1, 320, 64, 64])
x = AttentionBlock(8, 40)(x, context)
assert x.shape == torch.Size([1, 320, 64, 64])
skip_connections.append(x)
# Stage 3: again a ResidualBlock followed by an AttentionBlock
x = ResidualBlock(320, 320)(x, time)
assert x.shape == torch.Size([1, 320, 64, 64])
x = AttentionBlock(8, 40)(x, context)
assert x.shape == torch.Size([1, 320, 64, 64])
skip_connections.append(x)
# Stage 4: a stride-2 Convolution layer halves the spatial size
x = nn.Conv2d(320, 320, kernel_size = 3, stride = 2, padding = 1)(x)
assert x.shape == torch.Size([1, 320, 32, 32])
skip_connections.append(x)
# Stage 5: ResidualBlock widens to 640 channels, then an AttentionBlock
x = ResidualBlock(320, 640)(x, time)
assert x.shape == torch.Size([1, 640, 32, 32])
x = AttentionBlock(8, 80)(x, context)
assert x.shape == torch.Size([1, 640, 32, 32])
skip_connections.append(x)
# Stage 6: a second ResidualBlock + AttentionBlock pair at 640 channels
x = ResidualBlock(640, 640)(x, time)
assert x.shape == torch.Size([1, 640, 32, 32])
x = AttentionBlock(8, 80)(x, context)
assert x.shape == torch.Size([1, 640, 32, 32])
skip_connections.append(x)
# Stage 7: a stride-2 Convolution layer halves the spatial size
x = nn.Conv2d(640, 640, kernel_size = 3, stride = 2, padding = 1)(x)
assert x.shape == torch.Size([1, 640, 16, 16])
skip_connections.append(x)
# Stage 8: ResidualBlock widens to 1280 channels, then an AttentionBlock
x = ResidualBlock(640, 1280)(x, time)
assert x.shape == torch.Size([1, 1280, 16, 16])
x = AttentionBlock(8, 160)(x, context)
assert x.shape == torch.Size([1, 1280, 16, 16])
skip_connections.append(x)
# Stage 9: a second ResidualBlock + AttentionBlock pair at 1280 channels
x = ResidualBlock(1280, 1280)(x, time)
assert x.shape == torch.Size([1, 1280, 16, 16])
x = AttentionBlock(8, 160)(x, context)
assert x.shape == torch.Size([1, 1280, 16, 16])
skip_connections.append(x)
# Stage 10: a stride-2 Convolution layer halves the spatial size
x = nn.Conv2d(1280, 1280, kernel_size = 3, stride = 2, padding = 1)(x)
assert x.shape == torch.Size([1, 1280, 8, 8])
skip_connections.append(x)
# Stages 11 and 12: a single ResidualBlock each
x = ResidualBlock(1280, 1280)(x, time)
assert x.shape == torch.Size([1, 1280, 8, 8])
skip_connections.append(x)
x = ResidualBlock(1280, 1280)(x, time)
assert x.shape == torch.Size([1, 1280, 8, 8])
skip_connections.append(x)

##### Time-varying UNet
- Bottleneck

In [None]:
## Consider the layers of the bottleneck
print(f"At the start of the bottleneck, the inputs are of shape {x.shape}")
## We have two ResidualBlocks and one Attention mechanism
## (ResidualBlock -> AttentionBlock -> ResidualBlock, the same layout as UNet.bottleneck)
x = ResidualBlock(1280, 1280)(x, time)
assert x.shape == torch.Size([1, 1280, 8, 8])
x = AttentionBlock(8, 160)(x, context)
assert x.shape == torch.Size([1, 1280, 8, 8])
x = ResidualBlock(1280, 1280)(x, time)
assert x.shape == torch.Size([1, 1280, 8, 8])

##### Time-varying UNet
- Decoder Block

In [None]:
## List the type and shape of every stored skip connection, oldest first
print("The skip connections are as follows:")
for s_c in skip_connections:
    # One item per line, then the same blank lines a trailing print("\n") would give
    print(type(s_c), s_c.shape, "\n", sep = "\n")

In [None]:
## Consider the layers of the decoder block
## Every decoder stage pops one skip connection, concatenates it along the channel axis
## and runs the stage's layers, checking the shape after each step. The stages are listed
## in a table and executed by one loop. Layers are built only when their step runs.

def _residual(in_channels, out_channels):
    """Return a step that builds a ResidualBlock and applies it with the time features."""
    return lambda t: ResidualBlock(in_channels, out_channels)(t, time)

def _attention(n_embed):
    """Return a step that builds an 8-head AttentionBlock and applies it with the context."""
    return lambda t: AttentionBlock(8, n_embed)(t, context)

def _upsample(channels):
    """Return a step that builds an Upsample layer and applies it."""
    return lambda t: Upsample(channels)(t)

## Each stage: (channels after concatenation, spatial size at concatenation,
##              whether the extra "After concatenation" line is printed,
##              [(step, expected output channels, expected output size), ...])
decoder_stages = [
    (2560, 8, True, [(_residual(2560, 1280), 1280, 8)]),
    (2560, 8, True, [(_residual(2560, 1280), 1280, 8)]),
    (2560, 8, True, [(_residual(2560, 1280), 1280, 8), (_upsample(1280), 1280, 16)]),
    (2560, 16, True, [(_residual(2560, 1280), 1280, 16), (_attention(160), 1280, 16)]),
    (2560, 16, True, [(_residual(2560, 1280), 1280, 16), (_attention(160), 1280, 16)]),
    (1920, 16, True, [(_residual(1920, 1280), 1280, 16), (_attention(160), 1280, 16), (_upsample(1280), 1280, 32)]),
    (1920, 32, False, [(_residual(1920, 640), 640, 32), (_attention(80), 640, 32)]),
    (1280, 32, False, [(_residual(1280, 640), 640, 32), (_attention(80), 640, 32)]),
    (960, 32, False, [(_residual(960, 640), 640, 32), (_attention(80), 640, 32), (_upsample(640), 640, 64)]),
    (960, 64, False, [(_residual(960, 320), 320, 64), (_attention(40), 320, 64)]),
    (640, 64, False, [(_residual(640, 320), 320, 64), (_attention(40), 320, 64)]),
    (640, 64, False, [(_residual(640, 320), 320, 64), (_attention(40), 320, 64)]),
]

for stage_index, (cat_channels, cat_size, echo_after_concat, steps) in enumerate(decoder_stages):
    # Only the very first stage announces the start of the decoder block.
    if stage_index == 0:
        print(f"At the start of the decoder block, the inputs are of shape {x.shape}")
    else:
        print(f"There are {len(skip_connections)} skip connections left")
    x0 = skip_connections.pop()
    print(f"The skip connection is of shape {x0.shape}")
    print(f"The input is of shape {x.shape}")
    x = torch.cat((x, x0), dim = 1)
    print(f"The Concatenated input is of shape {x.shape}")
    assert x.shape == torch.Size([1, cat_channels, cat_size, cat_size])
    if echo_after_concat:
        print(f"After concatenation, the input is of shape {x.shape}")
    for apply_step, out_channels, out_size in steps:
        x = apply_step(x)
        assert x.shape == torch.Size([1, out_channels, out_size, out_size])

    print("\n")

print(f"After the UNet, the final output is of shape {x.shape}")

In [None]:
# Report how many skip connections are still stored after the decoder walk-through.
print(f"There are {len(skip_connections)} skip connections left")

#### Final Layer
- The final layer applies group normalization, a SiLU activation and a convolution that reduces the 320-channel UNet output to a 4-channel map for the Decoder

In [None]:
## GroupNorm -> SiLU -> Convolution, checking the shape after every step
final_steps = (
    (nn.GroupNorm(32, 320), [1, 320, 64, 64]),
    (F.silu, [1, 320, 64, 64]),
    (nn.Conv2d(320, 4, kernel_size = 3, padding = 1), [1, 4, 64, 64]),
)
for step, expected_shape in final_steps:
    x = step(x)
    assert x.shape == torch.Size(expected_shape)