In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset
from torchvision import transforms
import random
import numpy as np
from PIL import Image

In [2]:
class MAEDataset(Dataset):
    """Dataset yielding ``(masked_patches, patches, mask)`` triples for MAE-style pretraining.

    Each image is split into non-overlapping ``patch_size x patch_size`` patches.
    Each patch is flattened to one row of length ``patch_size * patch_size * C``.
    A random ``mask_ratio`` fraction of the rows is then zeroed out.
    """

    def __init__(self, image_paths, patch_size=16, mask_ratio=0.75, transform=None, img_size=224):
        """
        Args:
            image_paths (list): List of paths to the images in the dataset.
            patch_size (int): Side length of the square patches (default is 16).
            mask_ratio (float): Fraction of the patches to be masked, in [0, 1]
                (default is 0.75).
            transform (callable, optional): Optional transform applied to the loaded
                PIL image. It should produce a (C, H, W) tensor whose H and W are
                multiples of ``patch_size``. If the result is not a tensor, it is
                converted to a float tensor (8-bit data is scaled to [0, 1]).
            img_size (int): Expected side length of the square input images. It is
                used only to pre-compute ``num_patches`` (default is 224).

        Raises:
            ValueError: If ``patch_size`` is not positive or ``mask_ratio`` is
                outside [0, 1].
        """
        if patch_size <= 0:
            raise ValueError(f"patch_size must be positive, got {patch_size}")
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

        self.image_paths = image_paths
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.transform = transform

        # Expected patch count for an img_size x img_size image. Masking uses the
        # real per-image count (patches.shape[0]), so images of another size that
        # divides evenly are still masked correctly.
        self.img_size = img_size
        self.num_patches = (self.img_size // self.patch_size) ** 2

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(masked_patches, patches, mask)``.

        ``masked_patches`` and ``patches`` have shape (num_patches, patch_dim).
        ``mask`` has shape (num_patches,), with 0 for masked and 1 for visible patches.
        """
        # The context manager closes the file handle promptly. convert() forces a
        # full load and returns an independent image, so it stays valid afterwards.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        # Apply any transformations (e.g., resize, ToTensor, normalization)
        if self.transform is not None:
            image = self.transform(image)

        patches = self.image_to_patches(image)
        masked_patches, mask = self.mask_patches(patches)

        # Model input, reconstruction target, and the visibility mask.
        return masked_patches, patches, mask

    @staticmethod
    def _to_tensor(image):
        """Convert a PIL image or HWC array to a (C, H, W) float tensor (uint8 scaled to [0, 1])."""
        array = np.array(image)  # copy, so the resulting tensor is writable
        if array.ndim == 2:  # single-channel image without a channel axis
            array = array[:, :, None]
        tensor = torch.from_numpy(array).permute(2, 0, 1).float()
        # Mirror torchvision's ToTensor scaling for 8-bit images.
        if array.dtype == np.uint8:
            tensor = tensor.div(255.0)
        return tensor

    def image_to_patches(self, image):
        """
        Convert an image into non-overlapping, flattened patches.

        Args:
            image (Tensor or PIL Image): A (C, H, W) tensor. A PIL image or an HWC
                array is first converted to a float tensor. H and W must be
                multiples of ``patch_size``.
        Returns:
            Tensor: Patches of shape (num_patches, patch_size * patch_size * C).
                Rows follow the patch grid in row-major order. Within a row, the
                pixels are laid out in (row, col, channel) order.
        Raises:
            ValueError: If the image is not 3-D, or its height/width is not
                divisible by ``patch_size``.
        """
        if not torch.is_tensor(image):
            image = self._to_tensor(image)
        if image.dim() != 3:
            raise ValueError(f"expected a (C, H, W) image tensor, got shape {tuple(image.shape)}")

        channels, height, width = image.shape
        p = self.patch_size
        if height % p or width % p:
            raise ValueError(
                f"image size {height}x{width} is not divisible by patch_size {p}"
            )

        # (C, H, W) -> (C, H/p, W/p, p, p): dims 1 and 2 are the spatial axes.
        patches = image.unfold(1, p, p).unfold(2, p, p)
        # Move channels last so that each flattened row holds exactly one patch.
        patches = patches.permute(1, 2, 3, 4, 0)
        return patches.reshape(-1, p * p * channels)

    def mask_patches(self, patches):
        """
        Randomly mask a portion of the patches.
        Args:
            patches (Tensor): Flattened patches of shape (num_patches, patch_dim).
        Returns:
            Tensor: Copy of ``patches`` with the masked rows set to zero.
            Tensor: Float mask of shape (num_patches,): 0 = masked, 1 = visible.
        """
        # Use the actual patch count rather than the img_size-based estimate.
        num_patches = patches.shape[0]
        num_to_mask = int(num_patches * self.mask_ratio)
        mask_indices = torch.tensor(
            random.sample(range(num_patches), num_to_mask), dtype=torch.long
        )

        mask = torch.ones(num_patches)
        mask[mask_indices] = 0  # Masked patches have a 0 in the mask

        masked_patches = patches.clone()
        masked_patches[mask_indices] = 0  # Replace the masked patches with zeros

        return masked_patches, mask

In [None]:
class EncoderBlock(nn.Module):
    """Pre-norm Transformer encoder block.

    Computes ``x = x + Dropout(MHSA(LN(x)))`` followed by
    ``x = x + Dropout(FFN(LN(x)))``.
    """

    def __init__(self, dim, num_heads, ff_dim, dropout_rate=0.1):
        """
        Args:
            dim (int): The dimensionality of the input/embedding dimension (D).
                Must be divisible by ``num_heads``.
            num_heads (int): The number of attention heads for Multi-Head Self Attention.
            ff_dim (int): The dimensionality of the feed-forward layer.
            dropout_rate (float): Dropout rate for attention and feed-forward layers.
        """
        super().__init__()

        # Multi-Head Self Attention (MHSA). batch_first is left at its default
        # (False), so inputs are (seq_len, batch_size, dim).
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout_rate)

        # Feed-Forward Network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim),    # First Linear layer (Dim -> FF Dim)
            nn.GELU(),                 # GELU activation
            nn.Dropout(dropout_rate),  # Dropout layer
            nn.Linear(ff_dim, dim),    # Second Linear layer (FF Dim -> Dim)
        )

        # Layer Normalization layers (applied before attention and feed-forward)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Dropout applied to each sub-layer's output before the residual add
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, mask=None, key_padding_mask=None):
        """
        Forward pass through the encoder block.

        Args:
            x (Tensor): Input tensor of shape (seq_len, batch_size, dim).
            mask (Tensor, optional): Attention mask passed to MHSA as ``attn_mask``,
                e.g. of shape (seq_len, seq_len) (default: None).
            key_padding_mask (Tensor, optional): Padding mask of shape
                (batch_size, seq_len). True marks positions to ignore as keys
                (default: None).

        Returns:
            Tensor: Output tensor of shape (seq_len, batch_size, dim).
        """
        # Apply LayerNorm before MHSA
        x_norm = self.norm1(x)
        # The attention weights are discarded, so skip computing/averaging them.
        attn_output, _ = self.attn(
            x_norm, x_norm, x_norm,
            attn_mask=mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        # Residual connection around the dropped-out attention output
        x = x + self.dropout1(attn_output)

        # Apply LayerNorm before FFN, then residual connection
        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout2(ffn_output)

        return x


In [None]:
class DecoderBlock(nn.Module):
    """Pre-norm cross-attention decoder block (no self-attention sub-layer).

    Computes ``x = x + Dropout(CrossAttn(LN(x), enc, enc))`` followed by
    ``x = x + Dropout(FFN(LN(x)))``.
    """

    def __init__(self, dim, num_heads, ff_dim, dropout_rate=0.1):
        """
        Args:
            dim (int): The dimensionality of the input/embedding dimension (D).
                Must be divisible by ``num_heads``.
            num_heads (int): The number of attention heads for Cross-Attention.
            ff_dim (int): The dimensionality of the feed-forward layer.
            dropout_rate (float): Dropout rate for attention and feed-forward layers.
        """
        super().__init__()

        # Cross-Attention: decoder tokens (queries) attend to the encoder's output.
        # batch_first is left at its default (False): (seq_len, batch_size, dim).
        self.cross_attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, dropout=dropout_rate)

        # Feed-Forward Network (FFN)
        self.ffn = nn.Sequential(
            nn.Linear(dim, ff_dim),    # First Linear layer (Dim -> FF Dim)
            nn.GELU(),                 # GELU activation
            nn.Dropout(dropout_rate),  # Dropout layer
            nn.Linear(ff_dim, dim),    # Second Linear layer (FF Dim -> Dim)
        )

        # Layer Normalization layers (applied before cross-attention and FFN)
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)

        # Dropout applied to each sub-layer's output before the residual add
        self.dropout1 = nn.Dropout(dropout_rate)
        self.dropout2 = nn.Dropout(dropout_rate)

    def forward(self, x, encoder_output, mask=None, cross_mask=None, cross_key_padding_mask=None):
        """
        Forward pass through the decoder block.

        Args:
            x (Tensor): Decoder input of shape (tgt_len, batch_size, dim).
            encoder_output (Tensor): Encoder output of shape (src_len, batch_size, dim).
                src_len may differ from tgt_len.
            mask (Tensor, optional): Currently unused. This block has no
                self-attention sub-layer; the argument is kept for interface
                compatibility (default: None).
            cross_mask (Tensor, optional): Attention mask of shape
                (tgt_len, src_len), passed to cross-attention as ``attn_mask``
                (default: None).
            cross_key_padding_mask (Tensor, optional): Padding mask of shape
                (batch_size, src_len). True marks encoder positions to ignore
                (default: None).

        Returns:
            Tensor: Output tensor of shape (tgt_len, batch_size, dim).
        """
        # Apply LayerNorm to the decoder input only; encoder_output is used as-is.
        x_norm = self.norm1(x)
        # The attention weights are discarded, so skip computing/averaging them.
        cross_attn_output, _ = self.cross_attn(
            x_norm, encoder_output, encoder_output,
            attn_mask=cross_mask,
            key_padding_mask=cross_key_padding_mask,
            need_weights=False,
        )
        # Residual connection around the dropped-out cross-attention output
        x = x + self.dropout1(cross_attn_output)

        # Apply LayerNorm before FFN, then residual connection
        ffn_output = self.ffn(self.norm2(x))
        x = x + self.dropout2(ffn_output)

        return x
        