In [None]:
# imports
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

# AS USED IN TRAINING
class CNN(nn.Module):
    """Small two-stage convolutional classifier.

    Each stage is a 3x3 convolution (stride 1, padding 1) followed by ReLU
    and a 2x2 max-pool.  The flattened feature map is fed to one linear
    layer sized ``16 * 8 * 8``, which matches 32x32 inputs after the two
    pooling steps (32 -> 16 -> 8).

    Args:
        in_channel: Number of channels in the input images.
        num_classes: Number of output logits.
    """

    def __init__(self, in_channel=3, num_classes=200):
        super().__init__()
        conv_kwargs = dict(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        # Attribute names and creation order are kept as-is so that saved
        # state dicts and seeded initialisation still line up.
        self.conv1 = nn.Conv2d(in_channel, 8, **conv_kwargs)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))
        self.conv2 = nn.Conv2d(8, 16, **conv_kwargs)
        self.fc1 = nn.Linear(16 * 8 * 8, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        for conv in (self.conv1, self.conv2):
            x = self.pool(nn.functional.relu(conv(x)))
        # Keep the batch dimension, flatten everything else.
        flat = x.reshape(x.shape[0], -1)
        return self.fc1(flat)

def load_model(checkpoint_path, num_classes):
    """Build a ``CNN`` and restore its weights from a checkpoint file.

    Args:
        checkpoint_path: Path to a checkpoint readable by ``torch.load``
            that contains a ``"state_dict"`` entry.
        num_classes: Size of the classifier's output layer; it must match
            the checkpoint or ``load_state_dict`` will raise.

    Returns:
        The restored model, on the CPU and switched to evaluation mode.
    """
    model = CNN(num_classes=num_classes)
    # map_location="cpu" lets a checkpoint saved from a GPU be opened on a
    # CPU-only machine; callers move the model to another device afterwards.
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted source (newer PyTorch versions offer weights_only=True).
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    return model

def preprocess_image(image_path):
    """Load an image file and turn it into a single-item input batch.

    The image is converted to RGB, resized to 32x32, converted to a tensor
    and normalised per channel with mean 0.5 and std 0.5.

    Args:
        image_path: Path of the image file to open.

    Returns:
        The transformed image with a leading batch dimension of size 1.
    """
    my_transforms = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # Open through a context manager so the file handle is released
    # deterministically; convert() yields an image that no longer needs it.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    return my_transforms(image).unsqueeze(0)

def predict(image_path, model, class_names):
    """Classify one image file with ``model``.

    Args:
        image_path: Path of the image to classify.
        model: Network to run; it is moved to the GPU when one is
            available, otherwise to the CPU.
        class_names: Sequence indexed by the predicted class index.

    Returns:
        A ``(class_name, class_index)`` tuple, where ``class_index`` is the
        zero-based position of the largest output.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    batch = preprocess_image(image_path).to(device)
    model = model.to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = model(batch)
        best = logits.max(1).indices

    class_index = best.item()
    return class_names[class_index], class_index


# model
checkpoint_path = "my_checkpoint.pth.tar"

# Load the class names once: they do not change between images.
# NOTE(review): allow_pickle=True unpickles the file; only use trusted data.
class_dictionary = np.load("datafile/class_names.npy", allow_pickle=True).item()
# Keys have a prefix before the first '.'; keep only the part after it.
class_names = [name.split('.', 1)[1] for name in class_dictionary]

# Load the model with the correct weights once instead of re-reading the
# checkpoint from disk for every image.
model = load_model(checkpoint_path, num_classes=len(class_names))

for i in range(1, 1000):
    # image we want to predict
    image_path = f"datafile/train_images/train_images/{i}.jpg"

    # predict image: obtain the class name and its zero-based index
    predicted_class, predicted_class_number = predict(image_path, model, class_names)

    # The printed class number is one-based, hence the + 1.
    print(f"Predicted Class: {predicted_class} classified as: {predicted_class_number + 1}")


  checkpoint = torch.load(checkpoint_path)


Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_Albatross classified as: 1
Predicted Class: Black_footed_A