In [None]:
%cd ..
import os
import torch
import pandas as pd
import numpy as np
from dotenv import load_dotenv
import matplotlib.pyplot as plt

from typing import List, Dict, Tuple

from evaluation.utils.networks import FullScanPatchPredictor
from evaluation.extended_datasets import CachedEmbeddings
from evaluation.tasks.ct_rate.datasets import CT_RATE
from evaluation.utils.dataset import collate_sequences

In [None]:
# Configuration: environment-provided root paths, compute device and the
# experiment whose results are inspected in this notebook.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")

# Fail here with a readable message; otherwise an unset variable only shows
# up later as a TypeError raised by os.path.join(None, ...).
_missing_env = [
    name
    for name, value in (("PROJECTPATH", project_path), ("DATAPATH", data_path))
    if value is None
]
if _missing_env:
    raise RuntimeError(
        f"Missing environment variable(s): {', '.join(_missing_env)}. "
        "Set them in the environment or in a .env file."
    )

# Use the GPU when one is visible, otherwise fall back to CPU so the notebook
# still runs end to end on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

experiment_name = "ct_rate_cardiomegaly_patch"
label = "Cardiomegaly"

# The checkpoint lives below the per-label results directory, so build it
# from results_path instead of repeating every path component.
results_path = os.path.join(project_path, "runs", experiment_name, "results", label)
model_path = os.path.join(results_path, "test/model/model.pth")

In [None]:
def get_logits(
    result_path: str, valid_name: str, epoch: int
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Load the predictions stored for one epoch of one split.

    Reads ``<result_path>/<valid_name>/predictions/epoch_<epoch:02d>.csv`` and
    returns three of its columns. Folds and epochs are 0-indexed.

    Args:
        result_path: Root results directory.
        valid_name: Name of the sub-directory to read from (e.g. ``"test"``).
            Non-string values such as an integer fold index are converted
            with ``str()`` before the path is built.
        epoch: 0-indexed epoch; zero-padded to two digits in the file name.

    Returns:
        ``(map_ids, logits, labels)``: the ``map_id``, ``logits`` and
        ``labels`` columns of the CSV as NumPy arrays, one entry per row.

    Raises:
        FileNotFoundError: If the predictions CSV does not exist.
        KeyError: If one of the three expected columns is missing.
    """
    epoch_predictions_path = os.path.join(
        # str() keeps an integer fold index from crashing os.path.join.
        result_path, str(valid_name), "predictions", f"epoch_{epoch:02d}.csv"
    )
    epoch_predictions = pd.read_csv(epoch_predictions_path)

    map_ids = epoch_predictions["map_id"].to_numpy()
    logits = epoch_predictions["logits"].to_numpy()
    labels = epoch_predictions["labels"].to_numpy()

    return map_ids, logits, labels


# Predictions of the "test" split at the first (0-indexed) epoch.
map_ids, logits, labels = get_logits(
    result_path=results_path, valid_name="test", epoch=0
)

In [None]:
# Row indices of scans scored clearly on either side:
# logit above 0.3 on the positive side, below -0.1 on the negative side.
positives = np.flatnonzero(logits > 0.3)
negatives = np.flatnonzero(logits < -0.1)

print(f"Positives: {map_ids[positives]}")
print(f"Negatives: {map_ids[negatives]}")

In [None]:
# Build the patch-level classifier and restore the trained weights.
embed_dim = 768
hidden_dim = 512
patch_resample_dim = 16

classifier_model = FullScanPatchPredictor(
    embed_dim,
    hidden_dim,
    num_labels=1,
    patch_resample_dim=patch_resample_dim,
)

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (consider weights_only=True if the installed torch supports it).
state_dict = torch.load(model_path, map_location=device)


def _strip_module_wrappers(key: str) -> str:
    """
    Remove ``module`` wrapper segments from a dotted state-dict key.

    Checkpoints saved from a DataParallel-style wrapper carry a ``module.``
    prefix on every key. Only whole path segments named ``module`` are
    dropped, so names that merely end in "module" (e.g. ``submodule.weight``)
    are left intact; a plain ``str.replace("module.", "")`` would mangle them.
    The final segment (the parameter name) is never dropped.
    """
    *parents, leaf = key.split(".")
    return ".".join([part for part in parents if part != "module"] + [leaf])


classifier_model.load_state_dict(
    {_strip_module_wrappers(k): v for k, v in state_dict.items()}
)
classifier_model.to(device)
# Inference mode for layers that behave differently in training; the returned
# module is the cell's displayed output.
classifier_model.eval()

In [None]:
# Run / checkpoint whose cached embeddings are loaded below.
run_name, checkpoint_name = "base10pat", "training_99999"

batch_size, num_workers = 1, 4

# Label table for the CT-RATE validation split.
metadata_path = os.path.join(
    data_path, "niftis/CT-RATE/multi_abnormality_labels/valid_predicted_labels.csv"
)

# Cache directory layout: <project>/evaluation/cache/<dataset>/<run>/<checkpoint>.
embeddings_path = os.path.join(
    project_path, "evaluation/cache/CT-RATE_valid_eval", run_name, checkpoint_name
)

# Provider restricted to the "patch" feature of each cached entry.
embeddings_provider = CachedEmbeddings(embeddings_path, select_feature="patch")

# Dataset combining the cached embeddings with the rows of the label table
# for the configured `label`.
embeddings_dataset = CT_RATE(embeddings_provider, metadata_path, label)

In [None]:
from evaluation.utils.finetune import get_config, ImageTransformResampleSlices
from evaluation.extended_datasets import NiftiFullVolumeEval

# Training configuration of the run that produced the embeddings; the image
# size and channel count used for preprocessing are taken from it.
path_to_run = os.path.join(project_path, "runs", run_name)
config = get_config(path_to_run)

full_image_size = config.student.full_image_size
channels = config.student.channels

# Normalisation statistics passed to the image transform.
data_mean, data_std = -573.8, 461.3

img_processor = ImageTransformResampleSlices(
    full_image_size,
    data_mean,
    data_std,
    channels=channels,
)

# Keyword arguments for the full-volume NIfTI dataset.
db_params = dict(
    root_path=os.path.join(data_path, "niftis/CT-RATE"),
    dataset_name="CT-RATE_valid_eval",
    channels=channels,
    transform=img_processor,
)

image_dataset = NiftiFullVolumeEval(**db_params)


def get_image_from_map_id(map_id: str) -> Tuple[torch.Tensor, tuple]:
    """
    Fetch the image volume and the collated embedding sample for one scan.

    Args:
        map_id: Scan identifier, looked up in both ``image_dataset`` and
            ``embeddings_dataset``.

    Returns:
        ``(image, collated)`` where ``image`` is the first element of the
        matching ``image_dataset`` item and ``collated`` is the result of
        ``collate_sequences`` applied to a one-item batch from
        ``embeddings_dataset``. The inference cell unpacks ``collated`` as
        ``(map_id, embeddings, mask, labels)``.

        NOTE(review): the annotated types are inferred from downstream use
        (``image.numpy()`` and the 4-way unpacking) -- confirm against the
        dataset and collate implementations.

    Raises:
        ValueError: Presumably, when ``map_id`` is absent from
            ``embeddings_dataset.map_ids`` (assumes that attribute is a list).
    """
    image_index = image_dataset.get_index_from_map_id(map_id)
    # The second element of the dataset item is not needed here.
    image, _ = image_dataset[image_index]

    embedding_index = embeddings_dataset.map_ids.index(map_id)

    # Collate a batch of one so the sample has the layout the model expects.
    collated = collate_sequences([embeddings_dataset[embedding_index]])

    return image, collated

In [None]:
# Run the classifier on a single scan and keep its attention maps.
# NOTE: this cell rebinds `map_id`, `logits` and `labels`; the arrays loaded
# from the predictions CSV earlier in the notebook are no longer reachable
# under those names afterwards.
map_id = "valid_1008_a_2"
image, collated_sample = get_image_from_map_id(map_id)
map_id, embeddings, mask, labels = collated_sample

embeddings = embeddings.to(device)
mask = mask.to(device)

with torch.no_grad():
    logits, attention_map = classifier_model(embeddings, mask=mask)

# Move everything to NumPy; the batch holds a single scan, so take entry 0.
logits = logits.cpu().numpy().ravel()[0]
labels = labels.float().numpy().ravel()[0]
attention_maps_cpu = attention_map.cpu().numpy()[0]
image = image.numpy()

print(f"Map ID: {map_id}")
print(f"Logit: {logits}")
print(f"Label: {labels}")
print(f"Attention map shape: {attention_maps_cpu.shape}")
print(f"Image shape: {image.shape}")

# Image side length divided by the number of patches along one side
# (square root of the patch count on axis 2 of the attention maps).
patch_size = int(image.shape[2] / attention_maps_cpu.shape[2] ** 0.5)
print(f"Patch size: {patch_size}")

axial_dim = image.shape[0]
print(f"Axial dim: {axial_dim}")

In [None]:
# Global value range of the attention maps; reused as a shared colour scale
# when the individual heads are plotted.
attn_min, attn_max = attention_maps_cpu.min(), attention_maps_cpu.max()
print(f"Attention map range: {attn_min}, {attn_max}")

In [None]:
# Mean attention (averaged over axis 1, the heads) for one slice, shown next
# to the matching image slice.
rot = -1
slice_num = 10
chan = 4

# Side of the square patch grid, derived from the patch count on axis 2
# instead of hard-coding 36 so other image / patch sizes also work.
grid_size = int(round(attention_maps_cpu.shape[2] ** 0.5))

attn_map = np.mean(attention_maps_cpu, axis=1)[slice_num].reshape(
    grid_size, grid_size
)
attn_map = np.rot90(attn_map, k=rot)
img_slice = image[slice_num, chan]
# Same rotation as the attention map so both panels share an orientation.
img_slice = np.rot90(img_slice, k=rot)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(attn_map, cmap="inferno")
ax1.set_title("Mean Attention Map")
ax1.axis("off")

ax2.imshow(img_slice, cmap="gray")
ax2.set_title("Image")
ax2.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# One panel per attention head for the slice selected above, all drawn on the
# shared colour scale given by attn_min / attn_max.
rot = -1
attn_maps = attention_maps_cpu[slice_num]
img_slice = image[slice_num, chan]

# Side of the square patch grid, derived from the data instead of a literal 36.
grid_size = int(round(attention_maps_cpu.shape[2] ** 0.5))
num_heads = attn_maps.shape[0]

fig, axs = plt.subplots(4, 4, figsize=(10, 10))

axs = axs.flatten()

# The 4x4 layout shows at most 16 heads; with fewer heads the surplus panels
# are left blank instead of raising an IndexError.
for attn_head, ax in enumerate(axs):
    ax.axis("off")
    if attn_head >= num_heads:
        continue

    attn_map = attn_maps[attn_head].reshape(grid_size, grid_size)
    attn_map = np.rot90(attn_map, k=rot)

    ax.imshow(attn_map, cmap="inferno", vmin=attn_min, vmax=attn_max)
    ax.set_title(f"Attention Head: {attn_head}")

plt.show()

In [None]:
# Cross-section through the volume at a fixed patch row: mean attention over
# heads on the left, the corresponding image plane on the right.
rot = 2
row_number = 18

# Side of the square patch grid, derived from the data instead of a literal 36.
grid_size = int(round(attention_maps_cpu.shape[2] ** 0.5))

attn_map = attention_maps_cpu.reshape(axial_dim, -1, grid_size, grid_size)[
    :, :, row_number, :
]
attn_map = np.mean(attn_map, axis=1)
attn_map = np.rot90(attn_map, k=rot)

# Merge the two leading image axes into one stack of 2-D slices, then take
# the pixel row where the selected patch row starts.
img_slice = image.reshape(-1, full_image_size, full_image_size)
img_slice = img_slice[:, row_number * patch_size, :]
img_slice = np.rot90(img_slice, k=rot)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(attn_map, cmap="inferno")
ax1.set_title("Mean Attention Map")
ax1.axis("off")

ax2.imshow(img_slice, cmap="gray", aspect=14 / 10)
ax2.set_title("Image")
ax2.axis("off")

plt.tight_layout()
plt.show()

In [None]:
# Cross-section through the volume at a fixed patch column: mean attention
# over heads on the left, the corresponding image plane on the right.
rot = 2
col_number = 16

# Side of the square patch grid, derived from the data instead of a literal 36.
grid_size = int(round(attention_maps_cpu.shape[2] ** 0.5))

attn_map = attention_maps_cpu.reshape(axial_dim, -1, grid_size, grid_size)[
    :, :, :, col_number
]
attn_map = np.mean(attn_map, axis=1)
attn_map = np.rot90(attn_map, k=rot)

# Merge the two leading image axes into one stack of 2-D slices, then take
# the pixel column where the selected patch column starts.
img_slice = image.reshape(-1, full_image_size, full_image_size)
img_slice = img_slice[:, :, col_number * patch_size]
img_slice = np.rot90(img_slice, k=rot)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))

ax1.imshow(attn_map, cmap="inferno")
ax1.set_title("Mean Attention Map")
ax1.axis("off")

ax2.imshow(img_slice, cmap="gray", aspect=14 / 10)
ax2.set_title("Image")
ax2.axis("off")

plt.tight_layout()
plt.show()