<a href="https://colab.research.google.com/github/wj-arit/VIT_implement/blob/main/VIT_base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Mount Google Drive at /content/drive; later cells read the dataset archive
# from it and write checkpoints and TensorBoard logs under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [27]:
!pip install einops

import torch
from torch import nn
from torch.nn import Module, ModuleList



In [3]:
def pair(t):
    """Return ``t`` unchanged if it is a tuple, otherwise the 2-tuple ``(t, t)``."""
    if isinstance(t, tuple):
        return t
    return (t, t)

In [35]:
class FeedForward(Module):
    """Pre-norm MLP block: LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: size of the last dimension of both the input and the output.
        hidden_dim: width of the intermediate linear layer.
        dropout: dropout probability used after the activation and after the
            output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        # Listed in execution order. The position of each layer inside the
        # Sequential fixes its state_dict key (net.0, net.1, net.4).
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the block; the last dimension of ``x`` must be ``dim``."""
        return self.net(x)


In [31]:
class Attention(Module):
    """Pre-norm multi-head self-attention.

    Args:
        dim: size of the last dimension of both the input and the output.
        heads: number of attention heads.
        dim_head: per-head width of the query/key/value projections.
        dropout: dropout probability applied to the attention weights and
            after the output projection.
    """

    def __init__(self,dim,heads=8,dim_head=64,dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        # The output projection is only skipped for a single head whose width
        # already equals dim (the concatenated heads then already have size dim).
        project_out = not (heads == 1 and dim_head == dim)

        self.dim = dim
        self.heads = heads

        self.norm = nn.LayerNorm(dim)
        # Scaled dot-product attention: logits are multiplied by 1/sqrt(dim_head).
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)

        # One bias-free projection produces q, k and v for all heads at once.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        """Attend over the token axis of ``x``.

        Args:
            x: tensor of shape (batch, tokens, dim).

        Returns:
            Tensor of shape (batch, tokens, dim). The residual connection is
            added by the caller, not here.
        """
        # Bug fix: the normalized tensor used to be stored in ``self.x`` while
        # the q/k/v projection consumed the raw input, so the LayerNorm was
        # never applied (and an activation stayed referenced by the module).
        # NOTE: checkpoints trained before this fix were effectively trained
        # without this normalization and will behave differently once loaded.
        x = self.norm(x)
        batch, tokens, _ = x.shape

        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (batch, tokens, heads * d) -> (batch, heads, tokens, d). Plain torch
        # reshapes are used so this class does not rely on einops.rearrange,
        # which the notebook only imports in its final training cell.
        q, k, v = (
            t.reshape(batch, tokens, self.heads, -1).transpose(1, 2) for t in qkv
        )

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        # (batch, heads, tokens, d) -> (batch, tokens, heads * d)
        out = out.transpose(1, 2).reshape(batch, tokens, -1)
        return self.to_out(out)

In [6]:
class Transformer(Module):
    """Stack of ``depth`` (Attention, FeedForward) pairs followed by a LayerNorm.

    Args:
        dim: token width.
        depth: number of (Attention, FeedForward) pairs.
        heads: attention heads per Attention block.
        dim_head: per-head width in each Attention block.
        mlp_dim: hidden width of each FeedForward block.
        dropout: dropout probability forwarded to both sub-blocks.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Each entry is a two-element ModuleList: [attention, feed-forward].
        self.layers = ModuleList([
            ModuleList([
                Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Apply every pair with a residual connection, then the final norm."""
        for attention, feed_forward in self.layers:
            attended = attention(x)
            x = attended + x
            transformed = feed_forward(x)
            x = transformed + x
        return self.norm(x)



In [7]:
class ViT(Module):
    """Vision Transformer: patch embedding, Transformer encoder, linear head.

    Keyword Args:
        image_size: int or (height, width) of the input images.
        patch_size: int or (height, width) of a patch; must divide the image size.
        num_classes: number of output logits. When not positive, no head is
            created and ``forward`` returns the encoder output.
        dim: token embedding width.
        depth, heads, mlp_dim, dim_head, dropout: forwarded to ``Transformer``.
        pool: 'cls' (prepend a learned token and read it out) or 'mean'
            (average over the token axis).
        channels: number of image channels.
        emb_dropout: dropout probability applied to the embedded tokens.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        img_h, img_w = pair(image_size)
        p_h, p_w = pair(patch_size)

        assert img_h % p_h == 0 and img_w % p_w == 0, 'Image dimensions must be divisible by the patch size.'

        grid_rows = img_h // p_h
        grid_cols = img_w // p_w
        num_patches = grid_rows * grid_cols
        # Length of one flattened patch.
        patch_dim = channels * p_h * p_w

        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
        # Mean pooling uses no class token, so the token parameter gets zero rows.
        num_cls_tokens = int(pool == 'cls')

        # Cut the image into patches, flatten each one, and project it to `dim`.
        patch_layers = [
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p_h, p2 = p_w),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        ]
        self.to_patch_embedding = nn.Sequential(*patch_layers)

        self.cls_token = nn.Parameter(torch.randn(num_cls_tokens, dim))
        # One learned position vector per token (class token included).
        self.pos_embedding = nn.Parameter(torch.randn(num_patches + num_cls_tokens, dim))

        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        if num_classes > 0:
            self.mlp_head = nn.Linear(dim, num_classes)
        else:
            self.mlp_head = None

    def forward(self, img):
        """Encode ``img`` and return logits (or the encoded tokens when there is no head)."""
        tokens = self.to_patch_embedding(img)

        # Broadcast the class token over the batch and prepend it; with
        # pool='mean' the token has zero rows, so nothing is added.
        cls_tokens = repeat(self.cls_token, '... d -> b ... d', b = img.shape[0])
        tokens = torch.cat((cls_tokens, tokens), dim = 1)

        seq_len = tokens.shape[1]
        tokens = tokens + self.pos_embedding[:seq_len]
        tokens = self.dropout(tokens)

        tokens = self.transformer(tokens)

        if self.mlp_head is None:
            return tokens

        if self.pool == 'mean':
            pooled = tokens.mean(dim = 1)
        else:
            pooled = tokens[:, 0]

        return self.mlp_head(self.to_latent(pooled))

In [8]:
# Archive location on Drive and the local directory it is unpacked into.
DRIVE_TAR = "/content/drive/MyDrive/data/imagenet100.tar.gz"
WORK_DIR = "/content/data"
# The dataset root starts out as the working directory itself.
IMAGENET_DIR = WORK_DIR


In [18]:
import os

# Sanity check: is the archive reachable on Drive, and how large is it?
TAR_PATH = "/content/drive/MyDrive/data/imagenet100.tar.gz"

if os.path.exists(TAR_PATH):
    print("exists:", True)
    print("size (MB):", os.path.getsize(TAR_PATH) / 1024 / 1024)
else:
    print("exists:", False)
    print("size (MB):", None)


exists: True
size (MB): 16476.764285087585


In [17]:
import os

# List the contents of the local working directory (raises if it does not exist yet).
print(os.listdir("/content/data"))


[]


In [19]:
import os
import shutil
import tarfile

# ===== Path configuration =====
DRIVE_TAR = "/content/drive/MyDrive/data/imagenet100.tar.gz"
WORK_DIR = "/content/data"

# Fail before touching WORK_DIR: previously a missing archive (e.g. Drive not
# mounted) was only detected after the existing extracted data had been deleted.
if not os.path.exists(DRIVE_TAR):
    raise FileNotFoundError(f"Dataset archive not found: {DRIVE_TAR}")

# ===== 1) Reset the working directory =====
if os.path.exists(WORK_DIR):
    shutil.rmtree(WORK_DIR)
os.makedirs(WORK_DIR, exist_ok=True)

print("WORK_DIR created:", WORK_DIR)
print("TAR exists:", os.path.exists(DRIVE_TAR))

# ===== 2) Extract the archive =====
print("Extracting ImageNet100...")
with tarfile.open(DRIVE_TAR, "r:*") as tar:
    # Prefer the "data" extraction filter where the interpreter supports it:
    # it rejects members that would land outside WORK_DIR and avoids the
    # warning newer Python versions emit when no filter is given.
    if hasattr(tarfile, "data_filter"):
        tar.extractall(WORK_DIR, filter="data")
    else:
        tar.extractall(WORK_DIR)
print("Extraction done.")

# ===== 3) Locate the directory that holds both train/ and val/ =====
IMAGENET_DIR = None
for root, dirs, files in os.walk(WORK_DIR):
    if "train" in dirs and "val" in dirs:
        IMAGENET_DIR = root
        break

assert IMAGENET_DIR is not None, "train/val 폴더를 찾지 못함"

# ===== 4) Report what was found =====
print("\n ImageNet root found:")
print("IMAGENET_DIR =", IMAGENET_DIR)
print("Contents:", os.listdir(IMAGENET_DIR))
print("train classes:", len(os.listdir(os.path.join(IMAGENET_DIR, "train"))))
print("val classes:", len(os.listdir(os.path.join(IMAGENET_DIR, "val"))))


WORK_DIR created: /content/data
TAR exists: True
Extracting ImageNet100...


  tar.extractall(WORK_DIR)


Extraction done.

 ImageNet root found:
IMAGENET_DIR = /content/data
Contents: ['val', 'train']
train classes: 100
val classes: 100


In [20]:
import os
import time
import glob
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


In [21]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("device:", device)


device: cuda


In [22]:
# Data pipeline settings shared by the transform and DataLoader cells.
IMG_SIZE = 224  # crop size passed to RandomResizedCrop / CenterCrop
BATCH_SIZE = 128  # batch_size for both DataLoaders
NUM_WORKERS = 4  # num_workers for both DataLoaders


In [23]:
# Per-channel normalization statistics shared by both pipelines.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Training: random resized crop and horizontal flip, then tensor + normalize.
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.08, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])

# Validation: fixed resize and center crop, then tensor + normalize.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
])


In [24]:
# Use the dataset root found by the extraction cell instead of a second
# hard-coded copy of the path, so an archive that unpacks into a nested
# folder is still picked up correctly.
DATA_ROOT = IMAGENET_DIR

# ImageFolder reads one sub-directory per class under train/ and val/.
train_dataset = datasets.ImageFolder(
    root=os.path.join(DATA_ROOT, "train"),
    transform=train_transform
)

val_dataset = datasets.ImageFolder(
    root=os.path.join(DATA_ROOT, "val"),
    transform=val_transform
)


In [25]:
# Settings common to both loaders; only the dataset and shuffling differ.
_common_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

# Training batches are reshuffled every epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_common_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_common_loader_kwargs)


In [36]:
from einops.layers.torch import Rearrange
from einops import repeat

# ViT configuration: 16x16 patches, width 768, 12 layers, 12 heads, MLP 3072.
# image_size follows IMG_SIZE so the model's positional embedding always
# matches the crop size produced by the data transforms.
model = ViT(
    image_size=IMG_SIZE,
    patch_size=16,
    num_classes=100,
    dim=768,
    depth=12,
    heads=12,
    mlp_dim=3072,
    dropout=0.1,
    emb_dropout=0.1
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=3e-4,
    weight_decay=0.05
)

# Cosine decay over 100 scheduler steps; the training cell steps the scheduler
# once per epoch and sets NUM_EPOCHS = 100, so keep the two values in sync.
scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=100
)


In [37]:
# Checkpoints are written to Drive; create the directory if it is missing.
CKPT_DIR = "/content/drive/MyDrive/vit_checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)


In [38]:
def save_checkpoint(epoch, model, optimizer, scheduler, best_acc):
    """Write the training state for ``epoch`` to ``CKPT_DIR/ckpt_epoch_{epoch}.pt``.

    The file holds the epoch index, the model / optimizer / scheduler state
    dicts and the best accuracy so far. One file is written per epoch; nothing
    is deleted here.

    The checkpoint is first written to a ``.tmp`` sibling and then renamed into
    place, so an interrupted write never leaves a truncated ``ckpt_epoch_*.pt``
    file that the resume logic would select as the latest checkpoint.
    """
    path = os.path.join(CKPT_DIR, f"ckpt_epoch_{epoch}.pt")
    # The suffix keeps the partial file outside the "ckpt_epoch_*.pt" pattern.
    tmp_path = path + ".tmp"
    state = {
        "epoch": epoch,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "best_acc": best_acc
    }
    try:
        torch.save(state, tmp_path)
        os.replace(tmp_path, path)
    except BaseException:
        # Also covers a manual interrupt: drop the partial file, then re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


In [39]:
def load_latest_checkpoint(model, optimizer, scheduler):
    """Restore training state from the highest-numbered checkpoint in ``CKPT_DIR``.

    Returns:
        ``(start_epoch, best_acc)``: ``(0, 0.0)`` when no checkpoint exists,
        otherwise the epoch after the stored one and the stored best accuracy.
        ``model``, ``optimizer`` and ``scheduler`` are updated in place.
    """
    def epoch_number(ckpt_file):
        # ".../ckpt_epoch_12.pt" -> 12
        return int(ckpt_file.split("_")[-1].split(".")[0])

    candidates = glob.glob(os.path.join(CKPT_DIR, "ckpt_epoch_*.pt"))
    if not candidates:
        return 0, 0.0

    # Order numerically, not lexicographically, so epoch 10 ranks above epoch 2.
    ckpt_path = sorted(candidates, key=epoch_number)[-1]

    state = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    scheduler.load_state_dict(state["scheduler"])

    print(f"Resumed from {ckpt_path}")
    next_epoch = state["epoch"] + 1
    return next_epoch, state["best_acc"]


In [40]:
def train_one_epoch(model, loader, optimizer):
    """Run one optimization pass over ``loader``.

    Uses the notebook-level ``device`` and ``criterion``.

    Returns:
        ``(mean loss per sample, accuracy)`` accumulated over the whole pass.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0, 0, 0

    for imgs, labels in loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(imgs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so a smaller last batch
        # does not skew the per-sample average.
        loss_sum += batch_loss.item() * imgs.size(0)
        predicted = logits.max(1).indices
        n_correct += predicted.eq(labels).sum().item()
        n_seen += labels.size(0)

    return loss_sum / n_seen, n_correct / n_seen


In [42]:
@torch.no_grad()
def validate(model, loader):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    for imgs, labels in loader:
        imgs = imgs.to(device)
        labels = labels.to(device)

        outputs = model(imgs)
        loss = criterion(outputs, labels)

        total_loss += loss.item() * imgs.size(0)
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    return total_loss / total, correct / total


In [43]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard scalars from the training loop are written to this Drive directory.
LOG_DIR = "/content/drive/MyDrive/vit_logs"
writer = SummaryWriter(log_dir=LOG_DIR)


In [None]:
from einops import rearrange, repeat

NUM_EPOCHS = 100

# Resume from the newest checkpoint if there is one, else start at epoch 0.
start_epoch, best_acc = load_latest_checkpoint(
    model, optimizer, scheduler
)

# try/finally so the TensorBoard writer is closed even when training is
# interrupted or raises part-way through.
try:
    for epoch in range(start_epoch, NUM_EPOCHS):
        t0 = time.time()

        # Read the learning rate *before* the scheduler advances, so the value
        # logged for this epoch is the one the epoch was actually trained with
        # (it used to be read after scheduler.step(), i.e. the next epoch's LR).
        current_lr = optimizer.param_groups[0]["lr"]

        # ---- train / val ----
        train_loss, train_acc = train_one_epoch(
            model, train_loader, optimizer
        )
        val_loss, val_acc = validate(model, val_loader)

        # ---- lr schedule ----
        scheduler.step()

        # ---- best acc update ----
        if val_acc > best_acc:
            best_acc = val_acc

        # ---- checkpoint ----
        save_checkpoint(epoch, model, optimizer, scheduler, best_acc)

        # ---- TensorBoard logging ----
        writer.add_scalar("Loss/train", train_loss, epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)

        writer.add_scalar("Accuracy/train", train_acc, epoch)
        writer.add_scalar("Accuracy/val", val_acc, epoch)
        writer.add_scalar("Accuracy/best_val", best_acc, epoch)

        writer.add_scalar("LR", current_lr, epoch)

        # ---- console log ----
        print(
            f"[Epoch {epoch}] "
            f"Train loss {train_loss:.4f} | acc {train_acc:.4f} || "
            f"Val loss {val_loss:.4f} | acc {val_acc:.4f} || "
            f"Best acc {best_acc:.4f} || "
            f"LR {current_lr:.6f} || "
            f"Time {time.time() - t0:.1f}s"
        )
finally:
    writer.close()


In [None]:
# Load the TensorBoard notebook extension and open it on the Drive log directory.
%load_ext tensorboard
%tensorboard --logdir /content/drive/MyDrive/vit_logs
