In [None]:
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append("..")

import torch
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

import dinov2.utils.utils as dinov2_utils
from utils import (
    load_model, get_norm, get_dataloader, get_accuracy_logits,
    ImageTransform, DinoClassifier
)
from extended_datasets import NsclcRadiomicsEvalSex

In [None]:
# Checkpoint to evaluate and the device everything runs on.
path_to_run = "../runs/ctc_104x5x4/"
checkpoint_name = "training_399999"
device = torch.device("cuda:0")

# Restore the pretrained feature extractor together with the config of its run.
feature_model, config = load_model(path_to_run, checkpoint_name, device)

# Settings of the classifier built around the feature extractor.
# NOTE(review): embed_dim is written as 384 * 4 — presumably a 384-wide embedding
# taken from 4 sources (layers/tokens); confirm against DinoClassifier.
classifier_kwargs = dict(
    embed_dim=384 * 4,
    hidden_dim=2048,
    num_labels=2,
    device=device,
)
model = DinoClassifier(feature_model, **classifier_kwargs)

In [None]:
# Normalisation statistics come from the run config so that inputs are
# preprocessed with the values returned by get_norm for this checkpoint.
mean_, std_ = get_norm(config)
# NOTE(review): 224 is presumably the target image side length — confirm in ImageTransform.
img_processor = ImageTransform(224, mean_, std_)

# Filesystem locations of the NSCLC-Radiomics dataset.
nsclc_radiomics_kwargs = {
    "root": "../datasets/NSCLC-Radiomics/data",
    "extra": "../datasets/NSCLC-Radiomics/extra"
}

# Both splits share the same transform and locations; only the split name differs.
shared_dataset_kwargs = {"transform": img_processor, **nsclc_radiomics_kwargs}
train_dataset, val_dataset = (
    NsclcRadiomicsEvalSex(split=split_name, **shared_dataset_kwargs)
    for split_name in ("TRAIN", "VAL")
)

# The training loader is requested as infinite (it is consumed with next() in the
# training loop); the validation loader uses get_dataloader's defaults.
train_dataloader = get_dataloader(train_dataset, is_infinite=True)
val_dataloader = get_dataloader(val_dataset)

In [None]:
# Smoke test: push a single training sample through the classifier and display
# the raw model output.
im, ta = train_dataset[0]
# This is inference only, so disable autograd: otherwise `output` keeps the whole
# computation graph (and its GPU activations) alive for as long as it stays in
# the notebook namespace.
with torch.no_grad():
    # The sample is reshaped into a batch of one single-channel 224x224 image.
    output = model(im.view(1, 1, 224, 224).to(device))
output

In [None]:
# Inverse-frequency class weights for the weighted cross-entropy loss:
# weight[c] = (number of training samples) / (number of samples of class c).
num_classes = 2  # same as num_labels passed to DinoClassifier
counts = [0] * num_classes
for index in range(len(train_dataset)):
    target = train_dataset.get_target(index)
    # Plain list indexing would let a negative label (e.g. -1 for "unknown")
    # wrap around and be counted as the last class, so validate explicitly.
    if not 0 <= target < num_classes:
        raise ValueError(
            f"Unexpected target {target!r} at index {index}; "
            f"expected an integer in [0, {num_classes - 1}]"
        )
    counts[target] += 1

# A class without samples has no finite inverse frequency; fail with a clear
# message instead of a bare ZeroDivisionError.
empty_classes = [label for label, count in enumerate(counts) if count == 0]
if empty_classes:
    raise ValueError(
        f"No training samples for class(es) {empty_classes}; "
        "cannot compute inverse-frequency weights"
    )

cross_entropy_weights = torch.tensor([len(train_dataset)/x for x in counts]).to(device)
print("Adjusted weights for class imbalance:", cross_entropy_weights)

In [None]:
# Training schedule: evaluate on the validation set every `eval_interval`
# optimisation steps, stop after `max_iter` steps in total.
eval_interval = 10_000
max_iter = 100_000

# Step size for SGD. It is set explicitly because omitting `lr` raises an error
# on older PyTorch releases and silently falls back to the library default on
# newer ones; 1e-3 is that default, so behaviour on recent versions is unchanged.
learning_rate = 1e-3

# Class-weighted loss to compensate for the label imbalance in the training set.
criterion = torch.nn.CrossEntropyLoss(weight=cross_entropy_weights)
optimizer = torch.optim.SGD(
    model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0
)
# Cosine decay of the learning rate to 0 over the whole run; the scheduler is
# expected to be stepped once per optimisation step (T_max == max_iter).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0)

In [None]:
# Alternate between rounds of up to `eval_interval` training steps and a full
# pass over the validation set, until `max_iter` training steps have been taken.
# One entry is appended to each history list per round.
if eval_interval <= 0:
    # With a non-positive interval no training step would ever run and the
    # outer loop would never terminate.
    raise ValueError(f"eval_interval must be positive, got {eval_interval}")

iteration = 0
train_losses = []
val_losses = []
val_accuracies = []

while iteration < max_iter:

    # ---- Training round -------------------------------------------------
    model.train()
    train_loss_sum = 0.0
    # Never schedule more steps than remain: this keeps the progress bar total
    # honest and lets the round average divide by the number of steps actually
    # taken when max_iter is not a multiple of eval_interval.
    steps_this_round = min(eval_interval, max_iter - iteration)
    train_tqdm = tqdm(range(1, steps_this_round + 1), desc=f"Training (Iter {iteration}/{max_iter}).", leave=False)
    for i in train_tqdm:
        # train_dataloader is consumed as an iterator (created with is_infinite=True).
        inputs, targets = next(train_dataloader)
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        optimizer.step()
        # The learning-rate schedule advances once per optimisation step.
        scheduler.step()

        train_loss_sum += loss.item()
        iteration += 1

        # Running mean of the loss over the steps of this round so far.
        train_tqdm.set_postfix({"Loss": train_loss_sum / i})

    avg_train_loss = train_loss_sum / steps_this_round
    train_losses.append(avg_train_loss)

    # ---- Validation pass ------------------------------------------------
    model.eval()
    val_loss_sum = 0.0
    val_accuracy_sum = 0.0

    with torch.no_grad():
        for inputs, targets in tqdm(val_dataloader, leave=False):
            outputs = model(inputs.to(device))
            loss = criterion(outputs, targets.to(device))
            val_loss_sum += loss.item()

            # NOTE(review): `outputs` lives on `device` while `targets` is passed
            # as yielded by the loader — confirm get_accuracy_logits handles that.
            val_accuracy_sum += get_accuracy_logits(outputs, targets)

    # Per-batch averages: every batch counts equally, regardless of its size.
    avg_val_loss = val_loss_sum / len(val_dataloader)
    avg_val_accuracy = val_accuracy_sum / len(val_dataloader)
    val_losses.append(avg_val_loss)
    val_accuracies.append(avg_val_accuracy)

    print(f"Iteration: {iteration}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {avg_val_accuracy:.4f}")
