# Create TFRecords for TPU from Image Dataset

In [None]:
import numpy as np
import tensorflow as tf
import os, random
import keras as k
import IPython.display as display
# Single seed value read by runSeed() to seed every random number generator.
seed = 12

In [None]:
def runSeed():
    """Seed Python hashing, TensorFlow, NumPy and ``random`` from ``seed``.

    Reads the module-level ``seed`` and applies it to each generator in turn.
    """
    # NOTE(review): PYTHONHASHSEED is normally consumed at interpreter
    # start-up, so setting it here presumably only affects child processes --
    # confirm if hash reproducibility in this process matters.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for apply_seed in (tf.random.set_seed, np.random.seed, random.seed):
        apply_seed(seed)


runSeed()

## Reading Image Data

In [None]:
# Root of the dataset and its TRAIN/ and TEST/ sub-directories.
basePath = '/kaggle/input/animal-breed-cats-and-dogs/'
trainPath = f'{basePath}TRAIN/'
testPath = f'{basePath}TEST/'

In [None]:
# Parallel lists: imgPath[i] is an image file and labels[i] is the name of the
# class directory it was found in.
labels = []
imgPath = []

# os.listdir() returns entries in arbitrary, filesystem-dependent order.
# Sorting both levels keeps the lists -- and anything derived from them by
# index -- identical from run to run.
for x in sorted(os.listdir(trainPath)):
    currPath = os.path.join(trainPath, x)
    if not os.path.isdir(currPath):
        # A stray file next to the class folders is not a class; listing it
        # would raise NotADirectoryError.
        continue
    for y in sorted(os.listdir(currPath)):
        iPath = os.path.join(currPath, y)
        imgPath.append(iPath)
        labels.append(x)

In [None]:
# Shuffle the sample indices, then cut them into a training part (the first
# `train_size` fraction) and a validation part (the remainder).
train_size = 0.8
rand = np.random.permutation(len(labels))
train_set_bound = int(len(labels) * train_size)
train_set, valid_set = rand[:train_set_bound], rand[train_set_bound:]

In [None]:
# Turn each index array into the matching image paths and labels.
train_set_data, train_set_label = (
    [imgPath[idx] for idx in train_set],
    [labels[idx] for idx in train_set],
)

valid_set_data, valid_set_label = (
    [imgPath[idx] for idx in valid_set],
    [labels[idx] for idx in valid_set],
)

In [None]:
# Full paths of every entry directly under testPath.
# os.listdir() order is arbitrary; sorting gives a stable order so the
# position of an item in this list can be traced back to its file on any run.
test_set_data = [
    os.path.join(testPath, x) for x in sorted(os.listdir(testPath))
]

## Generating TFRecords

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature``."""
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))


def createTFRecord(tfFilePath, filenames, labels=None):
    """Serialise image files, and optionally their labels, into a TFRecord file.

    Every record is a ``tf.train.Example`` with an ``'image'`` bytes feature
    holding the file contents exactly as read from disk and, when labels are
    supplied, a ``'label'`` bytes feature holding the label.

    Args:
        tfFilePath: Path of the TFRecord file to write.
        filenames: Sequence of image file paths.
        labels: Optional sequence of labels (``str`` or ``bytes``) aligned
            with ``filenames``. ``None`` or an empty sequence produces
            image-only records.

    Raises:
        ValueError: If ``labels`` is non-empty but its length differs from
            ``filenames`` (previously the longer input was silently cut).
    """
    # Explicit check instead of bare truthiness so array-like labels work too.
    has_labels = labels is not None and len(labels) > 0
    if has_labels and len(labels) != len(filenames):
        raise ValueError(
            "filenames and labels differ in length: {} vs {}".format(
                len(filenames), len(labels)))

    # The context manager closes the writer even if an error interrupts the loop.
    with tf.io.TFRecordWriter(tfFilePath) as tfrecord_writer:
        print("Generating file..", tfFilePath)

        for idx, img_path in enumerate(filenames):
            try:
                raw_file = tf.io.read_file(img_path)
            except tf.errors.NotFoundError:
                # tf.io.read_file reports a missing file with TensorFlow's own
                # NotFoundError; the built-in FileNotFoundError is never raised.
                print("Couldn't read file  {}".format(img_path))
                continue

            feature = {'image': _bytes_feature(raw_file.numpy())}
            if has_labels:
                label = labels[idx]
                if isinstance(label, str):
                    # BytesList only accepts bytes.
                    label = label.encode('utf-8')
                feature['label'] = _bytes_feature(label)

            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            tfrecord_writer.write(example.SerializeToString())

In [None]:
# Write one TFRecord file per split; the test split is written without labels.
for record_path, split_files, split_labels in (
    ('train.tfrecords', train_set_data, train_set_label),
    ('valid.tfrecords', valid_set_data, valid_set_label),
    ('test.tfrecords', test_set_data, None),
):
    createTFRecord(record_path, split_files, split_labels)

## Reading TFRecords

In [None]:
def tfrecord_reader_fn(filePath):
    """Return a ``tf.data.TFRecordDataset`` over the file(s) at ``filePath``."""
    dataset = tf.data.TFRecordDataset(filePath)
    return dataset

In [None]:
def _parse_image_label_function(example_proto):
    """Parse one serialized ``tf.train.Example`` holding an image and a label.

    Returns the dict produced by ``tf.io.parse_single_example`` with the
    scalar string features ``'image'`` and ``'label'``.
    """
    schema = {
        key: tf.io.FixedLenFeature([], tf.string)
        for key in ('image', 'label')
    }
    return tf.io.parse_single_example(example_proto, schema)

def _parse_image_function(example_proto):
    """Parse one serialized ``tf.train.Example`` holding only an image.

    Returns the dict produced by ``tf.io.parse_single_example`` with the
    scalar string feature ``'image'``.
    """
    return tf.io.parse_single_example(
        example_proto,
        {'image': tf.io.FixedLenFeature([], tf.string)},
    )

In [None]:
# Open each TFRecord file and attach the matching parser: the train and
# validation records carry a label, the test records do not.
train_dataset, valid_dataset = (
    tfrecord_reader_fn(record_path).map(_parse_image_label_function)
    for record_path in ('train.tfrecords', 'valid.tfrecords')
)
test_dataset = tfrecord_reader_fn('test.tfrecords').map(_parse_image_function)

In [None]:
# Show the label and picture of the first record in the training dataset.
for image_features in train_dataset:
    image_raw, image_label = (
        image_features[key].numpy() for key in ('image', 'label')
    )
    print(image_label)
    display.display(display.Image(data=image_raw))
    break  # only the first record is needed