In [None]:
%cd ..
import warnings

warnings.filterwarnings("ignore")

import os
import numpy as np
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from dotenv import load_dotenv
from collections import deque

from evaluation.utils.finetune import load_model, binary_accuracy_logits, extract_class_tokens
from evaluation.utils.classification import ImageTransform, AggregateClassTokens
from evaluation.extended_datasets.deeprdt_lung import DeepRDTSplit, get_dataloader

In [None]:
# Populate the process environment from the local .env file, then read both
# path settings in one go (either one is None when its variable is unset).
load_dotenv()
project_path, data_path = (os.getenv(key) for key in ("PROJECTPATH", "DATAPATH"))

In [None]:
# Run directory and checkpoint name handed to load_model().
path_to_run = "runs/base_103x4x5"
checkpoint_name = "training_69999"
# Prefer the first GPU, but fall back to the CPU so the notebook still runs
# on machines where CUDA is not available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# load_model returns the feature extractor together with its config object.
feature_model, config = load_model(path_to_run, checkpoint_name, device)
print("Loaded model")

In [None]:
# Input geometry is read from the student section of the loaded config.
full_image_size, patch_size = (
    config.student.full_image_size,
    config.student.patch_size,
)
# Normalisation statistics passed to the image transform, and the number of
# channels used when reshaping a sample into a unit batch.
data_mean, data_std = -573.8, 461.3
channels = 4

print("Full image size:", full_image_size)

In [None]:
# Report how many CPUs the OS exposes (os.cpu_count() returns None when the
# count cannot be determined).
# NOTE(review): presumably a guide for choosing `max_workers` — confirm.
print("Num cpus:", os.cpu_count())

In [None]:
# Transform applied to every sample, built from the target image size and the
# normalisation statistics.
img_processor = ImageTransform(full_image_size, data_mean, data_std)

# Arguments shared by both dataset views below.
dataset_kwargs = {
    "root_path": os.path.join(data_path, "dicoms"),
    "metadata_path": os.path.join(data_path, "dicoms/DeepRDT-lung/metadata_lung_oldPat.csv"),
    "transform": img_processor,
    "max_workers": 4,
    "train_val_split": 0.95,
}

# Two views of the training split that differ only in the `labels` flag.
# NOTE(review): presumably labels=True selects positive cases and labels=False
# negative cases — confirm against DeepRDTSplit.
train_positives = DeepRDTSplit(**dataset_kwargs, split="train", labels=True)
train_negatives = DeepRDTSplit(**dataset_kwargs, split="train", labels=False)

# Use the shared `channels` setting rather than repeating the literal 4, so the
# loaders cannot drift from the value used when reshaping samples.
train_positives_dataloader = get_dataloader(train_positives, channels=channels, split="train")
train_negatives_dataloader = get_dataloader(train_negatives, channels=channels, split="train")

In [None]:
def show_embed_dim(dataloader=None):
    """Print and return the embedding dimension produced by ``feature_model``.

    Takes the first image of the first batch from ``dataloader``, reshapes it
    into a batch of one, runs it through ``feature_model`` without gradient
    tracking and reads the last axis of ``outputs[0][0]``.

    Args:
        dataloader: Iterable yielding ``(images, targets)`` batches. When
            omitted, ``train_positives_dataloader`` is used.

    Returns:
        The embedding dimension.
    """
    if dataloader is None:
        # The notebook defines no plain `train_dataloader`; the positives
        # loader is used as the sample source.
        dataloader = train_positives_dataloader
    # iter() works whether the loader is an iterable or already an iterator;
    # calling next() directly on a plain iterable would raise TypeError.
    test_images, _ = next(iter(dataloader))
    unit_batch = test_images[0].view(1, channels, full_image_size, full_image_size)
    with torch.no_grad():
        outputs = feature_model(unit_batch.to(device))
    # The unpacking requires outputs[0][0] to be rank 3; its last axis is
    # reported as the embedding size.
    _, _, embed_dim = outputs[0][0].shape
    print("Embedding dimension:", embed_dim)
    return embed_dim
# print(show_embed_dim())

In [None]:
# Peek at one batch and show the image tensor's shape. The positives loader is
# used because the notebook defines no plain `train_dataloader`.
batch_img, batch_target = next(iter(train_positives_dataloader))
batch_img.shape

In [None]:
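# Category codes (translated from Spanish): 1 - Complete, 2 - Partial, 3 - Stable, 4 - Progression
# NOTE(review): presumably the treatment-response label codes in the metadata — confirm.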

In [None]:
# Width of a single token embedding.
# NOTE(review): hard-coded; show_embed_dim() can be used to confirm it.
embed_dim = 768
# The classifier input is four times that width.
# NOTE(review): presumably four class tokens are concatenated — confirm
# against AggregateClassTokens / extract_class_tokens.
EMBED_DIM = 4 * embed_dim
PATCH_SIZE = config.student.patch_size
BATCH_SIZE = 100

# Single-logit head, moved to the device chosen earlier in the notebook.
classifier_model = AggregateClassTokens(embed_dim=EMBED_DIM, num_labels=1)
classifier_model = classifier_model.to(device)

In [None]:
from sklearn.metrics import classification_report

def compute_metrics(predictions, labels):
    """Build a text classification report for the binary response task.

    Args:
        predictions: Predicted class per sample.
        labels: Ground-truth class per sample.

    Returns:
        The report produced by ``sklearn.metrics.classification_report``, with
        the two classes named 'No Response' and 'Response'.
    """
    # Pin the label set so the two target names always line up, even when the
    # inputs happen to contain a single class (without `labels`, scikit-learn
    # raises because the class count no longer matches `target_names`).
    # assumes classes are encoded as 0 (no response) / 1 (response) — TODO confirm
    return classification_report(
        labels,
        predictions,
        labels=[0, 1],
        target_names=['No Response', 'Response'],
    )

In [None]:
class Trainer:
    """Trains a classifier head on class tokens from a feature extractor.

    Every collaborator can be injected. Anything left as ``None`` falls back to
    the matching notebook-level object (``classifier_model``, ``feature_model``,
    the two training dataloaders), so a bare ``Trainer()`` keeps working.
    """

    def __init__(
        self,
        dataloader=None,
        model=None,
        feature_extractor=None,
        max_iter: int = 100,
        lr: float = 1e-3,
        clip_value: float = 1.0,
        device=None,
    ):
        """Set up the optimizer, LR schedule and loss.

        Args:
            dataloader: Sized iterable of ``(inputs, labels)`` batches. When
                omitted, the positives and the negatives training loaders are
                iterated one after the other.
            model: Classifier to optimise; defaults to ``classifier_model``.
            feature_extractor: Module producing the token list; defaults to
                ``feature_model``.
            max_iter: Length of the cosine schedule, counted in ``train()``
                calls. NOTE(review): placeholder default — the original never
                defined it.
            lr: SGD learning rate (1e-3 is also PyTorch's own default).
            clip_value: Maximum gradient norm. NOTE(review): placeholder
                default — the original never defined it.
            device: Device the inputs are moved to; defaults to the device of
                the model's parameters.
        """
        # Globals are only looked up when nothing was injected.
        self.model = classifier_model if model is None else model
        self.feature_extractor = feature_model if feature_extractor is None else feature_extractor
        if dataloader is None:
            # NOTE(review): a binary loss needs both classes, so both loaders
            # are used by default — confirm this is the intended sampling.
            self.dataloaders = (train_positives_dataloader, train_negatives_dataloader)
        else:
            self.dataloaders = (dataloader,)

        self.optimizer = torch.optim.SGD(
            self.model.parameters(), lr=lr, momentum=0.9, weight_decay=0
        )
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, max_iter, eta_min=0
        )
        self.criterion = torch.nn.BCEWithLogitsLoss()

        self.device = next(self.model.parameters()).device if device is None else device
        self.clip_value = clip_value
        self.alpha = 0.95
        self.eval_interval = 20

        self.positives = []
        self.negatives = []

    def train(self, iteration: int=0) -> int:
        """Accumulate gradients over every batch, then take one optimizer step.

        Args:
            iteration: Counter value before this call.

        Returns:
            ``iteration + 1``.
        """
        self.model.train()
        self.optimizer.zero_grad()

        num_batches = sum(len(loader) for loader in self.dataloaders)
        batches = (batch for loader in self.dataloaders for batch in loader)
        progress = tqdm(batches, total=num_batches, desc="Training", leave=True)

        for inputs, labels in progress:
            # No gradients are tracked through the feature extractor.
            with torch.no_grad():
                x_tokens_list = self.feature_extractor(inputs.to(self.device))
            class_tokens = extract_class_tokens(x_tokens_list)
            logits = self.model(class_tokens)

            # BCEWithLogitsLoss requires targets with the same shape as the
            # logits and a floating dtype.
            targets = labels.to(device=logits.device, dtype=logits.dtype).reshape(logits.shape)
            loss = self.criterion(logits, targets)
            # Gradients add up across batches; the single step happens after
            # the loop, as in the original layout.
            # NOTE(review): the accumulated loss is not divided by the batch
            # count — confirm that is intended.
            loss.backward()

        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_value)
        self.optimizer.step()
        self.scheduler.step()

        return iteration + 1

In [None]:
# Build the trainer with its defaults and run one training pass starting from
# iteration 0; the returned counter is the cell's displayed output.
trainer = Trainer()
trainer.train(iteration=0)