In [None]:
import torch
import torch.nn as nn
from torchvision import models

In [None]:
class DropBlock2D(nn.Module):
    """DropBlock regularisation for 4-D feature maps.

    Instead of zeroing independent activations (as dropout does), this zeroes
    square ``block_size x block_size`` regions. The regions are shared across
    all channels of a sample. The surviving activations are then rescaled by
    ``mask.numel() / mask.sum()``.

    Args:
        drop_prob: Controls how much is dropped; ``0`` disables the layer.
        block_size: Side length of each square region that is dropped. Both
            odd and even sizes keep the output the same spatial size as the input.

    The layer is a no-op in ``eval()`` mode.
    """

    def __init__(self, drop_prob=0.1, block_size=3):
        super().__init__()
        self.drop_prob = drop_prob
        self.block_size = block_size

    def forward(self, x):
        """Apply DropBlock to ``x`` (indexed as ``(N, C, H, W)``) in training mode."""
        if not self.training or self.drop_prob == 0.:
            return x

        gamma = self._compute_gamma(x)
        # One seed mask per sample (channel dim of size 1, broadcast over C).
        mask = (torch.rand(x.shape[0], 1, x.shape[2], x.shape[3], device=x.device) < gamma).float()
        mask = self._compute_block_mask(mask)
        countM = mask.numel()
        # Clamp the denominator: if every position was dropped, the output is all
        # zeros instead of NaN (0 * inf).
        count_ones = mask.sum().clamp(min=1.0)
        return mask * x * (countM / count_ones)

    def _compute_block_mask(self, mask):
        """Dilate each seed into a square block and invert (1 = keep, 0 = drop)."""
        block_mask = nn.functional.max_pool2d(
            input=mask,
            kernel_size=(self.block_size, self.block_size),
            stride=(1, 1),
            padding=self.block_size // 2
        )
        # With an even kernel, symmetric padding of block_size // 2 grows the
        # output by one row and one column; trim it so the mask matches x's H x W.
        if self.block_size % 2 == 0:
            block_mask = block_mask[:, :, :-1, :-1]
        return 1 - block_mask

    def _compute_gamma(self, x):
        """Per-position seed probability (simplified: drop_prob / block_size**2)."""
        return self.drop_prob / (self.block_size ** 2)

In [None]:
class Encoder(nn.Module):
    """ResNet-50 feature extractor with DropBlock and a 2048 -> 512 channel bottleneck.

    All backbone parameters are frozen except ``layer4``; the new bottleneck is
    trainable. DropBlock is applied after ``layer3`` and after ``layer4``. It is
    active only in training mode.

    Args:
        drop_prob: Drop probability passed to both ``DropBlock2D`` layers.
        block_size: Block size passed to both ``DropBlock2D`` layers.
        pretrained: Load the ImageNet ``IMAGENET1K_V1`` weights, which are the
            weights the legacy ``pretrained=True`` flag selected. Pass ``False``
            for random initialisation (e.g. offline smoke tests).
    """

    def __init__(self, drop_prob=0.3, block_size=3, pretrained=True):
        super().__init__()
        resnet = self._build_backbone(pretrained)

        # Freeze the whole backbone, then unfreeze only layer4 for fine-tuning.
        # NOTE: requires_grad=False does not stop the frozen BatchNorm layers from
        # updating their running mean/var while the module is in train() mode.
        for param in resnet.parameters():
            param.requires_grad = False
        for param in resnet.layer4.parameters():
            param.requires_grad = True

        # Backbone stages in order, with DropBlock inserted after the two
        # deepest stages (layer3 and layer4).
        self.feature_extractor = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            DropBlock2D(drop_prob=drop_prob, block_size=block_size),
            resnet.layer4,
            DropBlock2D(drop_prob=drop_prob, block_size=block_size)
        )

        # layer4 outputs 2048 channels; the decoder and embedding head expect
        # 512, so a 1x1 conv reduces the channel dimension. Freshly created
        # modules are trainable (requires_grad=True) by default.
        self.bottleneck = nn.Sequential(
            nn.Conv2d(2048, 512, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(512),
            nn.ReLU()
        )

    @staticmethod
    def _build_backbone(pretrained):
        """Create a torchvision ResNet-50, preferring the multi-weight API."""
        if hasattr(models, "ResNet50_Weights"):  # torchvision >= 0.13
            weights = models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
            return models.resnet50(weights=weights)
        return models.resnet50(pretrained=pretrained)  # older torchvision

    def forward(self, x):
        """Encode images to a 512-channel latent map (e.g. [B, 512, 7, 7] for 224x224 input)."""
        x = self.feature_extractor(x)
        x = self.bottleneck(x)
        return x

In [None]:
class Decoder(nn.Module):
    """Upsample a 512-channel latent map by 32x back to a 3-channel image.

    Each of the five transposed convolutions (kernel 3, stride 2, padding 1,
    output_padding 1) exactly doubles the spatial size:
    ``H_out = (H_in - 1) * 2 - 2 * 1 + 3 + 1 = 2 * H_in``.
    Channels halve along the way: 512 -> 256 -> 128 -> 64 -> 32 -> 16.
    A final 3x3 convolution maps to RGB, and ``tanh`` bounds the output to [-1, 1].
    """

    def __init__(self):
        super().__init__()
        widths = [512, 256, 128, 64, 32, 16]
        stages = []
        for in_ch, out_ch in zip(widths[:-1], widths[1:]):
            stages.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=3, stride=2, padding=1, output_padding=1)
            )
            stages.append(nn.ReLU())
        # Project to RGB while keeping the spatial size (padding=1 with a 3x3 kernel).
        stages.append(nn.Conv2d(widths[-1], 3, kernel_size=3, padding=1))
        stages.append(nn.Tanh())
        self.decoder = nn.Sequential(*stages)

    def forward(self, x):
        """Decode a latent map, e.g. [B, 512, 7, 7] -> [B, 3, 224, 224]."""
        reconstruction = self.decoder(x)
        return reconstruction


In [None]:
class MaskedAutoencoder(nn.Module):
    """Encoder/decoder pair with an additional pooled embedding head.

    ``forward`` returns both the decoder's reconstruction and a flat embedding.
    Both are computed from the same latent map (global average pool -> flatten
    -> linear).

    Args:
        encoder: Module mapping the input to a latent feature map with
            ``latent_channels`` channels (the ``Encoder`` class produces 512).
        decoder: Module mapping that latent map back to image space.
        latent_channels: Channel count of the encoder output.
        embed_dim: Length of the returned embedding vector.
    """

    def __init__(self, encoder, decoder, latent_channels=512, embed_dim=256):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

        self.embedding_head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(latent_channels, embed_dim)
        )

    def forward(self, masked_img, mask=None):
        """Encode once; return ``(reconstruction, embedding)``.

        Args:
            masked_img: Input batch fed to the encoder.
            mask: Currently unused; accepted so callers can pass it.

        Returns:
            Tuple ``(recon, embedding)``. ``recon`` is the decoder output;
            ``embedding`` has shape ``[B, embed_dim]``.
        """
        latent = self.encoder(masked_img)
        recon = self.decoder(latent)
        embedding = self.embedding_head(latent)
        return recon, embedding

In [None]:
# Define `device` here so this cell works on a fresh kernel instead of relying
# on a variable left over from another cell. Uses the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder()
decoder = Decoder()
# nn.Module.to() moves all registered submodules in place, so `encoder` and
# `decoder` end up on `device` as well.
model = MaskedAutoencoder(encoder, decoder).to(device)