In [1]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
from PIL import Image
import numpy as np

from unet import UNet
from dataset import SegmentationDataset

In [2]:
# Define the data transforms
transform = transforms.Compose([
    transforms.ToTensor(),         # Convert the images to PyTorch tensors
    transforms.Grayscale()
])

In [3]:
# Define the data directories
data_dir = './data'
test_dir = data_dir + '/test'
test_data = SegmentationDataset(test_dir, transform=transform)
test_loader = DataLoader(test_data, batch_size=1)

In [4]:
model = UNet(n_channels=1, n_classes=1).to('cpu')
model.load_state_dict(torch.load("models/unet_best_epoch.pth"))
model.eval()

UNet(
  (inc): DoubleConv(
    (double_conv): Sequential(
      (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (5): ReLU(inplace=True)
    )
  )
  (down1): Down(
    (maxpool_conv): Sequential(
      (0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (1): DoubleConv(
        (double_conv): Sequential(
          (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
          (3): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
 

In [5]:
def dice_coefficient(outputs, labels):
    smooth = 1.
    num = outputs.size(0)
    outputs = outputs.view(num, -1)
    labels = labels.view(num, -1)
    intersection = (outputs * labels).sum(1)
    dice = (2. * intersection + smooth) / (outputs.sum(1) + labels.sum(1) + smooth)
    return dice.mean()

In [6]:
for i, (image, label) in enumerate(test_loader):
    predicted_mask = model(image)
    
    dice_coeff = dice_coefficient(predicted_mask, label)
    print(dice_coeff.item() * 100, "%")

    arr = predicted_mask.detach().numpy()
    arr = arr.reshape(256, 256)
    arr_mapped = np.interp(arr, (arr.min(), arr.max()), (255, 0))

    label = label.detach().numpy()
    label = label.reshape(256, 256) 
    arr_concat = np.concatenate((label, arr_mapped), axis=1)
    image = Image.fromarray(arr_concat.astype('uint8')*255, mode='L')
    image.save(f'prediction/pred_label_concat_{i}.png')

99.80581998825073 %
