In [None]:
!pip install -q 'monai[all]'

In [None]:
!pip install monai==0.9

In [22]:
def precision_and_recall(pred, target, smooth=1e-5):
    """Voxel-wise precision and recall of a binary prediction against a target.

    Args:
        pred: binary prediction tensor (0/1 values or bool), e.g. an argmax map.
        target: binary ground-truth tensor (0/1 values or bool). If it has exactly
            one more dimension than ``pred`` and a singleton axis 1 (the channel
            axis that ``argmax(dim=1)`` removes from a prediction), that axis is
            dropped so the two line up element-for-element.
        smooth: small constant added to both denominators to prevent division by 0.

    Returns:
        tuple[float, float]: ``(precision, recall)`` computed over every element
        of the inputs (the whole batch pooled together). When there are no true
        positives -- including an empty prediction on an empty target -- both
        values are 0.0, because ``smooth`` only appears in the denominators.

    Raises:
        ValueError: if the shapes still differ after the channel alignment.
            Mismatched shapes would otherwise broadcast silently and yield
            meaningless counts (e.g. every prediction in a batch multiplied with
            every target).
    """
    pred = torch.as_tensor(pred)
    target = torch.as_tensor(target)

    # Prediction without a channel axis vs. label with a singleton channel axis.
    if target.dim() >= 2 and target.dim() == pred.dim() + 1 and target.shape[1] == 1:
        target = target.squeeze(1)
    if pred.shape != target.shape:
        raise ValueError(
            f"pred shape {tuple(pred.shape)} does not match target shape "
            f"{tuple(target.shape)}"
        )

    # `1 - x` is not supported for bool tensors; treat bools as 0/1 integers.
    if pred.dtype == torch.bool:
        pred = pred.long()
    if target.dtype == torch.bool:
        target = target.long()

    tp = torch.sum(pred * target)  # TP
    fp = torch.sum(pred * (1 - target))  # FP
    fn = torch.sum((1 - pred) * target)  # FN

    precision = tp / (tp + fp + smooth)
    recall = tp / (tp + fn + smooth)

    return float(precision), float(recall)

In [2]:
import os
import gc
import glob
from pathlib import Path

import numpy as np
import torch
from torch.nn import functional as F
 
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    Rotated,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    EnsureTyped,
    EnsureType,
    Invertd,
    Rotate90d,
    RandRotated,
    RandShiftIntensityd,
    RandGaussianNoised
)
# from monai.handlers.utils import from_engine
from UNet3D import UNet
from AttUNet import AttUNet
from TransUNet import TransUNet
from monai.networks.nets import SwinUNETR
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch

In [3]:
from meaniou import MeanIoU

# Select the compute device (CUDA if available)

In [4]:
# Prefer the GPU when PyTorch can see one, otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)
print()

# GPU diagnostics: name, hardware properties and current memory use of GPU 0.
if device.type == 'cuda':
    gpu = 0
    bytes_per_gb = 1024 ** 3
    print(torch.cuda.get_device_name(gpu))
    print(torch.cuda.get_device_properties(gpu))
    print('Memory Usage:')
    allocated_gb = torch.cuda.memory_allocated(gpu) / bytes_per_gb
    reserved_gb = torch.cuda.memory_reserved(gpu) / bytes_per_gb
    print('Allocated:', round(allocated_gb, 1), 'GB')
    print('Cached:   ', round(reserved_gb, 1), 'GB')

Using device: cuda

Tesla T4
_CudaDeviceProperties(name='Tesla T4', major=7, minor=5, total_memory=14910MB, multi_processor_count=40)
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


# Load Data

In [5]:
def get_paths(mask, data_dir=None):
    """Recursively find files matching a glob pattern under the data directory.

    Args:
        mask: glob pattern passed to ``Path.rglob`` (e.g. ``"*.nii.gz"``).
        data_dir: directory to search. Defaults to ``<cwd>/data/train``,
            resolved at call time.

    Returns:
        list[pathlib.Path]: absolute (resolved) paths, sorted so that repeated
        calls -- and calls with different masks over the same tree, e.g. images
        vs. labels -- return files in a stable, comparable order.
    """
    if data_dir is None:
        data_dir = os.path.join(os.getcwd(), 'data/train')
    # rglob yields entries in filesystem order, which is not guaranteed to be stable.
    return sorted(path.resolve() for path in Path(data_dir).rglob(mask))

In [25]:
# Directory where training metrics/models were presumably written.
training_dir = os.path.join(os.getcwd(), 'metrics')

# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine
# (and places tensors on the GPU when one is available).
# NOTE(review): the file name says "att-unet" while the model built in this
# notebook is SwinUNETR -- confirm this is the intended checkpoint.
checkpoint = torch.load('postpretrained-att-unet_model_and_optim', map_location=device)

import pickle

# Train/validation split, presumably saved during training so evaluation uses
# the same validation cases. pickle.load can execute arbitrary code: only load
# files from a trusted source.
with open("datasplit", "rb") as f:
    saved_lists = pickle.load(f)

train_files, val_files = saved_lists

In [None]:
# Epoch value stored in the checkpoint (shown as the cell output).
saved_epoch = checkpoint['epoch']
saved_epoch

In [None]:
# Number of cases in the training split, then in the validation split.
for split_files in (train_files, val_files):
    print(len(split_files))

# Transforms using MONAI

In [7]:
# Deterministic preprocessing applied to every validation case. The steps run in
# list order on the {"image", "label"} data dicts.
# NOTE(review): this should mirror the preprocessing used at training time -- confirm.
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]), # read the image/label files referenced by each data dict
        EnsureChannelFirstd(keys=["image", "label"]), # make arrays channel-first, adding a channel axis if missing (e.g. (H, W, D) -> (1, H, W, D))
        Orientationd(keys=["image", "label"], axcodes="LPS"),
        Rotate90d(keys=["image", "label"], k=1, spatial_axes=(0,2)), # rotate 90 degrees in the plane of spatial axes 0 and 2; the original author added this so volumes look upright when viewed
        ScaleIntensityRanged(
            # linearly map image intensities [0, 302] -> [0, 1], clipping values outside the range;
            # labels are untouched. NOTE(review): 302 is presumably a dataset-specific max intensity.
            keys=["image"], a_min=0.0, a_max=302.0,
            b_min=0.0, b_max=1.0, clip=True,
        ),
        EnsureTyped(keys=["image", "label"]) # convert the data to torch tensors
    ]
)

# Random patch sampler: draws fixed-size patches whose centres fall on either
# foreground (label > 0) or background voxels.
cropper = RandCropByPosNegLabeld(
    keys=["image", "label"],
    label_key="label",
    spatial_size=(96, 96, 96),   # size of each cropped patch
    pos=1,      # pos / (pos + neg) = probability that a patch is centred on a foreground voxel...
    neg=1,      # with pos = neg = 1 that is 0.5: equal numbers of foreground (stroke) and background (no stroke) centres.
    num_samples=4,   # number of patches drawn from each input volume
    image_key="image",
    image_threshold=0,   # background centres are restricted to voxels where image > image_threshold
)

# Helper Functions

In [8]:
def rand_crop(images, labels, cropper):
    """Split every (image, label) pair of a batch into random patches.

    Args:
        images: batch of images indexed along dim 0.
        labels: batch of labels with the same length as ``images``.
        cropper: dictionary transform that maps ``{"image": ..., "label": ...}``
            to a list of such dicts (one per patch), e.g. a
            ``RandCropByPosNegLabeld`` instance.

    Returns:
        tuple: ``(imgs, lbls)`` -- the patches of all pairs stacked along a new
        dim 0, in batch order and then patch order.

    Raises:
        AssertionError: if ``images`` and ``labels`` differ in batch length.
    """
    assert images.shape[0] == labels.shape[0]
    patches = [
        patch
        for image, label in zip(images, labels)
        for patch in cropper({"image": image, "label": label})
    ]
    imgs = torch.stack([patch['image'] for patch in patches])
    lbls = torch.stack([patch['label'] for patch in patches])
    return imgs, lbls

In [9]:
# With a GPU, transform every validation case once up front and keep the results
# in memory; otherwise transform lazily each time a case is loaded.
if device.type == 'cuda':
    val_ds = CacheDataset(
        data=val_files,
        transform=val_transforms,
        cache_rate=1.0,   # cache 100% of the validation cases
        num_workers=4,    # workers used while building the cache
    )
else:
    # Same preprocessing as the cached branch.
    val_ds = Dataset(data=val_files, transform=val_transforms)

# One volume per batch.
val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)

Loading dataset: 100%|██████████| 130/130 [00:26<00:00,  4.93it/s]


In [24]:
# Alternative architectures imported above can be swapped in here:
#   UNet(), AttUNet(), TransUNet()
# The active model must match the architecture stored in the loaded checkpoint.
model = SwinUNETR(
    in_channels=1,
    out_channels=2,          # background + foreground logits
    feature_size=48,
    img_size=(96, 96, 96),   # presumably should match the sliding-window roi_size -- confirm
).to(device)

In [26]:
# Restore trained weights from the checkpoint. Loading is strict (the default),
# so mismatched parameter names raise instead of being silently skipped.
state_dict = checkpoint['model_state_dict']
model.load_state_dict(state_dict)

<All keys matched successfully>

In [16]:
# Both overlap metrics ignore the background channel and average over cases.
metric_kwargs = dict(include_background=False, reduction="mean")
dice_metric = DiceMetric(**metric_kwargs)
iou_metric = MeanIoU(**metric_kwargs)

In [None]:
# Post-processing for the MONAI metrics: logits -> one-hot prediction (argmax
# over the 2 output channels); integer label -> one-hot label.
post_pred = Compose([EnsureType(), AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([EnsureType(), AsDiscrete(to_onehot=2)])

roi_size = (96, 96, 96)   # sliding-window patch size
sw_batch_size = 4         # number of windows passed through the model at once
pres_recall = []          # one [precision, recall] pair per validation batch

model.eval()
with torch.no_grad():
    for val_data in val_loader:
        val_inputs = val_data["image"].to(device)
        val_labels = val_data["label"].to(device)
        val_outputs = sliding_window_inference(
            val_inputs, roi_size, sw_batch_size, model)
        # Binary 0/1 prediction. keepdim=True keeps the channel axis so the
        # prediction has the same (B, 1, ...) shape as the labels; without it the
        # two only line up by broadcasting when B == 1.
        output = torch.argmax(val_outputs, dim=1, keepdim=True)
        precision, recall = precision_and_recall(pred=output, target=val_labels)
        pres_recall.append([precision, recall])
        val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
        val_labels = [post_label(i) for i in decollate_batch(val_labels)]
        # accumulate Dice / IoU for the current batch; aggregated after the loop
        dice_metric(y_pred=val_outputs, y=val_labels)
        iou_metric(y_pred=val_outputs, y=val_labels)
    # aggregate the final mean scores over all validation cases
    dice_score = dice_metric.aggregate().item()
    iou_score = iou_metric.aggregate().item()
    # reset the accumulated state so the metric objects can be reused
    dice_metric.reset()
    iou_metric.reset()
    print(dice_score)
    print(iou_score)

# Mean precision and mean recall across batches: [precision, recall].
pres_recall = np.asarray(pres_recall).mean(axis=0)
print(pres_recall)