In [1]:
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Iterates over ``model.parameters()`` and adds up ``numel()`` for every
    parameter whose ``requires_grad`` flag is set; frozen parameters are skipped.
    """
    trainable = 0
    for param in model.parameters():
        if param.requires_grad:
            trainable += param.numel()
    return trainable

In [2]:
import torch
import torch.nn as nn
from models.vit import Transformer_Layer, Patch_Embedding

class ViT_Encoder(nn.Module):
    """
    Vision-Transformer encoder that produces the VAE posterior parameters.

    The input image is patch-embedded, offset by a learned positional embedding,
    run through `depth` Transformer_Layer blocks and a final LayerNorm. All patch
    tokens are then flattened and fed to two linear heads that return `mu` and
    `log_var`, each of size `latent_dim`.
    """
    def __init__(self, img_size, patch_size, in_channels, latent_dim,
                 embed_dim, depth, num_heads, mlp_dim, dropout):
        super().__init__()
        self.patch_embed = Patch_Embedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches
        # One learnable position vector per patch, broadcast across the batch
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        self.pos_drop = nn.Dropout(dropout)
        # `depth` encoder blocks, applied one after another in forward()
        self.transformer_layers = nn.ModuleList(
            Transformer_Layer(embed_dim, num_heads, mlp_dim, dropout)
            for _ in range(depth)
        )
        self.norm = nn.LayerNorm(embed_dim)
        # VAE heads: both consume the flattened [num_patches * embed_dim] tokens
        flat_features = embed_dim * num_patches
        self.ll_mu = nn.Linear(flat_features, latent_dim)
        self.ll_var = nn.Linear(flat_features, latent_dim)
        self._init_weights()

    def _init_weights(self):
        # Replace the zero-initialised positional embedding with small random values
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # x: [bs, in_channels, img_size, img_size]
        batch_size = x.shape[0]
        # Patch tokens [bs, num_patches, embed_dim] + positions, then dropout
        tokens = self.pos_drop(self.patch_embed(x) + self.pos_embed)
        # Blocks are fed [seq_length, batch_size, embed_dim]
        # NOTE(review): assumes Transformer_Layer is sequence-first - confirm in models.vit
        tokens = tokens.transpose(0, 1)
        for layer in self.transformer_layers:
            tokens = layer(tokens)
        # Normalise, then restore [bs, num_patches, embed_dim]
        tokens = self.norm(tokens).transpose(0, 1)
        # Flatten every patch token into one vector per sample
        flat = tokens.reshape(batch_size, -1)
        # Latent distribution parameters: (mu, log_var)
        return self.ll_mu(flat), self.ll_var(flat)

In [None]:
# Build the encoder used for the parameter count below.
encoder = ViT_Encoder(
    # input geometry
    img_size=256, patch_size=16, in_channels=3,
    # transformer stack (MLP hidden width is 4x the embedding width)
    embed_dim=512, depth=6, num_heads=8, mlp_dim=4 * 512, dropout=0.1,
    # size of the mu / log_var vectors
    latent_dim=256,
)

print("Number of parameters in ViT Encoder:", count_parameters(encoder))

Number of parameters in ViT Encoder: 61373312


In [4]:
class UpBlock(nn.Module):
    """
    Decoder stage: 2x bilinear upsampling followed by two 3x3 conv stages.

    Each conv stage is Conv2d -> BatchNorm2d -> GELU -> Dropout2d. The first stage
    maps `in_channels` to `out_channels`; the second keeps `out_channels`.
    """
    def __init__(self, in_channels, out_channels, dropout=0.1):
        super().__init__()
        # Parameter-free interpolation; the convolutions below do the learning
        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        # Assemble both conv stages; only the first one changes the channel count
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.GELU(),
                nn.Dropout2d(dropout),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        # Double the spatial size, then refine with the two conv stages
        return self.double_conv(self.upsample(x))

class UNet_Decoder(nn.Module):
    """
    UNet-style decoder for the VAE.

    A latent vector is projected to a small square feature map with
    `base_channels * 8` channels, passed through four UpBlocks (each one doubles
    height and width), and mapped to `out_channels` by a 3x3 conv + sigmoid.

    Args:
        img_size: Side length of the square output image. Must be a positive
            multiple of 16 so that four 2x upsamplings land exactly on it.
        out_channels: Number of channels of the generated image.
        latent_dim: Size of the latent vector `z`.
        base_channels: Channel width of the last upsampling stage.
        dropout: Dropout probability forwarded to every UpBlock.

    Raises:
        ValueError: If `img_size` is not a positive multiple of 16. Previously such
            values were floor-divided silently, producing an output whose size
            differed from `img_size` (or a zero-width projection for sizes < 16).
    """
    # Four UpBlocks with scale factor 2 each -> total spatial growth of 16x
    _UPSCALE = 2 ** 4

    def __init__(self, img_size, out_channels, latent_dim, base_channels, dropout):
        super().__init__()
        if img_size <= 0 or img_size % self._UPSCALE != 0:
            raise ValueError(
                f"img_size must be a positive multiple of {self._UPSCALE}, got {img_size}"
            )
        self.img_size = img_size
        # Side length of the initial feature map, before any upsampling
        self.init_size = img_size // self._UPSCALE
        self.latent_channels = base_channels * 8
        # Project from latent vector to the flattened initial feature map
        self.latent_to_features = nn.Sequential(
            nn.Linear(latent_dim, self.init_size * self.init_size * self.latent_channels),
            nn.GELU()
        )
        # Upsampling blocks: the first three halve the channels, the last keeps them
        self.up1 = UpBlock(self.latent_channels, base_channels * 4, dropout)
        self.up2 = UpBlock(base_channels * 4, base_channels * 2, dropout)
        self.up3 = UpBlock(base_channels * 2, base_channels, dropout)
        self.up4 = UpBlock(base_channels, base_channels, dropout)
        # Final convolution to get the right number of output channels
        self.final_conv = nn.Sequential(
            nn.Conv2d(base_channels, out_channels, kernel_size=3, padding=1),
            nn.Sigmoid()  # Ensures output is in [0, 1] range
        )

    def forward(self, z):
        """Decode latent vectors `z` of shape [bs, latent_dim] into images.

        Returns a tensor of shape [bs, out_channels, img_size, img_size] with
        values in [0, 1].
        """
        bs = z.shape[0]
        # Project and reshape to the initial [bs, C, init_size, init_size] map
        x = self.latent_to_features(z)
        x = x.view(bs, self.latent_channels, self.init_size, self.init_size)
        # Upsampling path: init_size -> 2x -> 4x -> 8x -> 16x
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        x = self.up4(x)
        # Final convolution + sigmoid
        x = self.final_conv(x)
        return x

In [5]:
# Build the decoder used for the parameter count below.
# NOTE(review): latent_dim is 512 here, while the ViT_Encoder cell uses 256 -
# confirm which value is intended before chaining the two models.
decoder = UNet_Decoder(
    # latent input and channel widths
    latent_dim=512, base_channels=64, dropout=0.1,
    # output image geometry
    img_size=256, out_channels=3,
)

print("Number of parameters in UNet Decoder:", count_parameters(decoder))

Number of parameters in UNet Decoder: 69640899
