# ResNet18 + CIFAR100 の性能確認



In [11]:
# Set Project Root Directory
# Walk up from the current working directory to the first directory that
# contains a `common` entry and make it the working directory, so that the
# `common.*` imports and the relative './image_recognition/...' paths used
# below resolve regardless of where the notebook was started.
import os

root_marker = "common"
start_dir = os.getcwd()
current_dir = start_dir
while not os.path.exists(os.path.join(current_dir, root_marker)):
    parent_dir = os.path.dirname(current_dir)
    if parent_dir == current_dir:
        # Reached the filesystem root without finding the marker. Fail loudly
        # here instead of treating the root as the project root, which would
        # only surface later as an ImportError for `common`.
        raise FileNotFoundError(
            f"No directory containing {root_marker!r} found above {start_dir!r}"
        )
    current_dir = parent_dir

project_root = current_dir
if os.getcwd() != project_root:
    # Plain os.chdir instead of the IPython-only `%cd` magic so this cell is
    # also valid Python outside Jupyter.
    os.chdir(project_root)

In [20]:
from collections import defaultdict
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms.v2 as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torchvision.models import resnet18
from torcheval.metrics import MulticlassAccuracy
import lightning as L
from tqdm.autonotebook import tqdm

from common.utils.training import evaluate

In [18]:
def train_phase(fabric):
    """Train a ResNet-18 from scratch on CIFAR-100 and save its weights.

    Meant to be run through ``fabric.launch(train_phase)``: Fabric places the
    model, optimizer and dataloader on its device and applies the configured
    precision.

    Args:
        fabric: A ``lightning.Fabric`` instance used to set up the model,
            optimizer and dataloader and to run the backward pass.

    Returns:
        dict: ``{'loss': [...]}`` holding the mean per-sample training loss of
        every epoch, in order, as Python floats.

    Side effects:
        Downloads CIFAR-100 into ``./image_recognition/data`` if missing and
        writes the trained ``state_dict`` to
        ``./image_recognition/weights/CIFAR100_RN18.pth``.
    """
    L.seed_everything(seed=0xcafe, workers=True, verbose=True)
    batch_size = 32
    epochs = 300
    print_interval = 10
    weight_path = './image_recognition/weights/CIFAR100_RN18.pth'

    # Create the output directory up front: torch.save does not create missing
    # directories, and failing only after all epochs would lose the whole run.
    os.makedirs(os.path.dirname(weight_path), exist_ok=True)

    train_transform = transforms.Compose([
        transforms.ToImage(),
        # NOTE(review): RandomResizedCrop's default scale=(0.08, 1.0) is a very
        # aggressive crop for 32x32 inputs; RandomCrop(32, padding=4) is the
        # more common CIFAR recipe. Kept as-is to preserve the experiment.
        transforms.RandomResizedCrop(size=32, antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.ToDtype(torch.float32, scale=True),
        # Approximate per-channel CIFAR-100 mean / std.
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    train_ds = datasets.CIFAR100(
        root='./image_recognition/data',
        train=True,
        transform=train_transform,
        download=True,
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True,
    )

    n_data = len(train_ds)

    net = resnet18(weights=None, num_classes=100)
    optimizer = optim.RAdam(
        params=net.parameters(),
        lr=1e-4,
        weight_decay=5e-4,
        decoupled_weight_decay=True,
        foreach=True,
    )
    # Stepped once per epoch below, so the LR drops 10x after epochs 100 and 200.
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer=optimizer,
        milestones=[100, 200],
        gamma=0.1,
    )
    criterion = nn.CrossEntropyLoss()

    model = torch.compile(net, mode='reduce-overhead')
    model, optimizer = fabric.setup(model, optimizer)
    loader = fabric.setup_dataloaders(train_loader)

    model.train()
    logs = defaultdict(list)
    for epoch in tqdm(range(epochs)):
        # Accumulate on the device: calling loss.item() every step would force
        # a host/device synchronization per batch. float64 matches the
        # precision of the Python-float accumulator this replaces.
        running_loss = torch.zeros((), dtype=torch.float64, device=fabric.device)
        for x, y in loader:
            optimizer.zero_grad()
            out = model(x)
            loss = criterion(out, y)
            fabric.backward(loss)
            optimizer.step()

            # Weight by batch size so the (possibly smaller) last batch counts
            # correctly in the per-sample mean below.
            running_loss += loss.detach() * len(x)

        scheduler.step()
        epoch_loss = running_loss.item() / n_data
        logs['loss'].append(epoch_loss)

        if epoch == 0 or (epoch + 1) % print_interval == 0:
            print(f"EPOCH: {epoch + 1}/{epochs}, LOSS: {epoch_loss:.5f}")

    # Save the un-wrapped `net` (it shares its parameters with the compiled,
    # Fabric-wrapped `model`) so the state_dict keys match a plain resnet18,
    # which is how evaluate_phase loads it.
    torch.save(net.to('cpu').state_dict(), weight_path)
    return dict(logs)

In [19]:
# Single CUDA device (GPU 0) with bfloat16 automatic mixed precision;
# `launch` invokes train_phase(fabric) and hands back its per-epoch logs.
fabric = L.Fabric(
    precision='bf16-mixed',
    accelerator='cuda',
    devices=[0],
)
logs = fabric.launch(train_phase)

Using bfloat16 Automatic Mixed Precision (AMP)
Seed set to 51966


Files already downloaded and verified


  0%|          | 0/300 [00:00<?, ?it/s]

EPOCH: 1/300, LOSS: 4.23593
EPOCH: 10/300, LOSS: 2.91062
EPOCH: 20/300, LOSS: 2.48746
EPOCH: 30/300, LOSS: 2.24010
EPOCH: 40/300, LOSS: 2.06269
EPOCH: 50/300, LOSS: 1.92576
EPOCH: 60/300, LOSS: 1.80724
EPOCH: 70/300, LOSS: 1.70563
EPOCH: 80/300, LOSS: 1.61176
EPOCH: 90/300, LOSS: 1.52736
EPOCH: 100/300, LOSS: 1.44388
EPOCH: 110/300, LOSS: 1.21908
EPOCH: 120/300, LOSS: 1.18727
EPOCH: 130/300, LOSS: 1.16606
EPOCH: 140/300, LOSS: 1.15199
EPOCH: 150/300, LOSS: 1.12781
EPOCH: 160/300, LOSS: 1.11832
EPOCH: 170/300, LOSS: 1.09859
EPOCH: 180/300, LOSS: 1.09236
EPOCH: 190/300, LOSS: 1.08570
EPOCH: 200/300, LOSS: 1.07308
EPOCH: 210/300, LOSS: 1.05013
EPOCH: 220/300, LOSS: 1.04341
EPOCH: 230/300, LOSS: 1.05114
EPOCH: 240/300, LOSS: 1.05063
EPOCH: 250/300, LOSS: 1.03405
EPOCH: 260/300, LOSS: 1.03838
EPOCH: 270/300, LOSS: 1.03224
EPOCH: 280/300, LOSS: 1.03460
EPOCH: 290/300, LOSS: 1.02707
EPOCH: 300/300, LOSS: 1.03347


In [27]:
def evaluate_phase(weight_name, device):
    """Evaluate saved ResNet-18 weights on the CIFAR-100 test split.

    Args:
        weight_name: File name of a ``state_dict`` inside
            ``./image_recognition/weights/`` (e.g. ``'CIFAR100_RN18.pth'``).
        device: ``torch.device`` to run the evaluation on.

    Returns:
        The accuracy value produced by ``common.utils.training.evaluate``.
        It is also printed as ``Acc: <weight file stem>, <acc>``.
    """
    # No augmentation at test time; same normalization as training.
    test_transform = transforms.Compose([
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),
        transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    ])
    test_ds = datasets.CIFAR100(
        root='./image_recognition/data',
        train=False,
        transform=test_transform,
        download=True,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=256,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        drop_last=False,
        persistent_workers=True,
    )

    net = resnet18(weights=None, num_classes=100).to(device)
    # map_location lets a checkpoint saved from another device load on this
    # one; weights_only restricts unpickling to tensors and plain containers.
    state_dict = torch.load(
        os.path.join('./image_recognition/weights', weight_name),
        map_location=device,
        weights_only=True,
    )
    net.load_state_dict(state_dict)
    # BatchNorm layers must use their running statistics for evaluation.
    # Idempotent, so harmless if `evaluate` (not visible here) also does it.
    net.eval()

    acc = evaluate(net, test_loader, device)
    print(f'Acc: {os.path.splitext(weight_name)[0]}, {acc:.5f}')
    return acc

In [28]:
# Prefer the first GPU; fall back to the CPU when CUDA is unavailable.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
evaluate_phase('CIFAR100_RN18.pth', device=device)

Files already downloaded and verified
Acc: CIFAR100_RN18, 0.55480
