In [None]:
import os
import warnings

# Move one directory up (presumably the repository root, since later cells import
# `eval.*` and read `runs/...` relative to the working directory). Guarded so that
# re-running this cell does not climb one more level each time, as a bare `%cd ..` does.
if "_NOTEBOOK_START_DIR" not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
    os.chdir("..")
print("Working directory:", os.getcwd())

warnings.filterwarnings("ignore")

import os
import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from dotenv import load_dotenv
from collections import deque

from eval.utils import (
    load_model,
    binary_accuracy_logits,
    ImageTargetTransform,
    LinearClassifier,
)

from eval.segmentation import sample_from_queues, process_batch, update_difficult_negatives
from eval.extended_datasets import LidcIdriNodules, get_lidcidri_loader

In [None]:
load_dotenv()


def _require_env(name: str) -> str:
    """Return the value of environment variable `name`.

    Raises:
        RuntimeError: if the variable is unset or empty. Failing here is clearer than
            the `os.path.join(None, ...)` TypeError it would otherwise cause later.
    """
    value = os.getenv(name)
    if not value:
        raise RuntimeError(f"Environment variable {name!r} is not set (check your .env file)")
    return value


project_path = _require_env("PROJECTPATH")
data_path = _require_env("DATAPATH")

In [None]:
path_to_run = "runs/base_103x4x5"
checkpoint_name = "training_69999"
# Fall back to CPU when no CUDA device is available so the notebook still runs
# (slowly) on GPU-less machines instead of failing on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

feature_model, config = load_model(path_to_run, checkpoint_name, device)
print("Loaded model")

In [None]:
full_image_size = config.student.full_image_size
patch_size = config.student.patch_size
# Intensity normalisation constants passed to ImageTargetTransform below; presumably
# the mean/std of the CT intensities (Hounsfield units) — TODO confirm provenance.
data_mean = -573.8
data_std = 461.3
# Number of input channels per sample; the probe below reshapes one image to
# (1, channels, full_image_size, full_image_size) before feeding the backbone.
channels = 4

# Later cells fold per-image patch predictions back into a square grid, which only
# works when the image tiles exactly into patches.
assert full_image_size % patch_size == 0, (
    f"full_image_size={full_image_size} is not divisible by patch_size={patch_size}"
)

print("Full image size:", full_image_size)

In [None]:
img_processor = ImageTargetTransform(full_image_size, data_mean, data_std)

dataset_kwargs = {
    "root_path": os.path.join(data_path, "dicoms"),
    "mask_path": os.path.join(project_path, "data/eval/LIDC-IDRI/masks"),
    "transform": img_processor,
}

train_dataset = LidcIdriNodules(**dataset_kwargs)
# Use the `channels` constant rather than a second hard-coded 4, so the loader and
# the probe reshape below cannot drift apart.
train_dataloader = get_lidcidri_loader(train_dataset, channels=channels)

# Probe the backbone with a single image to discover the patch-token embedding width.
# `train_dataloader` is consumed with `next`, so it is presumably an iterator
# (the training loop relies on it never raising StopIteration) — TODO confirm.
test_images, test_targets = next(train_dataloader)
unit_batch = test_images[0].view(1, channels, full_image_size, full_image_size)
with torch.no_grad():
    outputs = feature_model(unit_batch.to(device))
# outputs[0][0] is unpacked as a 3-D tensor; its last axis is the embedding width.
_, _, embed_dim = outputs[0][0].shape
print("Embedding dimension:", embed_dim)

In [None]:
# Hard-coded fallback for the embedding width (768, presumably the ViT-Base width),
# used only when the probe cell above was skipped. Overwriting a width that was
# actually measured from the model would make EMBED_DIM disagree with the features.
if "embed_dim" not in globals():
    embed_dim = 768

In [None]:
# The classifier sees patch tokens from several backbone outputs concatenated along the
# feature axis (the evaluation loop torch.cat's every `patch_token` in `features`);
# presumably 4 is the number of such outputs — TODO confirm against feature_model.
_NUM_CONCATENATED_OUTPUTS = 4
EMBED_DIM = _NUM_CONCATENATED_OUTPUTS * embed_dim
PATCH_SIZE = config.student.patch_size
BATCH_SIZE = 100  # upper bound on tokens resampled per training step

# Single-logit head, trained with BCEWithLogitsLoss in the next cells.
classifier_model = LinearClassifier(embed_dim=EMBED_DIM, hidden_dim=2048, num_labels=1)
classifier_model = classifier_model.to(device)

In [None]:
eval_interval = 1_000  # training steps between evaluations
max_iter = 2_000  # total optimisation steps
# Explicit learning rate: older PyTorch releases have no default for SGD's `lr` and
# raise when it is omitted; 1e-3 matches the default of releases that do have one.
learning_rate = 1e-3

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(
    classifier_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0
)
# Cosine decay over the whole run; stepped once per optimisation step in the loop below.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)

In [None]:
cache_max_size = 2_000
alpha = 0.95  # smoothing factor of the exponential moving average of the loss
# Max gradient norm for clipping. NOTE(review): the original referenced `args.clip`,
# which is not defined anywhere in this notebook; 1.0 is a placeholder — confirm/tune.
grad_clip_max_norm = 1.0

# Bounded FIFO caches of patch embeddings, stored on CPU: positives are patches with
# label > 0, negatives are patches with label == 0.
patch_embedding_cache = deque(maxlen=cache_max_size)
negative_patch_queue = deque(maxlen=cache_max_size)

# The evaluation phase needs these names, but nothing in this notebook defines them.
# Fail now rather than after `eval_interval` training steps.
_missing_names = [
    name
    for name in ("val_dataloader", "binary_mask_to_patch_labels")
    if name not in globals()
]
assert not _missing_names, f"Define {_missing_names} before running the training loop"


def _train_step(inputs, targets):
    """Run one optimisation step of the patch classifier on a training batch.

    The batch's patch embeddings are appended to the positive/negative caches, then up
    to BATCH_SIZE tokens are resampled via ``sample_from_queues`` and used for an SGD
    step, with gradient clipping and one scheduler step.

    Returns:
        The batch loss as a float, or ``None`` when no positive patch has been cached
        yet (in which case no parameter update is made).
    """
    optimizer.zero_grad()

    patch_tokens, patch_labels = process_batch(
        feature_model, inputs, targets, embed_dim, patch_size, device
    )

    positive_tokens = patch_tokens[(patch_labels > 0).nonzero(as_tuple=True)]
    patch_embedding_cache.extend(token.detach().cpu() for token in positive_tokens)

    negative_tokens = patch_tokens[(patch_labels == 0).nonzero(as_tuple=True)]
    negative_patch_queue.extend(token.detach().cpu() for token in negative_tokens)

    if not patch_embedding_cache:
        return None

    # Never request more than twice the number of cached positives.
    num_resample = min(len(patch_embedding_cache) * 2, BATCH_SIZE)
    resampled_tokens, resampled_labels, neg_indices_from_queue = sample_from_queues(
        patch_embedding_cache,
        negative_patch_queue,
        patch_tokens,
        patch_labels,
        num_resample,
        device,
    )

    classifier_output = classifier_model(resampled_tokens)
    loss = criterion(classifier_output, resampled_labels)

    # Presumably refreshes the negative queue with the negatives the classifier got
    # wrong (helper lives in eval.segmentation) — TODO confirm.
    update_difficult_negatives(
        classifier_output,
        resampled_labels,
        resampled_tokens,
        negative_patch_queue,
        neg_indices_from_queue,
    )

    loss.backward()
    torch.nn.utils.clip_grad_norm_(classifier_model.parameters(), grad_clip_max_norm)
    optimizer.step()
    scheduler.step()
    return loss.item()


def _evaluate():
    """Evaluate the classifier on ``val_dataloader`` without gradients.

    Returns:
        (mean per-batch accuracy, accuracy on positive patches, accuracy on negative
        patches). A ratio whose denominator is zero is reported as NaN instead of
        raising ZeroDivisionError.
    """
    classifier_model.eval()
    accuracy_sum = 0.0
    num_batches = 0
    positives = negatives = 0
    true_pred_positives = true_pred_negatives = 0

    with torch.no_grad():
        for inputs, targets in tqdm(val_dataloader, desc="Evaluation", leave=False):
            features = feature_model(inputs.to(device))
            patch_tokens = torch.cat(
                [patch_token for patch_token, _ in features], dim=-1
            ).view(-1, EMBED_DIM)
            patch_labels = binary_mask_to_patch_labels(
                targets.to(device), PATCH_SIZE
            ).view(-1, 1)

            accuracy, acc_breakdown = binary_accuracy_logits(
                classifier_model(patch_tokens), patch_labels
            )
            accuracy_sum += accuracy
            num_batches += 1
            # acc_breakdown is consumed as (positives, correct positives,
            # negatives, correct negatives).
            positives += acc_breakdown[0]
            true_pred_positives += acc_breakdown[1]
            negatives += acc_breakdown[2]
            true_pred_negatives += acc_breakdown[3]

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator else float("nan")

    return (
        _ratio(accuracy_sum, num_batches),
        _ratio(true_pred_positives, positives),
        _ratio(true_pred_negatives, negatives),
    )


iteration = 0
while iteration < max_iter:
    classifier_model.train()
    running_loss = 0.0
    train_tqdm = tqdm(range(1, eval_interval + 1), desc="Training", leave=False)

    for _ in train_tqdm:
        # Skipped steps (empty positive cache) do not advance `iteration`; without this
        # check a later chunk could run past max_iter and step the cosine schedule
        # beyond T_max.
        if iteration >= max_iter:
            break
        inputs, targets = next(train_dataloader)
        step_loss = _train_step(inputs, targets)
        if step_loss is None:
            continue

        running_loss = running_loss * alpha + (1 - alpha) * step_loss
        iteration += 1
        train_tqdm.set_postfix(
            {"Loss": running_loss, "Cache": len(patch_embedding_cache)}
        )

    avg_accuracy, avg_p_accuracy, avg_n_accuracy = _evaluate()
    print(
        f"Iteration: {iteration}, Overall Accuracy: {avg_accuracy:.4f}, Positives: {avg_p_accuracy}, Negatives: {avg_n_accuracy}"
    )

In [None]:
# Visualize some image segmentations: draw validation batches one at a time from a
# dedicated iterator, so each run of the next cell advances to a new batch.
# NOTE(review): `val_dataloader` is not defined anywhere in this notebook — define it first.
demoiter = iter(val_dataloader)

In [None]:
images, targets = next(demoiter)
# Per-patch labels for this batch, computed on the model device.
# NOTE(review): `binary_mask_to_patch_labels` is not imported in this notebook.
patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE)
positive_patch_mask = patch_labels > 0
# Tuple of index tensors locating every positive patch; displayed as the cell output.
masked_patch_indices = positive_patch_mask.nonzero(as_tuple=True)
masked_patch_indices

In [None]:
# Inference only: skip autograd bookkeeping for both the backbone and the classifier.
with torch.no_grad():
    features = feature_model(images.to(device))
    # Concatenate the patch tokens of every backbone output along the feature axis,
    # as in the evaluation loop, but keep the batch dimension for plotting.
    patch_tokens = torch.cat([patch_token for patch_token, _ in features], dim=-1)
    classifier_output = classifier_model(patch_tokens)
classifier_output.shape

In [None]:
selected_batch = 19  # index of the sample within the current batch to display

# Reshape one image's patch labels / logits into a square grid of patches. Derived
# from the config instead of a hard-coded 16 so other image/patch sizes also work.
patches_per_side = full_image_size // PATCH_SIZE

original_image = images[selected_batch][0].numpy()  # channel 0 of the sample
true_mask = patch_labels[selected_batch].view(patches_per_side, patches_per_side).cpu().numpy()
predicted_mask = (
    classifier_output[selected_batch]
    .detach()
    .view(patches_per_side, patches_per_side)
    .cpu()
    .numpy()
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (original_image, "Image"),
    (true_mask, "True mask"),
    (predicted_mask, "Predicted mask"),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel, cmap="gray")
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Binarised prediction: the classifier emits logits (trained with BCEWithLogitsLoss),
# and logit > 0 is equivalent to sigmoid probability > 0.5.
fig, ax = plt.subplots()
ax.imshow(predicted_mask > 0, cmap="gray")
ax.set_title("Predicted mask (logit > 0)")
plt.show()