In [1]:
import torch
from torchvision import transforms
from utils import mypreprocess, util_functions, eff_unet
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from collections import defaultdict
from mcdropout import MCDropout2D
import numpy as np
from tqdm import tqdm
import cv2
import nibabel as nib
from skimage import io

In [32]:
model_path = '/work/ovens_lab/thaonguyen/image_segmentation/best_model.pt' 
model = eff_unet.EffUNet(in_channels=1, classes=1)
model.load_state_dict(torch.load(model_path))
model.eval() 

EffUNet(
  (start_conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (down_block_2): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (se_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=(1, 1))
        (1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): ReLU()
        (3): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): Sigmoid()
      )
      (pointwise_conv): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,

In [33]:
IMG_SIZE = 512
inference_dir = '/work/ovens_lab/thaonguyen/image_segmentation/kits_2d_dataset_latest5/testing' 
inference_transformer = transforms.Compose([transforms.ToTensor(), 
                                            transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True)])
inference_loader = mypreprocess.create_data_loaders(path_dir=inference_dir, image_dir='images', 
                                                    label_dir='labels', 
                                                    data_transformer=inference_transformer, 
                                                    batch_size=1, split_size=None)

No of batches: 1619, batch shape: torch.Size([1, 1, 512, 512])


In [34]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

EffUNet(
  (start_conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (down_block_2): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (se_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=(1, 1))
        (1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): ReLU()
        (3): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): Sigmoid()
      )
      (pointwise_conv): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,

In [35]:
def dice_coefficient2_modified(target, preds):
    preds_flat = preds.view(-1)
    target_flat = target.view(-1)

    intersection = (preds_flat * target_flat).sum()
    set_sum = preds_flat.sum() + target_flat.sum()

    dice = (2 * intersection + 1e-8) / (set_sum + 1e-8) 

    return dice

In [36]:
# Save predictions together with original image and ground truth
def show_image_modified(image, mask, pred_image=None, path_dir=None, filename=None, iou=None, dice_score=None):
    image = image.squeeze() 
    mask = mask.squeeze() 

    plt.figure(figsize=(10, 5))

    if pred_image is not None:
        ax1 = plt.subplot(1, 3, 1)
        ax2 = plt.subplot(1, 3, 2)
        ax3 = plt.subplot(1, 3, 3)
    else:
        ax1 = plt.subplot(1, 2, 1)
        ax2 = plt.subplot(1, 2, 2)

    ax1.set_title('IMAGE')
    ax1.imshow(image, cmap='gray')
    ax1.axis('off')

    ax2.set_title('GROUND TRUTH')
    ax2.imshow(mask, cmap='gray')
    ax2.axis('off')

    if pred_image is not None and iou is not None and dice_score is not None:
        pred_image = pred_image.squeeze()
        ax3.imshow(pred_image, cmap='gray')
        ax3.set_title('MODEL OUTPUT')
        ax3.axis('off')
        ax3.text(5, 5, f'IoU: {iou:.2f}, Dice: {dice_score:.2f}', color='white', fontsize=8, backgroundcolor='black')

    if path_dir is not None and filename is not None:
        full_path = os.path.join(path_dir, filename)
        plt.savefig(full_path, bbox_inches='tight', pad_inches=0)

    plt.close()

In [37]:
# Save uncertainty map together with the original image
def show_uncertainty_map(original_image, uncertainty_map, path, filename, avg_iou, avg_dice):
    plt.figure(figsize=(10, 5))

    ax1 = plt.subplot(1, 2, 1)
    ax1.imshow(original_image, cmap='gray')
    ax1.set_title('Original Image')
    ax1.axis('off')

    ax2 = plt.subplot(1, 2, 2)
    ax2.imshow(uncertainty_map, cmap='gray')
    ax2.set_title('Uncertainty Map')
    ax2.axis('off')

    plt.figtext(0.5, 0.01, f"Average IoU: {avg_iou:.4f}, Average Dice: {avg_dice:.4f}", ha="center", fontsize=12)

    combined_save_path = os.path.join(path, filename)
    plt.savefig(combined_save_path, bbox_inches='tight', pad_inches=0)
    plt.close()

In [38]:
# Save the prediction only
def save_prediction(pred_image, path_dir, filename):
    if pred_image is not None:
        plt.figure(figsize=(512 / 100, 512 / 100)) 
        plt.imshow(pred_image.squeeze(), cmap='gray') 
        plt.axis('off')

        if path_dir is not None and filename is not None:
            full_path = os.path.join(path_dir, filename)
            plt.savefig(full_path, bbox_inches='tight', pad_inches=0)
            plt.close()


In [39]:
# Save the uncertainty map only
def save_uncertainty_map(uncertainty_map, path, filename):
    plt.figure(figsize=(512 / 100, 512 / 100)) 

    plt.imshow(uncertainty_map, cmap='gray')
    plt.axis('off')
    
    combined_save_path = os.path.join(path, filename)
    plt.savefig(combined_save_path, bbox_inches='tight', pad_inches=0)
    plt.close()

In [40]:
# Activate MC Dropout for inference
MCDropout2D.activate()

In [41]:
def calculate_entropy(prob_maps):
    return -np.sum(prob_maps * np.log(prob_maps + 1e-8), axis=0)

In [42]:
num_passes = 20

In [43]:
def calculate_iou_dice_per_pass(pred_mask, true_mask):
    iou = util_functions.calculate_IoU(pred_mask, true_mask)
    dice_score = dice_coefficient2_modified(true_mask, pred_mask)
    return iou.item(), dice_score.item()

In [45]:
# Save the prediction and uncertainty map only
mc_prediction_folder = '/work/ovens_lab/thaonguyen/image_segmentation/kits_mc_predictions_updated'
uncertainty_map_folder = '/work/ovens_lab/thaonguyen/image_segmentation/kits_uncertainty_map_updated'
os.makedirs(mc_prediction_folder, exist_ok=True)
os.makedirs(uncertainty_map_folder, exist_ok=True)

total_iou = 0.0
total_dice = 0.0
num_images = 0

for idx, (images, true_masks, index_tuple) in tqdm(enumerate(inference_loader), total=len(inference_loader.dataset)):
    images = images.to(device)
    true_masks = true_masks.to(device)
    image_index, slice_index = map(lambda x: x.item(), index_tuple)

    prediction_filename = f'image_{image_index:04d}_{slice_index:04d}.png'
    uncertainty_filename = f'image_{image_index:04d}_{slice_index:04d}_uncertainty.png'

    uncertainty_file_path = os.path.join(uncertainty_map_folder, uncertainty_filename)
    if os.path.exists(uncertainty_file_path):
        continue 

    pass_iou = 0.0
    pass_dice = 0.0
    prob_maps = []

    for pass_idx in range(num_passes):
        with torch.no_grad():
            logits_mask = model(images)
            pred_mask = torch.sigmoid(logits_mask)
            pred_mask_cpu = pred_mask[0].detach().cpu().squeeze().numpy()
            prob_maps.append(pred_mask_cpu)

            iou, dice_score = calculate_iou_dice_per_pass(pred_mask, true_masks)
            pass_iou += iou
            pass_dice += dice_score

            pass_folder = os.path.join(mc_prediction_folder, f'pass{pass_idx+1}')
            os.makedirs(pass_folder, exist_ok=True)
            save_path = os.path.join(pass_folder, prediction_filename)
            save_prediction(pred_mask_cpu, pass_folder, prediction_filename) 

    avg_iou = pass_iou / num_passes
    avg_dice = pass_dice / num_passes

    total_iou += avg_iou
    total_dice += avg_dice
    num_images += 1

    prob_maps = np.stack(prob_maps, axis=0)
    uncertainty_map = calculate_entropy(prob_maps)
    original_image = images[0].detach().cpu().squeeze().numpy()
    save_uncertainty_map(uncertainty_map, uncertainty_map_folder, uncertainty_filename)  

overall_iou = total_iou / num_images
overall_dice = total_dice / num_images
print(f"Overall IoU: {overall_iou:.4f}, Overall Dice Score: {overall_dice:.4f}")


100%|██████████| 1619/1619 [27:13<00:00,  1.01s/it]

Overall IoU: nan, Overall Dice Score: 0.3834





In [47]:
import os
import numpy as np
import nibabel as nib
import imageio as io
from tqdm import tqdm
import logging

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s:%(message)s')

def reverse_transform_slice(slice_data):
    # Rotate the slice by 90 degrees to the right
    unrotated_slice = np.rot90(slice_data, k=1, axes=(0, 1))
    return unrotated_slice

def load_and_combine_uncertainty_slices_to_nifti(uncertainty_map_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Dictionary to hold uncertainty slices for each image
    image_uncertainty_slices = {}

    # Iterate over uncertainty map files and group slices by image
    for file in tqdm(sorted(os.listdir(uncertainty_map_folder))):
        if file.endswith('_uncertainty.png'):
            parts = file.split('_')
            image_index = int(parts[1])
            slice_index = int(parts[2])
            slice_path = os.path.join(uncertainty_map_folder, file)
            try:
                slice_data = io.imread(slice_path, mode='F')
                # Reverse transformations on the slice
                #slice_data = reverse_transform_slice(slice_data)

                if image_index not in image_uncertainty_slices:
                    image_uncertainty_slices[image_index] = {}
                image_uncertainty_slices[image_index][slice_index] = slice_data
            except Exception as e:
                logging.error(f"Error loading slice {slice_path}: {e}")

    # Process each image group
    for image_index, slices_dict in image_uncertainty_slices.items():
        try:
            # Ensure all slices are present and in order
            max_slice_index = max(slices_dict.keys())
            slices = []
            for i in range(max_slice_index + 1):
                if i in slices_dict:
                    slices.append(slices_dict[i])
                else:
                    logging.warning(f"Missing slice index {i} for image {image_index}. Using empty slice.")
                    empty_slice = np.zeros_like(list(slices_dict.values())[0])
                    slices.append(empty_slice)

            # Reverse the flip transformations on the 3D image
            image_3d_uncertainty = np.stack(slices, axis=-1)
            image_3d_uncertainty = np.flip(image_3d_uncertainty, axis=2)  # reverse z-axis flip
            #image_3d_uncertainty = np.flip(image_3d_uncertainty, axis=1)  # reverse y-axis flip

            nifti_path = os.path.join(output_folder, f'image_{image_index:04d}_uncertainty.nii.gz')
            nifti_img = nib.Nifti1Image(image_3d_uncertainty, np.eye(4))
            nib.save(nifti_img, nifti_path)
        except Exception as e:
            logging.error(f"Error processing image {image_index}: {e}")

    print("3D uncertainty NIfTI files created.")

uncertainty_map_folder = '/work/ovens_lab/thaonguyen/image_segmentation/kits_uncertainty_map_updated'
output_folder = '/work/ovens_lab/thaonguyen/image_segmentation/3d_kits_uncertainty_map_updated'
load_and_combine_uncertainty_slices_to_nifti(uncertainty_map_folder, output_folder)

  slice_data = io.imread(slice_path, mode='F')
100%|██████████| 1619/1619 [00:05<00:00, 303.30it/s]


3D uncertainty NIfTI files created.
