<a href="https://colab.research.google.com/github/praneshnikhar/DL-projects/blob/main/MAE_on_satellite_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and embed each one.

    A convolution whose stride equals its kernel size performs the patch
    split and the per-patch linear projection in a single step.

    Args:
        img_size: Side length of the (square) input the module is sized for.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Size of the embedding produced for every patch.
    """

    def __init__(self, img_size=128, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size, self.patch_size = img_size, patch_size
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Return one embedding per patch, shaped [B, H' * W', embed_dim].

        H' * W' equals ``self.num_patches`` when the input is
        ``img_size`` x ``img_size``; the input size itself is not checked here.
        """
        feature_map = self.proj(x)  # [B, embed_dim, H', W']
        tokens = feature_map.flatten(start_dim=2)  # [B, embed_dim, H' * W']
        return tokens.permute(0, 2, 1)  # [B, H' * W', embed_dim]

In [3]:
class MAE(nn.Module):
    """Masked-autoencoder style model built from stock PyTorch transformer stacks.

    The image is split into patch embeddings, a random subset of them is kept
    and encoded, zero-valued placeholders are inserted at the dropped
    positions, and a transformer decoder (with the encoder output as memory)
    produces one prediction vector per patch.

    Args:
        img_size: Side length of the square input image.
        patch_size: Side length of each square patch.
        in_chans: Number of input image channels.
        embed_dim: Token width used by the encoder and decoder (must be
            divisible by the 8 attention heads).
        mask_ratio: Fraction of patches hidden from the encoder.

    NOTE(review): no positional information is added to the tokens and the
    mask placeholders are constant zeros, so the decoder has no signal to
    tell masked positions apart. Fixing that needs new parameters/buffers
    and was left out to keep the module's parameters unchanged.
    NOTE(review): the second value returned by ``forward`` holds patch
    *embeddings* (width ``embed_dim``), while predictions have width
    ``patch_size * patch_size * in_chans``; the two widths only coincide for
    some configurations (e.g. the defaults) -- confirm the intended target.
    """

    def __init__(self, img_size=128, patch_size=16, in_chans=3, embed_dim=768, mask_ratio=0.75):
        super().__init__()
        self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
        # batch_first=True is required: every tensor in this model is laid out
        # as [B, tokens, D]. With PyTorch's default (sequence-first) layout the
        # layers would attend across the batch dimension, and the decoder's
        # target and memory would disagree on the batch size once patches are
        # masked.
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(embed_dim, nhead=8, batch_first=True), num_layers=6)
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(embed_dim, nhead=8, batch_first=True), num_layers=4)
        self.mask_ratio = mask_ratio
        self.decoder_pred = nn.Linear(embed_dim, patch_size * patch_size * in_chans)

    def random_masking(self, x):
        """Keep a random subset of tokens independently for each sample.

        Args:
            x: Token tensor of shape [N, L, D].

        Returns:
            Tuple ``(x_masked, ids_restore, ids_keep)``:
            ``x_masked`` -- kept tokens, [N, len_keep, D] with
            ``len_keep = int(L * (1 - mask_ratio))``;
            ``ids_restore`` -- [N, L] inverse of the random shuffle, i.e. the
            index that maps shuffled order back to the original order;
            ``ids_keep`` -- [N, len_keep] original positions of kept tokens.
        """
        N, L, D = x.shape
        len_keep = int(L * (1 - self.mask_ratio))
        # Sorting uniform noise yields an independent random permutation per row.
        noise = torch.rand(N, L, device=x.device)
        ids_shuffle = torch.argsort(noise, dim=1)
        # argsort of a permutation is its inverse.
        ids_restore = torch.argsort(ids_shuffle, dim=1)
        ids_keep = ids_shuffle[:, :len_keep]
        # expand() broadcasts the index over D without materialising a copy.
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).expand(-1, -1, D))
        return x_masked, ids_restore, ids_keep

    def forward(self, x):
        """Encode the visible patches and predict a vector for every patch.

        Args:
            x: Image batch accepted by ``self.patch_embed``.

        Returns:
            Tuple ``(pred, patches)``: ``pred`` is
            [B, num_patches, patch_size * patch_size * in_chans] in original
            patch order; ``patches`` is the unmasked patch-embedding tensor
            [B, num_patches, embed_dim].
        """
        patches = self.patch_embed(x)  # [B, num_patches, embed_dim]
        x_masked, ids_restore, _ = self.random_masking(patches)
        enc_out = self.encoder(x_masked)  # [B, len_keep, embed_dim]
        num_masked = patches.size(1) - x_masked.size(1)
        # new_zeros() inherits dtype and device from enc_out, so the
        # concatenation below stays consistent under .half()/.double().
        mask_tokens = enc_out.new_zeros(enc_out.size(0), num_masked, enc_out.size(2))
        dec_input = torch.cat([enc_out, mask_tokens], dim=1)
        # Undo the shuffle so each token sits at its original patch position.
        restore_index = ids_restore.unsqueeze(-1).expand(-1, -1, dec_input.size(2))
        dec_input = torch.gather(dec_input, dim=1, index=restore_index)
        dec_out = self.decoder(dec_input, enc_out)
        pred = self.decoder_pred(dec_out)
        return pred, patches
