In [1]:
from torch.utils.data import random_split
from torchvision.datasets import ImageFolder

dataset = ImageFolder('../data/merged')

In [5]:
import torch
import wandb

def train(model, optimizer, criterion, train_loader, test_loader, num_epochs, device):
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        running_correct = 0
        total = 0
        
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_correct += torch.sum(predicted == labels.data)
            total += len(labels)
        
        model.eval()
        running_test_loss = 0.0
        running_test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                _, predicted = torch.max(outputs, 1)
                loss = criterion(outputs, labels)
                test_total += len(labels)

                running_test_loss += loss.item()
                running_test_correct += torch.sum(predicted == labels.data)
        
        log = {
            "epoch": epoch +1,
            "train_loss": running_loss / len(train_loader),
            "train_acc": running_correct / total,
            "test_loss": running_test_loss / len(test_loader),
            "test_acc": running_test_correct / test_total
        }
        print(log)
        wandb.log(log)


In [6]:
# SimpleCNN
import sys

sys.path.append('..')

import torch
import wandb
from safetensors.torch import save_file
from torch.utils.data import DataLoader
from torchvision import transforms

from src.SimpleCNN import SimpleCNN

model_name = "SimpleCNN"
learning_rate = 0.001
epochs = 5
image_size = 256
batch_size = 32

dataset.transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
])
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

wandb.init(
    project="pokemon-palworld",
    config={
        "model_name": model_name,
        "learning_rate": learning_rate,
        "architecture": "CNN",
        "dataset": "pokemon-palworld",
        "epochs": epochs,
        "image_size": image_size,
        "train_size": train_size,
        "test_size": test_size,
        "batch_size": batch_size,
    }
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN(image_size=image_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = torch.nn.CrossEntropyLoss()

train(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=epochs, device=device)

save_file(model.state_dict(), f"../models/{model_name}_epoch{epochs}.safetensors")



{'epoch': 1, 'train_loss': 2.0058670265910528, 'train_acc': tensor(0.7567, device='cuda:0'), 'test_loss': 0.29579666070640087, 'test_acc': tensor(0.8963, device='cuda:0')}
{'epoch': 2, 'train_loss': 0.25133627191341174, 'train_acc': tensor(0.9179, device='cuda:0'), 'test_loss': 0.2427434492856264, 'test_acc': tensor(0.9081, device='cuda:0')}
{'epoch': 3, 'train_loss': 0.17301090178079903, 'train_acc': tensor(0.9396, device='cuda:0'), 'test_loss': 0.18246092647314072, 'test_acc': tensor(0.9488, device='cuda:0')}
{'epoch': 4, 'train_loss': 0.1345923683353855, 'train_acc': tensor(0.9507, device='cuda:0'), 'test_loss': 0.1817974865746995, 'test_acc': tensor(0.9396, device='cuda:0')}
{'epoch': 5, 'train_loss': 0.10805761968367733, 'train_acc': tensor(0.9619, device='cuda:0'), 'test_loss': 0.25233888734752935, 'test_acc': tensor(0.9304, device='cuda:0')}


In [7]:
# Fine Tuning from ResNet18
import torchvision.models as models
import wandb
from safetensors.torch import save_file
from torch.utils.data import DataLoader
from torchvision import transforms

model_name = "ResNet18_FineTuned"
last_layer_learning_rate = 0.01
last_layer_momentum = 0.9
last_layer_epoches = 5
full_layer_learning_rate = 0.001
full_layer_momentum = 0.001
full_layer_epoches = 10
image_size = 256
batch_size = 32

dataset.transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

wandb.init(
    project="pokemon-palworld",
    config={
        "model_name": model_name,
        "last_layer_learning_rate": last_layer_learning_rate,
        "last_layer_momentum": last_layer_momentum,
        "last_layer_epochs": last_layer_epoches,
        "full_layer_learning_rate": full_layer_learning_rate,
        "full_layer_momentum": full_layer_momentum,
        "full_layer_epochs": full_layer_epoches,
        "architecture": "CNN",
        "dataset": "pokemon-palworld",
        "image_size": image_size,
        "train_size": train_size,
        "test_size": test_size,
        "batch_size": batch_size,
    }
)


model = models.resnet18(pretrained=True)
for param in model.parameters():
    param.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, len(dataset.class_to_idx.keys()))
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

criterion = torch.nn.CrossEntropyLoss()

# Fine-tune the last layer for a few epochs
optimizer = torch.optim.SGD(model.fc.parameters(), lr=last_layer_learning_rate, momentum=last_layer_momentum)
train(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=last_layer_epoches, device=device)

# Unfreeze all the layers and fine-tune the entire network for a few more epochs
for param in model.parameters():
    param.requires_grad = True
optimizer = torch.optim.SGD(model.parameters(), lr=full_layer_learning_rate, momentum=full_layer_momentum)
train(model, optimizer, criterion, train_dataloader, test_dataloader, num_epochs=full_layer_epoches, device=device)

save_file(model.state_dict(), f"../models/{model_name}_epoch{last_layer_epoches}_{full_layer_epoches}.safetensors")

0,1
epoch,▁▃▅▆█
test_acc,▁▃█▇▆
test_loss,█▅▁▁▅
train_acc,▁▆▇██
train_loss,█▂▁▁▁

0,1
epoch,5.0
test_acc,0.93045
test_loss,0.25234
train_acc,0.9619
train_loss,0.10806




{'epoch': 1, 'train_loss': 0.3158972087015475, 'train_acc': tensor(0.8906, device='cuda:0'), 'test_loss': 0.15371598408091813, 'test_acc': tensor(0.9606, device='cuda:0')}
{'epoch': 2, 'train_loss': 0.2375155989296521, 'train_acc': tensor(0.9323, device='cuda:0'), 'test_loss': 0.1321556754410267, 'test_acc': tensor(0.9567, device='cuda:0')}
{'epoch': 3, 'train_loss': 0.17485199307157018, 'train_acc': tensor(0.9448, device='cuda:0'), 'test_loss': 0.0980549325128474, 'test_acc': tensor(0.9777, device='cuda:0')}
{'epoch': 4, 'train_loss': 0.16236812840149165, 'train_acc': tensor(0.9603, device='cuda:0'), 'test_loss': 0.15898253349102257, 'test_acc': tensor(0.9633, device='cuda:0')}
{'epoch': 5, 'train_loss': 0.12463321681146529, 'train_acc': tensor(0.9586, device='cuda:0'), 'test_loss': 0.07893073613134523, 'test_acc': tensor(0.9724, device='cuda:0')}
{'epoch': 1, 'train_loss': 0.07786970144358445, 'train_acc': tensor(0.9777, device='cuda:0'), 'test_loss': 0.071285109166638, 'test_acc': t