In [50]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

In [51]:
class CNN(nn.Module):
    """Small convolutional classifier for 64x64 three-channel images, 5 classes.

    Layout: Conv2d (stride 1, padding 1) -> ReLU -> 2x2 max-pool -> flatten ->
    Linear -> ReLU -> Dropout -> Linear -> ReLU -> Linear (5 outputs).

    Args:
        num_filters: number of output channels of the convolution.
        kernel_size: side length of the square convolution kernel.
        dropout_rate: dropout probability applied after the first linear layer.
        num_units1: width of the first linear layer.
        num_units2: width of the second linear layer.

    Raises:
        ValueError: if ``kernel_size`` leaves no spatial output after pooling.
    """

    def __init__(self, num_filters, kernel_size, dropout_rate, num_units1, num_units2):
        super().__init__()
        padding = 1  # shared by the conv layer and the size computation below
        self.conv1 = nn.Conv2d(3, num_filters, kernel_size=kernel_size, padding=padding)
        self.pool = nn.MaxPool2d(2, 2)
        # Side length of the pooled feature map for the fixed 64x64 input.
        side = self._get_conv_output_size(64, kernel_size, 2, padding)
        self.fc1 = nn.Linear(num_filters * side ** 2, num_units1)
        self.fc2 = nn.Linear(num_units1, num_units2)
        self.fc3 = nn.Linear(num_units2, 5)  # 5 classes
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x):
        """Return the 5 raw class scores for each image in the batch ``x``."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.view(-1, self.num_flat_features(x))
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` except the first."""
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

    def _get_conv_output_size(self, input_size, kernel_size, pool_size, padding=1):
        """Side length after a stride-1 convolution followed by max-pooling.

        The previous formula, ``(input_size - kernel_size + 2) // pool_size + 1``,
        is only right when ``input_size - kernel_size`` is odd; for even kernel
        sizes on a 64-pixel input it was one too large, so ``fc1`` did not match
        the flattened feature map. The two stages are now computed separately.

        Args:
            input_size: side length of the square input.
            kernel_size: side length of the square convolution kernel.
            pool_size: kernel size and stride of the (unpadded) max-pool.
            padding: zero-padding added on each side by the convolution.

        Returns:
            Side length of the pooled feature map.

        Raises:
            ValueError: if the pooled feature map would be empty.
        """
        conv_size = input_size + 2 * padding - kernel_size + 1
        pooled_size = conv_size // pool_size  # pooling floors partial windows
        if pooled_size < 1:
            raise ValueError(
                f"kernel_size={kernel_size} is too large for input_size={input_size}"
            )
        return pooled_size

In [52]:
def preprocess_image(image_path):
    """Load an image file and turn it into a one-image batch tensor.

    The image is converted to RGB, resized to 64x64, passed through
    ``transforms.ToTensor()``, and given a leading batch dimension.

    Args:
        image_path: path of the image file, handed to ``Image.open``.

    Returns:
        The transformed image tensor with a batch dimension of size 1 in front.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])
    # Image.open is lazy and keeps the file open. The context manager releases
    # the handle as soon as convert() has produced an independent in-memory copy.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    return transform(image).unsqueeze(0)  # Add batch dimension

In [53]:
def predict_image(model, image_path, device, class_names):
    """Classify a single image file and return the name of the winning class.

    Args:
        model: network called on the preprocessed image batch.
        image_path: path forwarded to ``preprocess_image``.
        device: device the image batch is moved to before the forward pass.
        class_names: sequence indexed by the position of the largest output.

    Returns:
        The element of ``class_names`` at the index of the highest score.
    """
    # Put the model in evaluation mode before scoring.
    model.eval()
    batch = preprocess_image(image_path).to(device)

    # No gradients are needed for inference.
    with torch.no_grad():
        scores = model(batch)
        best = torch.max(scores, 1).indices

    label_index = best.item()
    return class_names[label_index]

In [54]:
# Constructor arguments for CNN; the resulting layer shapes have to match the
# tensors stored in the checkpoint that is loaded afterwards.
num_filters = 32
kernel_size = 5
num_units1 = 64
num_units2 = 32
dropout_rate = 0.0

# Checkpoint file containing the trained weights.
saved_model_path = 'best_model_20240615_101908.pth'

In [55]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [56]:
# Build the network with the configured hyperparameters and move it to `device`.
model = CNN(num_filters, kernel_size, dropout_rate, num_units1, num_units2).to(device)
# torch.load unpickles the file. weights_only=True restricts unpickling to
# tensors and plain containers, which is all a state_dict holds, so a tampered
# checkpoint cannot run arbitrary code. map_location places the tensors on `device`.
model.load_state_dict(
    torch.load(saved_model_path, map_location=device, weights_only=True)
)

<All keys matched successfully>

In [57]:
# Label for each model output index; predict_image looks up the winning index here.
# NOTE(review): presumably this order matches the label encoding used in training — confirm.
class_names = [
    'airfield',
    'bus stand',
    'canyon',
    'market',
    'temple',
]

In [82]:
# Image file to classify.
image_path = 'TestImages/00002385.jpg'

In [83]:
# Classify the selected image and report the predicted label.
predicted_class = predict_image(
    model=model,
    image_path=image_path,
    device=device,
    class_names=class_names,
)
print(f'The predicted class for the image is: {predicted_class}')

The predicted class for the image is: canyon
