In [None]:
%cd ..
import warnings

warnings.filterwarnings("ignore")

import os
import torch
from tqdm import tqdm
from dotenv import load_dotenv

from evaluation.utils.finetune import load_model, extract_class_tokens
from evaluation.utils.classification import ImageTransform, AggregateClassTokens
from evaluation.extended_datasets.deeprdt_lung import get_dataloaders

In [None]:
# Resolve machine-specific locations from the environment (load_dotenv() first
# merges a local .env file into it) so no absolute paths live in the notebook.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")
# DATAPATH is needed by the dataset cell; fail here with a readable message
# instead of a TypeError from os.path.join(None, ...) much later.
if data_path is None:
    raise RuntimeError("DATAPATH is not set; define it in .env or in the environment.")

In [None]:
# Run directory is relative to the working directory set by `%cd ..` above.
path_to_run = "runs/base_103x4x5"
checkpoint_name = "training_69999"
# Prefer the first GPU, but fall back to CPU so the notebook still runs
# (slowly) on machines without CUDA instead of failing at load time.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

feature_model, config = load_model(path_to_run, checkpoint_name, device)
print("Loaded model")

In [None]:
# Normalization statistics handed to ImageTransform.
# NOTE(review): magic constants — presumably a dataset-wide intensity mean/std
# (CT values?); TODO record where they were computed.
data_mean, data_std = -573.8, 461.3
# Channels per input image; get_embed_dim reshapes single images with it.
channels = 4

# Image geometry is dictated by the student network of the loaded checkpoint.
full_image_size, patch_size = (
    config.student.full_image_size,
    config.student.patch_size,
)

print("Full image size:", full_image_size)

In [None]:
# Report how many CPUs the kernel can see — a reference point for the
# `max_workers` value passed to the dataset below.
print("Num cpus:", os.cpu_count())

In [None]:
# Normalize/resize every image to the model's expected input.
img_processor = ImageTransform(full_image_size, data_mean, data_std)

dataset_kwargs = {
    "root_path": os.path.join(data_path, "dicoms"),
    "metadata_path": os.path.join(data_path, "dicoms/DeepRDT-lung/metadata_lung_oldPat.csv"),
    "transform": img_processor,
    "max_workers": 4,
}

# Use the `channels` setting instead of repeating the literal 4, so the loader
# and the model-input reshape in get_embed_dim cannot drift apart.
# train_val_split=0.9 is presumably the training fraction — TODO confirm.
dataloaders = get_dataloaders(dataset_kwargs, channels=channels, train_val_split=0.9)

In [None]:
def get_embed_dim(test_image):
    """Return the token width produced by the feature model for one image.

    The image is reshaped into a batch of one of shape
    ``(1, channels, full_image_size, full_image_size)`` and run on `device`
    without gradient tracking. The width is the last dimension of
    ``outputs[0][0]``, which must be 3-D for the unpacking to succeed.
    """
    batch_of_one = test_image.view(1, channels, full_image_size, full_image_size)
    with torch.no_grad():
        token_sets = feature_model(batch_of_one.to(device))
    first_tokens = token_sets[0][0]
    _, _, width = first_tokens.shape
    return width
def test_loader_and_model():
    """Smoke-test the data pipeline and feature model on one training batch.

    Pulls the first batch from the ``train_positives`` loader, prints its
    image tensor shape and the feature model's embedding width, and returns
    that width so it can replace a hard-coded value.

    Returns:
        The embedding width reported by ``get_embed_dim``.
    """
    images, _labels = next(iter(dataloaders["train_positives"]))
    embed_dim = get_embed_dim(images[0])
    print(images.shape)
    print("embed_dim:", embed_dim)
    return embed_dim

In [None]:
# Width of one token from the feature model.
# NOTE(review): hard-coded; test_loader_and_model() measures the real value.
embed_dim = 768
# Classifier input is 4 x embed_dim — presumably the number of class tokens
# concatenated by extract_class_tokens; TODO confirm the factor of 4.
EMBED_DIM = 4 * embed_dim
PATCH_SIZE = config.student.patch_size

# Uses the AggregateClassTokens imported from evaluation.utils.classification.
classifier_model = AggregateClassTokens(embed_dim=EMBED_DIM, hidden_dim=1024, num_labels=1)
classifier_model = classifier_model.to(device)

In [None]:
# Size check: megabytes of host memory taken by the class tokens of a single
# batch — one such tensor per batch is kept by the caching step later on.
with torch.no_grad():
    dataloader = iter(dataloaders["train_positives"])
    inputs, label = next(dataloader)
    x_tokens_list = feature_model(inputs.to(device))
    class_tokens = extract_class_tokens(x_tokens_list).detach().cpu()

nbytes = class_tokens.numel() * class_tokens.element_size()
MB = nbytes / 2**20
print(MB)

In [None]:
from sklearn.metrics import classification_report

def compute_metrics(predictions, labels):
    """Return scikit-learn's text classification report for the binary task.

    Args:
        predictions: Predicted class ids, presumably 0 (no response) or
            1 (response) to match the label convention used in train().
        labels: Ground-truth class ids in the same encoding.

    Returns:
        The formatted report string produced by ``classification_report``.
    """
    # Pin the label set so the two target names always map to classes 0 and 1.
    # Without it, sklearn raises when a split (or a small evaluation batch)
    # contains only one class, because the class count no longer matches
    # the two target_names.
    return classification_report(
        labels, predictions, labels=[0, 1], target_names=['No Response', 'Response']
    )

In [None]:
EVAL_INTERVAL = 20  # train() calls evaluate() every this many iterations
MAX_ITER = 50  # optimizer steps; also the cosine schedule's period
BATCH_SIZE = 100  # cached batches whose gradients are accumulated per step
# Explicit learning rate: older PyTorch releases have no default `lr` for SGD
# (constructing it without one raises), while newer ones default to 1e-3.
LEARNING_RATE = 1e-3

# NOTE: the optimizer is bound to the parameters of the classifier_model that
# exists right now; if classifier_model is rebuilt later, rebuild these too.
optimizer = torch.optim.SGD(
    classifier_model.parameters(), lr=LEARNING_RATE, momentum=0.9, weight_decay=0
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, MAX_ITER, eta_min=0)
criterion = torch.nn.BCEWithLogitsLoss()

class CircularList:
    """Round-robin cursor over a fixed list that wraps around forever.

    ``next()`` hands out the element under the cursor and advances it,
    returning to the first element after the last one; ``current()`` peeks
    at that element without moving; ``reset()`` rewinds to the start.
    """

    def __init__(self, data: list):
        self.data = data
        self.index = 0  # position of the element the next `next()` call returns

    def current(self):
        return self.data[self.index]

    def next(self):
        item = self.current()
        _, self.index = divmod(self.index + 1, len(self.data))
        return item

    def reset(self):
        self.index = 0

# Run every training batch through the feature model once and keep only the
# extracted class tokens; train() later cycles through these cached tensors
# instead of recomputing features at every step.
cls_cache = {}
with torch.no_grad():
    for loader_name in ["train_positives", "train_negatives"]:
        loader_cache = []
        print("Caching", loader_name)
        dataloader = iter(dataloaders[loader_name])
        # `label` is unused: the class is implied by which loader this is.
        for inputs, label in tqdm(dataloader):
            x_tokens_list = feature_model(inputs.to(device))
            class_tokens = extract_class_tokens(x_tokens_list)
            # Stored on the CPU, presumably to keep GPU memory free while
            # the cache grows.
            loader_cache.append(class_tokens.detach().cpu())
        # Wrapped so train() can draw batches round-robin indefinitely.
        cls_cache[loader_name] = CircularList(loader_cache)

print("Done caching.")


In [None]:
import torch.nn as nn
class AggregateClassTokens(nn.Module):
    """Attention-pooling classifier over a set of class-token vectors.

    Each token is projected to ``hidden_dim`` (no bias) and scored by a linear
    attention head. The projected tokens are then combined with softmax weights
    taken over the token axis (dim 1). A final linear layer maps the pooled
    vector to ``num_labels`` logits.

    Args:
        embed_dim: Size of each input token vector.
        hidden_dim: Width of the projection and of the pooled representation.
        num_labels: Number of output logits.
        device: Accepted for signature compatibility; currently unused (the
            module is moved with ``.to(device)`` by its caller).
    """

    def __init__(
        self,
        embed_dim=384 * 4,
        hidden_dim=1024,
        num_labels=1,
        device=torch.device("cpu"),
    ):
        super().__init__()
        self.linear = nn.Linear(embed_dim, hidden_dim, bias=False)
        self.attention_weights = nn.Linear(hidden_dim, 1)
        self.classifier = nn.Linear(hidden_dim, num_labels)

    def forward(self, class_tokens):
        """Return logits of shape ``(batch, num_labels)``.

        Assumes ``class_tokens`` is ``(batch, num_tokens, embed_dim)``, so the
        token axis is dim 1. TODO confirm against ``extract_class_tokens``.
        """
        x = self.linear(class_tokens)
        scores = self.attention_weights(x).squeeze(-1)
        # Normalize over the token axis — the same axis summed below. A softmax
        # over dim 0 would mix samples across the batch. With a batch of one
        # it would make every weight exactly 1 (a plain sum, no attention).
        weights = torch.softmax(scores, dim=1).unsqueeze(-1)

        attention_output = torch.sum(weights * x, dim=1)

        return self.classifier(attention_output)
classifier_model = AggregateClassTokens(
    embed_dim=EMBED_DIM, hidden_dim=1024, num_labels=1
).to(device)
# `optimizer` and `scheduler` were created against the previous
# classifier_model's parameters. Without rebuilding them, training would compute
# gradients for this new model while the optimizer kept stepping the old one.
# Reuse the existing hyperparameters so the two definitions cannot drift.
optimizer = type(optimizer)(classifier_model.parameters(), **optimizer.defaults)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=scheduler.T_max, eta_min=scheduler.eta_min
)

In [None]:
def evaluate():
    """Placeholder evaluation hook.

    train() calls this every EVAL_INTERVAL iterations, but it does nothing yet.
    TODO: score a held-out split (e.g. via compute_metrics).
    """
    pass

def train() -> list:
    """Train ``classifier_model`` on the cached class tokens.

    Each of the MAX_ITER iterations does the following:

    - Accumulates gradients over BATCH_SIZE cached entries, alternating
      positives (label 1) and negatives (label 0) drawn round-robin from
      ``cls_cache``.
    - Clips the total gradient norm to 1.0.
    - Takes one optimizer step and one scheduler step.

    ``evaluate()`` runs every EVAL_INTERVAL iterations, starting at iteration 0.

    Returns:
        The mean per-sample training loss of each iteration, in order.
    """
    loss_history = []

    for i in range(MAX_ITER):
        if i % EVAL_INTERVAL == 0:
            evaluate()

        # Re-enter training mode each iteration in case evaluate() switched
        # the model to eval mode; clear any gradients left from before.
        classifier_model.train()
        optimizer.zero_grad()

        total_loss = 0.0
        for batch_idx in range(BATCH_SIZE):
            if batch_idx % 2 == 0:
                class_tokens = cls_cache["train_positives"].next().to(device)
                label = 1
            else:
                class_tokens = cls_cache["train_negatives"].next().to(device)
                label = 0

            logits = classifier_model(class_tokens)
            target = torch.tensor([[label]], dtype=torch.float32, device=device)
            loss = criterion(logits, target)
            # Gradients are summed (not averaged) over the BATCH_SIZE samples;
            # the norm clipping below bounds the resulting update.
            loss.backward()
            total_loss += loss.item()

        torch.nn.utils.clip_grad_norm_(classifier_model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        loss_history.append(total_loss / BATCH_SIZE)

    return loss_history

loss_history = train()