In [1]:
import torch
import torch.nn as nn
from einops import rearrange

In [8]:
# Toy batch of patch embeddings: 1 image, 16 patches, 10 features per patch.
X = torch.randint(low=0, high=3, size=(1, 16, 10)).float()

# Learnable map from the 10-dim patch features to a 12-dim hidden space.
linear_projection = nn.Linear(in_features=10, out_features=12)

# nn.Linear only acts on the last axis, so the (batch, patches) axes are kept.
projected_patches = linear_projection(X)

print(projected_patches.shape)

torch.Size([1, 16, 12])


In [13]:
# Append learnable class tokens to a sequence of patch tokens.
X = torch.randn(1, 16, 10)  # patch embeddings: (batch, patches, features)
C = nn.Parameter(torch.randn(1, 5, 10))  # class embeddings: (1, classes, features)

# Repeat the class tokens along the batch axis, then place them after the
# patch tokens on the sequence axis (dim 1).
C = C.expand(X.shape[0], -1, -1)
X = torch.cat([X, C], dim=1)

print(X.shape)
print(C.shape)

# An element-wise `X + C` is not possible here: 16 patches and 5 classes do not
# broadcast against each other, hence the concatenation above.


torch.Size([1, 21, 10])
torch.Size([1, 5, 10])


In [25]:
# Stand-in image batch (batch, channels, height, width) with integer values 0..4 stored as float32.
rand_tensor = torch.randint(low=0, high=5, size=(1, 3, 224, 224)).float()
print(rand_tensor.shape)

torch.Size([1, 3, 224, 224])


In [20]:
# Collapse the two spatial axes (dims 2 and 3) into one; batch and channel axes stay as they are.
flat_img = nn.Flatten(2, 3)
print(flat_img(rand_tensor).shape)  # (b, c, h, w) -> (b, c, h*w)

torch.Size([1, 3, 50176])


In [26]:
class PatchEmbedding(nn.Module):
    """
    Splits an image batch into non-overlapping square patches and embeds each one.

    A convolution whose kernel size equals its stride (with no padding) produces
    one output position per patch, so each position holds that patch's embedding.

    Args:
        in_channels: number of channels of the input images (e.g. 3 for RGB).
        patch_size: side length of a patch; used as both kernel size and stride.
        embd_dim: number of output channels, i.e. the embedding size per patch
            (e.g. 768 = 16 * 16 * 3 values for a 16x16 RGB patch).
    """
    def __init__(self, in_channels: int, patch_size: int, embd_dim: int):
        super().__init__()
        # kernel == stride, padding 0 -> exactly one output position per patch
        self.patcher = nn.Conv2d(
            in_channels,
            embd_dim,
            kernel_size=patch_size,
            stride=patch_size,
            padding=0,
        )
        # merge the patch grid: (b, embd_dim, h, w) -> (b, embd_dim, h*w)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        """Return the patch embeddings with shape (b, embd_dim, num_patches)."""
        return self.flatten(self.patcher(x))

In [27]:
# Embed the random image batch with 16x16 patches and 768 features per patch.
patcher = PatchEmbedding(in_channels=3, patch_size=16, embd_dim=768)

out = patcher(rand_tensor)
print(out.shape)

torch.Size([1, 768, 14, 14])


In [2]:

#==============================================================================================#
# Patch embedding class
class PatchEmbedding(nn.Module):
    """
    Splits an image batch into non-overlapping square patches and embeds each one,
    returning the result as a token sequence.

    A convolution whose kernel size equals its stride produces one output
    position per patch, so each position holds that patch's embedding.

    Args:
        in_channels: number of channels of the input images (e.g. 3 for RGB).
        patch_size: side length of a patch; used as both kernel size and stride.
        embd_dim: number of output channels, i.e. the embedding size per patch
            (e.g. 768 = 16 * 16 * 3 values for a 16x16 RGB patch).
    """
    def __init__(self, in_channels: int, patch_size: int, embd_dim: int):
        super().__init__()
        # kernel == stride -> exactly one output position per patch
        self.patcher = nn.Conv2d(
            in_channels,
            embd_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )
        # merge the patch grid: (b, embd_dim, h, w) -> (b, embd_dim, h*w)
        self.flatten = nn.Flatten(start_dim=2, end_dim=3)

    def forward(self, x):
        """Return the patch tokens with shape (b, num_patches, embd_dim)."""
        patch_grid = self.patcher(x)
        tokens = self.flatten(patch_grid)
        # put the sequence axis before the feature axis
        return tokens.transpose(1, 2)
    
#================================================================================================#
# Multihead attention block
class MultiHeadAttentionBlock(nn.Module):
    """
    Pre-norm multi-head self-attention: LayerNorm followed by nn.MultiheadAttention.

    No residual connection is applied inside this block; it returns the
    attention output only.

    Args:
        embd_dim: size of the feature (last) axis of the input tokens.
        num_heads: number of attention heads.
        attn_drop: dropout probability passed to nn.MultiheadAttention.
    """
    def __init__(self, embd_dim: int, num_heads: int, attn_drop: float):
        super().__init__()

        # Normalise the tokens before attention (pre-norm, as in the ViT paper)
        self.layernorm = nn.LayerNorm(embd_dim)

        # batch_first=True -> inputs and outputs are (batch, sequence, features)
        self.multihead_attn = nn.MultiheadAttention(
            embd_dim,
            num_heads,
            dropout=attn_drop,
            batch_first=True,
        )

    def forward(self, x):
        """Self-attend over the normalised tokens and return the attended tokens."""
        normed = self.layernorm(x)
        # Query, key and value are all the same tensor; the attention weights
        # themselves are not requested, only the layer output is used.
        attended = self.multihead_attn(normed, normed, normed, need_weights=False)[0]
        return attended
    

In [3]:
# Define MLP block 
class MLPBlock(nn.Module):
    """
    Pre-norm MLP block from the ViT paper:
    LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    No residual connection is applied inside this block.

    Args:
        embd_dim: size of the feature (last) axis of the input and output tokens.
        mlp_size: width of the hidden layer (3072, as mentioned in the ViT paper).
        mlp_drop: dropout probability applied after each of the two linear layers.
    """
    def __init__(
        self,
        embd_dim: int,
        mlp_size: int, # 3072 (as mentioned in ViT paper)
        mlp_drop: float,
    ):
        super().__init__()
        # Normalization
        self.layernorm = nn.LayerNorm(normalized_shape=embd_dim)

        # Feed forward network
        self.ffwd = nn.Sequential(
            nn.Linear(in_features=embd_dim, out_features=mlp_size),
            nn.GELU(),
            nn.Dropout(mlp_drop),
            nn.Linear(in_features=mlp_size, out_features=embd_dim),
            # Use the configured rate here as well; a bare nn.Dropout() would
            # fall back to p=0.5 and ignore `mlp_drop`.
            nn.Dropout(mlp_drop),
        )

    def forward(self, x):
        """Normalise the tokens, then run them through the feed-forward network."""
        x = self.layernorm(x)
        x = self.ffwd(x)
        return x
    

In [4]:
# Define Encoder
class SegmenterEncoder(nn.Module):
    """
    Encoder block of Segmenter with residual connections. Its structure is similar to ViT:
    a multi-head self-attention sub-block followed by an MLP sub-block, each
    wrapped in a residual connection.

    Args:
        embd_dim: size of the feature (last) axis of the tokens.
        num_heads: number of attention heads.
        mlp_size: hidden width of the MLP sub-block.
        attn_drop: dropout probability inside the attention sub-block.
        mlp_drop: dropout probability inside the MLP sub-block.
    """
    def __init__(
        self,
        embd_dim: int,
        num_heads: int,
        mlp_size: int,
        attn_drop: float,
        mlp_drop: float
    ):
        super().__init__()
        # start with multihead self attention
        self.msa = MultiHeadAttentionBlock(
            embd_dim=embd_dim,
            num_heads=num_heads,
            attn_drop=attn_drop
        )

        self.mlp = MLPBlock(
            embd_dim=embd_dim,
            mlp_size=mlp_size,
            mlp_drop=mlp_drop
        )

    def forward(self, x):
        """Return `x` after the attention and MLP residual updates; `x` itself is left untouched."""
        # Residual connections are written out-of-place on purpose: an in-place
        # `x += ...` would overwrite the caller's tensor and modify a tensor the
        # sub-block has saved for the backward pass, which makes autograd fail.
        x = x + self.msa(x)
        x = x + self.mlp(x)
        return x

In [47]:
class MaskTransformer(nn.Module):
    """
    Transformer-based mask decoder in the style of Segmenter.

    Encoder patch tokens are projected to the decoder width, concatenated with
    one learnable embedding per class, passed through a stack of transformer
    blocks, and turned into one mask per class by taking the scalar product
    between L2-normalised patch features and class features.

    Args:
        num_class: number of classes; one class embedding and one output mask each.
        patch_size: side length of an image patch, used to recover the patch-grid height.
        dims_encoder: feature size of the incoming encoder tokens.
        num_layers: number of transformer blocks in the decoder.
        num_heads: number of attention heads per block.
        dims_model: feature size used inside the decoder.
        dims_ffwd: hidden width of each block's MLP.
        attn_drop: dropout probability inside the attention sub-blocks.
        mlp_drop: dropout probability inside the MLP sub-blocks.
    """
    def __init__(
        self,
        num_class: int,
        patch_size: int,
        dims_encoder: int,
        num_layers: int,
        num_heads: int,
        dims_model: int,
        dims_ffwd: int,
        attn_drop: float,
        mlp_drop: float
    ):
        super().__init__()
        self.num_class = num_class
        self.patch_size = patch_size
        self.dims_encoder = dims_encoder
        self.num_layers = num_layers
        self.dims_model = dims_model
        self.dims_ffwd = dims_ffwd
        # Initialisation scale for the (dims_model x dims_model) projections
        # below: it has to follow their own width, not the encoder's.
        self.scale_factor = dims_model ** -0.5

        # Create transformer blocks
        self.blocks = nn.ModuleList(
            [SegmenterEncoder(
                embd_dim=dims_model,
                num_heads=num_heads,
                mlp_size=dims_ffwd,
                attn_drop=attn_drop,
                mlp_drop=mlp_drop
            )
            for _ in range(num_layers)
            ]
        )

        # Initialise parameters
        self.cls_embd = nn.Parameter(torch.randn(1, num_class, dims_model))
        self.proj_dec = nn.Linear(dims_encoder, dims_model)
        self.proj_patch = nn.Parameter(self.scale_factor * torch.randn(dims_model, dims_model))
        self.proj_classes = nn.Parameter(self.scale_factor * torch.randn(dims_model, dims_model))
        self.decoder_norm = nn.LayerNorm(dims_model)
        self.mask_norm = nn.LayerNorm(num_class)

    def forward(self, x, img_size):
        """
        Decode encoder tokens into per-class masks on the patch grid.

        Args:
            x: encoder tokens of shape (b, num_patches, dims_encoder).
            img_size: (H, W) of the input image in pixels. Only H is used; the
                grid width is inferred from the number of patch tokens.

        Returns:
            Tensor of shape (b, num_class, grid_h, grid_w) with
            grid_h = H // patch_size and grid_w = num_patches // grid_h.
        """
        H, _ = img_size
        grid_h = H // self.patch_size

        # Project encoder output to the decoder width
        x = self.proj_dec(x)

        # Append one class embedding per class after the patch tokens
        cls_embd = self.cls_embd.expand(x.size(0), -1, -1)
        x = torch.cat((x, cls_embd), 1)

        # Transformer layers
        for block in self.blocks:
            x = block(x)
        x = self.decoder_norm(x)

        # The class tokens are the last `num_class` entries of the sequence
        patches, cls_seg_feat = x[:, : -self.num_class], x[:, -self.num_class :]

        patches = patches @ self.proj_patch
        cls_seg_feat = cls_seg_feat @ self.proj_classes

        # L2-normalise both sides so the product below is a cosine similarity
        patches = patches / patches.norm(dim=-1, keepdim=True)
        cls_seg_feat = cls_seg_feat / cls_seg_feat.norm(dim=-1, keepdim=True)

        # (b, num_patches, dims_model) @ (b, dims_model, num_class) -> (b, num_patches, num_class)
        masks = patches @ cls_seg_feat.transpose(1, 2)
        masks = self.mask_norm(masks)

        # (b, num_patches, num_class) -> (b, num_class, grid_h, grid_w)
        masks = rearrange(masks, "b (h w) n -> b n h w", h=int(grid_h))
        return masks

In [48]:
import numpy as np

# Random 8-bit RGB image of 224x224 pixels, laid out as (height, width, channels).
image_size = (224, 224, 3)
random_image = np.random.randint(low=0, high=256, size=image_size, dtype=np.uint8)

In [49]:
# Stand-in for an encoder output: random patch tokens of shape (1, num_patches, embedding_dim).
num_patches, embedding_dim = 196, 768
random_patches = torch.randn((1, num_patches, embedding_dim))

In [50]:
# Decoder hyper-parameters
num_classes, patch_size = 10, 16
dims_encoder = embedding_dim  # the decoder consumes tokens of the encoder's width
num_layers, num_heads = 6, 8
dims_model, dims_ffwd = 768, 2048
attn_drop = mlp_drop = 0.1

# Build the decoder
mask_transformer = MaskTransformer(
    num_class=num_classes, patch_size=patch_size, dims_encoder=dims_encoder,
    num_layers=num_layers, num_heads=num_heads,
    dims_model=dims_model, dims_ffwd=dims_ffwd,
    attn_drop=attn_drop, mlp_drop=mlp_drop,
)

# Run the patch tokens through the decoder; only (height, width) of the image size is passed on.
masks = mask_transformer(random_patches, image_size[:2])

In [51]:
# One mask per class on the patch grid: (batch, num_class, grid_h, grid_w).
print(masks.size())

torch.Size([1, 10, 14, 14])


In [15]:
# NOTE(review): `output` is not assigned in any cell visible in this notebook, so
# this cell presumably relies on leftover kernel state and would fail on
# Restart & Run All — confirm and either define `output` or drop the cell.
# The shape recorded below has num_patches (196) rows, not num_patches + num_class.
print(output.shape)

torch.Size([1, 196, 768])


In [23]:
# NOTE(review): in the visible notebook `patches` and `cls_seg_feat` are only
# assigned as locals inside MaskTransformer.forward, never at top level, so this
# cell presumably relies on leftover kernel state — confirm before keeping it.
print(patches.shape)
print(cls_seg_feat.shape)

torch.Size([1, 196, 768])
torch.Size([1, 10, 768])
