In [1]:
# Description: This script performs inference on the test dataset and prints the classification metrics
#%%
import torch
import torch.nn as nn
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
import torchvision.models as models
from image_dataset import ImageDataset


BATCHSIZE = 4
NUM_CLASSES = 3


test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_dataset = ImageDataset(path_name="test", transform=test_transform)
test_dataloader = DataLoader(test_dataset, batch_size=BATCHSIZE, shuffle=True)


DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = models.densenet121(pretrained=False) #Using an untrained model
num_ftrs = model.classifier.in_features
model.classifier = nn.Linear(num_ftrs, NUM_CLASSES)
model.load_state_dict(torch.load("model.pth"))
model.to(DEVICE)
model.eval()

true_labels = []
predicted_labels = []

Class path: test/.DS_Store
Class path: test/pneumonia
Adding file: test/pneumonia/person1438_bacteria_3721.jpeg
Adding file: test/pneumonia/00022965_004.png
Adding file: test/pneumonia/person1598_bacteria_4197.jpeg
Adding file: test/pneumonia/person616_bacteria_2487.jpeg
Adding file: test/pneumonia/person1018_virus_1706.jpeg
Adding file: test/pneumonia/person505_virus_1017.jpeg
Adding file: test/pneumonia/00010761_002.png
Adding file: test/pneumonia/person482_bacteria_2045.jpeg
Adding file: test/pneumonia/person475_virus_972.jpeg
Adding file: test/pneumonia/00009323_005.png
Adding file: test/pneumonia/person64_bacteria_316.jpeg
Adding file: test/pneumonia/person1458_virus_2501.jpeg
Adding file: test/pneumonia/person1526_virus_2660.jpeg
Adding file: test/pneumonia/person806_virus_1440.jpeg
Adding file: test/pneumonia/person968_virus_1642.jpeg
Adding file: test/pneumonia/person935_virus_1597.jpeg
Adding file: test/pneumonia/person673_virus_1263.jpeg
Adding file: test/pneumonia/person598_



In [2]:
# Perform inference on the test dataset
with torch.no_grad():
    for inputs, labels in test_dataloader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)
        print(outputs)

        _, predicted_class = torch.max(outputs, 1)
        true_labels.extend(labels.cpu().numpy())  # Store true labels
        predicted_labels.extend(predicted_class.cpu().numpy()) 


print(f"True Labels: {true_labels}")

Index: 1135, Dataset Size: 1384
Loaded image at test/covid/COVID-19 (860).jpg with label covid
Index: 537, Dataset Size: 1384
Loaded image at test/normal/00000052_000.png with label normal
Index: 1275, Dataset Size: 1384
Loaded image at test/covid/COVID-19 (500).jpg with label covid
Index: 41, Dataset Size: 1384
Loaded image at test/pneumonia/person264_virus_547.jpeg with label pneumonia
tensor([[ 6.6058, -2.3808, -4.3700],
        [ 0.9848,  1.7318, -2.6499],
        [ 2.5640, -0.8105, -1.5863],
        [-3.5588, -0.9699,  4.7928]])
Index: 1179, Dataset Size: 1384
Loaded image at test/covid/COVID19(337).jpg with label covid
Index: 170, Dataset Size: 1384
Loaded image at test/pneumonia/person973_virus_1647.jpeg with label pneumonia
Index: 570, Dataset Size: 1384
Loaded image at test/normal/00001154_006.png with label normal
Index: 203, Dataset Size: 1384
Loaded image at test/pneumonia/00016705_002.png with label pneumonia
tensor([[ 3.3976, -1.4656, -1.8657],
        [-3.0793, -0.4860, 

In [3]:
# Calculate classification metrics
accuracy = accuracy_score(true_labels, predicted_labels)
precision = precision_score(true_labels, predicted_labels, average="weighted")
recall = recall_score(true_labels, predicted_labels, average="weighted")
f1 = f1_score(true_labels, predicted_labels, average="weighted")
confusion = confusion_matrix(true_labels, predicted_labels)

In [4]:
# Print classification metrics
print(f"Accuracy: {accuracy * 100:.2f}%")
print(f"Precision: {precision * 100:.2f}%")
print(f"Recall: {recall * 100:.2f}%")
print(f"F1 Score: {f1 * 100:.2f}%")
print("Confusion Matrix:")
print(confusion)

Accuracy: 90.17%
Precision: 90.56%
Recall: 90.17%
F1 Score: 90.19%
Confusion Matrix:
[[434  18   4]
 [ 42 415   7]
 [ 16  49 399]]
