In [None]:
import torch
import torchvision
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
import numpy as np
import cv2
import imgaug.augmenters as iaa
from dataset import CardiacDataset
import os

In [None]:
# Folder holding the processed images; it must contain a "test/" subfolder and
# the "test_subjects.npy" subject list used below. Adjust to your environment.
processed_path = "./processed_heart_detection"
# Labels file for the train/val/test dataset.
labels_path = "./heart_detection_labels.csv"
# Folder containing the checkpoint files to evaluate.
ckpt_path = "./logs_heart/lightning_logs/version_0/checkpoints/"

In [None]:
# Test split: images under "<processed_path>/test/" and the subject list in
# "<processed_path>/test_subjects.npy".
test_root_path = "{}/test/".format(processed_path)
test_subjects = "{}/test_subjects.npy".format(processed_path)
# The fourth argument is None — presumably "no augmentation" for evaluation;
# confirm against CardiacDataset's signature.
test_dataset = CardiacDataset(
    labels_path,
    test_subjects,
    test_root_path,
    None,
)

In [None]:
# Checkpoint file names (not full paths) found in ckpt_path.
# os.listdir returns entries in arbitrary order and includes every file in the
# folder, so keep only Lightning checkpoint files (".ckpt", the extension used by
# ModelCheckpoint) and sort them for a deterministic evaluation/print order.
checkpoints = sorted(
    entry for entry in os.listdir(ckpt_path) if entry.endswith(".ckpt")
)

In [None]:
# Checkpoint files that will be evaluated below; as the last expression of the
# cell, Jupyter displays it directly.
checkpoints

In [None]:
class CardiacDetectionModel1(pl.LightningModule):
    """ResNet-18 regressor with 4 outputs, re-declared here to load checkpoints.

    The stem convolution is replaced so the network takes single-channel input,
    and the final fully connected layer is replaced with a 4-output linear layer.
    Keep the attribute name ``model`` unchanged: the state-dict keys stored in the
    checkpoints are prefixed with it (presumably the training module used the
    same layout — confirm against the training code).
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.resnet18()
        # Single-channel input instead of ResNet's default 3-channel stem.
        backbone.conv1 = torch.nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )
        # ResNet-18's pooled feature size is 512; regress 4 values from it.
        backbone.fc = torch.nn.Linear(512, 4)
        self.model = backbone
        self.loss_fn = torch.nn.MSELoss()

    def forward(self, data):
        """Run the backbone on ``data`` and return its raw 4-value output."""
        return self.model(data)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
#The loop below evaluates every checkpoint in the specified directory on the test set.
#Only checkpoints whose mean absolute offset (averaged over all test samples and the 4 model outputs) is lower than 5 are printed.

In [None]:
# Report only checkpoints whose mean absolute offset on the test set is below this value.
OFFSET_THRESHOLD = 5

if not checkpoints:
    print(f"No checkpoints found in {ckpt_path}")

for checkpoint in checkpoints:
    # os.path.join avoids the double slash that f"{ckpt_path}/{checkpoint}"
    # produces when ckpt_path already ends with "/".
    checkpoint_path = os.path.join(ckpt_path, checkpoint)
    model = CardiacDetectionModel1.load_from_checkpoint(checkpoint_path)
    model.eval()
    model.to(device)

    pred_list = []
    label_list = []

    with torch.no_grad():
        for data, label in test_dataset:
            # Add a batch dimension of 1, run the model, take the single result
            # back out of the batch and move it to the CPU to compare with the label.
            data = data.to(device).float().unsqueeze(0)
            pred_list.append(model(data)[0].cpu())
            label_list.append(label)

    # torch.stack fails with an opaque error on an empty list; an empty test set
    # almost always means a wrong processed_path / test_subjects setting.
    if not pred_list:
        raise ValueError("test_dataset is empty; check processed_path and test_subjects")

    preds = torch.stack(pred_list)
    labels = torch.stack(label_list)
    # Mean absolute error per model output over the test set, then averaged over outputs.
    offset_all = (preds - labels).abs().mean(0)
    offset = offset_all.mean()
    if offset < OFFSET_THRESHOLD:
        print(f"CKPT: {checkpoint}, offset: {offset.item():.4f}, all: {offset_all}")