## Plot selected segmentation results

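This notebook plots the ring segmentations predicted by the trained models.
Each prediction is shown beneath the corresponding input image and, for the
leave-one-image-out experiments, beneath the ground-truth ring mask. Please see
the project [README](./README.md) for further details of the process.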

In [1]:
import os
import tqdm.notebook
import yaml

import numpy as np
import matplotlib.pyplot as plt
from shellai import preprocessing, tf_util

# Read the experiment configuration (paths, image lists and training settings
# used by the cells below). `safe_load` is used so the YAML file can only
# produce plain Python data, never arbitrary objects.
with open("config.yaml", mode="r") as config_file:
    cfg = yaml.safe_load(config_file)
print('Config file loaded')

Config file loaded


In [2]:
# # #  General settings
# directory containing the saved models and their corresponding raw predictions
dir_saved_model = cfg['paths']['local']['model']

# base directory to save things to
dir_base = cfg['paths']['local']['base']

# where we wish to save the plots to. the directory is created here if it does
# not already exist, so the plotting cells below do not fail when saving
dir_saved_plots = os.path.join(dir_base, "output", "segmentation_plots")
os.makedirs(dir_saved_plots, exist_ok=True)

# patch size, as configured for training
patch_shape = tuple(cfg['training']['patch_shape'])

# select which model configurations (i.e. which checkpoint epochs) we wish to
# plot for
plotting_epochs = [100, 150, 200]

### Segmentation plots for the leave-one-image-out experiments

In [9]:
# settings for the LOIO (leave-one-image-out) experiments

# filename templates: the prediction archive saved for a given checkpoint
# epoch, and the name given to each saved plot
pred_file_mask = 'model_checkpoint_{epoch}_predictions.npz'
saved_image_name = "loo_{image_name}_{epoch}.png"

# resolution of the saved plots
saved_dpi = 200

# directory holding the training images
base_image_dir = os.path.join(dir_base, cfg['images']['train_dir'])

# the training images. there is one model per image, and that model was
# trained on every training image except the named one
image_names = cfg['images']['train_images']

In [None]:
for image_name in tqdm.notebook.tqdm(image_names):
    # The image and its ground-truth ring mask are loaded lazily, on the first
    # epoch whose plot is missing. Re-running this cell when every plot for the
    # image already exists therefore skips the (potentially slow) loading.
    image = ring_mask = None

    # Row images built from the extracted patches. They do not depend on the
    # epoch, so they are created once per image and reused.
    gt_mask = original_slice_image = None

    for epoch in plotting_epochs:
        savepath = os.path.join(
            dir_saved_plots,
            saved_image_name.format(image_name=image_name, epoch=epoch)
        )

        if os.path.exists(savepath):
            print(f'Image already saved: {savepath}')
            continue

        # load the image and ground-truth segmentation (once per image)
        if image is None:
            image, _, ring_mask_full = preprocessing.load_image_data(
                base_image_dir, image_name
            )
            ring_mask = preprocessing.threshold_drawn_mask_and_skeletonize(
                ring_mask_full, sparse=False
            )

        # each LOIO model's outputs live in a sub-directory named after the
        # image that model was NOT trained on
        pred_path = os.path.join(
            dir_saved_model, image_name, pred_file_mask.format(epoch=epoch)
        )

        # load the predictions, the coordinates of each extracted patch, and
        # their centres.
        # NOTE(review): allow_pickle=True can execute pickled objects; only use
        # it on prediction files produced by this project.
        with np.load(pred_path, allow_pickle=True) as fd:
            predictions = fd['predictions']
            patch_coords = fd['patch_coords']  # [c0, c1, r0, r1]
            patch_centres = fd['patch_centres']  # [idx0, idx1]

        # only create the ground truth mask once per image
        if gt_mask is None:
            gt_mask = tf_util.place_patches_in_row_image(
                ring_mask, patch_centres, patch_coords, patch_shape
            )

        # likewise for the image itself
        if original_slice_image is None:
            original_slice_image = tf_util.place_patches_in_row_image(
                image, patch_centres, patch_coords, patch_shape
            )

        # create the prediction image
        predicted_mask, _ = tf_util.create_patch_image(
            patches=predictions,
            patch_coords=patch_coords,
            image_shape=image.shape
        )

        # create the image of just the extracted patches
        pred_mask = tf_util.place_patches_in_row_image(
            predicted_mask, patch_centres, patch_coords, patch_shape
        )

        # plot the image, ground truth and prediction as stacked rows
        fig, axes = plt.subplots(
            3, 1, figsize=(15, 2), sharex=True, sharey=True, dpi=saved_dpi
        )
        axes[0].imshow(original_slice_image, aspect='auto')
        axes[1].imshow(gt_mask, aspect='auto')
        axes[2].imshow(pred_mask, aspect='auto')

        axes[0].set_title(
            f"Experiment: Leave-one-image-out - Image: {image_name} - Epochs: {epoch}"
        )

        for ax in axes:
            ax.axis('off')

        # Save and close this specific figure rather than pyplot's implicit
        # "current figure". Closing it stops figures piling up in memory
        # across the loop.
        fig.savefig(savepath)
        plt.close(fig)

        print(f'Saved: {savepath}')

### Segmentation plots for the final model

In [11]:
# settings for the final model experiment

# location of the test images
base_image_dir = os.path.join(dir_base, cfg['images']['test_dir'])

# list of the test images to plot. unlike the LOIO experiments there is a
# single final model. its predictions for every test image are stored
# together, with the image name included in each prediction file name
image_names = cfg['images']['test_images']

# mask of the model predictions for a given epoch and test image
pred_file_mask = 'model_checkpoint_{epoch}_predictions_{image_name:s}.npz'

# what we want the saved images to be called
saved_image_name = "final_{image_name}_{epoch}.png"
saved_dpi = 200

In [None]:
for image_name in tqdm.notebook.tqdm(image_names):
    # The image is loaded with `no_rings=True`, so no ground-truth ring mask is
    # returned or plotted here. It is loaded lazily, on the first epoch whose
    # plot is missing, so re-running this cell when every plot already exists
    # skips the load.
    image = None

    # Row image built from the extracted patches. It does not depend on the
    # epoch, so it is created once per image and reused.
    original_slice_image = None

    for epoch in plotting_epochs:
        savepath = os.path.join(
            dir_saved_plots,
            saved_image_name.format(image_name=image_name, epoch=epoch)
        )

        if os.path.exists(savepath):
            print(f'Image already saved: {savepath}')
            continue

        # load the image (once per image); there is no ground truth to load
        if image is None:
            image, _ = preprocessing.load_image_data(
                base_image_dir, image_name, no_rings=True
            )

        # all final-model predictions live in the "final_model" directory
        pred_path = os.path.join(
            dir_saved_model,
            "final_model",
            pred_file_mask.format(image_name=image_name, epoch=epoch)
        )

        # load the predictions, the coordinates of each extracted patch, and
        # their centres.
        # NOTE(review): allow_pickle=True can execute pickled objects; only use
        # it on prediction files produced by this project.
        with np.load(pred_path, allow_pickle=True) as fd:
            predictions = fd['predictions']
            patch_coords = fd['patch_coords']  # [c0, c1, r0, r1]
            patch_centres = fd['patch_centres']  # [idx0, idx1]

        # only create the slice image once per set of epochs
        if original_slice_image is None:
            original_slice_image = tf_util.place_patches_in_row_image(
                image, patch_centres, patch_coords, patch_shape
            )

        # create the prediction image
        predicted_mask, _ = tf_util.create_patch_image(
            patches=predictions,
            patch_coords=patch_coords,
            image_shape=image.shape
        )

        # create the image of just the extracted patches
        pred_mask = tf_util.place_patches_in_row_image(
            predicted_mask, patch_centres, patch_coords, patch_shape
        )

        # plot the image and the prediction as stacked rows
        fig, axes = plt.subplots(
            2, 1, figsize=(15, 1.75), sharex=True, sharey=True, dpi=saved_dpi
        )
        axes[0].imshow(original_slice_image, aspect='auto')
        axes[1].imshow(pred_mask, aspect='auto')

        axes[0].set_title(
            f"Experiment: Final model - Image: {image_name} - Epochs: {epoch}"
        )

        for ax in axes:
            ax.axis('off')

        # Save and close this specific figure rather than pyplot's implicit
        # "current figure". Closing it stops figures piling up in memory
        # across the loop.
        fig.savefig(savepath)
        plt.close(fig)

        print(f'Saved: {savepath}')