In [1]:
# Imports
import pathlib

import numpy as np
import torch
from skimage.io import imread
from skimage.transform import resize

from inference import predict
from transformations import normalize_01, re_normalize
from unet import UNet

import pickle


# Dataset root: the ISIC folders used below are resolved relative to it
root = pathlib.Path.cwd().joinpath("Data", "2018")
# Scratch directory that holds the pickled, already-resized arrays
root_temp = pathlib.Path.cwd().joinpath("temp_chkp")

# True  -> load the processed images/targets from the saved pickle files
# False -> read the original images, process them, and save new pickle files
USE_SAVED_IMG = True


def get_filenames_of_path(path: pathlib.Path, ext: str = "*") -> list:
    """Return the regular files in ``path`` that match the glob pattern ``ext``.

    The result is sorted. ``Path.glob`` gives no ordering guarantee, and the
    inputs and targets are collected by separate calls and later matched by
    list index, so a stable order keeps the two lists aligned between runs
    and machines.

    Args:
        path: Directory to search (non-recursive).
        ext: Glob pattern applied inside ``path``, e.g. ``"*.jpg"``.
            Defaults to ``"*"`` (every entry).

    Returns:
        Sorted list of ``pathlib.Path`` objects; directories are excluded.
    """
    return sorted(file for file in path.glob(ext) if file.is_file())


# input and target files
images_names = get_filenames_of_path(root / "ISIC2018_Task1-2_Validation_Input", ext='*.jpg')
targets_names = get_filenames_of_path(root / "ISIC2018_Task1_Validation_GroundTruth", ext='*.png')

# load data from saved files
if USE_SAVED_IMG:
    # SECURITY NOTE: pickle.load can execute arbitrary code while unpickling.
    # Only load cache files that were written by this notebook itself.
    with open(root_temp / "test_images.pkl", "rb") as f:
        images_res = pickle.load(f)
    with open(root_temp / "test_targets.pkl", "rb") as f:
        targets_res = pickle.load(f)

else:
    # read images and store them in memory
    images = [imread(img_name) for img_name in images_names]
    targets = [imread(tar_name) for tar_name in targets_names]

    # Resize images and targets
    images_res = [resize(img, (128, 128, 3)) for img in images]
    # order=0 (nearest neighbour), no anti-aliasing and preserve_range keep the
    # mask values discrete instead of interpolating between labels
    resize_kwargs = {"order": 0, "anti_aliasing": False, "preserve_range": True}
    targets_res = [resize(tar, (128, 128), **resize_kwargs) for tar in targets]

    # change target label (255 -> 1) to show a different color from the prediction result
    targets_res = [np.where(target==255, 1, target).astype(int) for target in targets_res]

    # make sure the cache directory exists; otherwise open(..., "wb") raises
    # FileNotFoundError only after all the images have been processed
    root_temp.mkdir(parents=True, exist_ok=True)

    # save images and targets
    with open(root_temp / "test_images.pkl", "wb") as f:
        pickle.dump(images_res, f)
    with open(root_temp / "test_targets.pkl", "wb") as f:
        pickle.dump(targets_res, f)


In [2]:
# device: use the GPU when one is available, otherwise fall back to the CPU.
# (The CPU branch previously built a torch.device without assigning it, which
# left `device` undefined on machines without CUDA.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model
model = UNet(
    in_channels=3,
    out_channels=2,
    n_blocks=4,
    start_filters=32,
    activation="relu",
    normalization="batch",
    conv_mode="same",
    dim=2,
).to(device)


model_name = "test.pt"
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a CUDA machine can still be loaded on a CPU-only one
model_weights = torch.load(pathlib.Path.cwd() / "Output" / model_name, map_location=device)

model.load_state_dict(model_weights)


<All keys matched successfully>

In [3]:
# preprocess function
def preprocess(img: np.ndarray):
    """Turn a channels-last image into a batched float32 array for the model.

    Steps, in order: move the last axis to the front ([H, W, C] -> [C, H, W]),
    pass the result through ``normalize_01`` (project helper; per its name it
    rescales to the range [0, 1]), prepend a batch axis ([B, C, H, W]) and
    cast to float32.
    """
    channels_first = np.moveaxis(img, -1, 0)
    scaled = normalize_01(channels_first)
    batched = np.expand_dims(scaled, axis=0)
    return batched.astype(np.float32)


# postprocess function
def postprocess(img: torch.Tensor):
    """Convert raw model output into a single-channel label image.

    Args:
        img: Model output with the class scores along dimension 1
            (assumed [B, C, H, W] — TODO confirm against ``predict``).

    Returns:
        NumPy array obtained by taking the argmax over dimension 1, moving the
        result to the CPU, squeezing singleton axes and passing it through
        ``re_normalize`` (project helper; per the original comment it scales
        the values to the range [0-255]).
    """
    img = torch.argmax(img, dim=1)  # perform argmax to generate 1 channel
    img = img.cpu().numpy()  # send to cpu and transform to numpy.ndarray
    img = np.squeeze(img)  # drop every size-1 axis (e.g. the batch dim) -> [H, W]
    img = re_normalize(img)  # scale it to the range [0-255]
    return img


In [4]:
# run inference on every loaded image; one segmentation map per input
output = [
    predict(image, model, preprocess, postprocess, device)
    for image in images_res
]


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


In [5]:
# NOTE(review): these imports sit below the notebook's import cell. The helper
# is called before napari is imported — presumably that order is required to
# set up the Qt GUI event loop; confirm before moving them to the top.
from visual import enable_gui_qt

enable_gui_qt()
import napari

# open a napari viewer window
viewer = napari.Viewer()

# index of the sample to inspect; must be valid for images_res, targets_res and output
idx = 32
# show the input as an image layer, and target / prediction as label layers
img_nap = viewer.add_image(images_res[idx], name="Input")
tar_nap = viewer.add_labels(targets_res[idx], name="Target")
out_nap = viewer.add_labels(output[idx], name="Prediction")


  zoom = np.min(canvas_size / scale)
