In [1]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
from sklearn.metrics import roc_curve, auc
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from tqdm import tqdm

from models.sfcn import SFCN
from models.sfcn_daft import SFCN_DAFT

import sys

sys.path.append("..")
from utils.datasets import TorchDataset as TD

In [2]:
# Data split to evaluate; samples are read from data/<mode>.
# shuffle=False keeps the sample order deterministic across runs.
mode = "test"
dataloader = DataLoader(
    TD(f"data/{mode}"),
    batch_size=8,
    shuffle=False,
)

In [3]:
@torch.no_grad()
def test(model, dataloader, device="cuda"):
    """Run an image-only model over ``dataloader`` and collect predictions.

    Each batch is indexed positionally: element 0 is the model input, and
    elements 1-4 are collected as target, sex, study and scanner respectively.

    Args:
        model: network called as ``model(x)``.
        dataloader: iterable of batches laid out as described above.
        device: device the inputs are moved to. Mixed-precision autocast is
            enabled only when this is a CUDA device.

    Returns:
        Tuple ``(pd_true, pd_pred, sex, study, scanner)`` of Python lists.
        ``pd_pred`` holds the raw model outputs (no activation is applied
        here), one ``tolist()`` entry per sample.
    """
    model.eval()
    # torch.cuda.amp.autocast() is deprecated in favour of torch.autocast;
    # tie enabling it to the requested device instead of ignoring `device`.
    use_amp = torch.device(device).type == "cuda"

    pd_true, pd_pred, sex, study, scanner = [], [], [], [], []
    for batch in tqdm(dataloader, desc="Test"):
        x = batch[0].to(device)
        pd_true += batch[1].tolist()
        sex += batch[2].tolist()
        study += batch[3].tolist()
        scanner += batch[4].tolist()

        # Forward pass with mixed precision (no .detach() needed: no_grad).
        with torch.autocast(device_type="cuda", enabled=use_amp):
            pd_pred += model(x).cpu().tolist()

    return pd_true, pd_pred, sex, study, scanner

In [4]:
@torch.no_grad()
def test_daft(model, dataloader, device="cuda"):
    """Run an image + tabular (DAFT) model over ``dataloader``.

    Each batch is indexed positionally: element 0 is the image input, and
    elements 1-4 are collected as target, sex, study and scanner respectively.
    The sex/study/scanner fields are also stacked column-wise into a float32
    tabular tensor that is passed to the model as its second argument.

    Args:
        model: network called as ``model(x, tabular)``.
        dataloader: iterable of batches laid out as described above.
        device: device the inputs are moved to. Mixed-precision autocast is
            enabled only when this is a CUDA device.

    Returns:
        Tuple ``(pd_true, pd_pred, sex, study, scanner)`` of Python lists.
        ``pd_pred`` holds the raw model outputs (no activation is applied
        here), one ``tolist()`` entry per sample.
    """
    model.eval()
    # torch.cuda.amp.autocast() is deprecated in favour of torch.autocast;
    # tie enabling it to the requested device instead of ignoring `device`.
    use_amp = torch.device(device).type == "cuda"

    pd_true, pd_pred, sex, study, scanner = [], [], [], [], []
    for batch in tqdm(dataloader, desc="Test"):
        x = batch[0].to(device)
        # Columns: sex, study, scanner. Assumes each field is a 1-D tensor of
        # length batch_size, giving a (batch, 3) tensor — TODO confirm.
        tabular = torch.stack(batch[2:5], dim=1).to(device, dtype=torch.float32)

        pd_true += batch[1].tolist()
        sex += batch[2].tolist()
        study += batch[3].tolist()
        scanner += batch[4].tolist()

        # Forward pass with mixed precision (no .detach() needed: no_grad).
        with torch.autocast(device_type="cuda", enabled=use_amp):
            pd_pred += model(x, tabular).cpu().tolist()

    return pd_true, pd_pred, sex, study, scanner

In [5]:
def sigmoid(z):
    """Numerically stable logistic sigmoid ``1 / (1 + exp(-z))``.

    The naive formula overflows in ``np.exp(-z)`` for large negative ``z``
    and emits a RuntimeWarning. Here ``exp`` is evaluated only on ``-|z|``,
    so its result stays in (0, 1] and cannot overflow.

    Args:
        z: scalar or array-like of logits.

    Returns:
        Values in [0, 1] with the same shape as ``z``. Scalar input gives a
        NumPy scalar.
    """
    z = np.asarray(z, dtype=float)
    e = np.exp(-np.abs(z))
    # z >= 0: 1 / (1 + e^-z);  z < 0: e^z / (1 + e^z)  — identical values.
    out = np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))
    # [()] unwraps 0-d results to a scalar and is a no-op view otherwise.
    return out[()]


def accuracy(y_true, y_pred, threshold=0.5):
    """Fraction of samples whose thresholded prediction matches the label.

    A prediction is counted as positive when it is strictly greater than
    ``threshold``.

    Args:
        y_true: binary labels (0/1 or bool), array-like.
        y_pred: scores/probabilities, array-like with the same number of
            elements as ``y_true``.
        threshold: decision threshold.

    Returns:
        Accuracy as a float in [0, 1], or NaN when there are no samples.

    Raises:
        ValueError: if ``y_true`` and ``y_pred`` have different sizes.
    """
    # Flatten both so that e.g. (n,) labels vs (n, 1) scores are compared
    # element-wise instead of silently broadcasting to an (n, n) matrix.
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    if y_true.size != y_pred.size:
        raise ValueError(
            f"y_true has {y_true.size} elements but y_pred has {y_pred.size}"
        )
    if y_true.size == 0:
        return float("nan")
    return float(np.mean((y_pred > threshold) == y_true))

## SFCN

In [6]:
# Evaluate the image-only SFCN on the `mode` split.
model = SFCN(output_dim=1, channel_number=[28, 58, 128, 256, 256, 64]).to("cuda")
# map_location="cpu": deserialise independently of the device the checkpoint
# was saved from; load_state_dict then copies the weights into the CUDA model.
# NOTE: torch.load unpickles data — only load trusted checkpoints.
checkpoint = torch.load("checkpoints/SFCN/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

pd_true, pd_pred, sex, study, scanner = test(model, dataloader)
# Raw outputs are treated as logits and mapped to probabilities.
pd_pred = sigmoid(np.array(pd_pred).flatten())
pd_true = np.array(pd_true).flatten()
assert pd_true.shape == pd_pred.shape, "expected one prediction per target"

fpr, tpr, thresholds = roc_curve(pd_true, pd_pred)
roc_auc = auc(fpr, tpr)
acc = accuracy(pd_true, pd_pred, threshold=0.5)
roc_auc, acc

Test: 100%|██████████| 11/11 [00:03<00:00,  3.22it/s]


(0.840909090909091, 0.7738095238095238)

## SFCN-DAFT

In [7]:
# Evaluate SFCN-DAFT (image + sex/study/scanner tabular input) on the `mode` split.
model = SFCN_DAFT(output_dim=1, channel_number=[28, 58, 128, 256, 256, 64]).to("cuda")
# map_location="cpu": deserialise independently of the device the checkpoint
# was saved from; load_state_dict then copies the weights into the CUDA model.
# NOTE: torch.load unpickles data — only load trusted checkpoints.
checkpoint = torch.load("checkpoints/SFCN_DAFT/best_model.pt", map_location="cpu")
model.load_state_dict(checkpoint["model_state_dict"])

pd_true, pd_pred, sex, study, scanner = test_daft(model, dataloader)
# Raw outputs are treated as logits and mapped to probabilities.
pd_pred = sigmoid(np.array(pd_pred).flatten())
pd_true = np.array(pd_true).flatten()
assert pd_true.shape == pd_pred.shape, "expected one prediction per target"

fpr, tpr, thresholds = roc_curve(pd_true, pd_pred)
roc_auc = auc(fpr, tpr)
acc = accuracy(pd_true, pd_pred, threshold=0.5)
roc_auc, acc

Test: 100%|██████████| 11/11 [00:01<00:00,  5.89it/s]


(0.8625, 0.75)