In [1]:
import torch
import torch.nn as nn

In [2]:
class PatchTokenizer(nn.Module):
    """Split images into non-overlapping patches and embed each patch as a token.

    A Conv2d whose kernel_size and stride both equal ``patch_size`` (no padding)
    applies the same linear projection to every non-overlapping patch, so each
    output spatial location is one patch embedding.

    Args:
        img_size: expected (square) input height/width; used to derive ``num_patches``.
        n_channels: number of input image channels.
        patch_size: side length of each square patch.
        latent_dim: embedding size of every patch token.

    ``forward(img)`` maps a tensor of shape (batch, n_channels, H, W) to
    (batch, (H // patch_size) * (W // patch_size), latent_dim).
    """

    def __init__(self, img_size, n_channels, patch_size, latent_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.latent_dim = latent_dim
        # Number of tokens produced for an img_size x img_size input.
        self.num_patches = (img_size // patch_size) ** 2

        self.conv2d_patch_tokenizer = nn.Conv2d(
            in_channels=n_channels,
            out_channels=latent_dim,
            kernel_size=(patch_size, patch_size),
            stride=patch_size,
            padding=0,
        )
        # Collapse the two spatial (patch-grid) dims into one token axis.
        self.flatten = nn.Flatten(start_dim=2)

    def forward(self, img):
        patches = self.conv2d_patch_tokenizer(img)  # (batch, latent_dim, H // patch_size, W // patch_size)
        flattened_patches = self.flatten(patches)  # (batch, latent_dim, num_patches)

        return flattened_patches.permute(0, 2, 1)  # (batch, num_patches, latent_dim)




In [3]:
torch.manual_seed(0)  # make the random images / parameters below reproducible

batch_size = 32
img_height = 224
img_width = 224
n_channels = 3
patch_size = 16
latent_dim = 384

img = torch.rand(batch_size, n_channels, img_height, img_width)

# Build the tokenizer from the config above instead of repeating literal values.
patch_tokenizer = PatchTokenizer(img_height, n_channels, patch_size, latent_dim)

patches = patch_tokenizer(img)
num_patches = (img_height // patch_size) * (img_width // patch_size)
assert patches.shape == (batch_size, num_patches, latent_dim), "unexpected token grid"
patches.shape

torch.Size([32, 196, 384])

In [None]:
# One learnable [CLS] token shared by every image in the batch. Sizing the parameter
# by batch_size would give each batch slot its own token and tie the model to that
# batch size; expand() broadcasts the single token as a view, so gradients from all
# samples accumulate into the same parameter.
cls_token = nn.Parameter(torch.randn(1, 1, latent_dim))
class_token = cls_token.expand(batch_size, -1, -1)  # (batch_size, 1, latent_dim)

In [5]:
# Prepend the class token in front of the patch tokens along the token axis (dim 1).
tensor_with_class_token = torch.cat([class_token, patches], dim=1)  # (batch_size, 1 + num_patches, latent_dim)
tensor_with_class_token.shape

torch.Size([32, 197, 384])

In [None]:
# One learnable positional embedding per token position (CLS + patches), shared by
# every image: shape (1, seq_len, latent_dim) broadcasts over the batch on addition.
# seq_len is read from the tensor rather than hard-coding 197.
seq_len = tensor_with_class_token.shape[1]
learnable_pos_embedding = nn.Parameter(torch.randn(1, seq_len, latent_dim))
tensor_with_pos_embedding = tensor_with_class_token + learnable_pos_embedding
tensor_with_pos_embedding.shape

torch.Size([32, 197, 384])

In [None]:
import math

def scaled_dot_product_attention(key: torch.Tensor, query: torch.Tensor, value: torch.Tensor, mask=None):
    """Compute softmax(query @ key^T / sqrt(d_k)) @ value.

    NOTE: the positional order is (key, query, value), not the conventional
    (query, key, value). Pass arguments by keyword to avoid silently swapping
    query and key.

    Args:
        key: tensor of shape (..., seq_k, d_k).
        query: tensor of shape (..., seq_q, d_k).
        value: tensor of shape (..., seq_k, d_v).
        mask: optional tensor broadcastable to (..., seq_q, seq_k). A boolean mask
            marks positions that may be attended to (True = keep, False = masked);
            any other dtype is added to the scores as an additive bias
            (e.g. 0 / -inf). A query row with every key masked yields NaN.

    Returns:
        (values, attention): ``values`` has shape (..., seq_q, d_v); ``attention``
        has shape (..., seq_q, seq_k) and each row sums to 1.
    """
    d_k = key.size(-1)
    scores = torch.matmul(query, key.transpose(-1, -2)) / math.sqrt(d_k)
    if mask is not None:
        if mask.dtype == torch.bool:
            # Disallowed positions get -inf so softmax assigns them exactly zero weight.
            scores = scores.masked_fill(~mask, float("-inf"))
        else:
            scores = scores + mask
    attention = torch.nn.functional.softmax(scores, dim=-1)
    values = torch.matmul(attention, value)

    return values, attention

class MultiheadSelfAttentionBlock(nn.Module):
    """Multi-head self-attention over a token sequence.

    Args:
        input_dim: size of the last dimension of the input ``x``.
        latent_dim: total attention width, split evenly across heads.
        num_heads: number of attention heads; must divide ``latent_dim``.

    ``forward(x, mask=None)`` takes ``x`` of shape (batch, seq_len, input_dim),
    forwards ``mask`` to ``scaled_dot_product_attention``, and returns a tensor of
    shape (batch, seq_len, latent_dim).

    Raises:
        ValueError: if ``latent_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, input_dim, latent_dim, num_heads):
        super().__init__()
        if latent_dim % num_heads != 0:
            raise ValueError(
                f"latent_dim ({latent_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.num_heads = num_heads
        self.head_dim = latent_dim // num_heads

        # One projection produces q, k and v for every head at once.
        self.qkv_layer = nn.Linear(input_dim, 3 * latent_dim)
        # Output projection applied after the heads are merged.
        self.linear_layer = nn.Linear(latent_dim, latent_dim)

        # NOTE(review): not used in forward(); kept so existing state_dicts and any
        # code referencing it keep working. TODO decide whether normalization belongs
        # here or in the enclosing transformer block.
        self.layer_norm = nn.LayerNorm(normalized_shape=latent_dim)

    def forward(self, x, mask=None):
        batch_size, sequence_length, _ = x.size()
        qkv = self.qkv_layer(x)  # (batch, seq_len, 3 * latent_dim)
        # Give each head its own contiguous 3 * head_dim slice, then move heads ahead of seq_len.
        qkv = qkv.reshape(batch_size, sequence_length, self.num_heads, 3 * self.head_dim)
        qkv = qkv.permute(0, 2, 1, 3)  # (batch, num_heads, seq_len, 3 * head_dim)
        q, k, v = qkv.chunk(3, dim=-1)  # each (batch, num_heads, seq_len, head_dim)
        # Pass by keyword: the helper's positional order is (key, query, value).
        values, _attention = scaled_dot_product_attention(query=q, key=k, value=v, mask=mask)
        # Move seq_len back in front of the heads before merging them; reshaping
        # straight from (batch, num_heads, seq_len, head_dim) would mix tokens across heads.
        values = values.permute(0, 2, 1, 3).reshape(
            batch_size, sequence_length, self.num_heads * self.head_dim
        )
        return self.linear_layer(values)  # (batch, seq_len, latent_dim)

In [None]:
class TransformerBlock(nn.Module):
    """Skeleton of a transformer block (work in progress).

    Args:
        num_blocks: number of blocks (stored; not used yet).
        num_attention_heads: number of attention heads (stored; not used yet).
        latent_dim: token embedding size; sets the LayerNorm width.
    """

    def __init__(self, num_blocks, num_attention_heads, latent_dim):
        super().__init__()

        self.num_blocks = num_blocks
        self.num_attention_heads = num_attention_heads
        self.latent_dim = latent_dim

        # LayerNorm requires the size of the normalized (last) dimension.
        self.layer_norm = nn.LayerNorm(normalized_shape=latent_dim)

    def forward(self, x):
        # TODO: implement the block. Fail loudly instead of silently returning None,
        # which would only surface as a confusing error further downstream.
        raise NotImplementedError("TransformerBlock.forward is not implemented yet")
