In [1]:
import sys
sys.path.append("/home/sagemaker-user/memoMAE/")

import torch
import torch.nn as nn
from tqdm import tqdm
from Model.memoMAE import memoMAE
from Experiments.utils import load_backbone_from_ckpt
from Experiments.dataloader import ImagenetData
from torch.utils.data import TensorDataset, DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = 0
config_path = '/home/sagemaker-user/memoMAE/Configs/config_mae.yaml'
ckpt_path = '/home/sagemaker-user/memoMAE/checkpoints/mae-ViT-B-Patch-16-MemoCap100000-NumSim5-NosimEpochs20/last.ckpt'

In [3]:
mae_baseline = load_backbone_from_ckpt(config_path=config_path,
                                       ModelClass=memoMAE,
                                       ckpt_path=ckpt_path)

Loaded memoMAE. Missing keys: []
Unexpected keys: []


In [4]:
mae_baseline = mae_baseline.to(device)

In [5]:
data = ImagenetData(train_txt='/home/sagemaker-user/memoMAE/Data/train.txt', 
                    val_txt='/home/sagemaker-user/memoMAE/Data/val.txt', 
                    root_dir='/home/sagemaker-user/memoMAE/Data/',
                    batch_size=64)

In [6]:
val_loader = data.val_dataloader()
train_loader = data.train_dataloader()

In [8]:
memorize = True
k = 5
if memorize:
    for _ in range(1):
        with torch.no_grad():
            for images, labels in tqdm(train_loader):
                mae_baseline.forward_encoder_memo(images.to(0), 0., 0, memorize=memorize, fill_memory=True)

100%|██████████| 1562/1562 [06:52<00:00,  3.79it/s]


In [9]:
LATENTS_TRAIN = []
LABELS_TRAIN = []
with torch.no_grad():
    for images, labels in tqdm(train_loader):
        latents = mae_baseline.forward_encoder_memo(images.to(0), 0., k, memorize=False)[0]
        LATENTS_TRAIN.append(latents.mean(1).cpu())
        LABELS_TRAIN.append(labels.to(torch.long).cpu())

  x.storage().data_ptr() + x.storage_offset() * 4)
100%|██████████| 1562/1562 [18:23<00:00,  1.42it/s]


In [18]:
LATENTS_VAL = []
LABELS_VAL = []
with torch.no_grad():
    for images, labels in tqdm(val_loader):
        latents = mae_baseline.forward_encoder_memo(images.to(0), 0., 5, memorize=True)[0]
        LATENTS_VAL.append(latents.mean(1).cpu())
        LABELS_VAL.append(labels.to(torch.long).cpu())

100%|██████████| 547/547 [06:24<00:00,  1.42it/s]


In [19]:
device = 'cpu'

In [20]:
# ----- 1. Stack features & labels -----
feats_train = torch.cat(LATENTS_TRAIN, dim=0)   # (N, D)
labels_train = torch.cat(LABELS_TRAIN, dim=0) # (N,)
feats_test = torch.cat(LATENTS_VAL, dim=0)   # (N, D)
labels_test = torch.cat(LABELS_VAL, dim=0) # (N,)
N, D = feats_train.shape
num_classes = int(labels_train.max().item()) + 1
# move to device
if device != 'cpu':
    feats_train = feats_train.to(device)
    labels_train = labels_train.to(device)
    feats_test = feats_test.to(device)
    labels_test = labels_test.to(device)
# ----- 3. Linear classifier with BatchNorm -----
clf = nn.Sequential(
    nn.BatchNorm1d(D, affine=False),
    nn.Linear(D, num_classes)
).to(device)
# MAE LR scaling rule
batch_size = 2048
batch_size = min(batch_size, feats_train.size(0))
base_lr = 0.1
max_lr = base_lr * batch_size / 256.0
train_loader = DataLoader(
    TensorDataset(feats_train, labels_train),
    batch_size=batch_size,
    shuffle=True,
)
optimizer = torch.optim.SGD(
    clf.parameters(),
    lr=max_lr,          # final LR, warmup scheduler will scale it
    momentum=0.9,
    weight_decay=0.0,
)
# ----- 4. Linear warmup + cosine schedulers -----
warmup_epochs = 10
total_epochs = 100
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=1e-6,       # start at near-zero LR
    end_factor=1.0,          # warm up to max_lr
    total_iters=warmup_epochs,
)
scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=total_epochs - warmup_epochs,
    eta_min=1e-6,
)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[scheduler_warmup, scheduler_cosine],
    milestones=[warmup_epochs],
)
criterion = nn.CrossEntropyLoss()
# ----- 5. Training -----
clf.train()
with torch.enable_grad():
    for epoch in tqdm(range(total_epochs)):
        for x, y in train_loader:
            optimizer.zero_grad()
            logits = clf(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()
        scheduler.step()
# ----- 6. Evaluation -----
clf.eval()
with torch.no_grad():
    preds = clf(feats_test).argmax(dim=1)
    acc = (preds == labels_test).float().mean().item()

print(acc)

100%|██████████| 100/100 [01:42<00:00,  1.02s/it]

0.2159428596496582





In [29]:
mae = 0.5870571136474609