## Imports

In [48]:
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
import numpy as np
from torchvision import transforms

# # book keeping namings and code
# from settings import img_size, prototype_shape, num_classes, \
#                      prototype_activation_function, \
#                      add_on_layers_type, test_information, \
#                      num_test_examples, img_size, test_batch_size

from dataset_class import ECGImageDataset
import model_for_superclasses as model

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Add paths

In [52]:
# Inputs for this evaluation run (both are relative paths).
# NOTE(review): `test_json_path` is not referenced by any visible cell; the
# dataset cell reads `test_information` instead. Confirm which one is intended.
test_json_path = 'test-100.json'
saved_model_path = 'saved_models/vgg19/5/19nopushAUROC_0.8465.pth'

## Create the data loader for the test set

In [45]:
# Define transformations
# Side length (pixels) that every image is resized to. It is defined here
# because this cell runs before the model-construction cell, which assigns the
# same value; without it a fresh kernel raises NameError. Keep the two in sync.
img_size = 224

# Per-channel mean/std used to standardise the image tensor. These are the
# commonly used ImageNet statistics.
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Resize -> convert to tensor -> normalise, applied to every test image.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    normalize,
])
# Function to create a subset of the dataset
def create_subset(dataset, num_examples, seed=None):
    """Return a random, duplicate-free subset of ``dataset``.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        num_examples: Desired subset size. Values larger than
            ``len(dataset)`` are clamped to the dataset length.
        seed: Optional seed for the index draw. When ``None`` (the default)
            indices come from NumPy's global random state, as before. When
            given, a dedicated generator is used so the same seed always
            selects the same examples.

    Returns:
        A ``torch.utils.data.Subset`` wrapping ``dataset``.
    """
    # Ensure num_examples doesn't exceed the dataset length
    num_examples = min(len(dataset), num_examples)

    if seed is None:
        indices = np.random.choice(len(dataset), num_examples, replace=False)
    else:
        rng = np.random.default_rng(seed)
        indices = rng.choice(len(dataset), num_examples, replace=False)

    # Hand Subset plain Python ints so it works with any indexable dataset.
    return torch.utils.data.Subset(dataset, indices.tolist())

# Initialize dataset and dataloader for testing
# NOTE(review): `test_information`, `num_test_examples` and `test_batch_size`
# are not defined in any visible cell. They appear only in the commented-out
# `from settings import ...` block in the imports cell, so this cell fails on
# a fresh kernel. Presumably `test_information` should be `test_json_path`;
# confirm and define all three in the configuration cell.
test_dataset = ECGImageDataset(test_information, transform=transform)

# Evaluate on a random subset when a cap is configured, otherwise on everything.
test_subset = (
    create_subset(test_dataset, num_test_examples)
    if num_test_examples is not None
    else test_dataset
)

# Create data loader for the subset. Shuffling is disabled because this loader
# is only used for evaluation, so a fixed batch order keeps runs repeatable.
test_loader = torch.utils.data.DataLoader(
    test_subset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=False,
)

## Load the saved model

In [56]:
# construct the model
base_architecture = 'vgg19'
img_size = 224
num_classes = 11
num_prototypes_for_each_class = 32
# 11 classes x 32 prototypes per class = 352 prototypes in total.
num_prototypes = num_classes * num_prototypes_for_each_class
prototype_shape = (num_prototypes, 128, 1, 1)
prototype_activation_function = 'log'
add_on_layers_type = 'regular'

# NOTE(review): `pretrained=True` presumably loads backbone weights that the
# checkpoint loaded in the next cell then overwrites. Confirm whether
# `pretrained=False` would avoid an unnecessary download.
ppnet = model.construct_PPNet(base_architecture=base_architecture,
                              pretrained=True, img_size=img_size,
                              prototype_shape=prototype_shape,
                              num_classes=num_classes,
                              prototype_activation_function=prototype_activation_function,
                              add_on_layers_type=add_on_layers_type)
# Move to the device selected in the setup cell instead of a hard-coded 'cuda',
# so the notebook also runs on CPU-only machines.
ppnet = ppnet.to(device)

In [55]:
# Load the model
# map_location remaps the stored tensors onto the device selected in the setup
# cell, so a checkpoint saved on a GPU also loads on a CPU-only machine.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state dict needs.
state_dict = torch.load(saved_model_path, map_location=device, weights_only=True)
ppnet.load_state_dict(state_dict)
# Switch to inference mode; as the last expression this also displays the model.
ppnet.eval()

  ppnet.load_state_dict(torch.load(saved_model_path))


PPNet(
	features: VGG19, batch_norm=False,
	img_size: 224,
	prototype_shape: (352, 128, 1, 1),
	proto_layer_rf_info: [7, 32, 268, 16.0],
	num_classes: 11,
	epsilon: 0.0001
)

In [36]:
# torch.save(ppnet, 'vgg19-proto-auroc@0.8465.pth')

In [38]:
# ppnet = torch.load('vgg19-proto-auroc@0.8465.pth')
# ppnet.eval()

## Test the best model on the test set again

In [32]:
# Display names for the classes. Position i in this list names column i of the
# label and prediction arrays in the per-class AUROC report.
heart_conditions = [
    'NORM',
    'Acute MI',
    'Old MI',
    'STTC',
    'CD',
    'HYP',
    'PAC',
    'PVC',
    'AFIB/AFL',
    'TACHY',
    'BRADY',
]

# Define a function to test the model and calculate AUROC scores
def test_model(test_loader, model):
    """Run ``model`` over ``test_loader`` and print macro and per-class AUROC.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` batches. Labels
            are sliced as ``[:, i]`` below, so they must be 2-D with one
            column per entry of the module-level ``heart_conditions`` list.
        model: Callable applied to a batch of images. Element ``[0]`` of its
            return value is used as the per-class scores.

    Returns:
        None. The scores are printed.

    Notes:
        Reads the module-level ``device`` and ``heart_conditions``. The model
        is not switched to eval mode here; the caller is expected to do that.
    """
    batch_scores = []
    batch_labels = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing", leave=False):
            # Only the images are needed on the compute device. The labels are
            # consumed on the CPU by scikit-learn, so they are never sent over.
            outputs = model(images.to(device))[0]

            # Store predictions and labels
            batch_scores.append(outputs.cpu())
            batch_labels.append(labels.cpu())

    # Concatenate the batches and hand scikit-learn plain NumPy arrays.
    all_preds = torch.cat(batch_scores).numpy()
    all_labels = torch.cat(batch_labels).numpy()

    # Calculate overall AUROC
    overall_auroc = roc_auc_score(all_labels, all_preds, average='macro', multi_class='ovr')
    print(f"Overall AUROC: {overall_auroc:.4f}")

    # Calculate AUROC for each individual class
    for i, condition in enumerate(heart_conditions):
        class_auroc = roc_auc_score(all_labels[:, i], all_preds[:, i])
        print(f"AUROC for {condition}: {class_auroc:.4f}")

# Evaluate the loaded network on the test loader and print the AUROC report.
test_model(test_loader=test_loader, model=ppnet)

                                                                                                                                                                                     

Overall AUROC: 0.8452
AUROC for NORM: 0.9426
AUROC for Acute MI: 0.6974
AUROC for Old MI: 0.9044
AUROC for STTC: 0.9119
AUROC for CD: 0.8932
AUROC for HYP: 0.8973
AUROC for PAC: 0.4952
AUROC for PVC: 0.8288
AUROC for AFIB/AFL: 0.9212
AUROC for TACHY: 0.9328
AUROC for BRADY: 0.8725
