In [1]:
import numpy as np
import pandas as pd
import timm
import torch
from sklearn.metrics import confusion_matrix

from dataset_EfficientNetV2 import Dataset_EfficientNetV2


In [2]:
# ---- Configuration: fill in these paths before running ----
folder_path = ''  # path to the top-level folder containing the patient directories
label_meta_data_path = ''  # CSV file with patient IDs and their corresponding diagnostic labels

model_path = ''  # path to the file containing the trained model checkpoint
predictions_output_csvfile = ''  # path where the CSV of model predictions on the testing dataset is saved

subject_meta = pd.read_csv(label_meta_data_path)

is_ct_image = False  # set to False if MCI, True if CT
is_shape = False     # set to True for binary shape images; only valid when is_ct_image is also True

# Enforce the constraint stated above so an invalid flag combination is caught here.
if is_shape and not is_ct_image:
    raise ValueError("is_shape=True requires is_ct_image=True")

In [None]:

# Build the testing dataset from the patient folders and the label table.
dataset = Dataset_EfficientNetV2(
    folder_path,
    subject_meta,
    is_ct_image=is_ct_image,
    is_shape=is_shape,
)

In [None]:
# Run on the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [11]:


# Batch the testing dataset for inference. Shuffling is kept off so batches
# arrive in dataset order and per-patient results are recorded in that order.
EVAL_BATCH_SIZE = 64  # adjust to fit the available device memory
val_loader = torch.utils.data.DataLoader(dataset, batch_size=EVAL_BATCH_SIZE, shuffle=False)


In [12]:

# Load trained model and run inference on the testing dataset.
#
# Every weight is overwritten by the checkpoint below, so the ImageNet weights are
# not downloaded (pretrained=False). The architecture arguments must still match training.
model = timm.create_model('tf_efficientnetv2_s', in_chans=1, pretrained=False).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
model_data = torch.load(model_path, map_location=device)
model.load_state_dict(model_data['model_state_dict'])
model.eval()

NUM_CLASSES = 3  # width of each patient's vote vector

# allPatients[patient] -> (1, NUM_CLASSES) array of per-image vote counts
# allLabels[patient]   -> that patient's label as returned by the dataset
allPatients = {}
allLabels = {}

# One pass over the loader: patient IDs and labels are recorded alongside the
# predictions, instead of iterating the dataset a second time (which would load
# every image twice).
with torch.no_grad():
    for images, labels, _, patient in val_loader:
        outputs = model(images.to(device))
        # Predicted class index per image, moved to the CPU once per batch.
        predicted = outputs.argmax(dim=1).cpu()
        for idx, p in enumerate(patient):
            if p not in allPatients:
                allPatients[p] = np.zeros((1, NUM_CLASSES))
                label = labels[idx]
                # The loader collates labels into a tensor; store a plain Python scalar.
                allLabels[p] = label.item() if torch.is_tensor(label) else label
            predicted_class = int(predicted[idx])
            # Model class k is counted in vote column k - 1 (labels appear to be
            # 1-based; TODO confirm). An out-of-range class is rejected explicitly:
            # class 0 would otherwise wrap to column -1 and silently count as the last class.
            column = predicted_class - 1
            if not 0 <= column < NUM_CLASSES:
                raise ValueError(
                    f"Model predicted class {predicted_class} for patient {p!r}; "
                    f"expected a class in 1..{NUM_CLASSES}"
                )
            allPatients[p][0, column] += 1



In [None]:
# Patient-level evaluation: each patient's prediction is the vote column with the
# most per-image votes (np.argmax picks the lowest column index on ties).
if not allPatients:
    raise ValueError("No patients were scored; check the dataset and the inference cell.")

y_true = []
y_pred = []
patient_name = []
for patient, votes in allPatients.items():
    # Shift the label by one to line up with the 0-based vote columns. int()
    # normalises tensor / NumPy scalar labels to plain ints for sklearn and the CSV.
    y_true.append(int(allLabels[patient]) - 1)
    y_pred.append(int(np.argmax(votes)))
    patient_name.append(patient)

# Fix the class axis explicitly so the matrix keeps one row/column per vote column,
# even when a class is missing from both y_true and y_pred.
class_ids = list(range(next(iter(allPatients.values())).size))
assert set(y_true) <= set(class_ids), (
    f"labels outside {class_ids}: {sorted(set(y_true) - set(class_ids))}"
)

cm = confusion_matrix(y_true, y_pred, labels=class_ids)
# Rows are true classes, columns are predicted classes.
pd.DataFrame(
    cm,
    index=[f'true_{c}' for c in class_ids],
    columns=[f'pred_{c}' for c in class_ids],
)

In [None]:
# Save one row per patient (ID, predicted class, true class) so the results
# can be inspected outside the notebook.
prediction_columns = {
    'patient': patient_name,
    'prediction': y_pred,
    'truth': y_true,
}
patient_predictions = pd.DataFrame(prediction_columns)
patient_predictions.to_csv(predictions_output_csvfile, index=False)