In [1]:
# Process for Training:
# 1. L̶o̶a̶d̶ ̶i̶n̶ ̶i̶m̶a̶g̶e̶ ̶a̶n̶d̶ ̶m̶a̶k̶e̶ ̶i̶t̶ ̶5̶1̶2̶x̶5̶1̶2̶  I will do preprocessing before to make file sizes smaller for training.
# 2. Apply various pertubations/distortions on image
# 3. Clone image, and create the generator copy (this is what is fed to generator)


# The functions defined below are configured such that you input a single image and it will return
# what you need to feed to the generator and what you need to feed to the discriminator

# I start by loading randomly sized images and scale them all to 512x512x3
# The generator image will be cut such that it hides 56 pixels from each side

# In this script, I have two methods for loading in files. The first is through randomly sized images in a directory saved
# as .jpgs. The second is through a TFrecord on pre-processed data to be exactly 512x512x3. The second method is a smaller
# file size so it is easier to train on and can be cached / stored on memory.

In [22]:
import tensorflow as tf
import matplotlib.pyplot as plt

def parse_tfrecord_fn(example_proto):
    """Decode one serialized TFRecord example into a 3-channel image tensor.

    Args:
        example_proto: A single serialized example, as yielded by a
            ``tf.data.TFRecordDataset``.

    Returns:
        The image decoded (as JPEG, 3 channels) from the ``image_raw`` feature.
    """
    # The integer size features are part of the parsed schema but are not used
    # afterwards; only the encoded image bytes are consumed.
    schema = {
        key: tf.io.FixedLenFeature([], tf.int64)
        for key in ('height', 'width', 'depth')
    }
    schema['image_raw'] = tf.io.FixedLenFeature([], tf.string)

    parsed = tf.io.parse_single_example(example_proto, schema)
    return tf.io.decode_jpeg(parsed['image_raw'], channels=3)

# For loading the TfRecord
# I should vectorize my mappings (random jitter / generator image) to work on batches
# 
def load_tfrecord(tfrecord_file, batch_size, buffer=56):
    """Build the training input pipeline from a TFRecord file.

    Args:
        tfrecord_file: Path (or paths) handed to ``tf.data.TFRecordDataset``.
        batch_size: Number of image pairs per batch.
        buffer: Border width, in pixels, that ``generator_image`` blanks out.

    Returns:
        A batched, prefetched ``tf.data.Dataset`` yielding
        ``(generator_input, target_image)`` pairs.
    """
    def _to_training_pair(target):
        # The generator sees the border-masked copy; the untouched image is the target.
        return generator_image(target, buffer), target

    return (
        tf.data.TFRecordDataset(tfrecord_file)
        .map(parse_tfrecord_fn, num_parallel_calls=tf.data.AUTOTUNE)
        # Cache only the decoded images so the random jitter below is re-drawn every epoch.
        .cache()
        .map(random_jitter, num_parallel_calls=tf.data.AUTOTUNE)
        .map(_to_training_pair)
        # Batching happens after the cache, so batches themselves are not stored in memory.
        .batch(batch_size)
        # Overlap input preparation with downstream consumption.
        .prefetch(buffer_size=tf.data.AUTOTUNE)
    )

def resize(input_image, height, width):
    """Resize an image to ``height`` x ``width`` using nearest-neighbour sampling."""
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    return tf.image.resize(input_image, target_size, method=nearest)

def random_crop(input_image, height, width):
    """Return a randomly positioned ``height`` x ``width`` crop of a 3-channel image."""
    crop_shape = [height, width, 3]  # the channel count is fixed at 3
    return tf.image.random_crop(input_image, size=crop_shape)

# Normalizing the images to [-1, 1]
def normalize(input_image):
    """Cast an image to float32 and rescale it as ``x / 127.5 - 1``.

    For 8-bit pixel values in [0, 255] this yields values in [-1, 1],
    giving the data a mean near 0.
    """
    as_float = tf.cast(input_image, dtype=tf.float32)
    return (as_float / 127.5) - 1

def denormalize(input_image):
    """Map an image from the [-1, 1] range back to integer pixel values.

    This is the inverse of ``normalize``: it computes ``(x + 1) * 127.5``,
    rounds to the nearest integer and casts to int64.

    Args:
        input_image: Float image (tensor or array) scaled as produced by
            ``normalize``.

    Returns:
        An int64 tensor of the same shape holding the rounded pixel values.
    """
    scaled = (input_image + 1) * 127.5
    # Round before casting: tf.cast truncates toward zero, so float32 error in
    # the normalize/denormalize round trip (e.g. 0.999996 instead of 1.0) would
    # otherwise come back one intensity level too low.
    return tf.cast(tf.round(scaled), dtype=tf.int64)

def generator_image(input_image, buffer=56):
    """Create the generator input by blanking a border around the image.

    The central region is kept and a frame of ``buffer`` pixels on every side
    is replaced with the constant 0, so the result has the same height, width
    and channel count as the input.

    Args:
        input_image: Image of shape (height, width, channels).
        buffer: Width, in pixels, of the border to blank on each side.

    Returns:
        A tensor shaped like ``input_image`` whose outer ``buffer`` pixels are 0.
    """
    input_image = tf.convert_to_tensor(input_image)

    # Read height from axis 0 and width from axis 1. The previous
    # len(input_image[0]) / len(input_image[1]) both measured the length of a
    # row (the width), which only worked for square images.
    # Prefer the static dimensions so the output shape stays fully defined;
    # fall back to the dynamic shape when a dimension is unknown.
    static_shape = input_image.shape
    dynamic_shape = tf.shape(input_image)
    height = static_shape[0] if static_shape[0] is not None else dynamic_shape[0]
    width = static_shape[1] if static_shape[1] is not None else dynamic_shape[1]

    center = input_image[buffer:(height - buffer), buffer:(width - buffer), :]

    # Pad the two spatial axes back to the original size; channels are untouched.
    paddings = tf.constant([[buffer, buffer], [buffer, buffer], [0, 0]])
    # TODO (from the original author): the fill value should probably cycle
    # between -1 and 1 rather than always being 0.
    return tf.pad(center, paddings, constant_values=0)

@tf.function()
def random_jitter(input_image):
    """Apply the random training augmentations and normalize the image.

    Steps: resize to 542x542, take a random 512x512 crop, randomly mirror
    left/right, then rescale with ``normalize``.

    Returns:
        The augmented, normalized 512x512 image.
    """
    enlarged = resize(input_image, 542, 542)
    # Cropping back down from the enlarged image shifts the content by a random offset.
    cropped = random_crop(enlarged, 512, 512)
    mirrored = tf.image.random_flip_left_right(cropped)
    return normalize(mirrored)

# Loading a single jpg
def train_image_load(image_file, buffer=56):
    """Load one JPEG file and return its training pair.

    Args:
        image_file: Path of the JPEG file to read.
        buffer: Border width, in pixels, blanked out in the generator copy.

    Returns:
        ``(generator_input, target_image)``: the border-masked copy and the
        augmented image it was derived from.
    """
    # Read the file, then apply the random perturbations/distortions.
    target = random_jitter(load(image_file))
    # The generator copy is derived from the already-augmented image.
    return generator_image(target, buffer), target

# Loading a single jpg.
def load(image_file):
    """Read a JPEG file from disk and decode it into a 3-channel image tensor."""
    encoded = tf.io.read_file(image_file)
    return tf.io.decode_jpeg(encoded, channels=3)

# Note: path is the directory holding the .jpg files; a trailing separator is optional.
def load_IMGS(path, batch_size):
    """Build the training input pipeline from a directory of .jpg files.

    Args:
        path: Directory containing the images, with or without a trailing
            path separator.
        batch_size: Number of image pairs per batch.

    Returns:
        A batched, prefetched ``tf.data.Dataset`` yielding
        ``(generator_input, target_image)`` pairs produced by
        ``train_image_load``.
    """
    import os  # local import: the visible import cell does not bring in os

    # os.path.join only inserts a separator when one is missing, so a path that
    # already ends in one yields the same pattern as plain concatenation.
    pattern = os.path.join(path, '*.jpg')

    dataset = tf.data.Dataset.list_files(pattern)
    dataset = dataset.map(train_image_load,
                          num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)
    # Prefetch like load_tfrecord does, so input preparation overlaps consumption.
    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [23]:
# Input locations.
# NOTE(review): these are absolute, machine-specific Windows paths; point them
# at your own data before running.
tfrecord_path = "C:\\Users\\Ilyas\\Desktop\\images.tfrecord"
# Directory of .jpg files read by load_IMGS (written with its trailing separator).
images_path = "C:\\Users\\Ilyas\\Desktop\\Hot-o-bot\\New folder\\"
# NOTE(review): BUFFER_SIZE is not referenced in the cells visible here --
# presumably intended as a shuffle buffer size; confirm against the rest of the notebook.
BUFFER_SIZE = 400
# Batch size passed to both pipelines below.
BATCH_SIZE = 4

# dataset1 reads from the TFRecord file; dataset2 reads the .jpg files under images_path.
dataset1 = load_tfrecord(tfrecord_path, BATCH_SIZE)
dataset2 = load_IMGS(images_path, BATCH_SIZE)


In [26]:
# Sanity check: pull one batch from the TFRecord pipeline and print the shape of
# its second element (the un-masked image batch of each (generator, image) pair).
for i in dataset1.take(1):
    print(i[1].shape)

(4, 512, 512, 3)
