In [None]:
# necessary imports
import os

import numpy as np
import torch
import tqdm
from monai.data import Dataset, decollate_batch
from monai.metrics import DiceMetric
from monai.networks.nets import UNETR
from monai.transforms import Activations, AsDiscrete, Compose
from torch.utils.data import DataLoader

from bmmae.augmentations import get_val_seg_transforms

In [None]:
def eval_loop(model, loader, device, threshold=0.5):
    """Run ``model`` over ``loader`` and report per-channel Dice scores.

    The raw network output goes through a sigmoid and is binarised at
    ``threshold`` before being compared with the labels.

    Args:
        model: network called as ``model(inputs)``; it is switched to eval mode.
        loader: iterable of batches, each a mapping with ``"image"`` and
            ``"label"`` tensors.
        device: device on which the forward pass is executed.
        threshold: cut-off applied to the sigmoid probabilities.

    Returns:
        NumPy array holding the un-reduced result of
        ``DiceMetric.aggregate("none")``. The summary below reduces it over
        axis 0 and reports columns 0, 1 and 2 as TC, WT and ET (presumably one
        row per evaluated case -- confirm against the MONAI version in use).
        NaN entries are kept in the returned array and only ignored when
        computing the printed mean/std.
    """
    dice_metric = DiceMetric(include_background=True, reduction="mean_batch")
    # The post-processing pipeline does not depend on the batch: build it once.
    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=threshold)])

    model.eval()
    with torch.no_grad():
        for batch_data in tqdm.tqdm(loader, desc="Validation..."):
            inputs = batch_data["image"].to(device)
            # The metric is computed on CPU, so the labels never need to visit `device`.
            label = batch_data["label"].cpu()
            outputs = model(inputs)

            # Outputs exposing `as_tensor` (presumably MONAI MetaTensor) are
            # unwrapped to a plain tensor; plain tensors are used as they are.
            # An explicit check replaces the former bare `except`, which also
            # hid genuine errors and KeyboardInterrupt.
            if hasattr(outputs, "as_tensor"):
                outputs = outputs.as_tensor()

            # put everything on cpu
            outputs = outputs.cpu()
            outputs = [post_trans(i) for i in decollate_batch(outputs)]
            dice_metric(y_pred=outputs, y=label)

        dice = dice_metric.aggregate("none").numpy()

    # Scores are reported as percentages; NaN entries are ignored.
    mean_dice = np.nanmean(dice, axis=0) * 100
    std_dice = np.nanstd(dice, axis=0) * 100

    mean_dice_tc, mean_dice_wt, mean_dice_et = mean_dice[0], mean_dice[1], mean_dice[2]
    std_dice_tc, std_dice_wt, std_dice_et = std_dice[0], std_dice[1], std_dice[2]

    print(
        "DICE"
        f"\n et: {mean_dice_et:.1f} ± {std_dice_et:.1f}"
        f"\n tc: {mean_dice_tc:.1f} ± {std_dice_tc:.1f}"
        f"\n wt: {mean_dice_wt:.1f} ± {std_dice_wt:.1f}"
    )

    return dice

# Test Wilcoxon
The code below runs the Wilcoxon signed-rank test between the pre-trained model and each baseline, to check whether the observed improvement is statistically significant.

In [None]:
from bmmae.model import BMMAEViT
from bmmae.tokenizers import MRITokenizer

modalities = ["t1", "t1ce", "t2", "flair"]

# Validation split: one sub-directory per patient containing
# `<patient>_<modality>.nii.gz` volumes and a `<patient>_seg.nii.gz` label map.
val_dir = os.path.join("../rsna/BraTS2021", "val")
# `os.listdir` returns entries in arbitrary order and may contain stray files;
# sort for a reproducible case order and keep directories only.
val_patients = sorted(
    entry for entry in os.listdir(val_dir) if os.path.isdir(os.path.join(val_dir, entry))
)
val_files = [
    {
        "image": [os.path.join(val_dir, patient, f"{patient}_{modality}.nii.gz") for modality in modalities],
        "label": os.path.join(val_dir, patient, f"{patient}_seg.nii.gz"),
    }
    for patient in val_patients
]
val_dataset = Dataset(data=val_files, transform=get_val_seg_transforms())

# A single device is used for every model so that weights and inputs always match.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# shuffle=False keeps the case order identical across models, which the paired
# statistical test further down relies on.
val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=8,
    pin_memory=device.type == "cuda",  # pinning is only useful when copying to a GPU
)


def build_unetr():
    """Return a freshly initialised UNETR with the architecture shared by all checkpoints."""
    return UNETR(
        in_channels=len(modalities),
        out_channels=3,
        img_size=(128, 128, 128),
        hidden_size=768,
        mlp_dim=1536,
        num_heads=12,
        proj_type="conv",
        qkv_bias=True,
    )


def load_checkpoint(model, checkpoint_path):
    """Load ``checkpoint_path`` into ``model``, move it to ``device`` and set eval mode.

    The state dict is first mapped to CPU so that checkpoints saved from a GPU
    can also be loaded on a machine without CUDA (or with a different GPU index).
    """
    state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()
    return model


# Tag used in the checkpoint file names; "-".join also covers the single-modality case.
placeholder = "-".join(modalities)

## BMMAE
tokenizers = {
    modality: MRITokenizer(patch_size=(16, 16, 16), img_size=(128, 128, 128), hidden_size=768)
    for modality in modalities
}

vit = BMMAEViT(
    modalities=modalities,
    tokenizers=tokenizers,
    hidden_size=768,
    mlp_dim=1536,
    num_heads=12,
    qkv_bias=True,
    classification=False,
)

model = build_unetr()
# The encoder must be swapped in BEFORE loading so the state-dict keys match.
model.vit = vit
model = load_checkpoint(model, f"pretrained_models/bmmae/seg_{placeholder}_False.pth")
dice_bmmae = eval_loop(model, val_loader, device)

## FROM SCRATCH
model = load_checkpoint(build_unetr(), f"pretrained_models/fs/seg_{placeholder}_True.pth")
dice_fs = eval_loop(model, val_loader, device)

## REGMAE
model = load_checkpoint(build_unetr(), f"pretrained_models/REGMAE/seg_{placeholder}_False.pth")
dice_mae = eval_loop(model, val_loader, device)

## SIMCLR
model = load_checkpoint(build_unetr(), f"pretrained_models/SIMCLR/seg_{placeholder}_False.pth")
dice_simclr = eval_loop(model, val_loader, device)

In [None]:
from scipy.stats import wilcoxon


def clean_nan(x, y=None):
    """Return the entries of ``x`` that are not NaN.

    When ``y`` is given, an entry of ``x`` is kept only if the entry at the
    same position in ``y`` is not NaN either, i.e. the result is restricted to
    positions that are valid in both arrays.
    """
    keep = np.logical_not(np.isnan(x))
    if y is not None:
        # Restrict to positions that are also valid in the companion array.
        keep = np.logical_and(keep, np.logical_not(np.isnan(y)))
    return x[keep]


# Paired Wilcoxon signed-rank test of BM-MAE against each baseline, per region.
# The test is paired, so a case must be dropped from BOTH samples as soon as
# its Dice is NaN for either model. Filtering each array on its own (as was
# done before) yields samples of different lengths, or silently misaligned
# pairs, whenever the NaNs do not fall on the same cases.
baselines = {
    "REGMAE": dice_mae,
    "SIMCLR": dice_simclr,
    "FROM SCRATCH": dice_fs,
}
# Column index -> region name, same order as the summary printed by eval_loop.
regions = ("TC", "WT", "ET")

for baseline_name, dice_baseline in baselines.items():
    print(f"------------------ {baseline_name} --------------------------")
    for channel, region in enumerate(regions):
        ours = dice_bmmae[:, channel]
        theirs = dice_baseline[:, channel]
        # Passing the partner array as `y` keeps only cases valid for both models.
        result = wilcoxon(clean_nan(ours, theirs), clean_nan(theirs, ours))
        print(f"p-value for DICE {region} : ", result.pvalue.round(5))