In [None]:
%cd ..
import warnings

warnings.filterwarnings("ignore")

import numpy as np
import random
import os
import torch
from tqdm import tqdm
from dotenv import load_dotenv

from evaluation.utils.networks import FullScanPredictor
from evaluation.extended_datasets.deeprdt_lung import DeepRDT_lung
from evaluation.extended_datasets.covid_19_ny_sbu import COVID_19_NY_SBU
from evaluation.utils.dataset import BalancedSampler, split_dataset, CombinedDataset
from evaluation.utils.metrics import compute_metrics

# Seed every RNG this notebook has imported so repeated runs start from the
# same state. torch must be seeded as well: numpy/random seeding alone does
# not affect anything that draws from torch's generator.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

In [None]:
# Read machine-specific locations from the environment (optionally populated
# from a .env file by load_dotenv) instead of hard-coding absolute paths.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")

# Identifiers forwarded to the dataset constructors further down.
run_name = "base4_103x4x5"
checkpoint_name = "training_659999"

# Prefer the first GPU, but fall back to CPU so the notebook still runs on a
# machine without CUDA instead of failing when the model is moved to `device`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Build the two source datasets, merge them, and split the result into the
# train/validation sets used by the loaders below.
metadata_path = os.path.join(data_path, "dicoms/DeepRDT-lung/metadata_lung_oldPat.csv")
# NOTE(review): label is passed as the string "False", not the boolean False.
# A non-empty string is truthy in Python — confirm DeepRDT_lung expects a string.
dataset_control = DeepRDT_lung(metadata_path, run_name, checkpoint_name, label="False")

dataset_positive = COVID_19_NY_SBU(run_name, checkpoint_name)

combined = CombinedDataset(dataset_control, dataset_positive)
# presumably 0.8 is the fraction assigned to the training split — verify against split_dataset
train_dataset, val_dataset = split_dataset(combined, 0.8)

In [None]:
def collate_fn(batch):
    """Zero-pad variable-length 3-D embeddings into one batch tensor.

    Args:
        batch: sequence of ``(embedding, label)`` pairs. Each embedding is a
            3-D tensor; only its first dimension may differ between items. The
            two trailing dimensions are taken from the first item.

    Returns:
        A tuple ``(padded, mask, labels)`` where

        * ``padded`` has shape ``(len(batch), longest, d1, d2)`` and is zero
          beyond each item's own length,
        * ``mask`` has shape ``(len(batch), longest)`` with 1 on real
          positions and 0 on padding,
        * ``labels`` is ``torch.tensor`` of the collected labels.
    """
    items, targets = zip(*batch)

    # Trailing dimensions are read from the first item only.
    _, dim_one, dim_two = items[0].shape
    lengths = [item.shape[0] for item in items]
    longest = max(lengths)
    batch_size = len(items)

    padded = torch.zeros(batch_size, longest, dim_one, dim_two)
    mask = torch.zeros(batch_size, longest)

    # Copy each item into the front of its row and flag those positions as real.
    for row, (item, length) in enumerate(zip(items, lengths)):
        padded[row, :length] = item
        mask[row, :length] = 1

    return padded, mask, torch.tensor(targets)


# Training loader: one item per batch, drawn in the order produced by
# BalancedSampler (presumably class-balanced — verify against its
# implementation). Larger effective batches come from gradient accumulation
# in train().
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=1,
    collate_fn=collate_fn,
    sampler=BalancedSampler(train_dataset),
)
# Validation loader: one item per batch, visited in dataset order.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=collate_fn,
)

In [None]:
# Size of a single embedding vector.
embed_dim = 768
# Input width given to the classifier: four times embed_dim.
# presumably four embeddings are concatenated per position — TODO confirm
combined_embed_dim = embed_dim * 4

hidden_dim = 1024

# num_labels=1: a single output, matching the binary BCEWithLogitsLoss
# criterion defined in the training cell.
classifier_model = FullScanPredictor(combined_embed_dim, hidden_dim, num_labels=1).to(
    device
)

In [None]:
# Number of full passes over train_loader.
train_epochs = 100

# Number of batches whose gradients are accumulated before each optimizer step
# (train_loader yields one item per batch).
accum_steps = 32

# SGD with momentum and weight decay over all classifier parameters.
optimizer = torch.optim.SGD(
    classifier_model.parameters(), momentum=0.9, weight_decay=0.01, lr=1e-3
)
# Binary cross-entropy applied to raw logits.
criterion = torch.nn.BCEWithLogitsLoss()
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epochs, eta_min=1e-5)


def evaluate():
    """Run the classifier over ``val_loader`` and compute validation metrics.

    The model is switched to eval mode for the duration of the pass and then
    returned to whatever mode it was in on entry, so calling this between
    training epochs does not leave the model stuck in eval mode.

    Returns:
        Whatever ``compute_metrics`` returns for the collected logits and
        labels; ``train()`` reads the ``"accuracy"`` and ``"aucroc"`` entries.
    """
    # Remember the caller's mode so it can be restored afterwards.
    was_training = classifier_model.training
    classifier_model.eval()

    all_logits = []
    all_labels = []

    try:
        with torch.no_grad():
            # Iterate the loader itself (not iter(loader)) so tqdm can read
            # its length and show a proper progress total.
            for embeddings, _, labels in tqdm(val_loader):
                embeddings = embeddings.to(device)
                logits = classifier_model(embeddings)

                all_logits.append(logits.flatten().cpu().numpy())
                all_labels.append(labels.float().cpu().numpy())
    finally:
        classifier_model.train(was_training)

    # One row per batch; val_loader uses batch_size=1.
    all_logits = np.array(all_logits)
    all_labels = np.array(all_labels)

    return compute_metrics(all_logits, all_labels)


def train() -> tuple:
    """Train ``classifier_model`` for ``train_epochs`` epochs with gradient accumulation.

    Gradients from ``accum_steps`` consecutive batches are accumulated before
    each optimizer step. Any batches left over at the end of an epoch are
    flushed with a final step instead of being discarded. After every epoch
    the model is evaluated on the validation set and a summary line is printed.

    Returns:
        A tuple ``(losses, grad_norms, val_metrics)`` with one entry per epoch:

        * ``losses`` - mean training loss per batch,
        * ``grad_norms`` - mean pre-clipping gradient norm per optimizer step,
        * ``val_metrics`` - the value returned by ``evaluate()``.
    """
    losses = []
    grad_norms = []
    val_metrics = []

    def apply_accumulated_step() -> float:
        """Clip the accumulated gradient, step the optimizer, reset grads; return the pre-clip norm."""
        # Clip once per optimizer step: clipping after every backward pass
        # would rescale a partially accumulated gradient and distort the sum.
        grad_norm = torch.nn.utils.clip_grad_norm_(
            classifier_model.parameters(), 1.0
        ).item()
        optimizer.step()
        optimizer.zero_grad()
        return grad_norm

    for train_epoch in range(train_epochs):
        # evaluate() switches the model to eval mode, so training mode has to
        # be (re-)enabled at the start of every epoch, not just once.
        classifier_model.train()

        total_loss = 0.0
        total_grad_norm = 0.0
        num_batches = 0
        num_steps = 0
        pending = 0  # batches accumulated since the last optimizer step

        optimizer.zero_grad()

        # Iterate the loader itself so tqdm can show a total.
        for embeddings, _, labels in tqdm(train_loader):
            logits = classifier_model(embeddings.to(device))
            loss = criterion(logits.flatten(), labels.float().to(device))
            total_loss += loss.item()
            loss.backward()

            num_batches += 1
            pending += 1

            if pending == accum_steps:
                total_grad_norm += apply_accumulated_step()
                num_steps += 1
                pending = 0

        # Flush the trailing partial accumulation window, if any.
        if pending:
            total_grad_norm += apply_accumulated_step()
            num_steps += 1

        # max(..., 1) guards against an empty loader.
        losses.append(total_loss / max(num_batches, 1))
        grad_norms.append(total_grad_norm / max(num_steps, 1))

        epoch_metrics = evaluate()
        val_metrics.append(epoch_metrics)

        accuracy = epoch_metrics["accuracy"]
        rocauc = epoch_metrics["aucroc"]

        print(
            f"Epoch {train_epoch + 1}/{train_epochs} - Accuracy: {accuracy:.4f} - ROCAUC: {rocauc:.4f}"
        )

    return losses, grad_norms, val_metrics

In [None]:
# Run the full training loop; each returned list has one entry per epoch.
losses, grad_norms, val_metrics = train()

In [None]:
# Display the validation metrics from the last epoch.
val_metrics[-1]