# Masked Autoencoders Are Scalable Vision Learners (Pretrain)

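Notebook conversion of `mae_pretrain.py`: pretrains an `MAE_ViT` masked autoencoder on CIFAR10 or a custom image dataset, logging to TensorBoard.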

## Imports

In [None]:
import os
import math
from tqdm import tqdm
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import time
import datetime

from model import *
from utils import setup_seed, EarlyStopper, ImageDataset

# Fix the random seed so runs are repeatable. Which RNGs are seeded is decided by
# utils.setup_seed (presumably Python / NumPy / torch — TODO confirm).
setup_seed(42)

## Hyperparameters and Dataset

In [2]:
# ---- Pretraining hyperparameters ----
batch_size: int = 128       # images per optimizer step
lr: float = 1.5e-4          # base learning rate; multiplied by batch_size / 256 when AdamW is built
weight_decay: float = 0.05  # AdamW weight decay
mask_ratio: float = 0.75    # passed to MAE_ViT and used to rescale the reconstruction loss
num_epochs: int = 2000      # upper bound on epochs; early stopping may end training sooner
warmup_epoch: int = 200     # length of the linear LR warm-up before the cosine decay

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### CIFAR10

In [None]:
from torchvision.datasets import CIFAR10

# Train and test share one deterministic pipeline: convert to a tensor, then
# normalize every channel with mean 0.5 / std 0.5. ToTensor's [0, 1] range
# therefore becomes [-1, 1]. Two independent Compose objects are built.
train_tf, test_tf = (
    transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])
    for _split in ("train", "test")
)

train_dataset = CIFAR10(root='./data', train=True, download=True, transform=train_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# The validation split is only indexed directly (for visualization), so it gets no loader.
val_dataset = CIFAR10(root='./data', train=False, download=True, transform=test_tf)

dataset_name = "cifar10"


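### Custom dataset

Alternative to the CIFAR10 cell above. Run only one of the two dataset cells: whichever runs last defines `train_loader`, `val_dataset` and `dataset_name`.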

In [13]:
# Each ImageDataset item is first converted to a PIL image (presumably items
# arrive as arrays/tensors — TODO confirm against utils.ImageDataset). It is
# then turned back into a tensor and normalized with mean 0.5 / std 0.5.
# Two independent Compose objects are built, one per split.
train_tf, test_tf = (
    transforms.Compose(
        [
            transforms.ToPILImage(),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ]
    )
    for _split in ("train", "test")
)

data_path = os.path.abspath("./data")

train_dataset = ImageDataset(root=data_path, train=True, transform=train_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# The validation split is only indexed directly (for visualization), so it gets no loader.
val_dataset = ImageDataset(root=data_path, train=False, transform=test_tf)

dataset_name = "example"

## Model Setup

In [4]:
# The checkpoint file and the TensorBoard run directory are both named after the chosen dataset.
output_model_path = f"./models/vit-t-mae-{dataset_name}.pt"
writer = SummaryWriter(os.path.join('logs', dataset_name, 'mae-pretrain'))

model = MAE_ViT(mask_ratio=mask_ratio).to(DEVICE)

# AdamW whose base lr is scaled linearly by batch_size / 256.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr * batch_size / 256,
    betas=(0.9, 0.95),
    weight_decay=weight_decay,
)


def lr_func(epoch):
    """Return the LambdaLR multiplier for ``epoch``.

    The multiplier is the smaller of two factors. The first is a linear warm-up
    ramp, (epoch + 1) / warmup_epoch. The second is a half-cosine that falls
    from 1 at epoch 0 to 0 at ``num_epochs``.
    """
    # The 1e-8 keeps this finite if warmup_epoch is 0 (the cosine factor then always wins).
    warmup_factor = (epoch + 1) / (warmup_epoch + 1e-8)
    cosine_factor = 0.5 * (math.cos(epoch / num_epochs * math.pi) + 1)
    return min(warmup_factor, cosine_factor)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)

early_stopper = EarlyStopper()

## Model Training

In [5]:
def _mae_reconstruction_grid(model, val_img):
    """Return one image grid showing masked input, reconstruction and original for ``val_img``.

    Visible patches are taken from the input and only masked patches come from
    the model. The code assumes ``mask`` is 1 where a patch was hidden, which
    matches how ``val_img * (1 - mask)`` is used as the "masked input" panel.
    Each output row holds two samples, each shown as a (masked, reconstruction,
    original) triplet.
    """
    model.eval()
    with torch.no_grad():
        predicted_val_img, mask = model(val_img)
        predicted_val_img = predicted_val_img * mask + val_img * (1 - mask)
        img = torch.cat([val_img * (1 - mask), predicted_val_img, val_img], dim=0)
        # v=3 panels per sample, w1=2 samples per row; h1 is inferred (8 rows for 16 images).
        img = rearrange(img, '(v h1 w1) c h w -> c (h1 h) (w1 v w)', w1=2, v=3)
    return img


def mae_run(model, optimizer, train_loader, val_dataset, lr_scheduler, early_stopper, num_epochs, mask_ratio, output_model_path):
    """Pretrain an MAE model, logging to TensorBoard and checkpointing every epoch.

    Each epoch does the following, in order:

    * one pass over ``train_loader``;
    * one ``lr_scheduler`` step;
    * logs the mean training loss and a reconstruction grid of the first 16
      validation images to the notebook-level ``writer``;
    * saves the whole model with ``torch.save``;
    * asks ``early_stopper`` whether to stop.

    Uses the notebook globals ``DEVICE``, ``writer`` and ``rearrange``.

    Args:
        model: module whose forward returns ``(predicted_img, mask)``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: sized iterable of ``(img, label)`` batches; labels are ignored.
        val_dataset: indexable dataset of ``(img, label)`` pairs with at least 16 items.
        lr_scheduler: scheduler stepped once per epoch.
        early_stopper: object providing ``early_stop(loss) -> bool`` and ``get_patience()``.
        num_epochs: maximum number of epochs.
        mask_ratio: the masked MSE is divided by this, presumably so the loss
            approximates the mean error over masked pixels only.
        output_model_path: checkpoint file, overwritten every epoch. Its parent
            directory is created if missing.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    print(f"Training MAE on {DEVICE}")

    num_batches = len(train_loader)
    if num_batches == 0:
        # Without this guard the loss average divides by zero and the ETA is never set.
        raise ValueError("train_loader yielded no batches")
    total_batches = num_epochs * num_batches

    # torch.save does not create missing directories (e.g. ./models on a fresh checkout).
    os.makedirs(os.path.dirname(output_model_path) or ".", exist_ok=True)

    # The visualization batch is identical every epoch, so load it once up front.
    # This is valid because the dataset transforms in this notebook are deterministic.
    val_img = torch.stack([val_dataset[k][0] for k in range(16)]).to(DEVICE)

    prev_time = time.time()
    time_left = datetime.timedelta(0)
    for epoch in range(num_epochs):
        print('.' * 64)
        print(f"--- Epoch {epoch + 1}/{num_epochs} ---")

        model.train()
        losses = []
        pbar = tqdm(train_loader, leave=False, desc=f'Epoch [{epoch + 1}/{num_epochs}]')
        for batch_idx, (img, _label) in enumerate(pbar):
            img = img.to(DEVICE)
            optimizer.zero_grad()

            predicted_img, mask = model(img)
            loss = torch.mean((predicted_img - img) ** 2 * mask) / mask_ratio
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

            # Estimate the remaining time from the duration of the most recent batch.
            # batch_idx is 0-based, so the batch that just finished brings this
            # epoch's count to batch_idx + 1.
            batches_done = epoch * num_batches + batch_idx + 1
            now = time.time()
            time_left = datetime.timedelta(seconds=(total_batches - batches_done) * (now - prev_time))
            prev_time = now

        # Update learning rate. Only decreases are reported; warm-up increases are silent.
        prev_lr = lr_scheduler.get_last_lr()[0]
        lr_scheduler.step()
        curr_lr = lr_scheduler.get_last_lr()[0]
        if prev_lr > curr_lr:
            print(f'Updating lr {prev_lr}->{curr_lr}')

        avg_loss = sum(losses) / len(losses)
        writer.add_scalar('mae_loss', avg_loss, global_step=epoch)
        print(f"train_loss: {avg_loss} ETA: {time_left}")

        grid = _mae_reconstruction_grid(model, val_img)
        # Undo the Normalize(0.5, 0.5) used by the dataset cells so pixels are back in [0, 1].
        writer.add_image('mae_image', (grid + 1) / 2, global_step=epoch)

        torch.save(model, output_model_path)

        # Early stopping
        if early_stopper.early_stop(avg_loss):
            print(f'Stopping early at Epoch {epoch + 1}, min loss failed to decrease after {early_stopper.get_patience()} epochs')
            break

In [None]:
# Launch pretraining with the objects configured in the cells above.
mae_run(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_dataset=val_dataset,
    lr_scheduler=lr_scheduler,
    early_stopper=early_stopper,
    num_epochs=num_epochs,
    mask_ratio=mask_ratio,
    output_model_path=output_model_path,
)