In [None]:
# Change the working directory three levels up before importing anything else.
# Presumably this lands on the project root so the `evaluation.*` imports below
# resolve -- confirm against the repository layout.
%cd ../../..
import warnings

# Suppress every Python warning for the remainder of this kernel session.
warnings.filterwarnings("ignore")

import numpy as np
import random
import os
import torch
from dotenv import load_dotenv

from evaluation.utils.networks import FullScanClassPredictor
from evaluation.tasks.ct_rate.datasets import CT_RATE
from evaluation.utils.dataset import BalancedSampler, split_dataset
from evaluation.utils.train import train

# Seed every RNG this notebook can touch so a Restart & Run All is repeatable.
# PyTorch keeps its own generator (weight initialisation, torch-based sampling),
# so seeding only NumPy and `random` is not enough.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Load variables from a .env file into the process environment, then read the
# two path settings this notebook relies on.
load_dotenv()
project_path = os.getenv("PROJECTPATH")
data_path = os.getenv("DATAPATH")
if data_path is None:
    # Fail here with a clear message instead of later with an opaque
    # TypeError raised by os.path.join(None, ...).
    raise RuntimeError(
        "Environment variable DATAPATH is not set; define it in the shell or in a .env file."
    )

# Settings for each embedding source that has been evaluated. Keeping them in a
# table (rather than in a comment that has to be copy-pasted) means switching
# backbone is a one-word change below.
RUN_PRESETS = {
    "dino": {
        "run_name": "vitb-14.10",
        "checkpoint_name": "training_319999",
        "cache_format": "npy",
        "embed_dim": 768 * 4,
    },
    "diffusion": {
        "run_name": "diffusion_new240",
        "checkpoint_name": "unnamed_checkpoint",
        "cache_format": "npz",
        "embed_dim": 512,
    },
}

# Change this key to evaluate a different backbone.
active_preset = RUN_PRESETS["diffusion"]

run_name = active_preset["run_name"]
checkpoint_name = active_preset["checkpoint_name"]
cache_format = active_preset["cache_format"]
embed_dim = active_preset["embed_dim"]

# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs end-to-end on a machine without a usable GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# AUC-ROC Scores
**DINO**
- Lung nodule (0.6302)
- Lung opacity (0.7028)
- Arterial wall calcification (0.8299)
- Pulmonary fibrotic sequela (0.6619)

**Diffusion**
- Lung nodule (0.5136)
- Lung opacity (0.5602)
- Arterial wall calcification (0.5366)
- Pulmonary fibrotic sequela (0.5000)

**CT-CLIP**
- Lung nodule (TBD)
- Lung opacity (TBD)
- Arterial wall calcification (TBD)
- Pulmonary fibrotic sequela (TBD)



In [None]:
# Abnormality column used as the binary classification target.
label = "Pulmonary fibrotic sequela"

# Label table for the CT-RATE validation scans, located under DATAPATH.
metadata_path = os.path.join(
    data_path,
    "CT-RATE/multi_abnormality_labels/valid_predicted_labels.csv",
)

# Dataset for the selected run/checkpoint, restricted to the single target label.
dataset = CT_RATE(run_name, checkpoint_name, metadata_path, label, cache_format=cache_format)

# Split with a ratio argument of 0.8 (presumably 80% train / 20% validation --
# confirm against split_dataset).
train_dataset, val_dataset = split_dataset(dataset, 0.8)

In [None]:
def collate_fn(batch):
    """Zero-pad a batch of variable-length embedding stacks to a common length.

    Args:
        batch: Sequence of ``(embedding, label)`` pairs. Every embedding must be
            three-dimensional; items may differ only in their first dimension.
            The two trailing dimensions are read from the first item.

    Returns:
        A ``(padded, mask, labels)`` tuple:
            * ``padded`` -- zero-initialised tensor of shape
              ``(len(batch), longest, d1, d2)`` with each embedding copied into
              the leading rows of its slot.
            * ``mask`` -- tensor of shape ``(len(batch), longest)`` holding 1
              for rows that carry real data and 0 for padding.
            * ``labels`` -- ``torch.tensor`` built from the collected labels.
    """
    samples, targets = zip(*batch)

    # Trailing dimensions come from the first sample; this unpack also enforces
    # that the embedding is exactly three-dimensional.
    token_count, feature_dim = samples[0].shape[1:]

    lengths = [sample.shape[0] for sample in samples]
    longest = max(lengths)
    batch_size = len(samples)

    padded = torch.zeros(batch_size, longest, token_count, feature_dim)
    mask = torch.zeros(batch_size, longest)

    # Copy every sample into the front of its slot and flag those rows as valid.
    for row, (sample, length) in enumerate(zip(samples, lengths)):
        padded[row, :length] = sample
        mask[row, :length] = 1

    return padded, mask, torch.tensor(targets)

# Training loader: the sampling order is delegated to BalancedSampler.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    sampler=BalancedSampler(train_dataset),
    collate_fn=collate_fn,
    batch_size=1,
)

# Validation loader: no shuffling, one scan per batch.
val_loader = torch.utils.data.DataLoader(
    dataset=val_dataset,
    collate_fn=collate_fn,
    shuffle=False,
    batch_size=1,
)

In [None]:
# Second constructor argument of the classifier (presumably its hidden width --
# confirm against FullScanClassPredictor).
hidden_dim = 1024

# Build the single-output classifier, then move it to the selected device.
classifier_model = FullScanClassPredictor(embed_dim, hidden_dim, num_labels=1)
classifier_model = classifier_model.to(device)

In [None]:
# Training schedule; both values are forwarded to train() unchanged.
train_epochs = 20
accum_steps = 32  # presumably gradient-accumulation steps -- confirm in evaluation.utils.train

# SGD with momentum and weight decay over all classifier parameters.
optimizer = torch.optim.SGD(
    classifier_model.parameters(),
    lr=1e-3,
    momentum=0.9,
    weight_decay=0.01,
)

# Binary cross-entropy computed on raw logits (the model emits one logit).
criterion = torch.nn.BCEWithLogitsLoss()

# Cosine learning-rate decay is currently switched off; to re-enable it:
# scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_epochs, eta_min=1e-5)

losses, grad_norms, val_metrics = train(
    classifier_model, optimizer, criterion,
    train_loader, val_loader,
    train_epochs, accum_steps, device,
)