## Extract training images for model training

This notebook processes images for training the convolutional neural networks
to segment clam rings. Please see the project [README](./README.md) for initial
setup of the training data, its file structure and the config file used in all
notebooks, as well as a fuller description of the methodology employed in this
notebook. The notebook uses the training data defined by the variable
`train_dir` in the `images` section of the configuration file. It also checks
that the files and folders in `train_dir` are structured correctly, i.e., that
each image is in its own self-named folder along with a ring segmentation image
and an axis of maximum growth image; folders missing any of these files are
reported and skipped.

Patches are then extracted from each image and its corresponding mask. This is
carried out by extracting patches that overlap across the entire image with
a given stride. In this work we use 256 by 256 pixel patches, with a stride of
64 pixels. Patches that lie wholly outside the convex hull of the hand-drawn
rings are discarded, along with those patches that contain more than 5\%
background pixels (defined as being all black). These settings are specified in
the configuration file. The retained patches are stored as
[TFRecord files](https://www.tensorflow.org/tutorials/load_data/tfrecord), a
format that lets TensorFlow stream data from disk while a model is training,
which reduces the time the training loop spends waiting on disk reads.
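
As a rough illustration of the overlapping extraction and the background filter described above, the sketch below builds the grid of patch positions for a toy image and drops patches with more than 5% black pixels. The helpers `sliding_patch_corners` and `background_fraction` are hypothetical names introduced here for illustration; they are not the `shellai.preprocessing` functions used later in the notebook, and the top-left `(row, col)` corner representation is an assumption.

```python
import numpy as np


def sliding_patch_corners(image_shape, patch_shape, stride):
    """Return an (N, 2) array of top-left (row, col) corners of every
    patch that fits fully inside the image."""
    height, width = image_shape
    ph, pw = patch_shape
    sh, sw = stride
    rows = np.arange(0, height - ph + 1, sh)
    cols = np.arange(0, width - pw + 1, sw)
    rr, cc = np.meshgrid(rows, cols, indexing="ij")
    return np.stack([rr.ravel(), cc.ravel()], axis=1)


def background_fraction(image, corner, patch_shape):
    """Fraction of pixels in the patch that are black in every channel."""
    r, c = corner
    ph, pw = patch_shape
    patch = image[r:r + ph, c:c + pw]
    if patch.ndim == 3:
        black = np.all(patch == 0, axis=-1)
    else:
        black = patch == 0
    return float(black.mean())


# toy 512 x 640 RGB image whose left 100 columns are black background
image = np.full((512, 640, 3), 128, dtype=np.uint8)
image[:, :100] = 0

patch_shape, stride = (256, 256), (64, 64)
corners = sliding_patch_corners(image.shape[:2], patch_shape, stride)

# keep patches with at most 5% background pixels
keep = np.array(
    [background_fraction(image, c, patch_shape) <= 0.05 for c in corners]
)
kept_corners = corners[keep]
print(len(corners), len(kept_corners))
```

With a stride of 64 and a patch size of 256, neighbouring patches share three quarters of their pixels along each axis, which is why the number of patches grows quickly with image size.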

The TFRecord files are stored in a folder called
`{train_dir}_patches_{ph}x{pw}_stride_{sh}x{sw}`, where `{train_dir}` is the
name of the directory containing the training images, `{ph}` and `{pw}` are the
patch height and width, and `{sh}` and `{sw}` are the stride height and width,
as defined in the config file. By default, the folder will be named
`image_train_patches_256x256_stride_64x64` and, if using the provided training
images, will contain roughly 21 GB of files. Please copy the contents of this
folder to a Google Cloud Storage (GCS) bucket and change the corresponding
variables in the configuration file (`project_id` and `bucket_name`) to point
to the account and corresponding bucket. Please see the
[GCS guides](https://cloud.google.com/storage/docs/creating-buckets) for
further information on setting up GCS.
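
One possible way to do the copy from Python is sketched below, assuming the `google-cloud-storage` client library is installed and credentials are already configured. The helper names, the optional `prefix` argument, and the example project and bucket names are hypothetical; whether the later notebooks expect the files at the bucket root or under a prefix is not stated here, so check the README before relying on this layout.

```python
import os


def list_upload_pairs(local_dir, prefix=""):
    """Pair each .tfrec file in local_dir with its destination object name."""
    pairs = []
    for fname in sorted(os.listdir(local_dir)):
        if fname.endswith(".tfrec"):
            blob_name = f"{prefix}/{fname}" if prefix else fname
            pairs.append((os.path.join(local_dir, fname), blob_name))
    return pairs


def upload_patches(local_dir, project_id, bucket_name, prefix=""):
    """Upload every TFRecord file to the bucket (needs GCS credentials)."""
    # imported lazily so the listing helper works without the client library
    from google.cloud import storage

    bucket = storage.Client(project=project_id).bucket(bucket_name)
    for local_path, blob_name in list_upload_pairs(local_dir, prefix):
        bucket.blob(blob_name).upload_from_filename(local_path)


# example call (hypothetical project and bucket names):
# upload_patches(
#     "image_train_patches_256x256_stride_64x64", "my-project", "my-bucket"
# )
```

For a folder of this size, a command-line copy may be faster than uploading files one at a time from Python; the sketch is only meant to show which values map to `project_id` and `bucket_name`.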

In [1]:
import gc
import os
import yaml
import tqdm.notebook

import numpy as np
import tensorflow as tf

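# _Qhull is a private SciPy class; newer SciPy releases expose it from
# scipy.spatial._qhull rather than the deprecated scipy.spatial.qhull
try:
    from scipy.spatial._qhull import _Qhull
except ImportError:
    from scipy.spatial.qhull import _Qhull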
from shellai import preprocessing, tf_util

# load the config file
with open("config.yaml", "r") as fd:
    cfg = yaml.safe_load(fd)
print('Config file loaded')

Config file loaded


In [3]:
# settings

# Here we are expecting the folder below to contain folders, one per image,
# where each folder contains (at least) three files:
# - fname_image.ext   -- image in some format (jpg or tif)
# - fname_mask.ext    -- mask denoting the line along which to measure rings
# - fname_extended_rings.ext  -- mask of growth rings
base_folder = cfg['images']['train_dir']

# expected filenames
names_expected = ['extended_rings', 'mask', 'image']

# get each image name from the folder
folder_names = []

for name in os.listdir(base_folder):
    path = os.path.join(base_folder, name)
    
    if os.path.isdir(path):
        dirlist = os.listdir(path)

        missing = False

        for expected_filename in names_expected:
            try:
                preprocessing.get_matching_strings_from_list(
                    dirlist, expected_filename, 1
                )
            except ValueError:
                print(
                    f"Missing file for folder '{name:s}': {expected_filename:s}"
                )
                missing = True
            
        if not missing:
            folder_names.append(name)

print(f'Found images: {folder_names}')

Found images: ['gg06007', 'gg110010', 'gg110013', 'gg110026', 'gg110049', 'gg110053', 'gg110055', 'gg140022', 'GY0023', 'gy0030']


In [5]:
# patch settings
patch_shape = tuple(cfg['training']['patch_shape'])
stride = (cfg['training']['stride'], cfg['training']['stride'])

# discard patches with more than this proportion as fully black (background)
empty_proportion = cfg['training']['empty_patch_proportion']

# TFRecord settings -- samples per record file. this should be bigger than
# a single batch so that files can be read while the network is training
# the equation below is (by default) 256 * 1.5 = 384
num_tfrecord_samples = int(cfg['training']['batch_size'] * 1.5)

# directory to save the TFRecord files to
save_directory = ''.join([
    base_folder,
    f"_patches_{patch_shape[0]}x{patch_shape[1]}",
    f"_stride_{stride[0]}x{stride[1]}"
])

if not os.path.exists(save_directory):
    os.makedirs(save_directory)
    print(f"Created save directory: {save_directory:s}")
else:
    print(f"Saving to: {save_directory:s}")

Saving to: image_train_patches_256x256_stride_64x64
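
The loop in the next cell discards patches that lie outside the convex hull of the ring mask via `preprocessing.patch_corner_in_hull`, whose internals are not shown in this notebook. A rough sketch of the idea follows, assuming a patch is kept when at least one of its four corners lies inside the hull. It uses the public `scipy.spatial.ConvexHull` class rather than the private `_Qhull` wrapper, and `points_in_hull` and `patches_touching_hull` are hypothetical helpers, so treat it as an illustration rather than the library's implementation.

```python
import numpy as np
from scipy.spatial import ConvexHull


def points_in_hull(points, hull, tol=1e-9):
    """True for each point that lies inside or on the convex hull."""
    # each row of hull.equations is [normal, offset], and points inside
    # the hull satisfy normal . x + offset <= 0 for every facet
    normals, offsets = hull.equations[:, :-1], hull.equations[:, -1]
    return np.all(points @ normals.T + offsets <= tol, axis=1)


def patches_touching_hull(corners, patch_shape, hull):
    """Keep patches with at least one of their four corners in the hull."""
    ph, pw = patch_shape
    shifts = np.array([[0, 0], [0, pw - 1], [ph - 1, 0], [ph - 1, pw - 1]])
    all_corners = corners[:, None, :] + shifts[None, :, :]  # (N, 4, 2)
    inside = points_in_hull(all_corners.reshape(-1, 2).astype(float), hull)
    return inside.reshape(-1, 4).any(axis=1)


# toy ring mask: "ring" pixels fill rows and columns 100..199
ring_mask = np.zeros((400, 400), dtype=bool)
ring_mask[100:200, 100:200] = True
hull = ConvexHull(np.argwhere(ring_mask).astype(float))

# top-left (row, col) corners of three 64 x 64 patches
corners = np.array([[0, 0], [80, 80], [250, 250]])
keep = patches_touching_hull(corners, (64, 64), hull)
print(keep)
```

A corner-only test is cheap, but it would miss a patch that overlaps the hull without any corner falling inside it, for example a patch larger than the hull itself.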


In [None]:
for image_name in tqdm.notebook.tqdm(folder_names):

    # load the image and masks
    image, line_mask, ring_mask = preprocessing.load_image_data(
        base_folder, image_name
    )

    # convert the image masks to binary arrays
    line_mask = preprocessing.threshold_drawn_mask_and_skeletonize(
        line_mask, sparse=False
    )
    ring_mask = preprocessing.threshold_drawn_mask_and_skeletonize(
        ring_mask, sparse=False
    )

    assert isinstance(line_mask, np.ndarray)
    assert isinstance(ring_mask, np.ndarray)

    # calculate the convex hull of the mask coords
    ring_mask_coords = np.stack(np.where(ring_mask)).T
    ring_mask_coords = ring_mask_coords.astype('float')

    hull = _Qhull(
        b"i", # type:ignore
        points=ring_mask_coords,
        options=b"",
        furthest_site=False,
        incremental=False,
        interior_point=None,
    )

    # get the indices of each possible patch of the image such that
    # patch_inds[i] = [[c0, r0], [c1, r1]], where the bottom left and upper
    # right coordinates of the patch are [c0, r0] and [c1, r1], and where
    # c = column index and r = row index
    patch_inds = preprocessing.extract_patch_indices(
        image.shape[:2], patch_shape, stride
    )

    # remove patches fully outside the convex hull
    patch_inside_hull_mask = preprocessing.patch_corner_in_hull(
        patch_inds, hull
    )
    patch_inds = patch_inds[patch_inside_hull_mask]

    # remove patches with a large proportion of empty pixels (all black)
    non_empty_mask = preprocessing.remove_empty_patches(
        image, patch_inds, proportion=empty_proportion
    )
    patch_inds = patch_inds[non_empty_mask]

    # extract the patches
    patch_images = preprocessing.extract_patches(image, patch_inds)
    patch_lines = preprocessing.extract_patches(line_mask, patch_inds)
    patch_rings = preprocessing.extract_patches(ring_mask, patch_inds)

    # delete the images from memory
    del (
        image,
        line_mask,
        ring_mask,
        ring_mask_coords,
        non_empty_mask,
    )
    gc.collect()

    # turn images into TFRecord files. These are designed to be quick to open
    # in tensorflow, and can be loaded while the model is taking a step,
    # therefore somewhat avoiding (in theory!) bottlenecks due to hardware IO
    total_size = patch_images.shape[0]
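    # ceiling division, so that an exact multiple of num_tfrecord_samples
    # does not produce an empty trailing record file
    num_tfrecords = -(-total_size // num_tfrecord_samples)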

    for tfrec_num in range(num_tfrecords):
        # work out the next start/end image indices to store in the record
        start = tfrec_num * num_tfrecord_samples
        end = min(start + num_tfrecord_samples, total_size)

        samples = range(start, end)

        save_filename = (
            f"{image_name}_{tfrec_num:03d}_{len(samples):d}.tfrec"
        )

        with tf.io.TFRecordWriter(
            os.path.join(save_directory, save_filename),
        ) as writer:

            # for each image going into the records
            for i in tqdm.notebook.tqdm(samples, leave=False):

                # turn it into a record
                feature = tf_util.create_tfrecord_feature(
                    patch_images[i],
                    patch_rings[i],
                    patch_lines[i],
                    patch_inds[i],
                    image_name,
                )

                # turn it into an 'example'
                example = tf.train.Example(
                    features=tf.train.Features(feature=feature)
                )
                
                # write to the record
                writer.write(example.SerializeToString()) # type:ignore

        print(f"Saved tf records: {save_filename:s}")
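
The way patches are split across record files can be checked in isolation. The sketch below computes the `(start, end)` index range for each file using ceiling division, so that a patch count that is an exact multiple of the record size does not yield an empty trailing range. `record_ranges` is a hypothetical helper written for this illustration and is not part of `shellai`.

```python
def record_ranges(total_size, samples_per_record):
    """Return (start, end) index pairs, one per TFRecord file."""
    # ceiling division: no empty trailing range for exact multiples
    num_records = -(-total_size // samples_per_record)
    return [
        (
            i * samples_per_record,
            min((i + 1) * samples_per_record, total_size),
        )
        for i in range(num_records)
    ]


# 1000 patches at 384 per file: two full files and one partial file
print(record_ranges(1000, 384))
# 768 patches at 384 per file: exactly two full files
print(record_ranges(768, 384))
```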