In [None]:
import torch
import torch.nn as nn
import torchvision.models as models

from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from datetime import datetime

In [None]:
# --- Input pipeline -------------------------------------------------------
# Side length (in pixels) of the square crop produced by the transforms.
IMAGE_SIZE = 224
# DataLoader settings shared by the training and validation loaders.
BATCH_SIZE = 16
NUM_WORKERS = 2
# Per-channel statistics handed to transforms.Normalize; these are the
# widely used ImageNet values.
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# --- Training -------------------------------------------------------------
# Presumably the number of class folders in the dataset -- TODO confirm.
NUM_CLASSES = 9
NUM_EPOCHS = 30
LEARNING_RATE = 1e-3

In [None]:
# Training-time pipeline: random crop / flips / rotation / colour jitter for
# augmentation, then conversion to a tensor and per-channel normalisation.
# ToTensor has to come before Normalize because Normalize only accepts
# tensors, not PIL images.
train_transforms = transforms.Compose([
    transforms.RandomResizedCrop(IMAGE_SIZE, scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# Validation-time pipeline: deterministic resize + centre crop (no random
# augmentation), followed by the same tensor conversion and normalisation
# as the training pipeline.
val_transforms = transforms.Compose([
    transforms.Resize(int(IMAGE_SIZE * 1.14)),
    transforms.CenterCrop(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(MEAN, STD),
])

# Dataset roots, relative to the notebook's working directory.
train_dataset_path = "../../../dataset/train"
val_dataset_path = "../../../dataset/val"

# ImageFolder takes the per-image pipeline through the `transform` keyword
# (singular); it has no `transforms` parameter.
train_dataset = datasets.ImageFolder(root=train_dataset_path, transform=train_transforms)
val_dataset = datasets.ImageFolder(root=val_dataset_path, transform=val_transforms)

# Training batches are reshuffled on every pass over the loader.
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
    pin_memory=True, shuffle=True,
)
# Validation batches keep the dataset order (no shuffling).
val_loader = DataLoader(
    val_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS,
    pin_memory=True, shuffle=False,
)

In [None]:
# Quick check of the data pipeline: draw a single batch from the training
# loader and print the shapes of the image tensor and the label tensor.
if __name__ == "__main__":
    # A fresh iterator is created only to pull the first batch.
    images, labels = next(iter(train_loader))
    print(f"Batch shape: {images.shape}")
    print(f"Labels shape: {labels.shape}")

In [None]:
# ResNet-50 initialised with the IMAGENET1K_V1 pretrained weights.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Only the classification head is handed to the optimizer below, so the
# pretrained layers never change. Disabling their gradients leaves the
# trained weights identical but skips the unused backward computation.
model.requires_grad_(False)

# Swap the pretrained final layer for a fresh linear layer that emits
# NUM_CLASSES logits. Newly created layers have requires_grad=True, so this
# head stays trainable.
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)

# Move the model (including the new head) to the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

criterion = nn.CrossEntropyLoss()
# `parameters` is a method and must be called: the optimizer needs an
# iterable of tensors, not the bound method object.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=LEARNING_RATE)

In [None]:
def train(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over ``loader`` and report aggregate metrics.

    Args:
        model: Network to optimise; it is switched to training mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called once per
            batch.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the batch-size-weighted average of the
        per-batch loss values and the fraction of samples whose predicted
        class equals the label.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)
        batch_len = images.shape[0]

        optimizer.zero_grad()
        # `outputs` is expected to hold one row of raw class scores (logits)
        # per sample, i.e. shape [batch_size, num_classes]. For example:
        #   tensor([[1.2, 0.3, 2.5],     # sample 1 -> class 2
        #           [0.1, 4.1, 0.2]])    # sample 2 -> class 1
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch loss by the batch size so that a smaller final
        # batch does not skew the epoch average.
        loss_sum += loss.item() * batch_len
        # argmax over dim=1 picks the index of the highest score in each row;
        # for the example above this gives tensor([2, 1]).
        preds = outputs.argmax(dim=1)
        num_correct += preds.eq(labels).sum().item()
        num_seen += batch_len

    return loss_sum / num_seen, num_correct / num_seen

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Callable computing the loss from ``(outputs, labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)``: the batch-size-weighted average of the
        per-batch loss values and the fraction of samples whose predicted
        class equals the label.
    """
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Compute the loss for this batch; it is accumulated below.
            loss = criterion(outputs, labels)

            # Weight the batch loss by the batch size so that a smaller
            # final batch does not skew the average.
            val_loss += loss.item() * images.size(0)
            # Predicted class = index of the highest score in each row.
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

    calc_loss = val_loss / total
    calc_accuracy = correct / total

    return calc_loss, calc_accuracy

In [None]:
# Lowest validation loss seen so far; decides when a checkpoint is written.
best_val_loss = float('inf')

for epoch in range(NUM_EPOCHS):
    train_loss, train_acc = train(model, train_loader, optimizer, criterion, device)
    # Evaluate on the validation loader, not the training loader, so the
    # reported metrics and the checkpoint decision use held-out data.
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    print(f"Epoch {epoch+1}/{NUM_EPOCHS}: ")
    print(f"Train loss {train_loss:.4f}, acc {train_acc:.4f} |")
    print(f"Val loss {val_loss:.4f}, acc {val_acc:.4f}")

    # Checkpoint the weights whenever the validation loss improves.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "best_resnet50.pth")