In [1]:
import torch
from torchvision import transforms
from utils import mypreprocess, util_functions, eff_unet, unet
from torch.utils.data import DataLoader
import os
import matplotlib.pyplot as plt

In [2]:
# Build the segmentation network, restore its trained weights from a checkpoint file
# and switch it to evaluation mode.
# model_path = '/work/ovens_lab/thaonguyen/image_segmentation/best_model.pt' 
model_path = 'results/2023-11-15_15-32/best_model.pt' 
# The checkpoint is fed directly to load_state_dict, so the architecture instantiated
# here must be the one the checkpoint was saved from (the commented-out UNet line is
# the alternative architecture).
model = eff_unet.EffUNet(in_channels=1, classes=1)
# model = unet.UNet(num_classes=1, input_channels=1)
# map_location='cpu' deserializes the weights on the CPU; the model is moved to the
# selected compute device in a later cell.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted
# source (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load(model_path, map_location='cpu'))
# eval() is the last expression of the cell, so the notebook also displays the module repr.
model.eval() 

EffUNet(
  (start_conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (down_block_2): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (se_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=(1, 1))
        (1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): ReLU()
        (3): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): Sigmoid()
      )
      (pointwise_conv): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,

In [3]:
# Side length (in pixels) every sample is resized to before inference.
IMG_SIZE = 512
# inference_dir = '/work/ovens_lab/thaonguyen/image_segmentation/2d_dataset/testing' 
# Root folder of the test split; 'images' and 'labels' below are passed as its sub-folder names.
inference_dir = '/scratch/student/sinaziaee/datasets/2d_dataset/testing'
# Convert each sample to a tensor, then resize it to IMG_SIZE x IMG_SIZE with antialiasing.
# NOTE(review): presumably create_data_loaders applies this same transform to the label
# masks; an antialiased resize could then leave non-binary mask values - confirm.
inference_transformer = transforms.Compose([transforms.ToTensor(), 
                                            transforms.Resize((IMG_SIZE, IMG_SIZE), antialias=True)])
# batch_size=1 yields one sample per batch (the inference loop relies on this when it
# looks up file names by batch index).
# split_size=None presumably disables any train/validation split - TODO confirm in mypreprocess.
inference_loader = mypreprocess.create_data_loaders(path_dir=inference_dir, image_dir='images', 
                                                    label_dir='labels', 
                                                    data_transformer=inference_transformer, 
                                                    batch_size=1, split_size=None)

dataset info: 
 No images: 941, No masks: 941, 
 No of batches: 941, batch shape: torch.Size([1, 1, 512, 512])


In [4]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Move the model's parameters and buffers to that device. As the last expression of
# the cell, the returned module is also displayed by the notebook.
model.to(device)

EffUNet(
  (start_conv): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
  )
  (down_block_2): Sequential(
    (0): MBConvBlock(
      (depthwise_conv): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
      )
      (se_block): Sequential(
        (0): AdaptiveAvgPool2d(output_size=(1, 1))
        (1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): ReLU()
        (3): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (4): Sigmoid()
      )
      (pointwise_conv): Sequential(
        (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True,

In [5]:
# Show which device was selected, so the run records whether inference used the GPU.
print(device)

cuda:0


In [6]:
# Directory (relative to the working directory) that receives the saved comparison figures.
results_folder = 'inference_results'
# exist_ok=True makes re-running this cell harmless when the folder already exists.
os.makedirs(results_folder, exist_ok=True)

In [7]:
def dice_coefficient2_modified(target, preds):
    """Compute the Dice coefficient between a ground-truth mask and a prediction.

    Both tensors are flattened before comparison, so the score is taken over
    every element at once (all batch items and channels together).

    Args:
        target: Ground-truth mask tensor.
        preds: Predicted mask tensor with the same number of elements as
            ``target``.

    Returns:
        A 0-dim tensor holding
        ``(2 * sum(preds * target) + eps) / (sum(preds) + sum(target) + eps)``
        with ``eps = 1e-8``. The epsilon keeps the ratio defined when both
        masks are empty, in which case the result is 1.
    """
    # reshape(-1), unlike view(-1), also accepts non-contiguous inputs
    # (e.g. transposed, permuted or sliced tensors) instead of raising.
    preds_flat = preds.reshape(-1)
    target_flat = target.reshape(-1)

    intersection = (preds_flat * target_flat).sum()
    set_sum = preds_flat.sum() + target_flat.sum()

    dice = (2 * intersection + 1e-8) / (set_sum + 1e-8)

    return dice

In [8]:
def show_image_modified(image, mask, pred_image=None, path_dir=None, filename=None, iou=None, dice_score=None):
    """Render an image, its ground-truth mask and (optionally) the model output side by side.

    The figure is never shown interactively: it is saved when both ``path_dir``
    and ``filename`` are given, and it is always closed before returning.

    Args:
        image: Input image; singleton dimensions are squeezed away before plotting.
        mask: Ground-truth mask; squeezed like ``image``.
        pred_image: Optional predicted mask. When given, a third panel titled
            'MODEL OUTPUT' is drawn.
        path_dir: Optional output directory for the saved figure.
        filename: Optional file name of the saved figure inside ``path_dir``.
        iou: Optional IoU value written onto the prediction panel.
        dice_score: Optional Dice value written onto the prediction panel.

    Returns:
        None.
    """
    image = image.squeeze()  # Remove channel dimension if it's present
    mask = mask.squeeze()  # Remove channel dimension if it's present

    n_panels = 3 if pred_image is not None else 2
    fig, axes = plt.subplots(1, n_panels, figsize=(10, 5))

    # Close the figure even if drawing or saving fails, so figures do not pile up
    # when this is called once per sample in a long loop.
    try:
        axes[0].set_title('IMAGE')
        axes[0].imshow(image, cmap='gray')
        axes[0].axis('off')

        axes[1].set_title('GROUND TRUTH')
        axes[1].imshow(mask, cmap='gray')
        axes[1].axis('off')

        if pred_image is not None:
            # Always draw the prediction; previously the panel stayed blank unless
            # both metrics were supplied.
            pred_image = pred_image.squeeze()  # Remove channel dimension if it's present
            axes[2].imshow(pred_image, cmap='gray')
            axes[2].set_title('MODEL OUTPUT')
            axes[2].axis('off')

            # Annotate with whichever metrics were provided.
            metric_parts = []
            if iou is not None:
                metric_parts.append(f'IoU: {iou:.2f}')
            if dice_score is not None:
                metric_parts.append(f'Dice: {dice_score:.2f}')
            if metric_parts:
                axes[2].text(5, 5, ', '.join(metric_parts), color='white', fontsize=8,
                             backgroundcolor='black')

        if path_dir is not None and filename is not None:
            full_path = os.path.join(path_dir, filename)
            fig.savefig(full_path, bbox_inches='tight', pad_inches=0)
    finally:
        plt.close(fig)

In [9]:
# Run the model over every test batch, save a comparison figure per sample and
# report the IoU / Dice scores averaged over all processed batches.
total_iou = 0.0
total_dice = 0.0
num_images = 0

for idx, (images, true_masks) in enumerate(inference_loader):
    images = images.to(device)
    true_masks = true_masks.to(device)

    # NOTE(review): idx is the batch index; using it as a dataset index assumes
    # batch_size=1 and an unshuffled loader - confirm against create_data_loaders.
    filename = inference_loader.dataset.get_filename(idx)

    with torch.no_grad():
        logits_mask = model(images)
        # Turn logits into probabilities, then binarize at 0.5.
        pred_mask = (torch.sigmoid(logits_mask) > 0.5).float()

    # Calculate IoU and Dice Score for each image
    iou = util_functions.calculate_IoU(pred_mask, true_masks)
    dice_score = dice_coefficient2_modified(true_masks, pred_mask)

    # Accumulate plain Python floats for both metrics so the running totals do not
    # stay on the compute device and the averages below are ordinary floats.
    total_iou += iou.item()
    total_dice += dice_score.item()
    num_images += 1

    # Move tensors to CPU (only the first sample of the batch is plotted)
    images_cpu = images[0].detach().cpu().squeeze()
    true_masks_cpu = true_masks[0].detach().cpu().squeeze()
    pred_mask_cpu = pred_mask[0].detach().cpu().squeeze()

    # Use the modified show_image function
    show_image_modified(images_cpu, true_masks_cpu, pred_mask_cpu,
                        path_dir=results_folder, filename=filename, iou=iou, dice_score=dice_score)

# Fail with an explicit message instead of a bare ZeroDivisionError when nothing was loaded.
if num_images == 0:
    raise RuntimeError("The inference loader produced no batches; cannot compute average metrics.")

overall_iou = total_iou / num_images
overall_dice = total_dice / num_images
print(f"Overall IoU: {overall_iou:.2f}, Overall Dice Score: {overall_dice:.2f}")

Overall IoU: 0.93, Overall Dice Score: 0.95
