In [1]:
# =========================
# Contents
# =========================
# 1) Setup
# 2) module import
# 3) path import
# 4) util function
# 5) Configuration
# 6) Train/validation
# 7) test
# 8) Final pipeline
# 9) Conclusion
# =========================


In [1]:
# 1) Setup & Imports
import os, random
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms, datasets, utils as vutils

import timm  # ViT backbone

def set_seed(seed: int = 42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True

device = "cuda" if torch.cuda.is_available() else "cpu"
set_seed(42)

print(f"[env] device={device}, torch={torch.__version__}, timm={timm.__version__}")

[env] device=cuda, torch=2.5.1+cu121, timm=1.0.21


In [None]:
# 2) Configuration & Paths

@dataclass
class CFG:
    img_size: int = 224
    patch_size: int = 16
    mask_ratio: float = 0.75
    enc_name: str = "vit_base_patch16_224"  # timm 모델명
    dec_dim: int = 384
    dec_depth: int = 6
    batch_size: int = 32
    num_workers: int = 4
    lr: float = 1e-4
    wd: float = 0.05
    epochs: int = 1 
    save_dir: str = "./runs/mae_demo"
    data_root: str = "./data"   

cfg = CFG()

# 디렉토리 준비
os.makedirs(cfg.save_dir, exist_ok=True)
os.makedirs(f"{cfg.save_dir}/ckpt", exist_ok=True)
os.makedirs(f"{cfg.save_dir}/viz",  exist_ok=True)


In [3]:
# 3) Utility Functions

#이미지를 패치 토큰들로 나눔
class Patchify(nn.Module):
    def __init__(self, patch_size=16):
        super().__init__()
        self.p = patch_size
    def forward(self, imgs):  # (B,C,H,W) -> (B, N, P*P*C)
        B, C, H, W = imgs.shape
        p = self.p
        assert H % p == 0 and W % p == 0, "H,W must be divisible by patch_size"
        h, w = H // p, W // p
        x = imgs.reshape(B, C, h, p, w, p).permute(0, 2, 4, 3, 5, 1).reshape(B, h*w, p*p*C)
        return x
#토큰을 다시 이미지로 합침침
class Unpatchify(nn.Module):
    def __init__(self, img_hw=224, patch_size=16, c=3):
        super().__init__()
        self.img_hw = img_hw; self.p = patch_size; self.c = c
    def forward(self, x):  # (B,N,P*P*C) -> (B,C,H,W)
        B, N, D = x.shape
        p, c = self.p, self.c
        h = w = self.img_hw // p
        x = x.reshape(B, h, w, p, p, c).permute(0,5,1,3,2,4).reshape(B, c, h*p, w*p)
        return x

def random_mask_indices(num_patches, mask_ratio=0.75):
    n_mask = int(num_patches * mask_ratio)
    ids = torch.randperm(num_patches)
    mask_ids = ids[:n_mask]
    keep_ids = ids[n_mask:]
    return keep_ids, mask_ids

def save_grid(tensor, path, nrow=8):
    path = Path(path); path.parent.mkdir(parents=True, exist_ok=True)
    vutils.save_image(tensor.clamp(0,1), str(path), nrow=nrow)


In [4]:
# 4) ViT Token Extractor
# timm의 ViT에서 "토큰 시퀀스 (B, N+1, D)"를 얻기 위한 도우미
# (forward_features는 보통 풀링 벡터를 내므로, 내부 모듈을 통해 토큰을 직접 구성)

def vit_tokens_from_timm(vit: nn.Module, imgs: torch.Tensor):
    """
    Return: (B, N+1, D)  [CLS + patch tokens]
    Assumes vit has attributes: patch_embed, cls_token, pos_embed, blocks, norm
    """
    x = vit.patch_embed(imgs)  # (B, N, D)
    B, N, D = x.shape
    cls_token = vit.cls_token.expand(B, -1, -1)  # (B,1,D)
    x = torch.cat((cls_token, x), dim=1)         # (B, N+1, D)

    # pos embed (224 입력 가정; 사이즈 다르면 보간 필요)
    if getattr(vit, "pos_embed", None) is not None:
        x = x + vit.pos_embed

    x = vit.pos_drop(x)
    for blk in vit.blocks:
        x = blk(x)
    x = vit.norm(x)
    return x  # (B, N+1, D)


In [5]:
# 5) MAE Model

class MAE(nn.Module):
    def __init__(self, cfg: CFG):
        super().__init__()
        self.cfg = cfg
        # Encoder (timm ViT)
        self.encoder = timm.create_model(cfg.enc_name, pretrained=False)
        emb_dim = self.encoder.embed_dim

        # Decoder (lightweight TransformerEncoder)
        self.mask_token = nn.Parameter(torch.zeros(1, 1, cfg.dec_dim))
        total_tokens = (cfg.img_size // cfg.patch_size) ** 2
        self.dec_pos = nn.Parameter(torch.zeros(1, total_tokens, cfg.dec_dim))
        self.enc_to_dec = nn.Linear(emb_dim, cfg.dec_dim)

        nhead = max(1, min(8, cfg.dec_dim // 64))  # 384→6
        layer = nn.TransformerEncoderLayer(d_model=cfg.dec_dim, nhead=nhead, batch_first=True)
        self.decoder = nn.TransformerEncoder(layer, num_layers=cfg.dec_depth)

        self.head = nn.Linear(cfg.dec_dim, cfg.patch_size * cfg.patch_size * 3)

        self.patchify = Patchify(cfg.patch_size)
        self.unpatchify = Unpatchify(cfg.img_size, cfg.patch_size, c=3)

        nn.init.trunc_normal_(self.mask_token, std=0.02)
        nn.init.trunc_normal_(self.dec_pos,   std=0.02)

    def forward(self, imgs: torch.Tensor):
        # 1) target patches
        target = self.patchify(imgs)        # (B, N, P2*C)
        B, N, D = target.shape

        # 2) mask indices per-sample
        keep_ids, mask_ids = [], []
        for _ in range(B):
            k, m = random_mask_indices(N, self.cfg.mask_ratio)
            keep_ids.append(k); mask_ids.append(m)
        keep_ids = torch.stack(keep_ids, 0).to(imgs.device)  # (B, Nk)
        mask_ids = torch.stack(mask_ids, 0).to(imgs.device)  # (B, Nm)

        # 3) Encoder tokens (CLS 제외)
        enc_all = vit_tokens_from_timm(self.encoder, imgs)   # (B, N+1, De)
        enc_tokens = enc_all[:, 1:, :]                       # (B, N,   De)

        # keep 위치만 취득
        enc_kept = torch.gather(
            enc_tokens, dim=1,
            index=keep_ids.unsqueeze(-1).expand(-1, -1, enc_tokens.size(-1))
        )  # (B, Nk, De)

        # 4) Decoder input: kept + mask
        Nk = enc_kept.size(1); Nm = mask_ids.size(1)
        dec_kept = self.enc_to_dec(enc_kept)      # (B, Nk, Dd)
        dec_mask = self.mask_token.expand(B, Nm, -1)
        dec_in = torch.cat([dec_kept, dec_mask], dim=1) + self.dec_pos[:, :Nk+Nm, :]

        dec_out = self.decoder(dec_in)            # (B, Nk+Nm, Dd)
        pred = self.head(dec_out[:, Nk:, :])      # (B, Nm, P2*C)

        target_masked = torch.gather(
            target, dim=1,
            index=mask_ids.unsqueeze(-1).expand(-1, -1, target.size(-1))
        )
        loss = F.mse_loss(pred, target_masked)
        return loss, pred, (keep_ids, mask_ids), target


In [6]:
# 6) Datasets & Dataloaders

def build_dataloaders(cfg: CFG):
    tfm = transforms.Compose([
        transforms.Resize((cfg.img_size, cfg.img_size)),
        transforms.ToTensor()
    ])

    train_dir = Path(cfg.data_root) / "train"
    val_dir   = Path(cfg.data_root) / "val"

    if train_dir.exists() and val_dir.exists():
        print(f"[data] Using ImageFolder at {cfg.data_root}")
        train_ds = datasets.ImageFolder(str(train_dir), transform=tfm)
        val_ds   = datasets.ImageFolder(str(val_dir),   transform=tfm)
    else:
        print("[data] ImageFolder not found — using FakeData for demo.")
        from torchvision.datasets import FakeData
        train_ds = FakeData(size=256, image_size=(3, cfg.img_size, cfg.img_size), transform=tfm)
        val_ds   = FakeData(size=64,  image_size=(3, cfg.img_size, cfg.img_size), transform=tfm)

    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True)
    val_dl   = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                          num_workers=cfg.num_workers, pin_memory=True)
    return train_dl, val_dl

train_dl, val_dl = build_dataloaders(cfg)


[data] ImageFolder not found — using FakeData for demo.


In [7]:
# 7) Train / Validate / Test

def train_one_epoch(model, dl, opt, epoch, cfg: CFG):
    model.train()
    total = 0.0
    for imgs, _ in dl:
        imgs = imgs.to(device)
        loss, pred, idx, target = model(imgs)
        opt.zero_grad(set_to_none=True)
        loss.backward()
        opt.step()
        total += loss.item() * imgs.size(0)
    return total / len(dl.dataset)

@torch.no_grad()
def validate(model, dl, cfg: CFG, save_samples=False):
    model.eval()
    total = 0.0
    for i, (imgs, _) in enumerate(dl):
        imgs = imgs.to(device)
        loss, pred, (keep_ids, mask_ids), target = model(imgs)
        total += loss.item() * imgs.size(0)
        if save_samples and i == 0:
            save_grid(imgs[:16].cpu(), f"{cfg.save_dir}/viz/input_epoch.jpg", nrow=4)
    return total / len(dl.dataset)

@torch.no_grad()
def test_reconstruction(model, dl, cfg: CFG):
    model.eval()
    te_loss = None
    for i, (imgs, _) in enumerate(dl):
        imgs = imgs.to(device)
        loss, pred, (keep_ids, mask_ids), target = model(imgs)
        te_loss = loss.item()
        # 데모에선 입력만 저장 (전체 복원 이미지는 ids_restore 로직이 필요)
        if i == 0:
            save_grid(imgs[:16].cpu(), f"{cfg.save_dir}/viz/test_input.jpg", nrow=4)
            break
    return float(te_loss) if te_loss is not None else float("nan")


In [8]:
# 8) Inference & Visualization Preview

model = MAE(cfg).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.wd)
sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=cfg.epochs)

best_val = 1e9
for epoch in range(1, cfg.epochs+1):
    tr = train_one_epoch(model, train_dl, opt, epoch, cfg)
    va = validate(model, val_dl, cfg, save_samples=(epoch % 1 == 0))
    sch.step()
    print(f"[{epoch:03d}] train {tr:.4f} | val {va:.4f}")
    if va < best_val:
        best_val = va
        torch.save(model.state_dict(), f"{cfg.save_dir}/ckpt/mae_best.pth")

te = test_reconstruction(model, val_dl, cfg)
print(f"[DONE] best_val={best_val:.4f} | test_recon_loss={te:.4f}")
print(f"[ARTIFACTS] {Path(cfg.save_dir).resolve()}")


[001] train 0.3982 | val 0.2139
[DONE] best_val=0.2139 | test_recon_loss=0.2138
[ARTIFACTS] F:\mae-from-scratch\notebook\runs\mae_demo
