In [None]:
%cd ..
import warnings

warnings.filterwarnings("ignore")

import os
from collections import deque

import matplotlib.pyplot as plt
import numpy as np
import torch
from dotenv import load_dotenv
from tqdm import tqdm

from evaluation.extended_datasets.lidc_idri import LidcIdriSplit, get_dataloader
from evaluation.utils.finetune import load_model, binary_accuracy_logits, LinearClassifier
from evaluation.utils.segmentation import (
    sample_from_queues,
    process_batch,
    update_difficult_negatives,
    binary_mask_to_patch_labels,
    ImageTargetTransform,
    show_mask,
)

In [None]:
# Machine-specific locations are read from a local .env file / the process environment.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")

# Fail here with a readable message instead of later inside os.path.join(None, ...).
missing_env = [
    name
    for name, value in (("PROJECTPATH", project_path), ("DATAPATH", data_path))
    if not value
]
if missing_env:
    raise EnvironmentError(
        f"Missing required environment variable(s): {', '.join(missing_env)}"
    )

In [None]:
# Which training run / checkpoint to evaluate.
path_to_run = "runs/base_103x4x5"
checkpoint_name = "training_69999"
# Prefer the first GPU, but keep the notebook runnable on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# `load_model` returns the feature extractor together with the run configuration
# that later cells read image and patch sizes from.
feature_model, config = load_model(path_to_run, checkpoint_name, device)
print("Loaded model")

In [None]:
# Input geometry is taken from the `student` section of the loaded run configuration.
full_image_size = config.student.full_image_size
patch_size = config.student.patch_size
# Normalisation constants handed to ImageTargetTransform in the dataset cell.
# NOTE(review): presumably the intensity mean/std of the CT data — confirm where they come from.
data_mean = -573.8
data_std = 461.3
# Channels per sample; show_embed_dim reshapes a single sample to this many channels.
channels = 4

print("Full image size:", full_image_size)

In [None]:
# Informational only: report how many CPUs this machine exposes (no later cell reads it).
print("Num cpus:", os.cpu_count())

In [None]:
# Joint image/target transform built from the input size and normalisation constants.
img_processor = ImageTargetTransform(full_image_size, data_mean, data_std)

# Arguments shared by the train and validation splits.
dataset_kwargs = {
    "root_path": os.path.join(data_path, "dicoms"),
    "mask_path": os.path.join(project_path, "data/eval/LIDC-IDRI/masks"),
    "transform": img_processor,
    "max_workers": 4,
    "train_val_split": 0.95,
}

train_dataset = LidcIdriSplit(**dataset_kwargs, split="train")
val_dataset = LidcIdriSplit(**dataset_kwargs, split="val")
# Use the notebook-level `channels` setting instead of repeating the literal 4, so the
# loaders cannot drift from the reshape done in show_embed_dim.
train_dataloader = get_dataloader(train_dataset, channels=channels, split="train")
val_dataloader = get_dataloader(val_dataset, channels=channels, split="val")

In [None]:
def show_embed_dim():
    """Print and return the per-block embedding width of the feature model.

    One training sample is reshaped to a batch of size one, pushed through
    ``feature_model`` without gradient tracking, and the size of the last axis of
    the first element of the first returned block is reported.

    Returns:
        int: the embedding dimension that was printed.
    """
    # Build an iterator explicitly: this works whether `train_dataloader` is a plain
    # iterable (DataLoader-like) or already an iterator, whereas a bare next() only
    # works in the latter case.
    test_images, _ = next(iter(train_dataloader))
    unit_batch = test_images[0].view(1, channels, full_image_size, full_image_size)
    with torch.no_grad():
        outputs = feature_model(unit_batch.to(device))
    embed_dim = outputs[0][0].shape[-1]
    print("Embedding dimension:", embed_dim)
    return embed_dim
# show_embed_dim()

In [None]:
# Sanity check of the masks: pull one training batch so that images and targets
# can be inspected by eye in the following cells.
batch_img, batch_target = next(iter(train_dataloader))
# Cell output: shape of the image batch.
batch_img.shape

In [None]:
def find_nodule(target):
    """Return the index of the first sample whose mask has any positive sum.

    Args:
        target: batch of masks, iterated along its first axis. Any number of
            trailing axes is accepted; each sample is reduced with ``.sum()``.

    Returns:
        int | None: position of the first sample with ``sum() > 0``, or ``None``
        when no sample in the batch qualifies.
    """
    for batch_idx, sample_mask in enumerate(target):
        if sample_mask.sum() > 0:
            return batch_idx
    # Explicit so callers can tell "no nodule in this batch" from a real index.
    return None
# Cell output: index of the first sample in the inspected batch that contains mask
# pixels (nothing is displayed when the batch has none).
find_nodule(batch_target)

In [None]:
# Show a sample that actually contains a nodule instead of a hard-coded index;
# fall back to the first sample when the inspected batch has none.
nodule_idx = find_nodule(batch_target)
batch_idx = 0 if nodule_idx is None else nodule_idx
# Within that sample, display the channel with the largest mask area
# (argmax yields 0 when every channel is empty).
channel = int(batch_target[batch_idx].sum(dim=(-2, -1)).argmax())

img_slice = batch_img[batch_idx][channel].numpy()
target_slice = batch_target[batch_idx][channel].numpy()

img_with_mask = show_mask(img_slice, target_slice)

fig, axs = plt.subplots(1, 2, figsize=(12, 8))

plt.subplots_adjust(wspace=0.2)

# Left: raw slice; right: the same slice with the mask overlaid by show_mask.
im = axs[0].imshow(img_slice, cmap="gray")
axs[0].set_title("image")
axs[1].imshow(img_with_mask)
axs[1].set_title("highlighted node")

# One shared colour bar for the intensity scale of the raw slice.
cbar = fig.colorbar(im, ax=axs, location="right", shrink=0.6)

plt.show()

In [None]:
# Per-block embedding width of the feature model.
# NOTE(review): hard-coded; show_embed_dim() can be used to check it against the loaded model.
embed_dim = 768
# Classifier input width: process_batch concatenates the outputs of the last
# `use_n_blocks` (default 4) blocks along the feature axis.
EMBED_DIM = embed_dim * 4
# Same configuration value as `patch_size` defined earlier.
PATCH_SIZE = config.student.patch_size
# Upper bound on the number of tokens resampled per training step (see train()).
BATCH_SIZE = 100

# Classification head producing one output per patch token (num_labels=1).
# NOTE(review): hidden_dim=2048 suggests the head has a hidden layer despite its name — confirm in evaluation.utils.finetune.
classifier_model = LinearClassifier(
    embed_dim=EMBED_DIM, hidden_dim=2048, num_labels=1
).to(device)

In [None]:
CACHE_MAX_SIZE = 1_000

# Bounded FIFO caches of patch tokens: once `maxlen` is reached, appending drops
# the oldest entry.
positive_patch_cache = deque(maxlen=CACHE_MAX_SIZE)
negative_patch_cache = deque(maxlen=CACHE_MAX_SIZE)

# Explicit iterator over the training loader.
iter_train_loader = iter(train_dataloader)

alpha = 0.95  # smoothing factor of the running training loss shown in the progress bar
iteration = 0
eval_interval = 500  # test on validation every this many iterations
max_iter = 2_000  # total number of train steps (eval happens inside this interval)
# Stated explicitly: older torch releases require `lr` for SGD, while newer ones
# default to this same value.
LEARNING_RATE = 1e-3

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(
    classifier_model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0
)
# The learning rate is annealed to zero over `max_iter` scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)

In [None]:
def add_to_cache(cache: deque, tokens: torch.Tensor):
    """Append every row of ``tokens`` to ``cache`` as a detached CPU tensor.

    Args:
        cache: destination deque; when it has a ``maxlen`` the oldest entries are
            evicted as new rows arrive.
        tokens: tensor iterated along its first axis.
    """
    cpu_rows = tokens.detach().cpu()
    for row in cpu_rows:
        cache.append(row)

def compute_loss_and_backprop(outputs, labels, clip_value=10.0):
    """Run one optimisation step on the classifier and return the loss as a float.

    Uses the notebook-level ``criterion``, ``classifier_model``, ``optimizer`` and
    ``scheduler``. Gradients are not zeroed here; the caller is expected to do so.

    Args:
        outputs: classifier outputs passed to ``criterion``.
        labels: targets passed to ``criterion``.
        clip_value: maximum gradient norm applied before the optimizer step.

    Returns:
        float: the loss value of this step.
    """
    step_loss = criterion(outputs, labels)
    step_loss.backward()

    # Bound the gradient norm before the parameters are updated.
    head_params = classifier_model.parameters()
    torch.nn.utils.clip_grad_norm_(head_params, clip_value)

    optimizer.step()
    scheduler.step()  # the schedule advances once per optimizer step
    return step_loss.item()

def compute_metrics(accuracy_sum, acc_breakdown, num_batches=None):
    """Turn accumulated validation counters into average accuracies.

    Args:
        accuracy_sum: sum of the per-batch accuracies.
        acc_breakdown: 4-tuple ``(positives, true_positives, negatives,
            true_negatives)`` accumulated over the evaluated batches.
        num_batches: number of batches ``accuracy_sum`` was accumulated over.
            Defaults to ``len(val_dataloader)`` (the notebook-level loader).

    Returns:
        tuple: ``(avg_accuracy, avg_p_accuracy, avg_n_accuracy)``. A per-class
        accuracy is ``nan`` when that class never occurred, instead of raising
        a division-by-zero error.
    """
    positives, true_positives, negatives, true_negatives = acc_breakdown
    if num_batches is None:
        num_batches = len(val_dataloader)

    undefined = float("nan")
    avg_accuracy = accuracy_sum / num_batches
    avg_p_accuracy = true_positives / positives if positives else undefined
    avg_n_accuracy = true_negatives / negatives if negatives else undefined
    return avg_accuracy, avg_p_accuracy, avg_n_accuracy

def process_batch(
    inputs, targets, embed_dim, patch_size, device, use_n_blocks=4
):
    """Turn an image/mask batch into flat patch tokens and flat patch labels.

    This definition shadows the ``process_batch`` imported at the top of the
    notebook. It takes no model argument and reads the notebook-level
    ``feature_model`` instead.

    Args:
        inputs: image batch, moved to ``device`` before the forward pass.
        targets: mask batch, moved to ``device`` before being converted to patch labels.
        embed_dim: width of one concatenated token (size of the last axis after
            joining the selected blocks).
        patch_size: patch size passed to ``binary_mask_to_patch_labels``.
        device: device used for the forward pass and the label computation.
        use_n_blocks: how many trailing block outputs of the model are concatenated.

    Returns:
        tuple: ``(patch_tokens, patch_labels)`` where ``patch_tokens`` has shape
        ``(-1, embed_dim)`` and ``patch_labels`` is flattened to one dimension.
    """
    # The feature extractor is only run for inference, never trained here.
    with torch.no_grad():
        block_outputs = feature_model(inputs.to(device))

    # Each block output is unpacked as a pair; only the first element (the patch
    # tokens) is kept and joined along the feature axis.
    trailing_blocks = block_outputs[-use_n_blocks:]
    tokens_per_block = [block_tokens for block_tokens, _ in trailing_blocks]
    patch_tokens = torch.cat(tokens_per_block, dim=-1).view(-1, embed_dim)

    device_targets = targets.to(device)
    patch_labels = binary_mask_to_patch_labels(device_targets, patch_size).view(-1)

    return patch_tokens, patch_labels

In [None]:
def train(
    iteration: int
) -> int:
    """Train the classifier head for up to ``eval_interval`` batches.

    For every batch the patch tokens are extracted, tokens labelled positive are
    pushed into ``positive_patch_cache``, a batch is resampled with
    ``sample_from_queues`` and one optimisation step is taken. Batches seen while
    the positive cache is still empty are skipped and do not count as iterations.

    Args:
        iteration: number of optimisation steps taken so far.

    Returns:
        int: ``iteration`` plus the number of optimisation steps taken in this call.
    """
    global iter_train_loader

    classifier_model.train()
    # None until the first real step, so skipped batches cannot bias the average.
    running_loss = None
    train_tqdm = tqdm(range(1, eval_interval + 1), desc="Training", leave=True)

    for _ in train_tqdm:
        # Draw from the prepared iterator and restart it when the loader is exhausted.
        try:
            inputs, targets = next(iter_train_loader)
        except StopIteration:
            iter_train_loader = iter(train_dataloader)
            inputs, targets = next(iter_train_loader)
        optimizer.zero_grad()

        # The notebook-level process_batch (which shadows the imported one) takes
        # no model argument; passing `feature_model` first shifted every argument.
        patch_tokens, patch_labels = process_batch(
            inputs, targets, EMBED_DIM, patch_size, device
        )

        # Remember every token labelled positive for later resampling.
        masked_patch_indices = (patch_labels > 0).nonzero(as_tuple=True)
        masked_patch_tokens = patch_tokens[masked_patch_indices]
        if len(masked_patch_tokens) > 0:
            add_to_cache(positive_patch_cache, masked_patch_tokens)

        # Nothing to resample from until at least one positive token has been seen.
        if len(positive_patch_cache) == 0:
            continue

        num_resample = min(len(positive_patch_cache) * 2, BATCH_SIZE)
        resampled_tokens, resampled_labels, _ = sample_from_queues(
            positive_patch_cache, negative_patch_cache, patch_tokens, patch_labels, num_resample, device
        )

        classifier_output = classifier_model(resampled_tokens)

        # Lets the helper update `negative_patch_cache` from this step's predictions.
        update_difficult_negatives(
            classifier_output, resampled_labels, resampled_tokens, negative_patch_cache
        )

        loss_value = compute_loss_and_backprop(
            classifier_output, resampled_labels
        )

        # Exponential moving average, seeded with the first observed loss.
        if running_loss is None:
            running_loss = loss_value
        else:
            running_loss = running_loss * alpha + (1 - alpha) * loss_value
        iteration += 1

        train_tqdm.set_postfix({"Loss": running_loss, "Cache": len(positive_patch_cache)})

    return iteration

In [None]:
def validation():
    """Evaluate the classifier head on every batch of ``val_dataloader``.

    Returns:
        tuple: ``(avg_accuracy, avg_p_accuracy, avg_n_accuracy)`` as computed by
        ``compute_metrics`` from the accumulated per-batch results.
    """
    classifier_model.eval()
    accuracy_sum, positives, negatives = 0.0, 0, 0
    true_pred_positives, true_pred_negatives = 0, 0

    with torch.no_grad():
        # Passing the loader itself (not an iterator) lets tqdm show the total.
        val_tqdm = tqdm(val_dataloader, desc="Evaluation", leave=True)
        for inputs, targets in val_tqdm:
            # The notebook-level process_batch (which shadows the imported one)
            # takes no model argument; passing `feature_model` first shifted
            # every argument.
            patch_tokens, patch_labels = process_batch(
                inputs, targets, EMBED_DIM, patch_size, device
            )

            classifier_output = classifier_model(patch_tokens)

            accuracy, acc_breakdown = binary_accuracy_logits(
                classifier_output.view(-1), patch_labels
            )
            # The breakdown is read as (positives, true positives, negatives,
            # true negatives) for the batch.
            accuracy_sum += accuracy
            positives += acc_breakdown[0]
            true_pred_positives += acc_breakdown[1]
            negatives += acc_breakdown[2]
            true_pred_negatives += acc_breakdown[3]

    # Reuse the shared helper instead of repeating its arithmetic here.
    return compute_metrics(
        accuracy_sum,
        (positives, true_pred_positives, negatives, true_pred_negatives),
    )

In [None]:
# Alternate between a training round (train() runs up to `eval_interval` batches and
# returns the updated step count) and a full pass over the validation loader, until
# `max_iter` optimisation steps have been taken.
while iteration < max_iter:
    iteration = train(iteration)
    avg_accuracy, avg_p_accuracy, avg_n_accuracy = validation()
    print(
        f"Iteration: {iteration}, Overall Accuracy: {avg_accuracy:.4f}, Positives: {avg_p_accuracy}, Negatives: {avg_n_accuracy}"
    )

In [None]:
# Visualise some image segmentations: iterator over the validation loader, advanced
# one batch at a time by re-running the next cell.
demoiter = iter(val_dataloader)

In [None]:
# Fetch the next validation batch and convert its masks to patch-level labels.
images, targets = next(demoiter)
patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE)
# Cell output: index tensors of all entries with a positive patch label in this batch.
masked_patch_indices = (patch_labels > 0).nonzero(as_tuple=True)
masked_patch_indices

In [None]:
# Inference only: put the head in eval mode and skip autograd bookkeeping.
classifier_model.eval()
with torch.no_grad():
    features = feature_model(images.to(device))
    # Keep as many trailing blocks as the classifier input width accounts for,
    # mirroring the block selection done in process_batch.
    n_blocks = EMBED_DIM // embed_dim
    patch_tokens = torch.cat(
        [patch_token for patch_token, _ in features[-n_blocks:]], dim=-1
    )
    classifier_output = classifier_model(patch_tokens)
# Cell output: shape of the per-patch predictions for the whole batch.
classifier_output.shape

In [None]:
# Sample to display, clamped so a smaller batch does not raise an IndexError.
selected_batch = min(19, len(images) - 1)

# Side length of the square patch grid, derived from the number of per-image
# predictions instead of a fixed 16 so other image/patch sizes also work.
grid_side = round(classifier_output[selected_batch].numel() ** 0.5)

original_image = images[selected_batch][0].numpy()
true_mask = patch_labels[selected_batch].view(grid_side, grid_side).cpu().numpy()
predicted_mask = (
    classifier_output[selected_batch].detach().view(grid_side, grid_side).cpu().numpy()
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

axes[0].imshow(original_image, cmap="gray")
axes[0].set_title("Image")

axes[1].imshow(true_mask, cmap="gray")
axes[1].set_title("True mask")

# Raw classifier outputs (not thresholded) on the patch grid.
axes[2].imshow(predicted_mask, cmap="gray")
axes[2].set_title("Predicted mask")

plt.tight_layout()
plt.show()

In [None]:
# Binarise the predictions at 0. The head is trained with BCEWithLogitsLoss, so its
# outputs are logits and a value above 0 corresponds to a sigmoid probability above 0.5.
plt.imshow(predicted_mask > 0)
plt.show()