In [None]:
# Move the kernel's working directory three levels up before the project imports run.
# NOTE(review): presumably this lands on the repository root so that the
# `evaluation` package below is importable - confirm.
# The path is relative, so re-running this cell without restarting the kernel
# climbs a further three levels each time.
%cd ../../..
import warnings

# Install a blanket "ignore" filter: warnings raised after this point are silenced.
warnings.filterwarnings("ignore")

import numpy as np
import random
import os
import torch
from dotenv import load_dotenv

from evaluation.utils.networks import FullScanClassPredictor, FullScanPatchPredictor

from evaluation.extended_datasets import CachedEmbeddings, EmbeddingsGenerator
from evaluation.tasks.ct_rate.datasets import CT_RATE

from evaluation.utils.dataset import BalancedSampler, split_dataset, collate_sequences
from evaluation.utils.train import train, evaluate

# Seed every random number generator this notebook has imported so that runs
# are repeatable after a kernel restart. The original seeded only NumPy and
# the stdlib `random` module; torch's generator was left unseeded.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Pull variables from a .env file (if present) into the process environment.
load_dotenv()


def _require_env(name):
    """Return the value of environment variable `name`.

    Raises:
        RuntimeError: if the variable is not set. Failing here gives a clear
            message instead of a later `TypeError` from `os.path.join(None, ...)`.
    """
    value = os.getenv(name)
    if value is None:
        raise RuntimeError(
            f"Environment variable {name} is not set; "
            "export it or add it to the .env file before running this notebook."
        )
    return value


# Root directories that every path below is joined onto.
project_path = _require_env("PROJECTPATH")
data_path = _require_env("DATAPATH")

# Identify the run/checkpoint whose embeddings are evaluated.
run_name = "vitb_CT-RATE"
checkpoint_name = "training_549999"
# Passed as the first argument to the predictor models further down.
# NOTE(review): presumably the embedding width of the checkpoint above - confirm.
embed_dim = 768

In [None]:
# Option A: generate embeddings during runtime.
# The cell that follows rebinds `embeddings_provider` to a cached reader, so
# run only one of the two when choosing a provider.
dataset_path = os.path.join(data_path, "niftis/CT-RATE")

# Keyword settings forwarded unchanged to EmbeddingsGenerator.
generator_options = {
    "dataset_name": "CT-RATE_train_eval",
    "db_storage": "nifti",
    "device": torch.device("cuda:0"),
    "embed_patches": False,
    "embed_cls": True,
    "max_batch_size": 64,
    "resample_slices": 240,
}

embeddings_provider = EmbeddingsGenerator(
    project_path, run_name, checkpoint_name, dataset_path, **generator_options
)

In [None]:
# Option B: use cached embeddings.
# Running this cell rebinds `embeddings_provider`, replacing whatever the
# previous cell assigned to it.
cache_location = ("evaluation/cache/CT-RATE_train_eval", run_name, checkpoint_name)
embeddings_path = os.path.join(project_path, *cache_location)

embeddings_provider = CachedEmbeddings(embeddings_path)

In [None]:
# Label name handed to CT_RATE.
# NOTE(review): presumably selects one abnormality column of the CSV below - confirm.
label = "Arterial wall calcification"
metadata_path = os.path.join(
    data_path, "niftis/CT-RATE/multi_abnormality_labels/train_predicted_labels.csv"
)

dataset = CT_RATE(embeddings_provider, metadata_path, label)

# 0.8 is the ratio given to split_dataset (presumably the training fraction - confirm).
train_dataset, val_dataset = split_dataset(dataset, 0.8)
print(f"{len(train_dataset)} {len(val_dataset)}")

In [None]:
batch_size = 1

# Settings common to both loaders. `shuffle` stays False because ordering is
# delegated to the explicit sampler passed to each DataLoader.
shared_loader_options = dict(shuffle=False, collate_fn=collate_sequences)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    sampler=BalancedSampler(train_dataset),
    **shared_loader_options,
)
# NOTE(review): the validation loader also draws through BalancedSampler, so
# validation metrics reflect whatever that sampler yields rather than a plain
# pass over val_dataset - confirm this is intended.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=1,
    sampler=BalancedSampler(val_dataset),
    **shared_loader_options,
)
print(f"{len(train_loader)} {len(val_loader)}")

In [None]:
# Build the class-embedding classifier with a single output label and place it
# on the first CUDA device.
hidden_dim = 512
device = torch.device("cuda:0")

classifier_model = FullScanClassPredictor(
    embed_dim,
    hidden_dim,
    num_labels=1,
)
classifier_model.to(device)

print("Model loaded")

In [None]:
# Alternative model: patch-based predictor. Running this cell rebinds
# `classifier_model`, replacing the model assigned by the previous cell.
patch_resample_dim = 16
hidden_dim = 512
device = torch.device("cuda:0")

classifier_model = FullScanPatchPredictor(
    embed_dim,
    hidden_dim,
    patch_resample_dim=patch_resample_dim,
    num_labels=1,
)
# Multi-GPU wrapping is currently switched off:
#classifier_model = torch.nn.DataParallel(classifier_model, device_ids=[0, 1, 2, 3])
classifier_model.to(device)

print("Model loaded")

In [None]:
# Both values are forwarded to train() by train_test(); the iteration budget
# per call is four times accum_steps.
accum_steps = 32
train_iters = 4 * accum_steps

criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(
    classifier_model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.01,
)
# Learning-rate scheduling is disabled. The draft below refers to
# `train_epochs`, which is not defined in the cells shown here.
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epochs, eta_min=1e-5)

def train_test():
    """Run one training round followed by one validation pass and report both.

    Reads the notebook globals `classifier_model`, `optimizer`, `criterion`,
    `train_loader`, `val_loader`, `train_iters`, `accum_steps` and `device`.

    The two metric lines are prefixed with their split so they can be told
    apart in the output; the original printed them in an identical format.

    Returns:
        tuple: `(train_metrics, eval_metrics)`, the objects returned by
        `train()` and `evaluate()`. Both are indexed here with the keys
        `'pr_auc'` and `'roc_auc'`.
    """
    train_metrics = train(
        classifier_model,
        optimizer,
        criterion,
        train_loader,
        train_iters,
        accum_steps,
        device
    )
    print(
        f"[train] PR AUC: {train_metrics['pr_auc']:.4f} - ROC AUC: {train_metrics['roc_auc']:.4f}\n"
    )

    # max_eval_n=100 caps the evaluation (presumably at 100 samples - confirm).
    eval_metrics = evaluate(
        classifier_model, val_loader, device=device, max_eval_n=100
    )
    print(
        f"[val]   PR AUC: {eval_metrics['pr_auc']:.4f} - ROC AUC: {eval_metrics['roc_auc']:.4f}\n"
    )

    # Hand the metrics back so callers can log or plot them; existing callers
    # that ignore the return value are unaffected.
    return train_metrics, eval_metrics


In [None]:
# Alternate training and validation for a fixed number of rounds.
num_rounds = 10
for i in range(num_rounds):
    train_test()
