In [None]:
import os
import torch
import pandas as pd
import numpy as np
from dotenv import load_dotenv
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple

from evaluation.utils.networks import FullScanPatchPredictor
from evaluation.extended_datasets import CachedEmbeddings
from evaluation.tasks.ct_rate.datasets import CT_RATE
from evaluation.utils.dataset import collate_sequences

In [None]:
# Read PROJECTPATH / DATAPATH from a local .env file or the existing environment.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")

# Fail here with a readable message; otherwise a missing variable only shows up
# later as an opaque TypeError raised from inside os.path.join.
_missing_env = [
    name
    for name, value in (("PROJECTPATH", project_path), ("DATAPATH", data_path))
    if value is None
]
if _missing_env:
    raise EnvironmentError(
        "Missing required environment variable(s): "
        + ", ".join(_missing_env)
        + ". Define them in the environment or in a .env file."
    )

# Prefer the GPU, but fall back to CPU so the notebook still runs without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Location of the trained classifier checkpoint for the selected label.
experiment_name = "test_cardiomegaly_cls"
label = "Cardiomegaly"
model_path = os.path.join(project_path, "runs", experiment_name, "results", label, "models/model.pth")

In [None]:
# Architecture hyper-parameters passed to FullScanPatchPredictor.
# NOTE(review): presumably these must match the values used when the checkpoint
# at `model_path` was trained — confirm against the training config.
embed_dim = 768
hidden_dim = 512
patch_resample_dim = 16

classifier_model = FullScanPatchPredictor(
    embed_dim,
    hidden_dim,
    num_labels=1,
    patch_resample_dim=patch_resample_dim,
)
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. If the installed torch supports it, consider passing
# weights_only=True since only a state dict is needed here.
classifier_model.load_state_dict(torch.load(model_path, map_location=device))
classifier_model.to(device)
# This notebook only runs inference: switch dropout / normalization layers (if
# the model has any) to evaluation behavior so the outputs are deterministic.
classifier_model.eval()

In [None]:
# Which cached embedding set to read, and how to batch it.
run_name = "base10pat"
checkpoint_name = "training_99999"
batch_size, num_workers = 1, 4

# Resolve both input locations up front; they are plain path joins.
embeddings_path = os.path.join(
    project_path, "evaluation/cache/CT-RATE_valid_eval", run_name, checkpoint_name
)
metadata_path = os.path.join(
    data_path,
    "niftis/CT-RATE/multi_abnormality_labels/valid_predicted_labels.csv",
)

# Cached embeddings (the "patch" feature is selected) wrapped in the CT-RATE
# dataset for the chosen label.
embeddings_provider = CachedEmbeddings(embeddings_path, select_feature="patch")
embeddings_dataset = CT_RATE(embeddings_provider, metadata_path, label)

# Fixed iteration order (no shuffling); batches are assembled by collate_sequences.
dataloader = torch.utils.data.DataLoader(
    dataset=embeddings_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=num_workers,
    collate_fn=collate_sequences,
    pin_memory=True,
)

In [None]:
from evaluation.utils.finetune import get_config, ImageTransformResampleSlices
from evaluation.extended_datasets import NiftiFullVolumeEval

# Load the configuration stored with the run that produced the embeddings.
path_to_run = os.path.join(project_path, "runs", run_name)
config = get_config(path_to_run)

# Image geometry comes from the student section of the run config.
full_image_size = config.student.full_image_size
channels = config.student.channels

# NOTE(review): presumably dataset-level intensity statistics used for
# normalization — confirm where these values come from.
data_mean, data_std = -573.8, 461.3

img_processor = ImageTransformResampleSlices(
    full_image_size,
    data_mean,
    data_std,
    channels=channels,
)

# Keyword arguments for the image dataset over the CT-RATE NIfTI files.
db_params = {
    "root_path": os.path.join(data_path, "niftis/CT-RATE"),
    "dataset_name": "CT-RATE_valid_eval",
    "channels": channels,
    "transform": img_processor,
}
image_dataset = NiftiFullVolumeEval(**db_params)

def get_image_from_map_id(map_id: str) -> torch.Tensor:
    """
    Look up a scan in ``image_dataset`` by its map ID and return only the image.

    Args:
        map_id: Identifier passed to ``image_dataset.get_index_from_map_id``.

    Returns:
        The first element of the dataset item at the resolved index. The
        second element of the item is discarded, so no metadata is returned.
        NOTE(review): annotated as ``torch.Tensor`` because this notebook
        calls ``.numpy()`` on the result — confirm against
        ``NiftiFullVolumeEval``.
    """
    dataset_index = image_dataset.get_index_from_map_id(map_id)
    image, _ = image_dataset[dataset_index]
    return image

In [None]:
# Take the first batch from a fresh iterator over the dataloader.
dataloader_iterator = iter(dataloader)
map_ids, embeddings, mask, labels = next(dataloader_iterator)
map_id = map_ids[0]

# Move the model inputs to the target device and run a gradient-free forward pass.
embeddings, mask = embeddings.to(device), mask.to(device)
with torch.no_grad():
    logits, attention_map = classifier_model(embeddings, mask=mask)

# Keep only the first entry of each output, converted to NumPy on the CPU.
logits = logits.cpu().flatten().numpy()[0]
labels = labels.float().flatten().numpy()[0]
attention_maps_cpu = attention_map.cpu().numpy()[0]

# Fetch the image that belongs to the same map ID.
image = get_image_from_map_id(map_id).numpy()

print(f"Map ID: {map_id}")
print(f"Logit: {logits}")
print(f"Label: {labels}")
print(f"Attention map shape: {attention_maps_cpu.shape}")
print(f"Image shape: {image.shape}")