In [5]:
# Download the Kaggle dataset archive quietly (-q) into data/; the next cell unzips it.
!kaggle datasets download -d mikoajkolman/pokemon-images-first-generation17000-files -p "data/" -q

In [1]:
import zipfile

# Unpack the downloaded archive into the same data/ directory it was saved to.
with zipfile.ZipFile('data/pokemon-images-first-generation17000-files.zip', mode='r') as archive:
    archive.extractall(path='data/')

In [1]:
import torch
from torch.utils.data import random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Fixed seed for the train/validation split so every run (including Restart & Run All)
# assigns the same images to each subset.
SPLIT_SEED = 42

# Preprocessing applied to every image: resize the shorter side to 256 px, center-crop
# to 224x224, convert to a tensor, and normalize with the standard ImageNet mean/std
# used by torchvision's pretrained models.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load the entire dataset; ImageFolder treats each sub-folder of data/pokemon as a class.
dataset = ImageFolder('data/pokemon', transform=transform)

# 80/20 train/validation split. The validation share takes the rounding remainder so
# the two sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


In [6]:
import json

# Folder-name -> integer class index mapping built by ImageFolder.
class_to_idx = dataset.class_to_idx
print(class_to_idx)

# Class names ordered by class index, so labels[i] names model output i. Sorting by
# index makes this order explicit instead of relying on the dict's insertion order.
labels = sorted(class_to_idx, key=class_to_idx.__getitem__)
assert [class_to_idx[name] for name in labels] == list(range(len(labels))), (
    "class indices are expected to be contiguous starting at 0"
)

# Write the index-ordered names to JSON (presumably used to decode predictions elsewhere).
with open('data/pokemon-1st-gen-labels.json', 'w', encoding='utf-8') as f:
    json.dump(labels, f)


{'Abra': 0, 'Aerodactyl': 1, 'Alakazam': 2, 'Arbok': 3, 'Arcanine': 4, 'Articuno': 5, 'Beedrill': 6, 'Bellsprout': 7, 'Blastoise': 8, 'Bulbasaur': 9, 'Butterfree': 10, 'Caterpie': 11, 'Chansey': 12, 'Charizard': 13, 'Charmander': 14, 'Charmeleon': 15, 'Clefable': 16, 'Clefairy': 17, 'Cloyster': 18, 'Cubone': 19, 'Dewgong': 20, 'Diglett': 21, 'Ditto': 22, 'Dodrio': 23, 'Doduo': 24, 'Dragonair': 25, 'Dragonite': 26, 'Dratini': 27, 'Drowzee': 28, 'Dugtrio': 29, 'Eevee': 30, 'Ekans': 31, 'Electabuzz': 32, 'Electrode': 33, 'Exeggcute': 34, 'Exeggutor': 35, 'Farfetchd': 36, 'Fearow': 37, 'Flareon': 38, 'Gastly': 39, 'Gengar': 40, 'Geodude': 41, 'Gloom': 42, 'Golbat': 43, 'Goldeen': 44, 'Golduck': 45, 'Graveler': 46, 'Grimer': 47, 'Growlithe': 48, 'Gyarados': 49, 'Haunter': 50, 'Hitmonchan': 51, 'Hitmonlee': 52, 'Horsea': 53, 'Hypno': 54, 'Ivysaur': 55, 'Jigglypuff': 56, 'Jolteon': 57, 'Jynx': 58, 'Kabutops': 59, 'Kadabra': 60, 'Kakuna': 61, 'Kangaskhan': 62, 'Kingler': 63, 'Koffing': 64, 'La

In [8]:
from torch.utils.data import DataLoader

# Mini-batches of 32 images. Only the training loader reshuffles every epoch;
# validation batches are served in a fixed order.
train_loader, val_loader = (
    DataLoader(train_dataset, batch_size=32, shuffle=True),
    DataLoader(val_dataset, batch_size=32, shuffle=False),
)


In [9]:
import torch
import torchvision.models as models

# ResNet-18 initialised with ImageNet-pretrained weights. The `weights=` argument
# replaces the deprecated `pretrained=True` flag (torchvision >= 0.13); DEFAULT
# resolves to the same IMAGENET1K_V1 checkpoint that `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  



In [10]:
# Freeze every pretrained weight so no parameter receives gradients until they are
# explicitly re-enabled. A layer created after this cell (such as a new head) keeps
# the default requires_grad=True and therefore stays trainable.
for parameter in model.parameters():
    parameter.requires_grad_(False)

In [11]:
# Swap the pretrained classification head for a freshly initialised linear layer
# that keeps the backbone's feature width and has one output per class in `labels`.
num_classes = len(labels)
model.fc = torch.nn.Linear(
    in_features=model.fc.in_features,
    out_features=num_classes,
)

In [9]:
import wandb

def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Run one full pass over ``loader``; update weights only when ``optimizer`` is given.

    Args:
        model: the ``torch.nn.Module`` being trained/evaluated.
        loader: DataLoader yielding ``(inputs, targets)`` batches.
        criterion: loss called as ``criterion(outputs, targets)``; assumed to return
            the batch mean (as ``CrossEntropyLoss`` does by default).
        device: device each batch is moved to.
        optimizer: if given, the pass trains (train mode, gradients, optimizer steps);
            if ``None``, the pass evaluates (eval mode, no gradients).

    Returns:
        ``(mean_loss_per_sample, accuracy)`` as Python floats.

    Raises:
        ValueError: if ``loader.dataset`` is empty.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("cannot run an epoch over an empty dataset")

    training = optimizer is not None
    # train(False) is equivalent to eval(). Note: in train mode BatchNorm layers keep
    # updating their running statistics even when their parameters are frozen.
    model.train(training)

    total_loss = 0.0
    total_correct = 0
    with torch.set_grad_enabled(training):
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            if training:
                loss.backward()
                optimizer.step()

            _, preds = outputs.max(dim=1)
            # Weight the batch-mean loss by batch size so the epoch figure is a
            # per-sample mean even when the last batch is smaller.
            total_loss += loss.item() * inputs.size(0)
            total_correct += (preds == targets).sum().item()

    return total_loss / num_samples, total_correct / num_samples


def train(model, train_loader, val_loader, criterion, optimizer, num_epochs,
          device=None, log_fn=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Each epoch runs one optimisation pass over ``train_loader`` and then a
    gradient-free pass over ``val_loader``. It prints the mean per-sample loss and
    the accuracy for both splits, and hands the same metrics to ``log_fn``.

    Args:
        model: the ``torch.nn.Module`` to train (updated in place).
        train_loader: DataLoader yielding ``(inputs, targets)`` training batches.
        val_loader: DataLoader yielding ``(inputs, targets)`` validation batches.
        criterion: loss function called as ``criterion(outputs, targets)``.
        optimizer: optimizer over the parameters that should be updated.
        num_epochs: number of epochs to run.
        device: device batches are moved to. Defaults to the device of the model's
            first parameter, so batches land where the weights already are.
        log_fn: callable receiving one metrics dict per epoch (keys ``epoch``,
            ``train_loss``, ``train_acc``, ``val_loss``, ``val_acc``). Defaults to
            ``wandb.log``.

    Returns:
        None.
    """
    if device is None:
        device = next(model.parameters()).device
    if log_fn is None:
        log_fn = wandb.log

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        print('Epoch [{}/{}], train loss: {:.4f}, train acc: {:.4f}, val loss: {:.4f}, val acc: {:.4f}'
              .format(epoch+1, num_epochs, train_loss, train_acc, val_loss, val_acc))
        log_fn({"epoch": epoch+1, "train_loss": train_loss, "train_acc": train_acc, "val_loss": val_loss, "val_acc": val_acc})


In [12]:
import wandb

# Hyperparameters for the two fine-tuning phases: phase 1 trains only the new
# classification head, phase 2 unfreezes and fine-tunes every layer.
last_layer_learning_rate = 0.01
last_layer_momentum = 0.9
last_layer_epoches = 5
full_layer_learning_rate = 0.001
# NOTE(review): phase 1 uses momentum 0.9, but this is 0.001 (the same number as
# full_layer_learning_rate), which makes phase-2 SGD effectively momentum-free.
# Possibly a copy/paste slip; confirm the intended value.
full_layer_momentum = 0.001
full_layer_epoches = 10

# Start a W&B run and record the hyperparameters with it.
wandb.init(
    project="fine-tuning-resnet18-to-pokemon-1st-gen",
    config=dict(
        last_layer_learning_rate=last_layer_learning_rate,
        last_layer_momentum=last_layer_momentum,
        last_layer_epochs=last_layer_epoches,
        full_layer_learning_rate=full_layer_learning_rate,
        full_layer_momentum=full_layer_momentum,
        full_layer_epochs=full_layer_epoches,
        architecture="CNN",
        dataset="mikoajkolman/pokemon-images-first-generation17000-files",
    ),
)

criterion = torch.nn.CrossEntropyLoss()

# Prefer the GPU when one is available and move the model there before training.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

# Phase 1: optimise only the head's parameters for a few epochs.
optimizer = torch.optim.SGD(
    model.fc.parameters(),
    lr=last_layer_learning_rate,
    momentum=last_layer_momentum,
)
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=last_layer_epoches)

# Phase 2: re-enable gradients on every layer and fine-tune the whole network.
model.requires_grad_(True)
optimizer = torch.optim.SGD(
    model.parameters(),
    lr=full_layer_learning_rate,
    momentum=full_layer_momentum,
)
train(model, train_loader, val_loader, criterion, optimizer, num_epochs=full_layer_epoches)
# TODO(review): call wandb.finish() once all logging for this run is complete.


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011111111111111112, max=1.0…



Epoch [1/5], train loss: 2.2959, train acc: 0.5347, val loss: 1.0452, val acc: 0.7795




Epoch [2/5], train loss: 0.8247, train acc: 0.8256, val loss: 0.7879, val acc: 0.8077
Epoch [3/5], train loss: 0.5786, train acc: 0.8759, val loss: 0.6857, val acc: 0.8332
Epoch [4/5], train loss: 0.4506, train acc: 0.9042, val loss: 0.6448, val acc: 0.8468
Epoch [5/5], train loss: 0.3731, train acc: 0.9226, val loss: 0.5991, val acc: 0.8605
Epoch [1/10], train loss: 0.2573, train acc: 0.9512, val loss: 0.4563, val acc: 0.8899
Epoch [2/10], train loss: 0.1882, train acc: 0.9670, val loss: 0.4189, val acc: 0.8969
Epoch [3/10], train loss: 0.1566, train acc: 0.9760, val loss: 0.3971, val acc: 0.9017
Epoch [4/10], train loss: 0.1355, train acc: 0.9802, val loss: 0.3786, val acc: 0.9057
Epoch [5/10], train loss: 0.1189, train acc: 0.9845, val loss: 0.3694, val acc: 0.9078
Epoch [6/10], train loss: 0.1071, train acc: 0.9860, val loss: 0.3587, val acc: 0.9081
Epoch [7/10], train loss: 0.0993, train acc: 0.9867, val loss: 0.3585, val acc: 0.9081
Epoch [8/10], train loss: 0.0905, train acc: 0.