In [None]:
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append("..")

import numpy as np
import random
from collections import deque

import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import dinov2.utils.utils as dinov2_utils
from utils import (
    load_model, get_norm, get_dataloader, binary_accuracy_logits, binary_mask_to_patch_labels,
    ImageTargetTransform, LinearClassifier
)

from extended_datasets import LidcIdriNodules

In [None]:
# Location of the pretrained run and the checkpoint to load from it.
path_to_run = "../runs/ctc_104x5x4/"
checkpoint_name = "training_399999"
# Prefer the first GPU, but fall back to CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Feature extractor plus its config (the config is used later for normalization).
feature_model, config = load_model(path_to_run, checkpoint_name, device)
# Patch-level binary classifier head. 384*4 = 1536 matches the EMBED_DIM used
# when patch tokens are concatenated in the training cell.
# NOTE(review): presumably four 384-d token sets are concatenated — confirm.
classifier_model = LinearClassifier(
    embed_dim=384*4,
    hidden_dim=2048,
    num_labels=1
).to(device)

In [None]:
# Normalization statistics come from the loaded model config; the transform
# is built with a size argument of 224 and applied to both splits.
mean_, std_ = get_norm(config)
img_processor = ImageTargetTransform(224, mean_, std_)

# Shared dataset locations for every split.
lidc_idri_kwargs = dict(
    root="../datasets/LIDC-IDRI/data",
    extra="../datasets/LIDC-IDRI/extra",
)

# Build the TRAIN and VAL datasets with identical settings (TRAIN first).
train_dataset, val_dataset = (
    LidcIdriNodules(split=split_name, transforms=img_processor, **lidc_idri_kwargs)
    for split_name in ("TRAIN", "VAL")
)

# The training loader is requested as infinite (it is consumed with next()
# in the training loop); the validation loader uses the helper's defaults.
train_dataloader = get_dataloader(train_dataset, is_infinite=True)
val_dataloader = get_dataloader(val_dataset)

In [None]:
# Training schedule: evaluate every `eval_interval` steps, stop at `max_iter`.
eval_interval = 1_000
max_iter = 3_000
# Explicit initial learning rate. Older PyTorch releases require `lr` for SGD
# (omitting it raises a TypeError); newer releases default to this same value,
# so spelling it out keeps behaviour identical while making it visible/tunable.
learning_rate = 1e-3

# Binary classification on raw logits (sigmoid is applied inside the loss).
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(
    classifier_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0
)
# Cosine decay of the learning rate to 0 over `max_iter` scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)

In [None]:
PATCH_SIZE = 14
BATCH_SIZE = 64
EMBED_DIM = 1536
cache_max_size = 2_000
alpha = 0.95  # smoothing factor of the running (exponential moving average) loss
# FIFO store (oldest entries are dropped first) of CPU copies of embeddings
# whose patch label is positive. It is used to re-inject positive examples
# into every training step.
patch_embedding_cache = deque(maxlen=cache_max_size)

iteration = 0
while iteration < max_iter:

    # ---------------------------------------------------------------- train
    classifier_model.train()
    running_loss = 0.0
    # Never train past max_iter: the cosine schedule is defined for exactly
    # max_iter steps, so overshooting (when max_iter is not a multiple of
    # eval_interval) would step the scheduler beyond its T_max.
    steps_this_round = min(eval_interval, max_iter - iteration)
    train_tqdm = tqdm(range(1, steps_this_round + 1), desc="Training", leave=False)

    for i in train_tqdm:
        inputs, targets = next(train_dataloader)
        optimizer.zero_grad()

        # The feature extractor is frozen: no gradients are tracked for it.
        with torch.no_grad():
            features = feature_model(inputs.to(device))
        # Concatenate the patch tokens of every returned feature set along the
        # channel axis and flatten to one row per patch.
        patch_tokens = torch.cat([patch_token for patch_token, _ in features], dim=-1).view(-1, EMBED_DIM)
        patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE).view(-1)

        masked_patch_indices = (patch_labels > 0).nonzero(as_tuple=True)
        masked_patch_tokens = patch_tokens[masked_patch_indices]

        # Move all positive embeddings to the CPU in a single transfer and
        # push them row by row into the cache.
        patch_embedding_cache.extend(masked_patch_tokens.detach().cpu().unbind(0))

        # Draw the same number of samples from the cache and from the current
        # batch. Clamping to the number of patches in the batch prevents
        # random.sample from raising when the cache is larger than the batch.
        # NOTE(review): the PATCH_SIZE ** 2 * BATCH_SIZE // 2 cap is kept as
        # written; it exceeds cache_max_size, so it never binds with the
        # current settings.
        num_resample = min(
            len(patch_embedding_cache),
            patch_labels.numel(),
            PATCH_SIZE ** 2 * BATCH_SIZE // 2,
        )

        if num_resample > 0:
            # Cached positives: stack on the CPU, then one transfer to the device.
            cache_indices = random.sample(range(len(patch_embedding_cache)), num_resample)
            cache_tokens = torch.stack(
                [patch_embedding_cache[idx] for idx in cache_indices], dim=0
            ).to(device)
            cache_labels = torch.ones((num_resample, 1), device=device)

            # Uniform sample (without replacement) of patches from this batch.
            new_indices = random.sample(range(len(patch_labels)), num_resample)
            new_index_tensor = torch.tensor(new_indices, dtype=torch.long, device=patch_tokens.device)
            new_tokens = patch_tokens[new_index_tensor]
            new_labels = patch_labels[new_index_tensor.to(patch_labels.device)].view(-1, 1)

            # Cached positives first, then the freshly sampled patches.
            resampled_tokens = torch.cat([cache_tokens, new_tokens], dim=0)
            resampled_labels = torch.cat([cache_labels, new_labels], dim=0)

        else:
            # Nothing cached yet: train on every patch of the batch as-is.
            resampled_tokens = patch_tokens
            resampled_labels = patch_labels.view(-1, 1)

        classifier_output = classifier_model(resampled_tokens)

        loss = criterion(classifier_output, resampled_labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

        # Exponential moving average of the loss, for display only. It starts
        # from 0 at the beginning of every round, so early values read low.
        running_loss = running_loss * alpha + (1-alpha) * loss.item()
        iteration += 1

        train_tqdm.set_postfix({"Loss": running_loss, "Cache": len(patch_embedding_cache)})

    # ------------------------------------------------------------- evaluate
    classifier_model.eval()
    accuracy_sum = 0.0
    positives = 0
    negatives = 0
    true_pred_positives = 0
    true_pred_negatives = 0

    with torch.no_grad():
        for inputs, targets in tqdm(val_dataloader, desc="Evaluation", leave=False):
            features = feature_model(inputs.to(device))
            patch_tokens = torch.cat([patch_token for patch_token, _ in features], dim=-1).view(-1, EMBED_DIM)
            patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE).view(-1, 1)

            classifier_output = classifier_model(patch_tokens)

            # acc_breakdown is read as (positives, correctly predicted
            # positives, negatives, correctly predicted negatives).
            accuracy, acc_breakdown = binary_accuracy_logits(classifier_output, patch_labels)
            accuracy_sum += accuracy
            positives += acc_breakdown[0]
            true_pred_positives += acc_breakdown[1]
            negatives += acc_breakdown[2]
            true_pred_negatives += acc_breakdown[3]

    # Mean of the per-batch accuracies (each batch weighted equally).
    avg_accuracy = accuracy_sum / len(val_dataloader)
    # Report NaN instead of dividing by zero when a class never appears.
    avg_p_accuracy = true_pred_positives / positives if positives else float("nan")
    avg_n_accuracy = true_pred_negatives / negatives if negatives else float("nan")

    print(f"Iteration: {iteration}, Overall Accuracy: {avg_accuracy:.4f}, Positives: {avg_p_accuracy}, Negatives: {avg_n_accuracy}")


In [None]:
# Visualize some image segmentations: create an iterator over the validation
# dataloader so the following cells can pull one batch at a time with next().
demoiter = iter(val_dataloader)

In [None]:
# Pull the next validation batch and compute its patch tokens and patch labels.
images, targets = next(demoiter)
# Inference only: skip autograd bookkeeping for the feature extractor, as the
# training and evaluation loops do.
with torch.no_grad():
    features = feature_model(images.to(device))
# Concatenate the patch tokens of every returned feature set along the last axis.
patch_tokens = torch.cat([patch_token for patch_token, _ in features], dim=-1)
patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE)
# Sum of the patch labels in this batch (displayed as the cell output).
patch_labels.sum()

In [None]:
# Locate every patch with a label greater than zero; with as_tuple=True the
# result is one index tensor per dimension of patch_labels.
masked_patch_indices = torch.nonzero(patch_labels > 0, as_tuple=True)
masked_patch_indices

In [None]:
# Score every patch of the demo batch. This is inference only, so gradient
# tracking is disabled, consistent with the evaluation loop.
with torch.no_grad():
    classifier_output = classifier_model(patch_tokens)
classifier_output.shape

In [None]:
# Index of the sample within the current demo batch to display
# (must be smaller than the batch size).
selected_batch = 18

# Derive the side of the square patch grid from the number of patches of the
# sample instead of hard-coding 16, so other image/patch sizes also work.
sample_labels = patch_labels[selected_batch]
grid_side = int(round(sample_labels.numel() ** 0.5))

original_image = images[selected_batch][0].numpy()  # first channel only
true_mask = sample_labels.reshape(grid_side, grid_side).cpu().numpy()
# Raw classifier logits (no sigmoid applied), laid out on the patch grid.
predicted_mask = (
    classifier_output[selected_batch]
    .detach()
    .reshape(grid_side, grid_side)
    .cpu()
    .numpy()
)

fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# One grayscale panel each for the input, the ground truth and the prediction.
panels = (
    (original_image, 'Image'),
    (true_mask, 'True mask'),
    (predicted_mask, 'Predicted mask'),
)
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)

plt.tight_layout()
plt.show()