In [None]:
import warnings
warnings.filterwarnings('ignore')

import sys
sys.path.append("..")

import torch
from tqdm.notebook import tqdm

import dinov2.utils.utils as dinov2_utils
from utils import (
    load_model, get_dataloader, multiclass_accuracy_logits,
    ImageTransform, DinoClassifier, 
)
from extended_datasets import NsclcRadiogenomicsEvalSmokers

In [None]:
# Pretrained backbone: run directory + checkpoint to restore, and the GPU to use.
path_to_run = "../runs/base_112x6x3/"
checkpoint_name = "training_149999"
device = torch.device("cuda", 0)  # same device as "cuda:0"

# load_model returns the feature extractor together with the run's training config.
feature_model, config = load_model(path_to_run, checkpoint_name, device)
print("Loaded model")

In [None]:
# Input resolution and normalisation statistics the backbone was trained with.
full_image_size = config.student.full_image_size
# NOTE(review): unpacking relies on `config.norm` listing mean before std;
# if the config ever reorders these keys they will be silently swapped.
norm_values = list(config.norm.values())
data_mean, data_std = norm_values

print("Full image size:", full_image_size)
print("mean: {}, std: {}".format(data_mean, data_std))

In [None]:
# Shared preprocessing: resize/normalise with the backbone's training statistics.
img_processor = ImageTransform(full_image_size, data_mean, data_std)

# Dataset locations shared by every split.
dataset_kwargs = dict(
    root="../datasets/NSCLC-Radiogenomics/data",
    extra="../datasets/NSCLC-Radiogenomics/extra",
)

# Build the TRAIN split first, then VAL, with identical transforms and paths.
train_dataset, val_dataset = (
    NsclcRadiogenomicsEvalSmokers(split=split_name, transform=img_processor, **dataset_kwargs)
    for split_name in ("TRAIN", "VAL")
)

# The training loader is infinite (consumed with next()); validation is a single pass.
train_dataloader = get_dataloader(train_dataset, is_infinite=True)
val_dataloader = get_dataloader(val_dataset)

In [None]:
# Probe the backbone with one training image to discover its embedding width.
im, _ = train_dataset[0]
with torch.no_grad():
    # Add batch and channel axes using the image's own spatial size instead of a
    # hard-coded 504x504, so this keeps working if full_image_size changes.
    # NOTE(review): assumes a single-channel image shaped (H, W) or (1, H, W).
    probe_outputs = feature_model(im.reshape(1, 1, *im.shape[-2:]).to(device))
# outputs[0][0] is a rank-3 tensor; its last axis is the embedding dimension.
# NOTE(review): presumably (batch, tokens, dim) — confirm against the model's forward.
_, _, embed_dim = probe_outputs[0][0].shape
print("Embedding dimension:", embed_dim)

In [None]:
# Classification head on top of the backbone.
# The head's input width is four times the backbone embedding dimension.
# NOTE(review): presumably DinoClassifier concatenates four feature vectors
# before the head — confirm against utils.DinoClassifier.
classifier_kwargs = dict(
    embed_dim=4 * embed_dim,
    hidden_dim=4096,
    num_labels=2,  # binary task: nonsmoker vs. current/former smoker
    device=device,
)
model = DinoClassifier(feature_model, **classifier_kwargs)

In [None]:
# Inverse-frequency class weights for the cross-entropy loss:
#   weight_c = N_train / count_c
# so that rarer classes contribute more per sample.
num_classes = 2  # must match num_labels of the classifier head
counts = [0] * num_classes
for index in range(len(train_dataset)):
    target = int(train_dataset.get_target(index))
    # Guard explicitly: a negative target would otherwise silently index from the
    # end of `counts` (e.g. -1 -> last class) and corrupt the tally.
    if not 0 <= target < num_classes:
        raise ValueError(
            f"Unexpected target {target} at index {index}; expected 0..{num_classes - 1}"
        )
    counts[target] += 1

# A class with no samples would get an infinite weight; fail with a clear message.
empty_classes = [c for c, n in enumerate(counts) if n == 0]
if empty_classes:
    raise ValueError(f"Class(es) with no training samples: {empty_classes}")

cross_entropy_weights = torch.tensor([len(train_dataset) / n for n in counts]).to(device)
print("Adjusted weights for class imbalance:", cross_entropy_weights)

In [None]:
# Training schedule: validate every `eval_interval` iterations, stop after `max_iter`.
eval_interval = 1_000
max_iter = 5_000
# Base learning rate for SGD. It is passed explicitly because older PyTorch
# releases have no default `lr` for SGD and raise an error without it. Recent
# releases default it to 1e-3, which is used here. TODO: tune.
learning_rate = 1e-3

# Class-weighted loss to counter label imbalance in the training split.
criterion = torch.nn.CrossEntropyLoss(weight=cross_entropy_weights)
optimizer = torch.optim.SGD(
    model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=0
)
# Cosine decay to 0 over max_iter scheduler steps; the training loop steps it once per iteration.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_iter, eta_min=0)

In [None]:
# Train for max_iter iterations, pausing every eval_interval iterations to
# measure validation accuracy.

# Gradient-norm clipping threshold. The original cell used `args.clip`, but no
# `args` exists in this notebook (NameError). TODO: tune.
max_grad_norm = 1.0
# Smoothing factor of the exponential moving average of the training loss.
alpha = 0.99

iteration = 0
while iteration < max_iter:
    # Never run past max_iter: CosineAnnealingLR was built with T_max=max_iter,
    # and stepping beyond it would make the learning rate rise again.
    steps_this_round = min(eval_interval, max_iter - iteration)

    model.train()
    running_loss = 0.0
    train_loss = float("nan")
    train_tqdm = tqdm(range(steps_this_round), desc="Training", leave=False)
    for step in train_tqdm:
        # train_dataloader was built with is_infinite=True, so next() is used directly.
        inputs, targets = next(train_dataloader)
        optimizer.zero_grad()
        outputs = model(inputs.to(device))
        loss = criterion(outputs, targets.to(device))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        scheduler.step()

        running_loss = alpha * running_loss + (1 - alpha) * loss.item()
        # Undo the EMA's bias toward its zero initialisation so early readings
        # are not artificially small.
        train_loss = running_loss / (1 - alpha ** (step + 1))
        train_tqdm.set_postfix({"Loss": train_loss})

        iteration += 1

    model.eval()
    accuracy_sum = 0.0
    with torch.no_grad():
        for inputs, targets in tqdm(val_dataloader, desc="Evaluation", leave=False):
            outputs = model(inputs.to(device))
            # Logits are on `device` while targets come from the loader unmoved;
            # compare both on the CPU to avoid a device mismatch.
            accuracy_sum += multiclass_accuracy_logits(outputs.cpu(), targets)

    # NOTE(review): this is the mean of per-batch accuracies; it equals per-sample
    # accuracy only when all validation batches have the same size.
    avg_accuracy = accuracy_sum / len(val_dataloader)

    print(f"Iteration: {iteration}, Training Loss: {train_loss:.4f}, Validation Accuracy: {avg_accuracy:.4f}")


In [None]:
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score
import matplotlib.pyplot as plt

# Per-class (one-vs-rest) ROC curves on the validation split.

# Put the model in evaluation mode explicitly instead of relying on the mode
# left over from whichever cell ran last.
model.eval()
predictions = []  # per-sample logit vectors
labels = []       # per-sample integer targets
with torch.no_grad():
    for inputs, targets in tqdm(val_dataloader):
        outputs = model(inputs.to(device))
        predictions += list(outputs.detach().cpu())
        labels += list(targets)

classes = [0, 1]
class_labels = ["Nonsmoker", "Current or Former"]

# The model outputs logits (it is trained with CrossEntropyLoss). Score each class
# by its softmax probability: a class's raw logit alone does not rank samples the
# same way as its probability, which depends on all logits, so ROC/AUC computed
# on raw logits is wrong.
y_pred = torch.softmax(torch.stack(predictions), dim=1).numpy()
y_true = torch.stack(labels).numpy()

fpr = dict()
tpr = dict()
roc_auc = dict()
for i in classes:
    # Samples labelled `i` are positives; every other label counts as negative.
    fpr[i], tpr[i], _ = roc_curve(y_true, y_pred[:, i], pos_label=i)
    roc_auc[i] = auc(fpr[i], tpr[i])

fig, ax = plt.subplots()
colors = ['red', 'blue']

for i, color in zip(classes, colors):
    ax.plot(fpr[i], tpr[i], color=color, lw=2,
            label=f'{class_labels[i]} (Area = {roc_auc[i]:0.2f})')

ax.plot([0, 1], [0, 1], 'k--', lw=2)  # chance-level diagonal
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('False Positive Rate')
ax.set_ylabel('True Positive Rate')
ax.set_title('Validation ROC: smoking status')
ax.legend(loc="lower right")
plt.show()