# **IDH Classification for Gliomas**


In [None]:
from IPython.display import clear_output

In [None]:
from google.colab import drive
# Mount Google Drive so the project folder (ROOT, set further down) is reachable;
# force_remount=True re-mounts even if Drive is already mounted in this runtime.
drive.mount('/content/drive', force_remount=True)

## Setup environment

In [None]:
# Install third-party dependencies with IPython shell magics (the `!` lines are not plain Python).
!pip install monai
!pip install wandb
!pip install pytorch-ignite  # optional
!pip install transformers  # optional
!pip install einops
# pydantic pinned to 1.x; presumably for compatibility with another package used here — TODO confirm which.
!pip install pydantic==1.10.11
!pip install lightning
# !pip install SimpleITK  # optional
# Hide the verbose pip output from the cell.
clear_output()

## Setup imports

In [None]:
import os
import sys
import wandb
import shutil
import logging
import tempfile

import numpy as np
import pandas as pd
import lightning as L
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, classification_report

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader# as TorchDataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.config import print_config
from monai.data import ImageDataset#, DataLoader


# Select the GPU when one is present, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Pinned host memory is only worthwhile when batches are copied to a GPU.
pin_memory = device.type == "cuda"

# Route INFO-level log records to stdout so they show up in the notebook output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
# print_config()

In [None]:
# Project folder on the mounted Drive, relative to the current directory (presumably /content on Colab).
# After chdir, the relative paths used below (data/, models/, participants.csv) resolve against it.
ROOT = 'drive/MyDrive/Proyecto Gliomas'
os.chdir(ROOT)
# os.listdir()

In [None]:
# Root folders of the two cohorts (Gregorio Marañón hospital and TCIA/UCSF), relative to ROOT.
# Plain string literals: the former f-strings had no placeholders (ruff F541).
GM_ROOT = 'data/Gregorio-Marañón'
UCSF_ROOT = 'data/TCIA'

In [None]:
# Pre-processed image folders per cohort (folder names suggest normalized, cropped volumes stored as .npz).
GM_DIR = GM_ROOT + '/GM-BRATS+HM+NORM+CROPPED-NPZ'
UCSF_DIR = UCSF_ROOT + '/UCSF-NORM-LIGHT-CROPPED-NPZ'

In [None]:
# Participant table (one row per subject); the CSV's first column becomes the DataFrame index.
# Plain string literal: the former f-string had no placeholders (ruff F541).
df = pd.read_csv('participants.csv', index_col=0)
df.head(2)

In [None]:
def get_path(participant_id, database, modality, format_='nii.gz'):
    """
    Build the image file path for one participant and modality.

    Parameters:
    - participant_id: participant folder name, also used as the file-name prefix.
    - database: 'TCIA' (files under UCSF_DIR) or 'GM' (files under GM_DIR).
    - modality: file-name suffix after the participant id, e.g. 'FLAIR'.
    - format_: file extension without the leading dot.

    Returns:
    - The path string, or None for any other database value. The file's existence is not checked.
    """
    if database == 'TCIA':
        base_dir = UCSF_DIR
    elif database == 'GM':
        base_dir = GM_DIR
    else:
        return None
    return f'{base_dir}/{participant_id}/anat/{participant_id}_{modality}.{format_}'

## Setup Images

In [None]:
# IMG_SIZE = (240, 240, 155)  # original TCIA image size
# Spatial (3-D) size of the network inputs: RandAffine's spatial_size and several nets' img_size/in_shape use it.
IMG_SIZE = (128, 128, 64)

In [None]:
# Per-participant .npz paths for the contrast-enhanced T1 and the FLAIR volumes.
df['T1w_ce_path'] = df.apply(lambda x: get_path(x['participant_id'], x['database'], 'ce-GADOLINIUM_T1w', 'npz'), axis=1)
df['FLAIR_path'] = df.apply(lambda x: get_path(x['participant_id'], x['database'], 'FLAIR', 'npz'), axis=1)

In [None]:
# Single-modality variants (one path per sample). The active one is overwritten immediately below.
images = np.array([[path] for path in df['T1w_ce_path']])
# images = np.array([[path] for path in df['FLAIR_path']])

# Two-channel input actually used: each row is (T1w_ce_path, FLAIR_path), in df row order.
images = np.array(list(zip(df['T1w_ce_path'], df['FLAIR_path'])))
# images

## Setup Labels

In [None]:
def encode_labels(labels_):  # unused
    """
    Map (idh_status, codeletion_1p19q_status) pairs to a single class index.

    (0, 0) -> 0, (1, 0) -> 1, (1, 1) -> 2. Any other pair, including (0, 1),
    raises KeyError.

    Parameters:
    - labels_: iterable of length-2 rows.

    Returns:
    - numpy array with one class index per row.
    """
    pair_to_class = {(0, 0): 0, (1, 0): 1, (1, 1): 2}
    return np.array(list(map(lambda pair: pair_to_class[tuple(pair)], labels_)))

In [None]:
# Target labels. Active: binary IDH status per participant (split_by_idh expects values 0/1).
# Commented alternatives: 1p/19q codeletion alone, or (IDH, codeletion) pairs mapped to 3 classes via encode_labels.
labels = np.array(list(df['idh_status']))
# labels = np.array(list(df['codeletion_1p19q_status']))

# labels = np.array(list(zip(df['idh_status'], df['codeletion_1p19q_status'])))
# labels = encode_labels(labels)
# labels

In [None]:
# Number of distinct label values present (2 for the binary IDH target).
no_labels = len(np.unique(labels))
print(no_labels)

In [None]:
# labels = torch.nn.functional.one_hot(torch.as_tensor(labels)).float()
# labels[0], labels[2], labels[5]

#### [0, 0] --> IDH negative without codeletion
#### [0, 1] --> IDH negative with codeletion
#### [1, 0] --> IDH positive without codeletion
#### [1, 1] --> IDH positive with codeletion

## Setup Feature Vector

In [None]:
def get_feature_vector(numerical_features, categorical_features):
    """
    Build the tabular (clinical) feature matrix that accompanies the images.

    Parameters:
    - numerical_features: list of 1-D array-likes (e.g. ``[ages]``), one per numeric column.
      Each column is standardised to zero mean and unit variance.
    - categorical_features: list of 1-D array-likes (e.g. ``[sexes]``), one per categorical
      column. Each column is one-hot encoded.

    Returns:
    - numpy array of shape ``(n_samples, n_numerical + n_one_hot_columns)``: the standardised
      numeric columns first, then the one-hot columns.

    NOTE(review): the scaler and encoder are fit on every row passed in, i.e. before any
    train/test split, so test-set statistics leak into the normalisation.
    """
    # column_stack turns k 1-D columns into an (n, k) matrix, so any number of numeric
    # features works. The former np.array(*features) only handled exactly one; with two,
    # the second feature was passed as np.array's dtype argument.
    numeric = np.column_stack(numerical_features)
    normalized_features = StandardScaler().fit_transform(numeric)

    categorical = np.column_stack(categorical_features)
    encoded_features = OneHotEncoder(sparse_output=False).fit_transform(categorical)

    # Single feature vector: normalized numeric columns followed by the one-hot columns.
    return np.hstack([normalized_features, encoded_features])

In [None]:
ages, sexes, grades = df['age'], df['sex'], df['who_cns_grade']

# Active tabular features: standardized age plus one-hot sex.
# feature_vector.shape[1] is 1 + the number of distinct sex values.
numerical_features = [ages]
categorical_features = [sexes]
# categorical_features = [sexes, grades]

feature_vector = get_feature_vector(numerical_features, categorical_features)
feature_vector.shape

## **Data Splitting**

In [None]:
# Split settings, applied per database and per IDH class (see split_by_idh).
random_state = 42

train_size = 0.8
test_size = 1.0 - train_size  # 0.2 (as a float: 0.19999999999999996)
# Fraction of the remaining 80% held out for validation: 0.25 * 0.8 = 20% of the total.
val_size = 0.25

In [None]:
def split_indices(df_ids, test_size=0.2, val_size=0.25, random_state=42):
    """
    Split ids into train / validation / test subsets.

    First ``test_size`` of *df_ids* is held out as the test set. Then ``val_size`` of
    what remains becomes validation. The defaults give a 60 / 20 / 20 split.

    Returns:
    - ``(train_ids, val_ids, test_ids)``.
    """
    remaining, held_out_test = train_test_split(df_ids, test_size=test_size, random_state=random_state)
    train_part, held_out_val = train_test_split(remaining, test_size=val_size, random_state=random_state)
    return train_part, held_out_val, held_out_test

def split_by_idh(df, idh_column='idh_status', test_size=0.2, val_size=0.25, random_state=42):
    """
    Split a DataFrame's index into train / validation / test, separately for each IDH class.

    Splitting the 0-rows and 1-rows independently with ``split_indices`` keeps the class
    proportions similar across the three subsets. Rows whose ``idh_column`` is neither
    0 nor 1 (e.g. NaN) are left out of every subset. Prints per-class split sizes.

    Returns:
    - ``(train_ids, val_ids, test_ids)``: numpy arrays of *df index labels*, each holding
      the IDH-negative ids first, then the IDH-positive ids.
    """
    shared = dict(test_size=test_size, val_size=val_size, random_state=random_state)

    neg_index = df[df[idh_column] == 0].index.to_numpy()
    pos_index = df[df[idh_column] == 1].index.to_numpy()
    neg_train, neg_val, neg_test = split_indices(neg_index, **shared)
    pos_train, pos_val, pos_test = split_indices(pos_index, **shared)

    print(f'IDH Negative --> Train: {len(neg_train)} / Validation: {len(neg_val)} / Test: {len(neg_test)}')
    print(f'IDH Positive --> Train: {len(pos_train)} / Validation: {len(pos_val)} / Test: {len(pos_test)}')

    return (
        np.concatenate((neg_train, pos_train), axis=0),
        np.concatenate((neg_val, pos_val), axis=0),
        np.concatenate((neg_test, pos_test), axis=0),
    )

In [None]:
# Split each database on its own (and, inside split_by_idh, each IDH class) so that both
# cohorts and both classes are represented in train, validation and test.
gm_df = df[df['database'] == 'GM']
gm_train_ids, gm_val_ids, gm_test_ids = split_by_idh(gm_df, test_size=test_size, val_size=val_size, random_state=random_state)

print(f'GM Database --> Train: {len(gm_train_ids)} / Validation: {len(gm_val_ids)} / Test: {len(gm_test_ids)}')
# Fails if some GM rows have an idh_status other than 0/1 (split_by_idh drops those rows).
assert len(gm_train_ids) + len(gm_val_ids) + len(gm_test_ids)  == len(gm_df)  # 40

In [None]:
tcia_df = df[df['database'] == 'TCIA']
tcia_train_ids, tcia_val_ids, tcia_test_ids = split_by_idh(tcia_df, test_size=test_size, val_size=val_size, random_state=random_state)

print(f'TCIA Database --> Train: {len(tcia_train_ids)} / Validation: {len(tcia_val_ids)} / Test: {len(tcia_test_ids)}')
assert len(tcia_train_ids) + len(tcia_val_ids) + len(tcia_test_ids)  == len(tcia_df)  # 494

In [None]:
# Combined splits across both databases.
train_ids = np.concatenate((gm_train_ids, tcia_train_ids), axis=0)
val_ids = np.concatenate((gm_val_ids, tcia_val_ids), axis=0)
test_ids = np.concatenate((gm_test_ids, tcia_test_ids), axis=0)

print(f'TOTAL --> Train: {len(train_ids)} / Validation: {len(val_ids)} / Test: {len(test_ids)}')

In [None]:
s1, s2, s3 = set(train_ids), set(val_ids), set(test_ids)
assert not (s1 & s2 or s1 & s3 or s2 & s3)  # Assert that indices have no common values

In [None]:
# NOTE(review): the ids are df *index labels* (from split_by_idh), but here they index numpy
# arrays (images, labels, feature_vector) built row-by-row from df, i.e. positionally.
# This is only correct if df's index (the first CSV column) is 0..n-1 in row order — confirm.
train_imgs, train_labels, train_feature_vector = images[train_ids], labels[train_ids], feature_vector[train_ids]
val_imgs,   val_labels,   val_feature_vector   = images[val_ids],   labels[val_ids],   feature_vector[val_ids]
test_imgs,  test_labels,  test_feature_vector  = images[test_ids],  labels[test_ids],  feature_vector[test_ids]

In [None]:
# Per-database test subsets, for evaluating each cohort separately.
gm_test_imgs,   gm_test_labels,   gm_test_features   = images[gm_test_ids],   labels[gm_test_ids],   feature_vector[gm_test_ids]
tcia_test_imgs, tcia_test_labels, tcia_test_features = images[tcia_test_ids], labels[tcia_test_ids], feature_vector[tcia_test_ids]

## **Data Augmentation**

In [None]:
from monai.transforms import (
    NormalizeIntensity,
    Compose,
    CropForeground,
    RandRotate90,
    RandZoom,
    RandAffine,
    RandScaleIntensity,
    RandShiftIntensity,
    RandGaussianNoise,
    RandAdjustContrast,
    RandGaussianSharpen,
    RandKSpaceSpikeNoise,
    ToTensor
)

# Training-time spatial augmentation (one pipeline for the whole multi-channel volume).
spatial_transforms = Compose([
    # Crop to the bounding box of voxels with value > 1, plus a 10-voxel margin.
    # The size after this step therefore varies per sample.
    CropForeground(select_fn=lambda x: x > 1, margin=10),
    RandRotate90(prob=0.25, spatial_axes=[0, 1]),    # Random 90-degree rotation in the plane of axes 0 and 1
    RandRotate90(prob=0.25, spatial_axes=[1, 2]),    # Random 90-degree rotation in the plane of axes 1 and 2
    RandZoom(prob=0.3, min_zoom=(1.0, 1.0), max_zoom=(1.2, 1.2)),  # Random zoom-in of up to 20%
    # Per MONAI's RandAffine, the output is resampled to spatial_size even when the random
    # affine itself is skipped, so every sample leaves this pipeline at IMG_SIZE.
    RandAffine(                                      # Random affine (rotation + shear), not an elastic deformation
        prob=0.25,
        rotate_range=(0, 0, np.pi/8),
        shear_range=(0.1, 0.1, 0.1),
        spatial_size=IMG_SIZE
    )
])

# NOTE(review): the ±20 intensity shift is large next to the std=0.1 Gaussian noise; these
# imply different intensity scales — confirm against the normalized data range.
intensity_transforms_t1 = Compose([
    RandScaleIntensity(prob=0.5, factors=(0.8, 1.2)),   # Random intensity scaling for T1
    RandShiftIntensity(prob=0.5, offsets=(-20, 20)),
    RandAdjustContrast(prob=0.5, gamma=(0.9, 1.1)),
    RandGaussianSharpen(prob=0.3),
    RandGaussianNoise(prob=0.2, mean=0, std=0.1),       # Gaussian Noise for T1
    RandKSpaceSpikeNoise(prob=0.2)
])
intensity_transforms_flair = Compose([
    RandScaleIntensity(prob=0.5, factors=(0.7, 1.3)),   # Random intensity scaling for FLAIR
    RandShiftIntensity(prob=0.5, offsets=(-20, 20)),
    RandAdjustContrast(prob=0.5, gamma=(0.9, 1.1)),
    RandGaussianSharpen(prob=0.3),
    RandGaussianNoise(prob=0.2, mean=0, std=0.1),       # Gaussian Noise for FLAIR
    RandKSpaceSpikeNoise(prob=0.2)
])

# One intensity pipeline per input channel, in the same order as the `images` columns
# (T1w_ce, FLAIR). Presumably applied per channel by LitModel.train_dataloader — confirm there.
intensity_transforms = [
    intensity_transforms_t1,
    intensity_transforms_flair
]

# Validation/test: no augmentation, only conversion to a tensor.
val_transforms = Compose([
    ToTensor()
])

## **Validation**

In [None]:
def validate(fabric, model, val_loader, epoch):
    """
    Validate a binary classification model and log the metrics to Weights & Biases.

    Parameters:
    - fabric: Lightning Fabric instance (not used here; kept for symmetry with ``train``).
    - model: model exposing ``validation_step(batch, idx) -> (loss, predictions, probs)``.
    - val_loader: DataLoader for the validation set; the last element of each batch is the
      label tensor.
    - epoch: current (0-based) epoch index, logged as ``val/epoch``.

    Returns:
    - Tuple ``(average_loss, accuracy)`` on the validation set.
      ROC-AUC is logged but not returned. It is logged as NaN when the validation labels
      contain a single class, where ROC-AUC is undefined.
    """
    model.eval()
    total_loss = 0.0
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for i, batch in enumerate(val_loader):
            *_, labels = batch
            loss, predictions, probs = model.validation_step(batch, i)
            total_loss += loss.item()
            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            print(f"Validating batch {i + 1}/{len(val_loader)} - Loss: {loss.item():.4f}")

    average_loss = total_loss / len(val_loader)
    accuracy_ = accuracy_score(all_labels, all_preds)
    # roc_auc_score raises ValueError when only one class is present, which is easy to hit
    # with a small validation split. Log NaN instead of aborting the whole training run.
    if len(np.unique(all_labels)) > 1:
        roc_auc = roc_auc_score(all_labels, all_probs)
    else:
        roc_auc = float("nan")

    print(f"Validation finished. Average Loss: {average_loss:.4f}\tAccuracy: {accuracy_:.4f}")
    wandb.log({
        "val/epoch": epoch,
        "val/loss": average_loss,
        "val/accuracy": accuracy_,
        "val/roc_auc": roc_auc
    })
    return average_loss, accuracy_

## **Training**

In [None]:
from classification.model import LitModel
from classification.early_stopper import EarlyStopper
from classification.nets import EnhancedAttentionUnet, EnhancedDenseNet, EnhancedHighResNet, EnhancedResNet, EnhancedUNET, EnhancedUNETR, EnhancedVarAutoEncoder, EnhancedViT, EnhancedViTAutoEnc, EnhancedVNet

In [None]:
def train(fabric, model, train_loader, val_loader, optimizer, scheduler, model_name='model.pth', early_stopper=None):
    """
    Train *model* with Lightning Fabric, validating every ``model.val_interval`` epochs.

    Parameters:
    - fabric: Lightning Fabric instance; ``fabric.backward`` performs the backward pass.
    - model: model exposing ``training_step`` / ``validation_step`` and the attributes
      ``num_epochs`` and ``val_interval`` (the Fabric-wrapped LitModel in this notebook).
    - train_loader / val_loader: training and validation DataLoaders.
    - optimizer: optimizer to step after every batch.
    - scheduler: optional scheduler whose ``step`` takes a metric (e.g. ``ReduceLROnPlateau``).
      It is stepped once per epoch with the mean *training* loss. Pass ``None`` to disable.
    - model_name: path where the ``state_dict`` with the best validation accuracy is saved.
    - early_stopper: object with ``on_validation_end(val_loss) -> bool`` and a ``patience``
      attribute. When ``None``, the notebook-level ``early_stop_callback`` is used, which
      keeps older call sites working.

    Returns:
    - Best validation accuracy seen (``-1`` if validation never ran).
    """
    torch.cuda.empty_cache()

    best_accuracy_epoch = -1
    best_accuracy = -1

    # Batches actually yielded per epoch; also what epoch_loss is averaged over below.
    n_steps_per_epoch = len(train_loader)

    for epoch in range(model.num_epochs):
        print("-" * 10)
        print(f"Epoch {epoch + 1}/{model.num_epochs}")
        # validate() switches the model to eval mode, so switch back every epoch.
        model.train()
        epoch_loss = 0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()
            loss = model.training_step(batch, batch_idx)
            fabric.backward(loss)
            optimizer.step()
            epoch_loss += loss.item()
            print(f"{batch_idx + 1}/{n_steps_per_epoch}, Train Loss: {loss.item():.4f}")

        epoch_loss /= len(train_loader)
        if scheduler is not None:
            scheduler.step(epoch_loss)

        print(f"Epoch {epoch + 1} average loss: {epoch_loss:.4f}")
        wandb.log({
            'train/epoch': epoch,
            'train/loss': epoch_loss
        })

        if (epoch + 1) % model.val_interval != 0:
            continue

        # VALIDATION
        avg_val_loss, accuracy = validate(fabric, model, val_loader, epoch)

        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_accuracy_epoch = epoch + 1
            torch.save(model.state_dict(), model_name)
            print("Saved new best metric model")

        print(f"Best accuracy: {best_accuracy:.4f} at epoch {best_accuracy_epoch}")

        # Resolve the stopper lazily, so the global is only required once validation runs.
        stopper = early_stopper if early_stopper is not None else early_stop_callback
        should_stop = stopper.on_validation_end(avg_val_loss)
        print(f'Patience: {stopper.patience}')
        if should_stop:
            break

    print(f"Training completed, best_accuracy: {best_accuracy:.4f} at epoch: {best_accuracy_epoch}")
    return best_accuracy

In [None]:
spatial_dims = 3  # 3D spatial dimensions
in_channels = 2  # 1 for each sequence (T1CE + FLAIR)
out_channels = 1  # 1 for binary classification 0 | 1 (a single logit, paired with BCEWithLogitsLoss below)
feature_dim = feature_vector.shape[1]  # 3 with age + one-hot sex (two sex categories)
feature_dim

In [None]:
def get_model(model, in_channels, out_channels, feature_dim, config):
    """
    Build the requested backbone and wrap it in a ``LitModel``.

    Parameters:
    - model: architecture name; one of 'AttentionUnet', 'DenseNet', 'HighResNet', 'ResNet',
      'UNET', 'UNETR', 'VarAutoEncoder', 'ViT', 'VitAutoEnc', 'VNet'.
    - in_channels: number of image channels (one per MRI sequence).
    - out_channels: number of network outputs (1 for a single binary logit).
    - feature_dim: length of the tabular feature vector, passed to the network as ``feature_dim``.
    - config: keyword arguments forwarded to ``LitModel``. It is copied first, so the
      per-architecture ``batch_size`` override below does not leak into the caller's dict.

    Returns:
    - A ``LitModel`` instance with ``.name`` set to *model*.

    Raises:
    - ValueError: if *model* is not a known architecture name.

    Also reads the notebook globals ``spatial_dims`` and ``IMG_SIZE``.
    """
    # Copy before overriding batch_size. Otherwise the caller's dict is mutated, and e.g.
    # 'ResNet' (which sets no batch size) would inherit whatever the previous call left behind.
    config = dict(config)

    if model == 'AttentionUnet':
        net = EnhancedAttentionUnet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            channels=(16, 32, 64),
            strides=(2, 2, 2),
            feature_dim=feature_dim
        )
        config["batch_size"] = 8
    elif model == 'DenseNet':
        net = EnhancedDenseNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_dim=feature_dim
        )
        config["batch_size"] = 32
    elif model == 'HighResNet':
        net = EnhancedHighResNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_dim=feature_dim,
        )
        config["batch_size"] = 2

    elif model == 'ResNet':
        net = EnhancedResNet(
            spatial_dims=spatial_dims,
            n_input_channels=in_channels,
            feature_dim=feature_dim,
            block='basic',
            layers=[3, 4, 6, 3],
            block_inplanes=[64, 128, 256, 512],
            conv1_t_stride=2,
            num_classes=out_channels,
        )
        # config["batch_size"] = 2
    elif model == 'UNET':
        net = EnhancedUNET(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_dim=feature_dim,
            channels=(4, 8, 16, 32, 64),
            strides=(2, 2, 2, 2),
        )
        config["batch_size"] = 128
    elif model == 'UNETR':
        net = EnhancedUNETR(
            in_channels=in_channels,
            out_channels=out_channels,
            img_size=IMG_SIZE,
            spatial_dims=spatial_dims,
            feature_dim=feature_dim
        )
        config["batch_size"] = 8
    elif model == 'VarAutoEncoder':
        net = EnhancedVarAutoEncoder(
            spatial_dims=spatial_dims,
            feature_dim=feature_dim,
            # Channel count must follow in_channels (it was hard-coded to 2).
            in_shape=(in_channels, *IMG_SIZE),
            out_channels=out_channels,
            latent_size=3,
            channels=(8, 16, 32, 64),
            strides=(1, 2, 2, 2),
        )
        config["batch_size"] = 32
    elif model == 'ViT':
        net = EnhancedViT(
            spatial_dims=spatial_dims,
            img_size=IMG_SIZE,
            in_channels=in_channels,
            num_classes=out_channels,
            patch_size=(16, 16, 16),
            hidden_size=768,
            mlp_dim=3072,
            feature_dim=feature_dim,
            classification=True
        )
        config["batch_size"] = 8
    elif model == 'VitAutoEnc':
        net = EnhancedViTAutoEnc(
            img_size=IMG_SIZE,
            patch_size=(16, 16, 16),
            hidden_size=768,
            deconv_chns=16,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_dim=feature_dim
        )
        config["batch_size"] = 64
    elif model == 'VNet':
        net = EnhancedVNet(
            spatial_dims=spatial_dims,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_dim=feature_dim
        )
        config["batch_size"] = 8
    else:
        raise ValueError(f'Unknown model name: {model}')

    model_ = LitModel(model=net, **config)
    model_.name = model
    return model_

In [None]:
# Use the GPU when available and fall back to the CPU otherwise, matching how `device` is chosen.
# A hard-coded accelerator='cuda' makes Fabric raise on CPU-only runtimes.
fabric = L.Fabric(
    accelerator='cuda' if torch.cuda.is_available() else 'cpu', devices=1,
    strategy="auto",
    # callbacks=[early_stop_callback],
)
fabric.launch()

In [None]:
# Architecture to train (a name accepted by get_model).
MODEL = 'DenseNet'
NUM_EPOCHS = 100
# Binary loss on a single logit per sample (out_channels = 1).
LOSS_FUNCTION = nn.BCEWithLogitsLoss()  # nn.CrossEntropyLoss()

In [None]:
# Read the W&B API key from the environment so no secret is committed with the notebook.
# The placeholder is kept only as a visible fallback when the variable is unset.
KEY = os.environ.get('WANDB_API_KEY', 'YOUR_WANDB_KEY')
wandb.login(key=KEY)
wandb_logger = wandb.init(project="idh-status")#, name=f'{MODEL}')

In [None]:
# Default keyword arguments forwarded to LitModel; get_model overrides batch_size per architecture.
config = {
    "num_epochs": NUM_EPOCHS,
    "batch_size": 16,
    'optimizer': torch.optim.Adam,
    "lr": 1e-4,
    'loss_func': LOSS_FUNCTION,
    "val_interval": 1,
}

In [None]:
model = get_model(
    MODEL,
    in_channels=in_channels,
    out_channels=out_channels,
    feature_dim=feature_dim,
    config=config
)

In [None]:
# train() reads this global as `early_stop_callback`. The meaning of stopping_threshold and
# patience is defined in classification.early_stopper, which is not visible here.
early_stop_callback = EarlyStopper(
    stopping_threshold=0.05,
    patience=10
)
# model.early_stop = early_stop_callback

In [None]:
num_workers = 4

# Training loader gets the spatial + per-channel intensity augmentations; validation only ToTensor.
train_loader = model.train_dataloader(
    train_imgs, train_labels, train_feature_vector,
    num_workers=num_workers,
    spatial_transforms=spatial_transforms,
    intensity_transforms=intensity_transforms
)
val_loader = model.val_dataloader(
    val_imgs, val_labels, val_feature_vector,
    num_workers=num_workers,
    transforms=val_transforms
)

In [None]:
# Where train() saves the best-validation-accuracy state_dict.
model_name = f'models/{MODEL.lower()}.pth'

In [None]:
optimizer = model.configure_optimizers()

# `model` stays the plain LitModel; `model_` is the Fabric-wrapped one used for training.
model_, optimizer = fabric.setup(model, optimizer)
train_loader, val_loader = fabric.setup_dataloaders(train_loader, val_loader)

In [None]:
# Divide the LR by 10 after 5 epochs without a decrease of the metric passed to step()
# (train() passes the mean training loss). `verbose` is deprecated in recent PyTorch releases.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.1, verbose=True)
# scheduler = None

In [None]:
best_accuracy = train(
    fabric, model_,
    train_loader, val_loader,
    optimizer=optimizer,
    scheduler=scheduler,
    model_name=model_name
)

## **Testing**

In [None]:
def load_model(model, checkpoint_path, map_location="cpu"):
    """
    Load a ``state_dict`` checkpoint from *checkpoint_path* into *model*, in place.

    Parameters:
    - model: ``torch.nn.Module`` whose parameters are overwritten.
    - checkpoint_path: file written with ``torch.save(module.state_dict(), ...)``.
    - map_location: device the stored tensors are deserialised onto before being copied into
      *model*. Defaults to ``"cpu"``, so checkpoints saved from a GPU also load on CPU-only
      machines. ``load_state_dict`` then copies the values onto the model's own device.

    Returns:
    - The same *model* object, for chaining.
    """
    # weights_only=True: a state_dict contains only tensors, so refuse arbitrary pickled objects.
    state_dict = torch.load(checkpoint_path, map_location=map_location, weights_only=True)
    model.load_state_dict(state_dict)
    return model


def split_probs_by_class(probs):
    """
    Turn positive-class probabilities into a two-column ``[P(neg), P(pos)]`` matrix.

    Parameters:
    - probs: per-sample IDH-Pos probabilities (the same scores that ``roc_auc_score`` consumes
      in validate/test_model), either flat ``(n,)`` or column-shaped ``(n, 1)``. For 2-D
      input only the first column is used.

    Returns:
    - numpy array of shape ``(n, 2)`` with columns ``(1 - p, p)``. This is the per-class layout
      needed by the commented-out ``wandb.plot.roc_curve`` call with labels ['IDH-Neg', 'IDH-Pos'].
    """
    p = np.asarray(probs, dtype=float)
    if p.ndim > 1:
        p = p[:, 0]
    # P(neg) is always 1 - P(pos). The former np.where(p >= 0.5, 1 - p, p) swapped the
    # two columns for every sample with p < 0.5.
    return np.column_stack((1.0 - p, p))


def test_model(fabric, model, test_loader):
    """
    Evaluate a trained binary IDH classifier on a test DataLoader.

    Parameters:
    - fabric: Lightning Fabric instance (accepted for symmetry with ``train``; not used here).
    - model: model exposing ``eval()``, ``test_step(batch, idx) -> (preds, probs, labels)``
      and a ``name`` attribute that prefixes the metric keys.
    - test_loader: DataLoader whose ``dataset.verbose`` flag, when true, means each batch
      carries the sample name as its last element. The name is popped off before
      ``test_step`` and printed together with the first label/prediction of the batch.

    Returns:
    - ``(test_metrics, report)``:
      - ``test_metrics``: a dict keyed ``test/<model.name>_<metric>`` with accuracy,
        precision, recall, F1, ROC-AUC and a W&B confusion-matrix plot.
      - ``report``: the sklearn classification report string.
    """
    model.eval()

    predictions, targets, scores = [], [], []
    show_names = test_loader.dataset.verbose

    with torch.no_grad():
        for step, batch in enumerate(test_loader):
            if show_names:
                sample_name = batch.pop()
            batch_preds, batch_probs, batch_labels = model.test_step(batch, step)
            predictions.extend(batch_preds.cpu().numpy())
            targets.extend(batch_labels.cpu().numpy())
            scores.extend(batch_probs.cpu().numpy())
            print(f'Testing batch {step + 1}/{len(test_loader)}')
            if show_names:
                print(f'Name: {sample_name} --> Label: {int(batch_labels.cpu().numpy()[0])} / Pred: {int(batch_preds.cpu().numpy()[0][0])}')

    # Threshold metrics use the hard predictions; ROC-AUC uses the positive-class scores.
    acc = accuracy_score(targets, predictions)
    prec = precision_score(targets, predictions)
    rec = recall_score(targets, predictions)
    f1 = f1_score(targets, predictions)
    auc = roc_auc_score(targets, scores)

    report = classification_report(targets, predictions, target_names=['IDH-Neg', 'IDH-Pos'])

    print(f"Overall Accuracy: {acc * 100:.2f}%")
    print(f"Overall Precision: {prec * 100:.2f}%")
    print(f"Overall Recall: {rec * 100:.2f}%")
    print(f"Overall F1 Score: {f1 * 100:.2f}%")
    print(f"Overall ROC-AUC: {auc * 100:.2f}%")

    test_metrics = {
        f"test/{model.name}_accuracy": acc,
        f"test/{model.name}_precision": prec,
        f"test/{model.name}_recall": rec,
        f"test/{model.name}_f1_score": f1,
        f"test/{model.name}_roc_auc": auc,
        f"test/{model.name}_confusion_matrix": wandb.sklearn.plot_confusion_matrix(
            targets, predictions, ['IDH-Neg', 'IDH-Pos']
        )
    }
    return test_metrics, report

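Tabular feature size each saved checkpoint was trained with. The rebuilt model's `feature_dim` presumably has to match for its checkpoint to load:

DenseNet --> 6 features (age, sex, WHO CNS grade)

HighResNet --> 6 features (age, sex, WHO CNS grade)

AttentionUnet --> 3 features (age, sex)

VitAutoEnc --> 6 features (age, sex, WHO CNS grade)

UNETR --> 6 features (age, sex, WHO CNS grade)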

Single-model testing

In [None]:
MODEL = 'DenseNet'

In [None]:
# The second assignment overrides the first with a fixed checkpoint file.
# Note that the training cell saves to models/{MODEL.lower()}.pth, without the epoch suffix.
model_name = f'models/{MODEL.lower()}_{NUM_EPOCHS}.pth'
model_name = f'models/densenet_100.pth'

In [None]:
model = get_model(
    MODEL,
    in_channels=in_channels,
    out_channels=out_channels,
    feature_dim=feature_dim,
    config=config
)

In [None]:
# Combined test set plus per-cohort test sets; only test_loader is evaluated below.
test_loader      = model.test_dataloader(test_imgs, test_labels, test_feature_vector)
gm_test_loader   = model.test_dataloader(gm_test_imgs, gm_test_labels, gm_test_features)
tcia_test_loader = model.test_dataloader(tcia_test_imgs, tcia_test_labels, tcia_test_features)

In [None]:
# This model is not passed through fabric.setup, so evaluation runs on whatever device
# LitModel places it on — presumably CPU; confirm if speed matters.
eval_model = load_model(model, model_name)

metrics, report = test_model(fabric, eval_model, test_loader)
wandb.log(metrics)

In [None]:
print(report)

Multi-model testing

In [None]:
MODELS = ['DenseNet', 'HighResNet', 'AttentionUnet', 'VitAutoEnc', 'UNETR']

In [None]:
import time

# One independent dict per model. dict.fromkeys(MODELS, {}) made every key share the *same*
# dict object, so each model's metrics/report overwrote the previous model's entries.
results = {name: {} for name in MODELS}

for model_name in MODELS:
    st = time.time()
    wandb.init(project="idh-status", name=f'{model_name}_test')
    print(f'Evaluating {model_name}...')

    try:
        ckpt_name = f'models/{model_name.lower()}.pth'
        # NOTE(review): the same feature_dim is used for every model, but the notes above say
        # AttentionUnet was trained with 3 features and the others with 6 — confirm before running.
        model = get_model(
            model_name,
            in_channels=in_channels,
            out_channels=out_channels,
            feature_dim=feature_dim,
            config=config
        )
        test_loader = model.test_dataloader(test_imgs, test_labels, test_feature_vector)
        # test_loader = model.test_dataloader(gm_test_imgs, gm_test_labels, gm_test_features)  # GM
        # test_loader = model.test_dataloader(tcia_test_imgs, tcia_test_labels, tcia_test_features)  # TCIA
        eval_model = load_model(model, ckpt_name)
        metrics, report = test_model(fabric, eval_model, test_loader)
        # wandb.log(metrics)
        results[model_name]['metrics'] = metrics
        results[model_name]['report'] = report
    finally:
        # Close the W&B run even when evaluation fails, so a failure does not leave it open.
        wandb.finish()

    seconds = time.time() - st
    print(f'Elapsed time {seconds} seconds')
    # len(test_loader) counts batches; the dataset length counts images.
    # The timing also covers model construction and checkpoint loading, not only inference.
    print(f'No. images per second {len(test_loader.dataset) / seconds}')