In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import glob
from imageio import imread
from collections import namedtuple
from sklearn.utils import resample
import cv2

In [None]:
# Fix the global NumPy and TensorFlow random seeds so random draws are repeatable
NUMPY_SEED = 77
TF_SEED = 88
np.random.seed(NUMPY_SEED)
tf.random.set_seed(TF_SEED)

In [None]:
# set the dataset basepath
# The DATASET_BASEPATH environment variable overrides the default location, so
# the notebook can run on another machine without editing this cell.
# Joining with '' guarantees a trailing separator, which code that builds
# paths by plain string concatenation relies on.
basepath = os.path.join(
    os.environ.get('DATASET_BASEPATH', '/home/thalles/Documents/datasets/'), '')

In [None]:
def get_filenames(folder, base=None):
    """Return the paths of all PNG patches of one dataset split.

    Args:
        folder: Name of the split directory under the dataset root
            (the notebook uses 'train', 'val' and 'test'). Patches are looked
            up in its '0' and '1' sub-directories.
        base: Optional dataset root directory. When omitted, the module-level
            `basepath` is used.

    Returns:
        A list holding the paths found under '0' followed by the paths found
        under '1'. Each group is sorted: glob returns files in arbitrary
        filesystem order, and a seeded shuffle is only reproducible when its
        input order is deterministic.
    """
    root = basepath if base is None else base
    # Both class folders are resolved against the same root, so changing the
    # root moves the whole dataset rather than only class 0.
    return [
        path
        for class_dir in ('0', '1')
        for path in sorted(glob.glob(os.path.join(root, folder, class_dir, '*.png')))
    ]

In [None]:
# collect the training patch paths and shuffle them in place
training_files = get_filenames('train')
np.random.shuffle(training_files)
print(f"Number of training examples: {len(training_files)}")

In [None]:
# collect the validation patch paths and shuffle them in place
val_files = get_filenames('val')
np.random.shuffle(val_files)
print(f"Number of validation examples: {len(val_files)}")

In [None]:
# collect the test patch paths and shuffle them in place
test_files = get_filenames('test')
np.random.shuffle(test_files)
print(f"Number of testing examples: {len(test_files)}")

In [None]:
# show the first 25 entries of the shuffled training list on a 5x5 grid
fig, axs = plt.subplots(nrows=5, ncols=5, constrained_layout=False)

for i, ax in enumerate(axs.flat):
    # OpenCV decodes to BGR channel order; convert to RGB before plotting
    img = cv2.cvtColor(cv2.imread(training_files[i], cv2.IMREAD_COLOR),
                       cv2.COLOR_BGR2RGB)
    ax.imshow(img)
plt.show()

In [None]:
# Helpers that wrap a single value in the tf.train.Feature type that matches
# it, so the value can be stored inside a tf.Example.
def _bytes_feature(value):
  """Wrap a string / bytes value in a bytes_list Feature."""
  eager_tensor_type = type(tf.constant(0))
  if isinstance(value, eager_tensor_type):
    # BytesList cannot unpack an EagerTensor, so extract its value first.
    value = value.numpy()
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap a float / double value in a float_list Feature."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap a bool / enum / int / uint value in an int64_list Feature."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

def get_tf_feature(image_np, label):
    """Build a tf.train.Example holding one image and its label.

    Args:
        image_np: Image array with at least three dimensions; shape[0],
            shape[1] and shape[2] are stored as 'height', 'width' and 'depth'.
        label: Integer class label.

    Returns:
        A tf.train.Example with the features 'height', 'width', 'depth',
        'image_raw' (the raw bytes of the array) and 'label'. The array dtype
        is not stored, so a reader has to know it to decode 'image_raw'.
    """
    # ndarray.tostring() is a deprecated alias of tobytes() and is not
    # available in recent NumPy releases; tobytes() returns the same bytes.
    img_raw = image_np.tobytes()

    example = tf.train.Example(features=tf.train.Features(feature={
                        'height': _int64_feature(image_np.shape[0]),
                        'width': _int64_feature(image_np.shape[1]),
                        'depth': _int64_feature(image_np.shape[2]),
                        'image_raw': _bytes_feature(img_raw),
                        'label': _int64_feature(label)}))
    return example

In [None]:
class TFRecordManager:
    """Context manager that opens a TFRecord writer and closes it on exit.

    Usage:
        with TFRecordManager(path) as writer:
            writer.write(serialized_example)
    """

    def __init__(self, filename):
        """Remember the output path; the writer is only opened in __enter__."""
        self.filename = filename
        # Defined up front so __exit__ is safe even if __enter__ never ran.
        self.file = None

    def __enter__(self):
        """Open the TFRecord writer for `filename` and return it."""
        self.file = tf.io.TFRecordWriter(self.filename)
        return self.file

    def __exit__(self, exception_type, exception_value, traceback):
        """Report any exception from the with-body and close the writer.

        Returns False, so an exception raised inside the with-body is never
        suppressed.
        """
        if exception_type:
            print(exception_type, exception_value)
        if self.file is not None:
            self.file.close()
            # Drop the reference so a repeated __exit__ does not close twice.
            self.file = None
        return False

In [None]:
# directory where the generated .tfrecords files are written
DATASET_DIR = "./tfrecords"

# exist_ok=True removes the check-then-create race of exists() + mkdir(),
# and makedirs also creates any missing parent directories
os.makedirs(DATASET_DIR, exist_ok=True)

In [None]:
def get_patch(file, image_size=96):
    """Load one slide patch and build its 7-channel feature image.

    The patch is read from disk, converted to RGB and resized to
    `image_size` x `image_size`. The result stacks, in this order: the three
    RGB planes after CLAHE and a 5x5 Gaussian blur, then L* (LAB), H (HSV),
    S (HSV) and a* (LAB) computed from the resized, unequalized image.

    Args:
        file: Path of the image file to read.
        image_size: Side length, in pixels, of the square output.

    Returns:
        An array of shape (image_size, image_size, 7), or None when the file
        cannot be read or the decoded image is not 50x50 with 3 channels.
    """
    img = cv2.imread(file, cv2.IMREAD_COLOR)

    # imread returns None (it does not raise) when the file is missing or
    # cannot be decoded, so this has to be checked before reading img.shape.
    if img is None or img.shape != (50,50,3):
        return None

    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (image_size, image_size))

    # extract L* and a* from LAB and H and S from HSV.
    # img is in RGB order here, so the RGB2* conversion codes are required;
    # the BGR2* codes would treat the red plane as blue and vice versa.
    L, A, _ = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2LAB))
    H, S, _ = cv2.split(cv2.cvtColor(img, cv2.COLOR_RGB2HSV))

    # perform CLAHE normalization on each colour plane independently.
    # A new list is built because cv2.split may return an immutable tuple,
    # which does not support item assignment.
    clahe = cv2.createCLAHE(clipLimit=5, tileGridSize=(16,16))
    img = cv2.merge([clahe.apply(plane) for plane in cv2.split(img)])

    # apply a Gaussian blur
    img = cv2.GaussianBlur(img, (5,5), 0)

    # merge the final feature vector
    return cv2.merge([img, L, H, S, A])

In [None]:
# test get_patch: show the first three channels of the first 15 training patches
fig, axs = plt.subplots(nrows=3, ncols=5, constrained_layout=False, figsize=(12,6))

for i, ax in enumerate(axs.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    img = get_patch(training_files[i])

    if img is None:
        # get_patch returns None for patches it rejects; indexing None would
        # raise a TypeError, so leave this panel empty instead
        continue

    ax.imshow(img[...,:3])
plt.show()

In [None]:
def create_dataset(dataset_name, files, training, negative_keep_prob=0.3864):
    """Write the patches in `files` to DATASET_DIR/<dataset_name>.tfrecords.

    Args:
        dataset_name: Output file name without the '.tfrecords' extension.
        files: Iterable of patch paths. The class id is read from the fifth
            character from the end of each path (file[-5]), i.e. the character
            right before a 4-character extension such as '.png'.
        training: When true, negative (class id != 1) patches are randomly
            dropped; otherwise every patch is kept.
        negative_keep_prob: Probability of keeping each negative patch when
            `training` is true. The default 0.3864 is the value the notebook
            uses to bring the positive:negative ratio to about 1:1.

    Side effects:
        Writes one serialized tf.train.Example per accepted patch and prints a
        summary. Patches for which get_patch returns None are skipped and
        counted separately.
    """
    n_positives = 0
    n_negatives = 0
    skip = 0
    record_path = os.path.join(DATASET_DIR, dataset_name + '.tfrecords')

    with TFRecordManager(record_path) as writer:
        for file in files:
            class_id = int(file[-5])

            # ensure the ratio of positive to negative samples is about 1:1 by
            # randomly dropping negatives; the random draw only happens for
            # negatives of a training set
            if class_id != 1 and training and np.random.rand() >= negative_keep_prob:
                continue

            patch = get_patch(file)
            if patch is None:
                skip += 1
                continue

            example = get_tf_feature(patch, class_id)
            writer.write(example.SerializeToString())

            # count only after a successful write, so the reported totals
            # match the records in the file (skipped patches are excluded)
            if class_id == 1:
                n_positives += 1
            else:
                n_negatives += 1

    print(f"Process has finished with a total of {n_negatives} negatives and {n_positives} positive patches.")
    print(f"Skipped {skip} images for UNEXPECTED patch shape")

In [None]:
# every bag is a bootstrap sample (drawn with replacement) holding 50% as
# many files as the original training set
subsample = 0.5
bag_size = int(subsample*len(training_files))

# write the 3 training bags, one tfrecord file per bag
for i in range(3):
    bag_train_files = resample(training_files, replace=True, n_samples=bag_size)
    create_dataset(f'train_bag_{i}', bag_train_files, training=True)

In [None]:
# write the validation split to its tfrecord file (no negative subsampling)
create_dataset(dataset_name='val', files=val_files, training=False)

In [None]:
# write the test split to its tfrecord file (no negative subsampling)
create_dataset(dataset_name='test', files=test_files, training=False)