## Testing
Testing notebook for animal classification models

In [2]:
import os
from pathlib import Path

import torch, torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

print(f"Pytorch version: {torch.__version__}")

Pytorch version: 2.1.1


### SETUP

In [4]:
# GPU or CPU?
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using mps device


### Load Dataset

1. Make sure the animals dataset is downloaded into the "datasets/animals" directory
   (https://www.kaggle.com/datasets/npurav/animal-classification-dataset)
2. Add augmented images to the dataset.
3. Load the dataset then split into train and test data

In [5]:
# Get root directory
BASE_DIR = Path.cwd().resolve().parent.parent

# Path to your dataset
dataset_path = BASE_DIR / 'datasets/animals/dataset'

# Define transformations (you can customize these based on your needs)
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to a common size
    transforms.ToTensor(),
])

# Create ImageFolder dataset
dataset = ImageFolder(dataset_path, transform=transform)
class_names = dataset.classes
print(f'Total number of classes: {len(class_names)}')

Total number of classes: 117


In [26]:
# Image augmentation
# Define image transformations for augmentation
augmentation_transform = transforms.Compose([
    #transforms.RandomResizedCrop(224),        # Random crop and resize
    transforms.RandomHorizontalFlip(),        # Random horizontal flip
    transforms.RandomVerticalFlip(),          # Random vertical flip
    # transforms.RandomRotation(degrees=15),    # Random rotation (up to 15 degrees)
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),  # Color jitter
    transforms.RandomGrayscale(p=0.1),        # Randomly convert to grayscale
    transforms.RandomPerspective(distortion_scale=0.2, p=0.2),  # Random perspective
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # Convert to PyTorch tensor
])

# Create augmented dataset
for i in range(1):
    # Apply augmentation to the entire dataset
    augmented_dataset = ImageFolder(dataset_path, transform=augmentation_transform)
    # Combine the original and augmented datasets
    dataset = augmented_dataset
    
print(f'Total number of images: {len(dataset)}')

Total number of images: 19225


In [27]:
# Create data loaders
batch_size = 64

test_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

In [28]:
# Model architecture
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(3, 3)
        self.conv2 = nn.Conv2d(16, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc1 = nn.Linear(6272, 3000)
        self.fc2 = nn.Linear(3000, 1000)
        self.fc3 = nn.Linear(1000, len(class_names))

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [29]:
model = Net()
model.to(device)
model.eval()
print(f'Trainable params: {sum(p.numel() for p in net.parameters() if p.requires_grad)}')

Trainable params: 22021469


In [34]:
# Load the trained model
large = 'models/ac_large.pth'
medium = 'models/ac_med.pth'
medium = 'models/ac_small.pth'
PATH = BASE_DIR / large

model.load_state_dict(torch.load(PATH))

<All keys matched successfully>

In [31]:
# Test the model on the test dataset
correct = 0
total = 0

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

accuracy = 100 * correct / total
print(f'Accuracy on the test dataset: {accuracy:.2f}%')

Accuracy on the test dataset: 91.46%
