## Setup

In [1]:
# from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf

In [2]:
# Short alias for the tf.data autotune constant; it is passed as
# `num_parallel_calls` when the file-path dataset is mapped further below.
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
import IPython.display as display
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os
import pathlib

In [None]:
# Display the installed TensorFlow version so the run can be reproduced.
tf.__version__

## Prep data

In [None]:
# Root directory of the raw images, given relative to the current working directory.
# The globs below expect one sub-directory per class, with the image files inside it.
data_dir = pathlib.Path('../data/raw/images')
data_dir

In [None]:
# Dataset size: number of .jpg files located one level below data_dir
# (i.e. inside the per-class sub-directories). Counted lazily, without
# materialising the full list of paths.
image_count = sum(1 for _ in data_dir.glob('*/*.jpg'))
image_count

In [None]:
# Collect the class names: one class per sub-directory of data_dir.
# - Only directories count as classes; stray files sitting next to them
#   (e.g. a LICENSE.txt or .DS_Store) would otherwise become bogus classes.
# - glob() yields entries in arbitrary, filesystem-dependent order, so the names
#   are sorted to keep the class order (and the one-hot label layout derived
#   from it) reproducible across runs and machines.
CLASS_NAMES = np.array(sorted(item.name for item in data_dir.glob('*') if item.is_dir()))
CLASS_NAMES

In [None]:
# Preview some images from one class as a visual sanity check.
# Collect every file in the 'sabre' class directory ...
sabre = list(data_dir.glob('sabre/*'))

# ... and render the first three inline.
for image_path in sabre[:3]:
    display.display(Image.open(str(image_path)))

## Load with keras

In [None]:
# Keras image generator that rescales every pixel value by 1/255,
# i.e. converts uint8 pixel values to float32 in the range [0, 1].
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

In [None]:
# Define loader parameters
BATCH_SIZE = 32    # images per batch
IMG_HEIGHT = 224   # target image height in pixels
IMG_WIDTH = 224    # target image width in pixels
# Number of batches needed to see every image once (the last batch may be partial).
# np.ceil returns a float64; cast to int because a step count must be an integer.
STEPS_PER_EPOCH = int(np.ceil(image_count/BATCH_SIZE))

In [None]:
# Stream shuffled batches of resized images (with their labels) from the
# per-class sub-directories of data_dir, using CLASS_NAMES as the class order.
train_data_gen = image_generator.flow_from_directory(
    directory=str(data_dir),
    classes=list(CLASS_NAMES),
    target_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE,
    shuffle=True,
)

In [None]:
# inspect a batch
def show_batch(image_batch, label_batch):
    """Plot up to 25 images of a batch in a 5x5 grid, titled with their class name.

    Args:
        image_batch: indexable batch of images; `image_batch[n]` must be
            something `plt.imshow` can render.
        label_batch: indexable batch of label rows aligned with `CLASS_NAMES`;
            the entry equal to 1 in row `n` selects the class name of image `n`.
    """
    plt.figure(figsize=(10,10))
    # A batch can hold fewer than 25 images (e.g. the final, partial batch of
    # an epoch), so never index past its end.
    n_shown = min(25, len(image_batch))
    for n in range(n_shown):
        plt.subplot(5,5,n+1)
        plt.imshow(image_batch[n])
        # Boolean-mask CLASS_NAMES with the label row to recover the class name.
        plt.title(CLASS_NAMES[label_batch[n]==1][0].title())
        plt.axis('off')

In [None]:
# Pull one batch from the Keras generator and plot it.
image_batch, label_batch = next(train_data_gen)
show_batch(image_batch, label_batch)

The above keras.preprocessing method is convenient, but has three downsides:

1. It's slow. See the performance section below.
1. It lacks fine-grained control.
1. It is not well integrated with the rest of TensorFlow.


## Load using tf.data

In [None]:
# Create a dataset of the file paths: every entry one level below data_dir.
# NOTE(review): this pattern matches *any* file inside the class directories,
# whereas image_count above only counts '*/*.jpg' -- confirm the directories
# contain nothing but JPEGs, otherwise the two will disagree.
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))
# Print a handful of paths as a sanity check.
for f in list_ds.take(5):
    print(f.numpy())

In [None]:
# Pure TF functions that convert a file path to an (image_data, label) pair:
def get_label(file_path):
    """Return the label for `file_path` as an element-wise match against CLASS_NAMES.

    The class is taken from the name of the directory that directly contains
    the file; the result is that name compared (`==`) with `CLASS_NAMES`.
    """
    # Break the path into its components using the OS path separator.
    path_components = tf.strings.split(file_path, os.path.sep)
    # The last component is the file name, the one before it the class directory.
    class_dir = path_components[-2]
    return class_dir == CLASS_NAMES

def decode_img(img):
    """Decode JPEG bytes into a float32 image resized to the loader's target size.

    Args:
        img: the compressed JPEG file contents as a string tensor.

    Returns:
        A 3-channel float image tensor with values in [0, 1], resized to
        IMG_HEIGHT x IMG_WIDTH.
    """
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_jpeg(img, channels=3)
    # Use `convert_image_dtype` to convert to floats in the [0,1] range.
    img = tf.image.convert_image_dtype(img, tf.float32)
    # resize the image to the desired size.
    # `tf.image.resize` expects the size as [new_height, new_width]; this also
    # matches target_size=(IMG_HEIGHT, IMG_WIDTH) used by the Keras loader.
    return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])

def process_path(file_path):
    """Turn one image file path into an `(image, label)` pair.

    The image comes from reading the file and passing its contents through
    `decode_img`; the label comes from `get_label` applied to the same path.
    """
    # Read the raw file contents as a string, then decode them into an image.
    raw_bytes = tf.io.read_file(file_path)
    image = decode_img(raw_bytes)
    return image, get_label(file_path)

In [None]:
# Create a dataset of (image, label) pairs by mapping process_path over the file paths.
# Set `num_parallel_calls` so multiple images are loaded/processed in parallel.
labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)

# Sanity check: print the image shape and the label of the first element.
for image, label in labeled_ds.take(1):
    print("Image shape: ", image.numpy().shape)
    print("Label: ", label.numpy())