In [1]:
import io
import math
from PIL import Image
import os
import wandb
import glob
import torch
import monai
import random
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import SimpleITK as sitk
import warnings
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
from medpy.metric.binary import hd, dc

In [2]:
warnings.filterwarnings('ignore')
torch.manual_seed(1024)
np.random.seed(1024)
device = torch.device(
    "cuda:1"
    if torch.cuda.is_available()
    else
    "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda:1 device


In [3]:
class SegDataset(Dataset):
    def __init__(self,
                 data_root,
                 include_background=True,
                 val=True,
                 to3d=False):
        self.data_root = data_root
        self.val = val
        self.to3d = to3d
        self.gt_files_path = []
        self.include_background = include_background
        self.transform = transforms.Compose([
                                transforms.ToPILImage(),
                                transforms.ToTensor(),
                                transforms.Normalize([0.45, ], [0.35, ]),
                                transforms.Resize([256, 256])])

        # find all patient directories
        patient_directories = glob.glob(os.path.join(self.data_root, 'patient*'))
        # find all files with the suffix _gt.npy
        train_size = int(len(patient_directories)*0.8)

        if self.val:
            for patient_directory in patient_directories[train_size:]:
                per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.nii'))
                for path in per_patient_file_path:
                    self.gt_files_path.append(path)
        else:
            for patient_directory in patient_directories:
                per_patient_file_path = glob.glob(os.path.join(patient_directory, '*_gt.nii'))
                for path in per_patient_file_path:
                    self.gt_files_path.append(path)

    def __len__(self):
        return len(self.gt_files_path)

    def __getitem__(self, index):
        gt_image_path = self.gt_files_path[index]
        image_path = gt_image_path[:-7] + ".nii"
        gt_image_array = sitk.GetArrayFromImage(sitk.ReadImage(gt_image_path))
        image_array = sitk.GetArrayFromImage(sitk.ReadImage(image_path))
        
        
        # Create an empty list to store transformed images
        transformed_images = []
        
        for image in image_array:
            image = torch.tensor(image[None,:,:]).float()
            
            transformed_image = self.transform(image)
            
            transformed_images.append(transformed_image)

        # Stack the transformed images back into a tensor
        transformed_tensor = torch.stack(transformed_images)


        return transformed_tensor, gt_image_array, gt_image_path

In [4]:
%matplotlib inline


def vis_img(img, mask):
    # img: (B, 256, 64, 64), {: (B, 1, 256, 256)
    print(f"{img.shape=}, {mask.shape=}")
    img = np.squeeze(img)
    mask = np.squeeze(mask)
    plt.figure()
    plt.imshow(img, 'gray')
    overlay_mask_0 = np.ma.masked_where(mask[1] == 0, img)
    overlay_mask_1 = np.ma.masked_where(mask[2] == 0, img)
    overlay_mask_2 = np.ma.masked_where(mask[3] == 0, img)
    plt.imshow(overlay_mask_0, 'Greens', alpha=1, interpolation='nearest')
    plt.imshow(overlay_mask_1, 'Reds', alpha=1, interpolation='nearest')
    plt.imshow(overlay_mask_2, 'Purples', alpha=1, interpolation='nearest')
    plt.show()

# Note

If test on test set and val=False to SegDataset.
If test on validation set and val=True to SegDataset.

In [9]:
val_dataset = SegDataset(data_root='./database/testing', val=False)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
print(len(val_dataset))

100


In [10]:
# Set up model and optimizer
model = smp.Unet(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=1,
        classes=4
    )

# model_path = './trained-model-remco/model/Trying_transforms/Unet-res50_45Rotation/model_best.pth'
model_path = './model/unet-tune/model_best_19.pth'
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()

Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
      

# Note

Change pred_save_dir, gt_save_dir

In [11]:
pred_save_dir = 'prediction/test-pred-unet-res50-vanilla'
gt_save_dir = 'prediction/test-gt-unet-res50-vanilla'
for step, (image_array, gt_image_array, gt_image_path) in enumerate(tqdm(val_loader)):
    gt_image_path = gt_image_path[0]
    image_array = image_array[0]
    gt_image_array = gt_image_array[0]
    img_size = gt_image_array[0].size()
    
    resize_transform = transforms.Resize(img_size)
    prediction = model(image_array.to(device))
    prediction = torch.argmax(prediction, dim=1).detach().cpu()
    prediction = resize_transform(prediction)
    
    
    # Load the ground truth image
    gt_image = sitk.ReadImage(gt_image_path)

    # Convert the numpy array prediction to a SimpleITK image
    pred_image_sitk = sitk.GetImageFromArray(prediction)

    # Use the CopyInformation function
    pred_image_sitk.CopyInformation(gt_image)
    
    base_file_name = os.path.basename(gt_image_path) 

    # replace 'gt' with 'seg' and add '.gz'
    # seg_file_name = base_file_name.replace('_gt.nii', '_seg.nii.gz')
    
    seg_file_name = base_file_name
    
    # Save the image
    sitk.WriteImage(pred_image_sitk, os.path.join(pred_save_dir, seg_file_name))
    sitk.WriteImage(gt_image, os.path.join(gt_save_dir, base_file_name))
#     image_array = resize_transform(image_array)
    
#     one_hot_label = torch.nn.functional.one_hot(gt_image_array.to(torch.int64), num_classes=4)
#     one_hot_label = one_hot_label.permute(0, 3, 1, 2)
    
#     one_hot_pred = torch.nn.functional.one_hot(prediction.to(torch.int64), num_classes=4)
#     one_hot_pred = one_hot_pred.permute(0, 3, 1, 2)
    
#     print(one_hot_label.size(), one_hot_pred.size())
    
#     vis_img(image_array[0][0].numpy(), one_hot_label[0].numpy())
#     vis_img(image_array[0][0].numpy(), one_hot_pred[0].numpy())
#     if step > 0:
#         break

NiftiImageIO (0xa4c49da0): ./database/testing/patient102/patient102_frame13_gt.nii has unexpected scales in sform

NiftiImageIO (0xa4c49da0): ./database/testing/patient102/patient102_frame13_gt.nii has unexpected scales in sform

NiftiImageIO (0x10b475d0): ./database/testing/patient102/patient102_frame13.nii has unexpected scales in sform

NiftiImageIO (0x10b475d0): ./database/testing/patient102/patient102_frame13.nii has unexpected scales in sform

NiftiImageIO (0x166bd5a0): ./database/testing/patient102/patient102_frame13_gt.nii has unexpected scales in sform

NiftiImageIO (0x166bd5a0): ./database/testing/patient102/patient102_frame13_gt.nii has unexpected scales in sform

NiftiImageIO (0x1670b000): ./database/testing/patient102/patient102_frame01_gt.nii has unexpected scales in sform

NiftiImageIO (0x1670b000): ./database/testing/patient102/patient102_frame01_gt.nii has unexpected scales in sform

NiftiImageIO (0x166bd5a0): ./database/testing/patient102/patient102_frame01.nii has un