In [1]:
import torch
from torchvision import transforms
from utils import mypreprocess, util_functions, eff_unet
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt
from collections import defaultdict

In [2]:
model_path = '/work/ovens_lab/thaonguyen/image_segmentation/best_model.pt' 
model = eff_unet.EffUNet(in_channels=1, classes=1)
model.load_state_dict(torch.load(model_path))
model.eval() 

EffUNet(
  (start_conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (down_block_2): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (se_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=(1, 1))
        (1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): ReLU()
        (3): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): Sigmoid()
      )
      (pointwise_conv): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,

In [3]:
IMG_SIZE = 512
inference_dir = '/work/ovens_lab/thaonguyen/image_segmentation/2d_dataset/testing' 
inference_transformer = transforms.Compose([transforms.ToTensor(), 
                                            transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True)])
inference_loader = mypreprocess.create_data_loaders(path_dir=inference_dir, image_dir='images', 
                                                    label_dir='labels', 
                                                    data_transformer=inference_transformer, 
                                                    batch_size=1, split_size=None)

No of batches: 941, batch shape: torch.Size([1, 1, 512, 512])


In [4]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model.to(device)

EffUNet(
  (start_conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (down_block_2): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (se_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=(1, 1))
        (1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): ReLU()
        (3): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): Sigmoid()
      )
      (pointwise_conv): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,

In [5]:
def dice_coefficient2_modified(target, preds):
    preds_flat = preds.view(-1)
    target_flat = target.view(-1)

    intersection = (preds_flat * target_flat).sum()
    set_sum = preds_flat.sum() + target_flat.sum()

    dice = (2 * intersection + 1e-8) / (set_sum + 1e-8) 

    return dice

In [6]:
def show_image_modified(image, mask, pred_image=None, path_dir=None, filename=None, iou=None, dice_score=None):
    image = image.squeeze()  # Remove channel dimension if it's present
    mask = mask.squeeze()  # Remove channel dimension if it's present

    plt.figure(figsize=(10, 5))

    if pred_image is not None:
        ax1 = plt.subplot(1, 3, 1)
        ax2 = plt.subplot(1, 3, 2)
        ax3 = plt.subplot(1, 3, 3)
    else:
        ax1 = plt.subplot(1, 2, 1)
        ax2 = plt.subplot(1, 2, 2)

    ax1.set_title('IMAGE')
    ax1.imshow(image, cmap='gray')
    ax1.axis('off')

    ax2.set_title('GROUND TRUTH')
    ax2.imshow(mask, cmap='gray')
    ax2.axis('off')

    # Add IoU and Dice score to the predicted image subplot
    if pred_image is not None and iou is not None and dice_score is not None:
        pred_image = pred_image.squeeze()  # Remove channel dimension if it's present
        ax3.imshow(pred_image, cmap='gray')
        ax3.set_title('MODEL OUTPUT')
        ax3.axis('off')
        ax3.text(5, 5, f'IoU: {iou:.2f}, Dice: {dice_score:.2f}', color='white', fontsize=8, backgroundcolor='black')

    if path_dir is not None and filename is not None:
        full_path = os.path.join(path_dir, filename)
        plt.savefig(full_path, bbox_inches='tight', pad_inches=0)

    plt.close()

In [8]:
# Save predictions with the original image and ground truth

# from tqdm import tqdm

# total_iou = 0.0
# total_dice = 0.0
# num_images = 0

# # Define the 2D predictions folder
# predictions_folder = '/work/ovens_lab/thaonguyen/image_segmentation/2d_results'
# os.makedirs(predictions_folder, exist_ok=True)

# # Wrap the inference loader with tqdm for a progress bar
# for idx, (images, true_masks, index_tuple) in tqdm(enumerate(inference_loader), total=len(inference_loader)):
#     images = images.to(device)
#     true_masks = true_masks.to(device)

#     # Extract and convert image index and slice index to integers
#     image_index, slice_index = index_tuple
#     image_index = image_index.item()
#     slice_index = slice_index.item()

#     with torch.no_grad():
#         logits_mask = model(images)
#         pred_mask = torch.sigmoid(logits_mask)
#         pred_mask = (pred_mask > 0.5) * 1.0

#     # Calculate IoU and Dice Score for each image
#     iou = util_functions.calculate_IoU(pred_mask, true_masks)
#     dice_score = dice_coefficient2_modified(true_masks, pred_mask)

#     total_iou += iou.item()
#     total_dice += dice_score
#     num_images += 1

#     # Move tensors to CPU and format filename
#     images_cpu = images[0].detach().cpu().squeeze()
#     true_masks_cpu = true_masks[0].detach().cpu().squeeze()
#     pred_mask_cpu = pred_mask[0].detach().cpu().squeeze()
#     save_filename = f'image_{image_index:04d}_{slice_index:04d}.png'

#     # Save the prediction in the 2d_predictions folder
#     save_path = os.path.join(predictions_folder, save_filename)
#     show_image_modified(images_cpu, true_masks_cpu, pred_mask_cpu, path_dir=predictions_folder, filename=save_filename, iou=iou, dice_score=dice_score)

# # Calculate and print the overall IoU and Dice Score
# overall_iou = total_iou / num_images
# overall_dice = total_dice / num_images
# print(f"Overall IoU: {overall_iou:.2f}, Overall Dice Score: {overall_dice:.2f}")


In [13]:
# Save predictions only

from tqdm import tqdm
import matplotlib.pyplot as plt
import cv2

def save_prediction(prediction, path, filename, output_size=(512, 512)):
    """Save the prediction mask as a PNG file with a specific size."""
    # Resize prediction to the desired output size
    resized_prediction = cv2.resize(prediction, output_size, interpolation=cv2.INTER_NEAREST)

    plt.figure(figsize=(6, 6))
    plt.imshow(resized_prediction, cmap='gray')
    plt.axis('off')
    plt.savefig(os.path.join(path, filename), bbox_inches='tight', pad_inches=0)
    plt.close()

total_iou = 0.0
total_dice = 0.0
num_images = 0

# Define the 2D predictions folder
predictions_folder = '/work/ovens_lab/thaonguyen/image_segmentation/2d_predictions'
os.makedirs(predictions_folder, exist_ok=True)

# Process each image and mask in the dataset
for idx, (images, true_masks, index_tuple) in tqdm(enumerate(inference_loader), total=len(inference_loader)):
    images = images.to(device)
    true_masks = true_masks.to(device)
    image_index, slice_index = map(lambda x: x.item(), index_tuple)

    # Run inference
    with torch.no_grad():
        logits_mask = model(images)
        pred_mask = torch.sigmoid(logits_mask)
        # Ensure we are extracting a single 2D slice
        pred_mask_cpu = (pred_mask > 0.5).float().cpu().numpy()[0, 0, :, :]

    # Calculate IoU and Dice Score
    iou = util_functions.calculate_IoU(pred_mask, true_masks)
    dice_score = dice_coefficient2_modified(true_masks, pred_mask)

    total_iou += iou.item()
    total_dice += dice_score
    num_images += 1

    # Save the prediction mask
    save_filename = f'image_{image_index:04d}_{slice_index:04d}.png'
    save_prediction(pred_mask_cpu, predictions_folder, save_filename)

# Calculate and print the overall IoU and Dice Score
overall_iou = total_iou / num_images
overall_dice = total_dice / num_images
print(f"Overall IoU: {overall_iou:.2f}, Overall Dice Score: {overall_dice:.2f}")

100%|██████████| 941/941 [01:28<00:00, 10.61it/s]

Overall IoU: nan, Overall Dice Score: 0.46





In [14]:
import os
import numpy as np
import nibabel as nib
from skimage import io
from tqdm import tqdm

def load_and_combine_slices_to_nifti(png_folder, output_folder):
    if not os.path.exists(output_folder):
        os.makedirs(output_folder)

    # Dictionary to hold slices for each image
    image_slices = {}

    # Iterate over PNG files and group slices by image
    for file in sorted(os.listdir(png_folder)):
        if file.endswith('.png'):
            parts = file.split('_')
            image_index = int(parts[1])
            slice_index = int(parts[2].split('.')[0])
            slice_path = os.path.join(png_folder, file)
            slice_data = io.imread(slice_path, as_gray=True)

            if image_index not in image_slices:
                image_slices[image_index] = {}
            image_slices[image_index][slice_index] = slice_data

    # Process each image group
    for image_index in image_slices:
        slices = [image_slices[image_index][i] for i in sorted(image_slices[image_index])]
        image_3d = np.stack(slices, axis=-1)

        # Save as NIfTI file
        nifti_path = os.path.join(output_folder, f'image_{image_index:04d}.nii.gz')
        nifti_img = nib.Nifti1Image(image_3d, np.eye(4))
        nib.save(nifti_img, nifti_path)

    print("3D NIfTI files created.")

# Example usage
png_folder = '/work/ovens_lab/thaonguyen/image_segmentation/2d_predictions'
output_folder = '/work/ovens_lab/thaonguyen/image_segmentation/3d_predictions'
load_and_combine_slices_to_nifti(png_folder, output_folder)

3D NIfTI files created.
