### The Encoder is used to:
- Map the input images, if given, to the latents for the diffusion model, adding the supplied noise while sampling those latents
- As such, the encoder is used **ONLY WHEN** we give an input image, i.e. during training; it isn't used the rest of the time, mainly **during inference**.

In [None]:
## The necessary imports
import math
import torch
from torch import nn
from torch.nn import functional as F

### The Encoder block is composed of:
- The Attention block
- The Residual Block

#### The Self-Attention and the Attention Blocks
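- The Attention block is built around a Self-Attention block: it normalises the feature map, runs self-attention over its spatial positions, and adds the input back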

In [None]:
class SelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention over a sequence of embeddings.
    """
    def __init__(self, n_heads, d_embed, in_proj_bias=True, out_proj_bias=True):
        """
        Param n_heads: the number of heads in the attention block
        Param d_embed: the embedding dimension of the token, i.e. the length of the vector for each token
        Param in_proj_bias: whether the joint query/key/value projection has a bias term
        Param out_proj_bias: whether the output projection has a bias term
        Raises ValueError: if n_heads is not positive or does not divide d_embed evenly
        """
        super().__init__()
        # Fail early with a clear message: d_embed // n_heads would silently truncate,
        # and the mismatch would only surface later as a shape error inside view().
        if n_heads <= 0 or d_embed % n_heads != 0:
            raise ValueError(
                f"d_embed ({d_embed}) must be divisible by a positive n_heads ({n_heads})"
            )
        # A single linear layer produces the queries, keys and values in one pass
        self.in_proj = nn.Linear(d_embed, 3 * d_embed, bias=in_proj_bias)
        self.out_proj = nn.Linear(d_embed, d_embed, bias=out_proj_bias)
        self.n_heads = n_heads
        self.d_head = d_embed // n_heads

    def forward(self, x, causal_mask=False):
        """
        Param x: tensor of shape (batch_size, sequence_length, d_embed)
        Param causal_mask: when True, a token cannot attend to the tokens that come after it
        Returns: tensor with the same shape as x
        """
        input_shape = x.shape
        batch_size, sequence_length, _ = input_shape
        per_head_shape = (batch_size, sequence_length, self.n_heads, self.d_head)

        # (B, S, 3 * D) -> three tensors of shape (B, S, D)
        q, k, v = self.in_proj(x).chunk(3, dim=-1)
        # (B, S, D) -> (B, S, n_heads, d_head) -> (B, n_heads, S, d_head)
        q = q.view(per_head_shape).transpose(1, 2)
        k = k.view(per_head_shape).transpose(1, 2)
        v = v.view(per_head_shape).transpose(1, 2)

        # (B, n_heads, S, S): similarity of every query with every key
        scores = q @ k.transpose(-1, -2)
        if causal_mask:
            # Entries strictly above the diagonal pair a query with a later key
            future = torch.ones_like(scores, dtype=torch.bool).triu(1)
            scores = scores.masked_fill(future, -torch.inf)
        scores = scores / math.sqrt(self.d_head)
        weight = F.softmax(scores, dim=-1)

        # (B, n_heads, S, d_head) -> (B, S, n_heads, d_head) -> (B, S, D)
        output = (weight @ v).transpose(1, 2).reshape(input_shape)
        return self.out_proj(output)
    
class AttentionBlock(nn.Module):
    """
    Group-normalises a feature map, runs single-head self-attention over its
    spatial positions and adds the original input back as a skip connection.
    """
    def __init__(self, channels):
        """
        Param channels: number of channels of the feature map; also used as the embedding size
        """
        super().__init__()
        self.groupnorm = nn.GroupNorm(32, channels)
        self.attention = SelfAttention(1, channels)

    def forward(self, x):
        """
        Param x: feature map of shape (n, c, h, w)
        Returns: tensor of shape (n, c, h, w)
        """
        skip = x
        batch, channels, height, width = x.shape

        # (n, c, h, w) -> (n, c, h*w) -> (n, h*w, c): every pixel becomes one token
        tokens = self.groupnorm(x).view((batch, channels, height * width)).transpose(-1, -2)
        attended = self.attention(tokens)

        # (n, h*w, c) -> (n, c, h*w) -> (n, c, h, w)
        out = attended.transpose(-1, -2).view(batch, channels, height, width)
        out += skip
        return out

#### Residual Block
- The Residual Block has a skip connection that adds the block's input to its output, which helps with the problem of **vanishing gradients**.

In [None]:
class ResidualBlock(nn.Module):
    """
    Two (GroupNorm -> SiLU -> 3x3 Conv) stages whose result is added to the block's input.
    """
    def __init__(self, in_channels, out_channels):
        """
        Param in_channels: number of channels of the incoming feature map
        Param out_channels: number of channels of the produced feature map
        """
        super().__init__()
        self.groupnorm_1 = nn.GroupNorm(32, in_channels)
        self.conv_1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

        self.groupnorm_2 = nn.GroupNorm(32, out_channels)
        self.conv_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)

        # The skip path needs a 1x1 convolution only when the channel count changes
        self.residual_layer = (
            nn.Identity()
            if in_channels == out_channels
            else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        )

    def forward(self, x):
        """
        Param x: feature map of shape (n, in_channels, h, w)
        Returns: feature map of shape (n, out_channels, h, w)
        """
        stages = (
            (self.groupnorm_1, self.conv_1),
            (self.groupnorm_2, self.conv_2),
        )
        hidden = x
        for norm, conv in stages:
            hidden = conv(F.silu(norm(hidden)))
        return hidden + self.residual_layer(x)

### The Encoder Block

In [None]:
class Encoder(nn.Module):
    """
    This module stacks Convolutional layers, Residual blocks and one Attention block to
    convert an input image of shape (B, 3, H, W) into sampled latents of shape
    (B, 4, H//8, W//8).
    """
    # Constant the sampled latents are multiplied by before being returned.
    # NOTE(review): presumably the Stable Diffusion latent scaling factor; confirm against the reference VAE config.
    LATENT_SCALE = 0.18215

    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size = 3, padding = 1),
            ResidualBlock(128, 128),
            ResidualBlock(128, 128),
            # Stride-2 convolutions halve the resolution; they are padded manually in forward()
            nn.Conv2d(128, 128, kernel_size = 3, stride = 2, padding = 0),
            ResidualBlock(128, 256),
            ResidualBlock(256, 256),
            nn.Conv2d(256, 256, kernel_size = 3, stride = 2, padding = 0),
            ResidualBlock(256, 512),
            ResidualBlock(512, 512),
            nn.Conv2d(512, 512, kernel_size = 3, stride = 2, padding = 0),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            ResidualBlock(512, 512),
            AttentionBlock(512),
            ResidualBlock(512, 512),
            nn.GroupNorm(32, 512), # Normalises the 512 channels in 32 groups of 16 channels each
            nn.SiLU(), # Non-linearity x * sigmoid(x), the same activation the Residual blocks use
            # 8 output channels: forward() splits them into 4 for the mean and 4 for the log-variance
            nn.Conv2d(512, 8, kernel_size = 3, padding = 1),
            nn.Conv2d(8, 8, kernel_size = 1, padding = 0)
        )

    def forward(self, x, noise):
        """
        Param x: batch of images of shape (B, 3, H, W)
        Param noise: noise used to sample the latents; must broadcast against (B, 4, H//8, W//8)
        Returns: the sampled and scaled latents, of shape (B, 4, H//8, W//8) when noise broadcasts to that shape
        """
        batch_size, _, h, w = x.shape
        for layer in self.layers:
            if getattr(layer, 'stride', None) == (2, 2): # Padding at downsampling should be asymmetric
                # Pad one pixel on the right and on the bottom only
                x = F.pad(x, (0, 1, 0, 1))
            x = layer(x)
        # Internal sanity check; uses the actual batch size so batched inputs are accepted
        assert x.shape == torch.Size([batch_size, 8, h // 8, w // 8]), (
            f"unexpected encoder output shape {tuple(x.shape)}"
        )

        mean, log_variance = torch.chunk(x, 2, dim = 1)
        # Keep the log-variance in a range where exp() neither underflows to 0 nor overflows
        log_variance = torch.clamp(log_variance, -30, 20)
        stdev = log_variance.exp().sqrt()
        # Turn the supplied noise into a sample with the predicted mean and standard deviation
        latents = mean + stdev * noise

        return latents * self.LATENT_SCALE

### The Inputs to the encoder are:
- Input images tensor of size (B, C, H, W)
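- Encoder noise that matches (or broadcasts to) the latent shape (len(prompts), 4, H//8, W//8)
- Thus, we see that an image of size (3, 512, 512) is reduced to a feature map of size (8, 64, 64), which is split into a mean and a log-variance of 4 channels each, giving latents of size (4, 64, 64).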

In [None]:
## Consider an input of one prompt
H = W = 512
## Generate the noise with the full latent shape, so that each of the 4 latent
## channels gets its own sample instead of one sample broadcast over the channels
generator = torch.Generator()
noise_shape = (1, 4, H // 8, W // 8)
encoder_noise = torch.randn(noise_shape, generator = generator)
## Random stand-in for an image; the upper bound of randint is exclusive, so 256 covers 0..255
input_images = torch.randint(0, 256, (1, 3, H, W), dtype = torch.float32)
## Encode a sample image
encoder = Encoder()
## No gradients are needed here, so avoid storing the activations of every layer
with torch.no_grad():
    latents = encoder(input_images, encoder_noise)
assert latents.shape == torch.Size([1, 4, H // 8, W // 8])
assert latents.dtype == torch.float32