## Masked Autoencoders


In [None]:
import random
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import os
# List the entries of the dataset root directory. As the last expression of a
# notebook cell the returned list is displayed; when run as a plain script the
# result is discarded.
os.listdir("/kaggle/input/ssl-dataset/ssl_dataset")


**Parameters and Labels**

In [None]:
# --- Dataset location --------------------------------------------------------
# Root directory plus the sub-folders that hold the training / validation data.
DATA_ROOT = "/kaggle/input/ssl-dataset/ssl_dataset"
TRAIN_FOLDERS = [f"train.X{i}" for i in range(1, 5)]
VAL_FOLDER = "val.X"

# --- Image / masking geometry ------------------------------------------------
IMG_SIZE = 224      # images are resized to IMG_SIZE x IMG_SIZE
PATCH_SIZE = 16     # side length of one square patch
MASK_RATIO = 0.75   # fraction of patches that get masked per image

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 64
NUM_WORKERS = 8     # DataLoader worker processes
EPOCHS_PRE = 20     # epochs of MAE pre-training
LR_PRETRAIN = 1e-4  # peak learning rate for pre-training
# NOTE(review): the two values below are not referenced in the visible cells;
# presumably they belong to a later linear-probe stage - confirm.
EPOCHS_LINEAR = 50
LR_LINEAR = 1e-3

**Transformations, Dataset, and DataLoader**

In [None]:
# Preprocessing applied to every training image: resize to a fixed square,
# convert to a tensor, then normalise each channel. The mean/std values are the
# widely used ImageNet channel statistics.
train_transform = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

def _make_dataset(folders, root=None):
    """Collect image file paths under each folder, in a deterministic order.

    Args:
        folders: Iterable of folder names; each is joined onto the root
            directory and walked recursively.
        root: Directory containing the folders. When omitted, the module-level
            ``DATA_ROOT`` is used.

    Returns:
        A list of paths to files whose names end in ``.jpg``, ``.png`` or
        ``.jpeg`` (case-insensitive). Folders are visited in the order given;
        inside a folder, sub-directories and files are visited in sorted order.
        A folder that does not exist contributes no paths.
    """
    base_dir = DATA_ROOT if root is None else root
    image_exts = (".jpg", ".png", ".jpeg")

    images = []
    for fld in folders:
        fld_path = os.path.join(base_dir, fld)
        for current_dir, subdirs, files in os.walk(fld_path):
            # os.walk yields entries in arbitrary filesystem order. Sorting the
            # sub-directory list in place (which steers the traversal) and the
            # file names makes the result reproducible, so a seeded random
            # subsample of it is reproducible as well.
            subdirs.sort()
            images.extend(
                os.path.join(current_dir, fname)
                for fname in sorted(files)
                if fname.lower().endswith(image_exts)
            )
    return images


class MAEDataset(Dataset):
    """Dataset yielding an image, a random patch mask, and the masked image.

    Each item is a dict with:
        ``'img_masked'``: the transformed image with masked patches set to 0,
        ``'mask'``: bool tensor of length ``num_patches`` (True = masked),
        ``'img_orig'``: the transformed image.

    Args:
        image_paths: Sequence of image file paths.
        transform: Callable applied to the RGB PIL image; it must return a
            ``(channels, img_size, img_size)`` tensor.
        patch_size: Side length of one square patch.
        mask_ratio: Fraction of patches to mask in every image.
        img_size: Side length of the transformed image. Defaults to the
            module-level ``IMG_SIZE`` when omitted.

    Raises:
        ValueError: If the image size is not a multiple of ``patch_size``.
    """

    def __init__(self, image_paths, transform, patch_size, mask_ratio, img_size=None):
        side = IMG_SIZE if img_size is None else img_size
        if side % patch_size != 0:
            # Fail early; otherwise the reshape in patchify fails later with a
            # much less helpful message.
            raise ValueError(
                f"image size {side} is not divisible by patch size {patch_size}"
            )
        self.image_paths = image_paths
        self.transform = transform
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        self.img_size = side
        self.num_patches = (side // patch_size) ** 2

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving the handle open until
        # garbage collection (which can exhaust file descriptors in workers).
        with Image.open(self.image_paths[idx]) as raw:
            img = raw.convert('RGB')
        img = self.transform(img)

        patches = self.patchify(img)

        # Choose exactly int(mask_ratio * num_patches) patches at random.
        num_mask = int(self.mask_ratio * self.num_patches)
        perm = torch.randperm(self.num_patches)
        mask = torch.zeros(self.num_patches, dtype=torch.bool)
        mask[perm[:num_mask]] = True

        # Zero the masked patches on a copy, then fold back into image layout.
        patches_masked = patches.clone()
        patches_masked[mask] = 0
        img_masked = self.unpatchify(patches_masked)

        return {
            'img_masked': img_masked,
            'mask':       mask,
            'img_orig':   img
        }

    def patchify(self, img):
        """Split a ``(c, h, w)`` image into ``(num_patches, c * p * p)`` rows.

        Patches are ordered row-major over the patch grid; within a patch the
        values are laid out channel-first.
        """
        p = self.patch_size
        c, h, w = img.shape
        x = img.reshape(c, h // p, p, w // p, p)
        x = x.permute(1, 3, 0, 2, 4)
        return x.reshape(-1, c * p * p)

    def unpatchify(self, patches):
        """Inverse of ``patchify``: rebuild a ``(c, img_size, img_size)`` image."""
        p = self.patch_size
        grid = self.img_size // p
        # Recover the channel count from the patch width rather than assuming 3.
        c = patches.shape[-1] // (p * p)
        x = patches.reshape(grid, grid, c, p, p)
        x = x.permute(2, 0, 3, 1, 4)
        return x.reshape(c, self.img_size, self.img_size)


import random

# Cap on the number of training images. The same value is used for both the
# size check and the sample size: sampling more items than the list holds makes
# random.sample raise ValueError.
MAX_TRAIN_IMAGES = 90000

# Walk the training folders once (the walk is expensive on a large dataset).
image_paths = _make_dataset(TRAIN_FOLDERS)

# Fixed seed so the chosen subset is the same on every run.
random.seed(42)
if len(image_paths) > MAX_TRAIN_IMAGES:
    image_paths = random.sample(image_paths, MAX_TRAIN_IMAGES)

train_dataset = MAEDataset(
    image_paths=image_paths,
    transform=train_transform,
    patch_size=PATCH_SIZE,
    mask_ratio=MASK_RATIO
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=False
)


if __name__ == '__main__':
    # Sanity check: pull a single batch from the loader and print the shape of
    # each tensor it contains.
    with torch.no_grad():
        batch = next(iter(train_loader))
    for label, key in (
        ('img_masked:', 'img_masked'),
        ('mask:', 'mask'),
        ('img_orig:', 'img_orig'),
    ):
        print(label, batch[key].shape)

In [None]:
import timm
import torch.nn as nn

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# 1) Backbone: timm's 'vit_tiny_patch16_224' created with pretrained weights.
encoder = timm.create_model(model_name='vit_tiny_patch16_224', pretrained=True)

# 2) Drop the classification head (zero output classes).
encoder.reset_classifier(num_classes=0)

# 3) Pull out the encoder pieces that the MAE wrapper uses directly.
patch_embed = encoder.patch_embed
pos_embed = encoder.pos_embed
encoder_blocks = encoder.blocks
encoder_norm = encoder.norm
embed_dim = encoder.embed_dim
num_patches = patch_embed.num_patches

# 4) Learnable token that stands in for every masked patch (zero-initialised).
mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

# 5) Lightweight decoder: a stack of 4 Transformer encoder layers operating on
#    batch-first input, i.e. tensors shaped [B, N, D].
decoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_dim,
    nhead=4,
    dim_feedforward=2 * embed_dim,
    batch_first=True,
)
decoder = nn.TransformerEncoder(encoder_layer=decoder_layer, num_layers=4)

# 6) Reconstruction head: project each D-dimensional token to the flattened
#    pixels of one patch (patch_size * patch_size * 3 values).
patch_size = patch_embed.patch_size[0]
reconstruction_head = nn.Linear(
    in_features=embed_dim,
    out_features=patch_size * patch_size * 3,
)


**Custom MAE Design**

In [None]:
class CustomMAE(nn.Module):
    """Masked autoencoder: encode visible patches, decode all, predict masked pixels.

    Args:
        encoder: Backbone module; stored on the model but not called by
            ``forward``. NOTE(review): presumably the other encoder pieces are
            taken from this module - confirm with the caller.
        patch_embed: Callable mapping images ``(B, C, H, W)`` to patch tokens
            ``(B, N, D)``.
        pos_embed: Positional embedding of shape ``(1, N + 1, D)``; position 0
            is skipped (assumed to belong to a class token - TODO confirm).
        encoder_blocks: Iterable of modules applied in order to visible tokens.
        encoder_norm: Normalisation applied after the encoder blocks.
        mask_token: Learnable ``(1, 1, D)`` token used at masked positions.
        decoder: Module applied to the full token sequence. Batch-first and
            sequence-first ``nn.TransformerEncoder`` decoders are both handled.
        reconstruction_head: Maps decoder tokens to flattened patch pixels.
        patch_size: Side length of one square patch (used by ``patchify``).
    """

    def __init__(self, encoder, patch_embed, pos_embed,
                 encoder_blocks, encoder_norm,
                 mask_token, decoder, reconstruction_head,
                 patch_size=16):
        super().__init__()
        self.patch_size = patch_size
        self.encoder          = encoder
        self.patch_embed      = patch_embed
        self.pos_embed        = pos_embed
        self.encoder_blocks   = encoder_blocks
        self.encoder_norm     = encoder_norm
        self.mask_token       = mask_token
        self.decoder          = decoder
        self.reconstruction_head = reconstruction_head

        self.embed_dim = pos_embed.size(-1)
        # Remember the layout the decoder expects, so forward() feeds it
        # [B, N, D] or [N, B, D] accordingly instead of always permuting.
        self._decoder_batch_first = self._detect_batch_first(decoder)

    @staticmethod
    def _detect_batch_first(decoder):
        """Return True when ``decoder`` expects batch-first ``[B, N, D]`` input."""
        layers = getattr(decoder, 'layers', None)
        if layers is not None and len(layers) > 0:
            attn = getattr(layers[0], 'self_attn', None)
            if attn is not None:
                return bool(getattr(attn, 'batch_first', False))
        # Unknown decoder type: fall back to an explicit attribute if present,
        # else to sequence-first, PyTorch's default for transformer layers.
        return bool(getattr(decoder, 'batch_first', False))

    def patchify(self, imgs):
        """Convert images ``(B, C, H, W)`` to patches ``(B, h * w, p * p * C)``.

        Patches are ordered row-major over the patch grid; within a patch the
        values are laid out pixel by pixel with the channel varying fastest.
        """
        p = self.patch_size
        n, c, height, width = imgs.shape
        assert height % p == 0 and width % p == 0, \
            "image height and width must be multiples of patch_size"
        h = height // p
        w = width // p
        x = imgs.reshape(n, c, h, p, w, p)
        x = torch.einsum('nchpwq->nhwpqc', x)
        return x.reshape(n, h * w, p * p * c)

    def forward(self, x, mask):
        """Predict the pixels of the masked patches.

        Args:
            x: Images ``(B, C, H, W)``.
            mask: ``(B, N)`` mask, True (non-zero) at masked patch positions.
                Every sample must mask the same number of patches.

        Returns:
            Tensor ``(B, num_masked, out_features)`` with one prediction per
            masked patch, in ascending patch-index order for each sample.

        Raises:
            ValueError: If samples do not all have the same number of visible
                patches (they could not be stacked into one encoder batch).
        """
        mask = mask.bool()
        keep = ~mask
        B = x.size(0)

        tokens = self.patch_embed(x)
        tokens = tokens + self.pos_embed[:, 1:, :]
        N, D = tokens.size(1), tokens.size(2)

        keep_counts = keep.sum(dim=1)
        if B > 0 and not torch.all(keep_counts == keep_counts[0]):
            raise ValueError(
                "every sample in the batch must have the same number of visible patches"
            )

        # Encoder runs on the visible tokens only.
        enc = tokens[keep].reshape(B, -1, D)
        for blk in self.encoder_blocks:
            enc = blk(enc)
        enc = self.encoder_norm(enc)

        # Full-length decoder input: start from the mask token everywhere, then
        # scatter the encoded tokens back to their original positions. Boolean
        # indexing walks (sample, patch) in row-major order, which matches the
        # order in which the visible tokens were gathered above.
        dec_seq = self.mask_token.to(enc.dtype).expand(B, N, -1).clone()
        dec_seq[keep] = enc.reshape(-1, enc.size(-1))
        dec_seq = dec_seq + self.pos_embed[:, 1:, :]

        if self._decoder_batch_first:
            dec_out = self.decoder(dec_seq)
        else:
            # Sequence-first decoders expect [N, B, D].
            dec_out = self.decoder(dec_seq.transpose(0, 1)).transpose(0, 1)

        # Keep only the masked positions; counts are equal per sample (checked
        # above), so the selection reshapes cleanly to [B, num_masked, D].
        masked_out = dec_out[mask].reshape(B, -1, dec_out.size(-1))
        return self.reconstruction_head(masked_out)


In [None]:
# Assemble the MAE from the encoder pieces, mask token, decoder and head, then
# move it to the selected device. patch_size is passed explicitly so that the
# model's patchify() uses the patch size read from the encoder's patch
# embedding rather than silently relying on the constructor default of 16.
model = CustomMAE(
    encoder=encoder,
    patch_embed=patch_embed,
    pos_embed=pos_embed,
    encoder_blocks=encoder_blocks,
    encoder_norm=encoder_norm,
    mask_token=mask_token,
    decoder=decoder,
    reconstruction_head=reconstruction_head,
    patch_size=patch_size,
).to(device)


In [None]:
import torch.optim as optim
# Element-wise mean-squared-error loss (no reduction applied).
criterion = nn.MSELoss(reduction='none')

# AdamW over every model parameter.
optimizer = optim.AdamW(
    params=model.parameters(),
    lr=LR_PRETRAIN,
    weight_decay=0.05,
)

# Step budget: the first 10% of all optimisation steps are reserved for
# warmup; the cosine schedule below spans the remaining steps and decays the
# learning rate towards eta_min.
total_steps = EPOCHS_PRE * len(train_loader)
warmup_steps = int(total_steps * 0.1)
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer=optimizer,
    T_max=total_steps - warmup_steps,
    eta_min=1e-6,
)


**Pre-Train Loop**

In [None]:
import os
from tqdm import tqdm
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast


# MAE pre-training loop: linear LR warmup followed by cosine decay, optional
# mixed precision, and one checkpoint written per epoch.
os.makedirs("/kaggle/working/", exist_ok=True)


device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.train()

# torch.cuda.amp only applies to CUDA; disable it explicitly on CPU so the
# scaler and autocast become no-ops.
use_amp = device == 'cuda'
scaler = GradScaler(enabled=use_amp)

WARMUP_START_LR = 1e-6
steps_per_epoch = len(train_loader)

for epoch in range(1, EPOCHS_PRE + 1):
    epoch_loss = 0.0
    loop = tqdm(enumerate(train_loader), total=steps_per_epoch, desc=f"Epoch {epoch}/{EPOCHS_PRE}")

    for batch_idx, batch in loop:
        current_step = (epoch - 1) * steps_per_epoch + batch_idx

        # Linear warmup. The learning rate is written BEFORE the optimizer
        # step so that the very first update already uses the warmup value.
        # The guard avoids a division by zero on very short runs where
        # warmup_steps rounds down to 0.
        if warmup_steps > 0 and current_step <= warmup_steps:
            lr = WARMUP_START_LR + (LR_PRETRAIN - WARMUP_START_LR) * (current_step / warmup_steps)
            for pg in optimizer.param_groups:
                pg['lr'] = lr

        imgs_masked = batch['img_masked'].to(device)
        masks       = batch['mask'].to(device)
        imgs_orig   = batch['img_orig'].to(device)

        with autocast(enabled=use_amp):
            preds = model(imgs_masked, mask=masks)
            with torch.no_grad():
                patchified = model.patchify(imgs_orig)
                # Ground-truth pixels of the masked patches, in the same
                # (sample, ascending patch index) order as the predictions.
                # Relies on every sample masking the same number of patches,
                # which the model's forward pass requires as well.
                target_patches = patchified[masks.bool()].reshape(preds.shape)

            # Mean squared error over masked patches only.
            loss = ((preds - target_patches) ** 2).mean()

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Once warmup has finished, advance the cosine schedule after each
        # optimizer step; this is called exactly total_steps - warmup_steps
        # times, matching the scheduler's T_max.
        if current_step >= warmup_steps:
            scheduler.step()

        # Read the loss back once per step (each .item() forces a sync).
        loss_value = loss.item()
        epoch_loss += loss_value
        loop.set_postfix(loss=loss_value, lr=optimizer.param_groups[0]['lr'])

    avg_loss = epoch_loss / steps_per_epoch
    print(f"\n✅ Epoch {epoch}/{EPOCHS_PRE} — Avg Loss: {avg_loss:.4f}")

    # Save everything needed to resume training after this epoch.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'loss': avg_loss,
    }
    torch.save(checkpoint, f"/kaggle/working/mae_checkpoint_epoch{epoch}.pth")
    print(f"💾 Checkpoint saved at epoch {epoch}")
