In [None]:
from FileLoader import load_file
import torch
import torchvision
from torchvision import transforms
import torchmetrics
import distutils.version
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from tqdm.notebook import tqdm
import numpy as np
import os

In [None]:
# Configuration — edit these two paths before running the notebook.
#   ckpt_path      : folder holding the Lightning checkpoint files to evaluate
#   processed_path : folder with the preprocessed images; a "test/" subfolder is read below
ckpt_path, processed_path = (
    "./logs_resnet18/lightning_logs/version_14/checkpoints",
    "./processed",
)

In [None]:
# Collect the checkpoint files to evaluate.
# os.listdir returns entries in arbitrary order and includes anything else that
# lives in the folder, so keep only Lightning ".ckpt" files and sort them to get
# a reproducible evaluation order.
checkpoints = sorted(
    name for name in os.listdir(ckpt_path) if name.endswith(".ckpt")
)

In [None]:
# Display the checkpoints that will be evaluated (bare last expression -> rich cell output).
checkpoints

In [None]:
# Normalization statistics for the single-channel input. Presumably these are the
# training-set mean/std (TODO confirm); they must match what the model was trained with.
NORM_MEAN = [0.49044]
NORM_STD = [0.24787]

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# DatasetFolder expects one subfolder per class under "<processed_path>/test/".
# It matches files with str.endswith, so the extension includes the dot; a bare
# "npy" would also accept any name that merely ends in those three letters.
test_dataset = torchvision.datasets.DatasetFolder(
    f"{processed_path}/test/", loader=load_file, extensions=(".npy",), transform=test_transforms)

print(f"There are {len(test_dataset)} test images")
# Class indices and the number of test images per class.
np.unique(test_dataset.targets, return_counts=True)

In [None]:
# Define the architecture of the model to be evaluated.
# Currently this is ResNet18; adjust the class below to evaluate a different
# architecture (it must match the one the checkpoints were trained with).

In [None]:
class PneumoniaModel(pl.LightningModule):
    """ResNet18 adapted to single-channel input with a single-logit output head.

    Only the architecture is defined here; trained weights are restored with
    ``PneumoniaModel.load_from_checkpoint``. Those weights are matched by
    state_dict key, and the keys are derived from the attribute names ``model``
    and ``loss_fn`` (BCEWithLogitsLoss keeps ``pos_weight`` as a buffer), so do
    not rename these attributes.

    Args:
        weight: value passed as ``pos_weight`` to ``BCEWithLogitsLoss``. The
            default ratio 20672/6012 presumably is the negative/positive sample
            count from the training set — TODO confirm.
    """

    def __init__(self, weight=(20672/6012)):
        super().__init__()
        backbone = torchvision.models.resnet18()
        # Swap the stem convolution so the network accepts 1-channel images.
        backbone.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=(7, 7),
            stride=(2, 2),
            padding=(3, 3),
            bias=False,
        )
        # One output logit for binary classification.
        backbone.fc = torch.nn.Linear(in_features=512, out_features=1)
        self.model = backbone
        self.loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=torch.tensor([weight]))

    def forward(self, data):
        """Return the raw (pre-sigmoid) logits of the backbone for ``data``."""
        return self.model(data)

# Run inference on the first GPU when one is available, otherwise on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# The cell below evaluates the model for every checkpoint in the specified directory.
# Full metrics are printed only for checkpoints with accuracy above 0.8; for the others only the checkpoint name is printed.

In [None]:
# Evaluate every checkpoint on the full test set, one image at a time.
# Accuracy, recall, precision and F1 are printed only when accuracy exceeds
# ACC_THRESHOLD; other checkpoints are listed by name only.
ACC_THRESHOLD = 0.8

for checkpoint in checkpoints:
    checkpoint_path = f"{ckpt_path}/{checkpoint}"
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
    model = PneumoniaModel.load_from_checkpoint(checkpoint_path, map_location=device)
    model.eval()
    model.to(device)

    preds = []
    labels = []

    with torch.no_grad():
        for data, label in tqdm(test_dataset, desc=checkpoint):
            # The dataset yields single samples; add a batch dimension of 1.
            data = data.to(device).float().unsqueeze(0)
            # model(data)[0] is this sample's single logit; .item() stores the
            # sigmoid probability as a plain float so `preds` stays flat.
            pred = torch.sigmoid(model(data)[0]).item()
            preds.append(pred)
            labels.append(label)

    # Both are 1-D tensors of length len(test_dataset), as torchmetrics
    # requires preds and target to have the same shape.
    preds = torch.tensor(preds)
    labels = torch.tensor(labels).int()
    acc = torchmetrics.classification.BinaryAccuracy()(preds, labels)
    if acc > ACC_THRESHOLD:
        precision = torchmetrics.classification.BinaryPrecision()(preds, labels)
        recall = torchmetrics.classification.BinaryRecall()(preds, labels)
        f1 = torchmetrics.classification.BinaryF1Score()(preds, labels)
        print(f"CKPT: {checkpoint}, acc: {acc}, recall: {recall}, precision: {precision}, F1: {f1}")
    else:
        print(f"CKPT: {checkpoint}")