# Masked Autoencoders Are Scalable Vision Learners (Classification)

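Notebook version of `train_classifier.py`: trains a ViT classifier on top of an MAE encoder, either from scratch or starting from a pretrained MAE checkpoint.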

## Imports

In [1]:
import os
from tqdm import tqdm
import math
import torch
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader

from model import *
from utils import setup_seed, EarlyStopper, ImageDataset

from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import time
import datetime

# Fix the global random seed (42) so runs are repeatable.
# NOTE(review): presumably seeds the Python/NumPy/torch RNGs — confirm in utils.setup_seed.
setup_seed(42)

2024-09-07 16:40:36.031638: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


## Load Dataset

In [2]:
# --- Training hyperparameters ---
batch_size, lr, weight_decay = 128, 1e-3, 0.05
num_epochs, warmup_epoch = 100, 5

# Use the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _to_normalized_tensor():
    """Return a fresh image -> tensor pipeline normalised with mean 0.5 / std 0.5.

    For 8-bit images ToTensor yields values in [0, 1], so the output lies in [-1, 1].
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5),
    ])


# Train and test share the same preprocessing (no augmentation).
train_tf = _to_normalized_tensor()
test_tf = _to_normalized_tensor()

### CIFAR-10

In [3]:
from torchvision.datasets import CIFAR10

dataset_name = "cifar10"
num_classes = 10  # CIFAR-10 has ten classes

# Both splits are cached under ./data; download=True only fetches missing files.
train_dataset = CIFAR10(root='./data', train=True, download=True, transform=train_tf)
val_dataset = CIFAR10(root='./data', train=False, download=True, transform=test_tf)

# Shuffle the training split only; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)


Files already downloaded and verified
Files already downloaded and verified


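### Custom Dataset

Alternatively, run this cell instead of the CIFAR-10 cell to train on a custom `ImageDataset` (set `num_classes` to match it).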

In [None]:
dataset_name = "example"
# TODO(review): hard-coded; set to the number of classes in the custom dataset.
num_classes = 10

# Training split: shuffled every epoch.
train_dataset = ImageDataset(root='./data', train=True, transform=train_tf)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Validation split: fixed order.
val_dataset = ImageDataset(root='./data', train=False, transform=test_tf)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

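## Model Setup

Run only one of the next two cells: train from scratch, or load a pretrained MAE checkpoint. The cell after them wraps the chosen model's encoder in a classifier.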

In [4]:
# Option A: train from scratch, starting from a newly constructed MAE_ViT.
output_model_path = f"./models/vit-t-classifier-from_scratch-{dataset_name}.pt"
model = MAE_ViT()
log_dir = os.path.join('logs', dataset_name, 'scratch-cls')
writer = SummaryWriter(log_dir)

In [4]:
# Option B: start from a pretrained MAE checkpoint.
pretrained_model_path = f"./models/vit-t-mae-{dataset_name}.pt"
# The checkpoint holds a whole pickled module (its `.encoder` is used below), not a
# state_dict, so it needs weights_only=False. Recent PyTorch releases default to
# weights_only=True and would refuse to load it. Unpickling can run arbitrary code:
# only load checkpoints you produced or otherwise trust.
model = torch.load(pretrained_model_path, map_location='cpu', weights_only=False)
writer = SummaryWriter(os.path.join('logs', dataset_name, 'pretrain-cls'))
output_model_path = f"./models/vit-t-classifier-from_pretrained-{dataset_name}.pt"

In [5]:
# Wrap the chosen MAE's encoder in a classification head and move it to DEVICE.
model = ViT_Classifier(model.encoder, num_classes=num_classes).to(DEVICE)

# Scale the base learning rate linearly with the batch size (reference batch of 256).
scaled_lr = lr * batch_size / 256
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=scaled_lr,
    betas=(0.9, 0.999),
    weight_decay=weight_decay,
)
criterion = torch.nn.CrossEntropyLoss()


def lr_func(epoch):
    """Return the LR multiplier: linear warm-up, capped by a half-cosine decay over num_epochs."""
    warmup = (epoch + 1) / (warmup_epoch + 1e-8)  # 1e-8 avoids dividing by zero when warmup_epoch == 0
    cosine = 0.5 * (math.cos(epoch / num_epochs * math.pi) + 1)
    return min(warmup, cosine)


lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)


def acc_fn(logit, label):
    """Return the fraction of rows whose arg-max logit equals the label (0-d float tensor)."""
    hits = logit.argmax(dim=-1) == label
    return torch.mean(hits.float())


early_stopper = EarlyStopper()

## Model Training

In [6]:
def _run_epoch(model, loader, criterion, acc_fn, device, optimizer=None, desc=None):
    """Make one pass over ``loader`` and return the sample-weighted (mean_loss, mean_acc).

    With an ``optimizer`` the model is trained, with a tqdm progress bar labelled
    ``desc``. Without one it is evaluated under disabled autograd.

    Assumes ``criterion`` and ``acc_fn`` return per-batch *means*, which holds for the
    default ``CrossEntropyLoss`` and the notebook's ``acc_fn``. Each batch is therefore
    re-weighted by its size; averaging the per-batch means directly would over-weight
    a short final batch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    batches = tqdm(loader, leave=False, desc=desc) if training else loader
    total_loss = 0.0
    total_acc = 0.0
    total_samples = 0
    with torch.set_grad_enabled(training):
        for img, label in batches:
            img, label = img.to(device), label.to(device)
            if training:
                optimizer.zero_grad()
            logits = model(img)
            loss = criterion(logits, label)
            acc = acc_fn(logits, label)
            if training:
                loss.backward()
                optimizer.step()
            n = label.size(0)
            total_loss += loss.item() * n
            total_acc += acc.item() * n
            total_samples += n
    if total_samples == 0:
        raise ValueError("loader yielded no samples")
    return total_loss / total_samples, total_acc / total_samples


def mae_run(model, optimizer, train_loader, val_loader, criterion, acc_fn, lr_scheduler, early_stopper,
            num_epochs, output_model_path, *, summary_writer=None, device=None):
    """Fine-tune ``model`` for classification and checkpoint the best validation accuracy.

    Each epoch:
      - trains on ``train_loader`` and evaluates on ``val_loader``;
      - logs loss and accuracy to TensorBoard;
      - steps ``lr_scheduler`` once;
      - saves the whole model to ``output_model_path`` when validation accuracy
        improves (always on the first epoch);
      - stops early when ``early_stopper.early_stop(val_loss)`` returns True.

    Args:
        model: Classifier mapping an image batch to logits.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Iterable of ``(img, label)`` training batches.
        val_loader: Iterable of ``(img, label)`` validation batches.
        criterion: Loss returning the batch-mean loss.
        acc_fn: ``(logits, labels) -> 0-d tensor`` batch-mean accuracy.
        lr_scheduler: Scheduler stepped once per epoch.
        early_stopper: Object exposing ``early_stop(val_loss)`` and ``get_patience()``.
        num_epochs: Maximum number of epochs.
        output_model_path: Where the best model is written with ``torch.save``.
        summary_writer: TensorBoard writer; defaults to the notebook-level ``writer``.
        device: Device for batches; defaults to the notebook-level ``DEVICE``.

    Returns:
        The best validation accuracy seen, or None if no epoch ran.

    Raises:
        ValueError: if either loader yields no samples.
    """
    # Resolve the notebook globals only when no override is given.
    device = DEVICE if device is None else device
    summary_writer = writer if summary_writer is None else summary_writer

    print(f"Training MAE classification on {device}")
    best_accuracy_val = None
    run_start = time.time()
    for epoch in range(num_epochs):
        print('.' * 64)
        print(f"--- Epoch {epoch + 1}/{num_epochs} ---")

        avg_train_loss, avg_train_acc = _run_epoch(
            model, train_loader, criterion, acc_fn, device,
            optimizer=optimizer, desc=f'Epoch [{epoch + 1}/{num_epochs}]',
        )
        avg_val_loss, avg_val_acc = _run_epoch(model, val_loader, criterion, acc_fn, device)

        # ETA is the mean wall time per epoch so far (train + validation) times the
        # epochs remaining. This is steadier than extrapolating from one batch.
        epochs_done = epoch + 1
        seconds_per_epoch = (time.time() - run_start) / epochs_done
        time_left = datetime.timedelta(seconds=round(seconds_per_epoch * (num_epochs - epochs_done)))

        print(f"train_loss: {avg_train_loss:.4f} - train_accuracy: {avg_train_acc:.4f}")
        print(f"validation_loss: {avg_val_loss:.4f} - validation_accuracy: {avg_val_acc:.4f}")
        print(f"ETA: {time_left}")

        # Per-epoch learning-rate schedule; report only decreases.
        prev_lr = lr_scheduler.get_last_lr()[0]
        lr_scheduler.step()
        curr_lr = lr_scheduler.get_last_lr()[0]
        if prev_lr > curr_lr:
            print(f'Updating lr {prev_lr}->{curr_lr}')

        # Keep the best model on the validation set. The None start guarantees a
        # checkpoint exists even if accuracy never rises above 0.
        if best_accuracy_val is None or avg_val_acc > best_accuracy_val:
            best_accuracy_val = avg_val_acc
            torch.save(model, output_model_path)

        summary_writer.add_scalars('cls/loss', {'train' : avg_train_loss, 'val' : avg_val_loss}, global_step=epoch)
        summary_writer.add_scalars('cls/acc', {'train' : avg_train_acc, 'val' : avg_val_acc}, global_step=epoch)

        # Early stopping is driven by validation loss, not accuracy.
        if early_stopper.early_stop(avg_val_loss):
            print(f'Stopping early at Epoch {epoch + 1}, min val loss failed to decrease after {early_stopper.get_patience()} epochs')
            break

    return best_accuracy_val

In [None]:
# Launch training with everything configured in the cells above.
mae_run(
    model=model,
    optimizer=optimizer,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    acc_fn=acc_fn,
    lr_scheduler=lr_scheduler,
    early_stopper=early_stopper,
    num_epochs=num_epochs,
    output_model_path=output_model_path,
)