In [1]:
import torch
import torch.nn as nn

from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize, Compose
from torch.utils.data import DataLoader

# UNet

Diffusion models often follow UNet-like architectures, which empirically work very well.
UNet was originally proposed as a model for image segmentation, but it has also proven effective in denoising-based diffusion models, thanks to its ability to capture features of an image at different levels, from local to global.

UNet-like architectures are designed according to the following principles.

## Encoder-decoder structure

The encoder captures the context and extracts high-level features from the input image, while the decoder reconstructs the output (a segmentation map in the original UNet) by upsampling and combining the features from the encoder. This structure allows the network to learn both local and global information.

IMPORTANT: Do not confuse this with variational autoencoders! In diffusion models, the encoder's output (a.k.a. latent variable) does not serve any specific purpose, except for one that will be discussed later (conditioning the model on the current time step). The encoder-decoder structure is mainly used to extract different types of features.

### Contracting path (Encoder)

In practice, starting from an image, the encoder gradually increases the number of channels while gradually decreasing the spatial resolution (height and width).

This is called a <i>contracting path</i>. For example, with the encoder implemented below, the tensor shape (channels x height x width) undergoes the following transformations, going from the input to the latent space:

Input: 1x28x28 --> Intermediate 1: 64x14x14 --> Intermediate 2: 128x7x7 --> Intermediate 3: 128x3x3 --> Latent: 128x1x1

In [2]:
# Elementary building block shared by the UNet encoder and decoder:
# 3x3 convolution -> GroupNorm -> ReLU. With stride 1 and padding 1 the
# spatial size (H, W) is preserved; only the number of channels changes.
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        # GroupNorm with 8 groups (out_channels must be divisible by 8).
        # Author's note: it works much better than BatchNorm for this model.
        norm = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.block = nn.Sequential(conv, norm, nn.ReLU())

    def forward(self, x):
        """Apply conv -> norm -> ReLU; the result has out_channels and x's H, W."""
        out = self.block(x)
        return out
    
# Elementary encoder block: a ConvBlock (changes the channel count, keeps H, W)
# followed by max pooling that shrinks H and W by a factor of `downscale`
# (odd sizes are floored, e.g. 7x7 -> 3x3 with downscale=2).
class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels, downscale=2):
        super().__init__()
        stages = [ConvBlock(in_channels, out_channels), nn.MaxPool2d(downscale)]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the pooled feature map with out_channels channels."""
        return self.model(x)
    
# Encoder network (contracting path). For a 1x28x28 input it produces
# x_f: 64x28x28, d1: 64x14x14, d2: 128x7x7, d3: 128x3x3, latent: 128x1x1.
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()

        # Lift the single-channel image to 64 feature maps at full resolution.
        self.initial_features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1),
            nn.GroupNorm(8, 64),
            nn.ReLU(),
        )
        # Three halving stages: 28 -> 14 -> 7 -> 3 (max pooling floors odd sizes).
        self.down1 = UnetDown(64, 64, downscale=2)
        self.down2 = UnetDown(64, 128, downscale=2)
        self.down3 = UnetDown(128, 128, downscale=2)
        # Average the 3x3 map down to 1x1. The ReLU is effectively a no-op here:
        # d3 is already non-negative (ReLU inside ConvBlock, then max pooling).
        self.down4 = nn.Sequential(nn.AvgPool2d(3), nn.ReLU())

    def forward(self, x):
        """Return (x_f, d1, d2, d3, latent).

        x_f, d1, d2 and d3 are the intermediate feature maps that the decoder
        consumes as skip connections; latent is the bottleneck tensor.
        """
        feats = [self.initial_features(x)]
        for stage in (self.down1, self.down2, self.down3):
            feats.append(stage(feats[-1]))
        latent = self.down4(feats[-1])
        return (*feats, latent)

In [3]:
# Fix the global torch seed so that the shuffled batch below, and every later
# random draw in this notebook (model initialisation, sampled time steps),
# is reproducible on Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) maps them to [-1, 1].
transform = Compose([ToTensor(), Normalize((0.5,), (0.5,))])
dataset = MNIST("./data", train=True, download=True, transform=transform)

dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Grab a single batch to inspect tensor shapes; the labels are not needed.
x, _ = next(iter(dataloader))
# Each MNIST image is already a (1, 28, 28) tensor after ToTensor, so no reshape is needed.
assert x.shape[1:] == (1, 28, 28), x.shape

In [4]:
encoder = Encoder()

# Push one batch through the contracting path and report every stage's shape.
x_f, d1, d2, d3, latent = encoder(x)

print("\n".join(
    f"{name} shape: {tensor.shape}"
    for name, tensor in (
        ("Input", x),
        ("Down 1", d1),
        ("Down 2", d2),
        ("Down 3", d3),
        ("Latent", latent),
    )
))

Input shape: torch.Size([32, 1, 28, 28])
Down 1 shape: torch.Size([32, 64, 14, 14])
Down 2 shape: torch.Size([32, 128, 7, 7])
Down 3 shape: torch.Size([32, 128, 3, 3])
Latent shape: torch.Size([32, 128, 1, 1])


## Expanding path and skip connections (decoder)

In a denoising model, the decoder still has a "reverse" (mirrored) structure compared to the encoder. However, its purpose is to predict what is needed for the next reconstruction step (here, the noise to remove, `eps_hat` in the code), rather than the entire image. Additionally, in the case of UNet, the encoder and the decoder are not independent, and we can take advantage of that.

In particular, we feed the intermediate outputs of the encoder to the corresponding decoder layers. These are called "skip connections", since the intermediate outputs 'skip' part of the model.

    Initial Layer ----------------------- Tensor: x_f ---------------------------> Out Layer
            |                                                                          Ʌ
            V                                                                          |
        Enc Layer 1 --------------------- Tensor: d1 ------------------------> Dec Layer 4
                |                                                                Ʌ
                V                                                                |
            Enc Layer 2 ----------------- Tensor: d2 ------------------> Dec Layer 3
                    |                                                        Ʌ
                    V                                                        |
                Enc Layer 3 ------------- Tensor: d3 --------------> Dec Layer 2
                        |                                               Ʌ
                        V                                               |
                    Enc Layer 4 --------- Tensor: latent ----------> Dec Layer 1


In [5]:
# Elementary decoder block, the counterpart of UnetDown. The incoming tensor is
# first concatenated with the matching encoder skip tensor along the channel
# axis, so `in_channels` must equal the sum of both channel counts. A transposed
# convolution then upsamples H and W by `upscale`; `extra_dim` is passed as
# output_padding to recover odd sizes (e.g. with upscale 2, Cx3x3 -> C'x6x6,
# so extra_dim=1 is needed to reach C'x7x7).
class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels, upscale=2, extra_dim=0):
        super().__init__()
        upsample = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=upscale,
            stride=upscale,
            output_padding=extra_dim,
        )
        self.model = nn.Sequential(upsample, ConvBlock(out_channels, out_channels))

    def forward(self, x, skip):
        """Concatenate x with its skip connection, then upsample and refine."""
        merged = torch.cat([x, skip], dim=1)
        return self.model(merged)
    
# Decoder network (expanding path). It mirrors the encoder and consumes the
# encoder's skip tensors in reverse order: 1x1 -> 3x3 -> 7x7 -> 14x14 -> 28x28.
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

        # The first stage has no skip connection: a stride-3 transposed
        # convolution expands the 128x1x1 latent to 128x3x3.
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(128, 128, kernel_size=3, stride=3),
            nn.GroupNorm(8, 128),
            nn.ReLU(),
        )
        # in_channels = decoder channels + skip channels (128+128, 128+128, 64+64).
        self.up2 = UnetUp(256, 128, upscale=2, extra_dim=1)  # 3x3 -> 7x7
        self.up3 = UnetUp(256, 64, upscale=2)                # 7x7 -> 14x14
        self.up4 = UnetUp(128, 64, upscale=2)                # 14x14 -> 28x28
        # Final 3x3 conv over [u4, x_f] (64 + 64 channels) -> 1-channel prediction.
        self.out = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, latent, d3, d2, d1, x_f):
        """Return (u1, u2, u3, u4, eps_hat).

        u1..u4 are the intermediate decoder outputs; eps_hat is the single-channel
        prediction at the spatial resolution of x_f.
        """
        outs = [self.up1(latent)]
        for stage, skip in ((self.up2, d3), (self.up3, d2), (self.up4, d1)):
            outs.append(stage(outs[-1], skip))
        eps_hat = self.out(torch.cat((outs[-1], x_f), dim=1))
        return (*outs, eps_hat)
    

In [6]:
decoder = Decoder()

# Run the expanding path on the encoder outputs computed above.
u1, u2, u3, u4, eps_hat = decoder(latent, d3, d2, d1, x_f)

print("\n".join(
    f"{name} shape: {tensor.shape}"
    for name, tensor in (
        ("Latent", latent),
        ("Up 1", u1),
        ("Up 2", u2),
        ("Up 3", u3),
        ("Up 4", u4),
        ("Output", eps_hat),
    )
))

Latent shape: torch.Size([32, 128, 1, 1])
Up 1 shape: torch.Size([32, 128, 3, 3])
Up 2 shape: torch.Size([32, 128, 7, 7])
Up 3 shape: torch.Size([32, 64, 14, 14])
Up 4 shape: torch.Size([32, 64, 28, 28])
Output shape: torch.Size([32, 1, 28, 28])


In [7]:
# Release all decoder demo outputs (u4 included) to limit memory consumption;
# none of them is read again before being reassigned.
del u1, u2, u3, u4, eps_hat

## Encoding temporal information

In denoising-based diffusion models, generation happens through multiple denoising steps. To improve generation quality, the model needs to know the current processing step $t$.

This can be done by encoding $t$ into a vector with one value per latent channel, which is then summed to the latent variable. These vectors are called <i>time embeddings</i> and, clearly, must have the same shape as the latent variable (here 128x1x1). Before being processed by the time-embedding layer, $t$ should be normalized by the maximum number of steps $T$, i.e. the layer receives $t/T$.

Alternatively, time embeddings can be concatenated as extra channels, but for MNIST summing them works perfectly fine.

In [8]:
# Time embedding layer: maps a normalized time step t to a vector with one value
# per latent channel, reshaped so that it can be summed to the latent tensor.
class TimeEmbedding(nn.Module):
    def __init__(self, embedding_dim, hidden_dim=32):
        """
        Args:
            embedding_dim: size of the output embedding; must match the number
                of latent channels the embedding is summed with.
            hidden_dim: width of the intermediate sine features (default 32,
                the value previously hard-coded).
        """
        super(TimeEmbedding, self).__init__()

        self.embedding_dim = embedding_dim
        self.lin1 = nn.Linear(1, hidden_dim, bias=False)
        self.lin2 = nn.Linear(hidden_dim, embedding_dim)

    def forward(self, ts):
        """Embed time steps.

        Args:
            ts: one time step per batch element, e.g. a tensor of shape (B,)
                or (B, 1), or a Python scalar / sequence. Values are expected to
                be already normalized by the number of steps (t / T).

        Returns:
            Tensor of shape (B, embedding_dim, 1, 1).
        """
        if not torch.is_tensor(ts):
            # Allow plain Python numbers/lists; match the layer's dtype and device.
            ts = torch.as_tensor(ts, dtype=self.lin1.weight.dtype, device=self.lin1.weight.device)
        # reshape (unlike view) also works for non-contiguous inputs.
        ts = ts.reshape(-1, 1)
        temb = torch.sin(self.lin1(ts))  # sine activation is a common way to encode time
        temb = self.lin2(temb)
        # Add singleton H, W dims so the embedding broadcasts against a (B, C, 1, 1) latent.
        return temb.view(-1, self.embedding_dim, 1, 1)
    

In [9]:
# Embedding size 128 = number of latent channels produced by the encoder.
time_embedding_layer = TimeEmbedding(128)

n_T = 1000  # number of time steps used for reconstruction

# One random step per image in the batch, drawn uniformly from {1, ..., n_T},
# then normalized to (0, 1] before embedding.
timesteps = torch.randint(1, n_T + 1, (len(x),))
t = timesteps / n_T

temb = time_embedding_layer(t)

# Condition the latent on time by summing the embedding (both are B x 128 x 1 x 1).
latent_temb = latent + temb

u1, u2, u3, u4, eps_hat = decoder(latent_temb, d3, d2, d1, x_f)

In [10]:
# The time-conditioned decoder outputs were only for inspection; release them
# to limit memory consumption.
del eps_hat, u4, u3, u2, u1

## UNet model

In [11]:
# Overall UNet model: encoder -> (latent + time embedding) -> decoder.
class Unet(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()
        # 128 must match the number of latent channels produced by the encoder.
        self.time_embedding_layer = TimeEmbedding(128)

    def forward(self, x, t):
        """Return eps_hat for images x at normalized time steps t (one per image)."""
        x_f, d1, d2, d3, latent = self.encoder(x)
        conditioned = latent + self.time_embedding_layer(t)
        # Only the decoder's final output is needed; intermediate stages are discarded.
        return self.decoder(conditioned, d3, d2, d1, x_f)[-1]

In [12]:
unet = Unet()

# Prediction for the sample batch x at the normalized time steps t.
eps_hat = unet(x, t)
# Sanity check: the prediction must have exactly the input images' shape.
assert eps_hat.shape == x.shape, (eps_hat.shape, x.shape)