In [None]:
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append("..")

import random
from collections import deque

import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import dinov2.utils.utils as dinov2_utils
from utils import (
    load_model, get_norm, get_dataloader, get_accuracy_logits, binary_mask_to_patch_labels,
    ImageTargetTransform, LinearClassifier
)

from extended_datasets import LidcIdriNodules

In [None]:
# Everything below is placed on this device.
device = torch.device("cuda:0")

# Run directory and checkpoint name handed to `load_model`.
path_to_run = "../runs/ctc_104x5x4/"
checkpoint_name = "training_399999"

# `load_model` returns the feature extractor together with its config.
feature_model, config = load_model(path_to_run, checkpoint_name, device)

# Classification head applied to the patch tokens.
# NOTE(review): embed_dim=384*4 presumably corresponds to 384-dim tokens from
# 4 blocks concatenated along the last axis (the training loop concatenates
# per-block patch tokens with dim=-1) — confirm against the checkpoint config.
classifier_model = LinearClassifier(embed_dim=384*4, hidden_dim=2048, num_labels=2)
classifier_model = classifier_model.to(device)

In [None]:
# Normalisation statistics are derived from the loaded run config.
mean_, std_ = get_norm(config)
# Joint image/target transform, built with 224 as its first argument
# (presumably the target image size — TODO confirm in utils).
img_processor = ImageTargetTransform(224, mean_, std_)

# Filesystem locations shared by every LIDC-IDRI split.
lidc_idri_kwargs = dict(
    root="../datasets/LIDC-IDRI/data",
    extra="../datasets/LIDC-IDRI/extra",
)

# Build the TRAIN split first, then the VAL split, with identical settings.
train_dataset, val_dataset = [
    LidcIdriNodules(split=split_name, transforms=img_processor, **lidc_idri_kwargs)
    for split_name in ("TRAIN", "VAL")
]

# The training loader is requested as infinite (the training loop pulls
# batches from it with `next(...)`); the validation loader is iterated once
# per evaluation.
train_dataloader = get_dataloader(train_dataset, is_infinite=True)
val_dataloader = get_dataloader(val_dataset)

In [None]:
# Training schedule.
eval_interval = 10_000  # optimizer steps between two validation passes
max_iter = 100_000      # total number of optimizer steps

# Step size passed explicitly: 1e-3 is the default of recent torch.optim.SGD
# releases, while older releases require `lr` and raise without it.
learning_rate = 1e-3

criterion = torch.nn.BCEWithLogitsLoss()

# Only the classifier head is optimised. The original referenced an undefined
# name `model`, which fails on a fresh kernel.
optimizer = torch.optim.SGD(
    classifier_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0
)
# Cosine decay of the learning rate down to 0 over the full run; the training
# loop steps the scheduler once per optimizer step.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)

In [None]:
# Train the classifier head on patch tokens from the feature model, running a
# validation pass every `eval_interval` steps until `max_iter` steps are done.
#
# Positive patches (label > 0) are kept in a bounded cache and a few of them
# are re-added to every training batch.
PATCH_SIZE = 14
cache_max_size = 1000
# Rolling buffer of (patch_embedding, patch_label) pairs for positive patches;
# the deque drops the oldest entries once `cache_max_size` is reached.
patch_embedding_cache = deque(maxlen=cache_max_size)

iteration = 0
train_losses = []
val_losses = []
val_accuracies = []

while iteration < max_iter:

    # ------------------------------------------------------------------ train
    classifier_model.train()
    train_loss_sum = 0.0
    steps_run = 0
    train_tqdm = tqdm(range(1, eval_interval+1), desc=f"Training (Iter {iteration}/{max_iter}).", leave=False)

    for i in train_tqdm:
        inputs, targets = next(train_dataloader)
        optimizer.zero_grad()

        # Features are only inputs to the classifier, so no autograd graph is
        # built for the feature model. This also keeps cached embeddings from
        # holding on to a graph.
        # NOTE(review): assumes the optimizer holds classifier parameters only.
        with torch.no_grad():
            features = feature_model(inputs.to(device)) # (block_num, ((batch_size, patch_num, embed_dim), (batch_size, embed_dim)))
            patch_tokens = torch.cat([patch_token for patch_token, _ in features], dim=-1) # (batch_size, patch_num, embed_dim*block_num)

        # NOTE(review): `get_patch_labels` is not defined anywhere in this
        # notebook; `binary_mask_to_patch_labels` is the imported helper that
        # matches the call — confirm its signature in utils.
        patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE)

        # Work on flat per-patch rows so cached patches can be appended along
        # dim 0: tokens -> (num_patches, feature_dim), labels -> (num_patches,).
        patch_tokens = patch_tokens.reshape(-1, patch_tokens.shape[-1])
        patch_labels = patch_labels.reshape(-1)
        assert patch_labels.shape[0] == patch_tokens.shape[0], (
            "expected exactly one label per patch token"
        )

        # Cache every positive patch together with its OWN label.
        positive_mask = patch_labels > 0
        for patch, label in zip(patch_tokens[positive_mask], patch_labels[positive_mask]):
            patch_embedding_cache.append((patch.detach().clone(), label.detach().clone()))

        # `batch_size` was never defined in the notebook; the number of images
        # in the current batch is used.
        # NOTE(review): confirm this is the intended unit for the resample cap.
        batch_size = inputs.shape[0]
        num_resample = min(len(patch_embedding_cache), batch_size // 2)
        if num_resample > 0:  # zip(*[]) cannot be unpacked into two names
            resampled_patches, resampled_labels = zip(*random.sample(patch_embedding_cache, num_resample))
            patch_tokens = torch.cat([patch_tokens, torch.stack(resampled_patches)], dim=0)
            patch_labels = torch.cat([patch_labels, torch.stack(resampled_labels)], dim=0)

        classifier_output = classifier_model(patch_tokens)

        # Flatten both sides so the loss sees one logit per label.
        # NOTE(review): this requires the classifier to emit a single logit per
        # patch; with num_labels=2 the element counts would differ — confirm
        # what LinearClassifier returns.
        loss = criterion(classifier_output.reshape(-1), patch_labels.reshape(-1))
        loss.backward()
        optimizer.step()
        scheduler.step()

        train_loss_sum += loss.item()
        iteration += 1
        steps_run = i

        train_tqdm.set_postfix({"Loss": train_loss_sum / i})

        if iteration >= max_iter:
            break

    # Average over the steps actually executed (the loop may stop early).
    avg_train_loss = train_loss_sum / max(steps_run, 1)
    train_losses.append(avg_train_loss)

    # ------------------------------------------------------------- validation
    classifier_model.eval()
    val_loss_sum = 0.0
    val_accuracy_sum = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(val_dataloader, leave=False):
            features = feature_model(inputs.to(device)) # (block_num, ((batch_size, patch_num, embed_dim), (batch_size, embed_dim)))
            patch_tokens = torch.cat([patch_token for patch_token, _ in features], dim=-1) # (batch_size, patch_num, embed_dim*block_num)
            patch_labels = binary_mask_to_patch_labels(targets.to(device), PATCH_SIZE)

            # NOTE(review): validation feeds (batch, patch, dim) tokens while
            # training feeds flat (num_patches, dim) rows — fine for a head
            # that only acts on the last axis; confirm for LinearClassifier.
            classifier_output = classifier_model(patch_tokens)

            # Same flattening as in training, so batch sizes above 1 work.
            loss = criterion(classifier_output.reshape(-1), patch_labels.reshape(-1))
            val_loss_sum += loss.item()

            val_accuracy_sum += get_accuracy_logits(classifier_output, patch_labels)

    # Per-batch averages over the validation loader.
    avg_val_loss = val_loss_sum / len(val_dataloader)
    avg_val_accuracy = val_accuracy_sum / len(val_dataloader)
    val_losses.append(avg_val_loss)
    val_accuracies.append(avg_val_accuracy)

    print(f"Iteration: {iteration}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {avg_val_accuracy:.4f}")
