The purpose of this notebook is to create TFRecords that can be used to train TensorFlow models on TPUs.

This notebook is adapted, with slight modifications, from Chris Deotte's notebook https://www.kaggle.com/cdeotte/how-to-create-tfrecords.

In this notebook, images are resized to 256 × 256. By changing the RESIZE and IMAGE_SIZE parameters (which must agree with each other), the same notebook can generate TFRecords of any other size.

If you have any feedback, please feel free to comment.

You may notice a few warnings/messages; they do not prevent the creation of the TFRecords. If anyone knows how to eliminate them, please comment.


In [None]:
import numpy as np, pandas as pd, os
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'

import matplotlib.pyplot as plt, cv2
import tensorflow as tf, re, math
from tqdm.notebook import tqdm

In [None]:
# Side length (pixels) every image is resized to before being written to the TFRecords.
RESIZE = 256
# Shape decode_image() reshapes decoded JPEGs to. Derived from RESIZE so the two can
# never disagree: a mismatch makes tf.reshape fail when the records are read back.
IMAGE_SIZE = [RESIZE, RESIZE]
# Batch size used by get_training_dataset() when reading the records back.
BATCH_SIZE = 32

In [None]:
# Directory holding the competition's training images.
PATH = '../input/petfinder-pawpularity-score/train/'
# Sorted because os.listdir order is arbitrary, and this list decides which images
# land in which TFRecord shard; sorting makes shard contents reproducible.
IMGS = sorted(os.listdir(PATH))

print('There are %i train images '%(len(IMGS)))

In [None]:
# Load the tabular labels and rename 'Id' so it matches the image file stems.
df = (
    pd.read_csv('../input/petfinder-pawpularity-score/train.csv')
    .rename(columns={'Id': 'image_name'})
)
df.head()

## Write TFRecords - Train

In [None]:
def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` holding a BytesList.

    An eager tensor is first converted to its numpy value, because BytesList
    won't unpack a string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
    """Wrap a single float/double in a ``tf.train.Feature`` holding a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
    """Wrap a single bool/enum/int/uint in a ``tf.train.Feature`` holding an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)

In [None]:
def serialize_example(feature0, feature1, feature2, feature3, feature4, feature5, feature6, feature7, feature8, feature9, feature10, feature11, feature12, feature13, feature14):
    """Serialize one image plus its train.csv row into a ``tf.train.Example`` byte string.

    Args:
        feature0: encoded image bytes (stored under 'image').
        feature1: image name as bytes (stored under 'image_name').
        feature2..feature14: integer values stored, in order, under
            'Subject Focus', 'Eyes', 'Face', 'Near', 'Action', 'Accessory',
            'Group', 'Collage', 'Human', 'Occlusion', 'Info', 'Blur', 'Pawpularity'.

    Returns:
        The serialized Example as bytes.
    """
    int_keys = ('Subject Focus', 'Eyes', 'Face', 'Near', 'Action', 'Accessory',
                'Group', 'Collage', 'Human', 'Occlusion', 'Info', 'Blur', 'Pawpularity')
    int_values = (feature2, feature3, feature4, feature5, feature6, feature7, feature8,
                  feature9, feature10, feature11, feature12, feature13, feature14)

    feature = {
        'image': _bytes_feature(feature0),
        'image_name': _bytes_feature(feature1),
    }
    for key, value in zip(int_keys, int_values):
        feature[key] = _int64_feature(value)

    features = tf.train.Features(feature=feature)
    return tf.train.Example(features=features).SerializeToString()

In [None]:
# Number of images per TFRecord shard; each shard's image count is also encoded in its file name.
SIZE = 413
# train.csv columns stored with every image, in the positional order serialize_example expects.
LABEL_COLS = ['Subject Focus', 'Eyes', 'Face', 'Near', 'Action', 'Accessory',
              'Group', 'Collage', 'Human', 'Occlusion', 'Info', 'Blur', 'Pawpularity']

# Index the labels by image name once. Filtering with df.image_name == name for every
# image rescans the whole frame each time, which is quadratic in the number of images.
labels_by_name = df.set_index('image_name')[LABEL_COLS]
assert labels_by_name.index.is_unique, 'image_name values in train.csv must be unique'

CT = (len(IMGS) + SIZE - 1) // SIZE  # number of shards; the last one may be smaller
for j in range(CT):
    shard = IMGS[SIZE*j:SIZE*(j+1)]
    CT2 = len(shard)
    print(); print('Writing TFRecord %i of %i...'%(j+1,CT))
    with tf.io.TFRecordWriter('train%.2i-%i.tfrec'%(j,CT2)) as writer:
        for fname in tqdm(shard, desc='train%.2i'%j):
            img = cv2.imread(PATH+fname)
            # cv2.imread returns None (instead of raising) for unreadable / non-image files.
            if img is None:
                raise IOError('cv2 could not read image %s' % (PATH+fname))
            img = cv2.resize(img, (RESIZE,RESIZE))
            ok, encoded = cv2.imencode('.jpg', img, (cv2.IMWRITE_JPEG_QUALITY, 95))
            if not ok:
                raise IOError('JPEG encoding failed for %s' % fname)
            name = os.path.splitext(fname)[0]
            labels = labels_by_name.loc[name]
            example = serialize_example(encoded.tobytes(), name.encode(), *labels.to_numpy())
            writer.write(example)

In [None]:
# List the working directory to check that the train*.tfrec shards were written
# (each file name ends with the number of images in that shard).
!ls -l

# Verify TFRecords

In [None]:
# Class labels indexed by title_from_label_and_target when building titles.
CLASSES = [0, 1]
# Keep printed numpy arrays short.
np.set_printoptions(linewidth=80, threshold=15)

def batch_to_numpy_images_and_labels(data):
    """Convert one batch of tensors into numpy arrays.

    Args:
        data: either an ``(images, labels)`` tuple (as yielded by a batched
            ``tf.data`` dataset of pairs) or just the images tensor.

    Returns:
        ``(numpy_images, numpy_labels)``. ``numpy_labels`` is ``None`` when
        ``data`` contains images only.
    """
    if isinstance(data, tuple):
        images, labels = data
        return images.numpy(), labels.numpy()
    # Images-only batch: unpacking it as a pair would split the batch axis instead,
    # so return None for the labels and let the caller substitute placeholders.
    return data.numpy(), None

def title_from_label_and_target(label, correct_label, classes=None):
    """Build a subplot title comparing a predicted class with the true class.

    Args:
        label: index (into ``classes``) of the predicted class.
        correct_label: index of the true class, or ``None`` when unknown.
        classes: class names to display; defaults to the module-level ``CLASSES``.

    Returns:
        ``(title, correct)`` where ``title`` is always a string:
        ``"<pred>"`` when ``correct_label`` is None, ``"<pred> [OK]"`` when the
        prediction matches, and ``"<pred> [NO→<true>]"`` otherwise.
    """
    if classes is None:
        classes = CLASSES
    predicted = classes[label]
    if correct_label is None:
        # str(): callers such as display_one_flower call len(title), which fails on int classes.
        return str(predicted), True
    correct = (label == correct_label)
    if correct:
        return "{} [OK]".format(predicted), correct
    return "{} [NO\u2192{}]".format(predicted, classes[correct_label]), correct

def display_one_flower(image, title, subplot, red=False, titlesize=16):
    """Draw ``image`` in the given subplot slot, with an optional title.

    Args:
        image: array accepted by ``plt.imshow``.
        title: title text; no title is drawn when it is empty.
        subplot: ``(rows, cols, index)`` triple passed to ``plt.subplot``.
        red: draw the title in red with a slightly smaller font.
        titlesize: base title font size.

    Returns:
        The subplot triple with its index advanced by one, for the next image.
    """
    plt.subplot(*subplot)
    plt.axis('off')
    plt.imshow(image)
    if len(title):
        if red:
            fontsize, color = int(titlesize/1.2), 'red'
        else:
            fontsize, color = int(titlesize), 'black'
        plt.title(title, fontsize=fontsize, color=color,
                  fontdict={'verticalalignment': 'center'}, pad=int(titlesize/1.5))
    return (*subplot[:2], subplot[2] + 1)
    
def display_batch_of_images(databatch, predictions=None):
    """Show a batch of images in a near-square grid, titled with their Pawpularity.

    Intended calling forms:
    display_batch_of_images((images, labels))
    display_batch_of_images((images, labels), predictions)
    display_batch_of_images(images)               # needs batch_to_numpy_images_and_labels
    display_batch_of_images(images, predictions)  # to accept a bare images tensor

    Args:
        databatch: batch passed to batch_to_numpy_images_and_labels.
        predictions: optional per-image predicted values, appended to each title.

    The grid is rows x cols with rows = floor(sqrt(n)) and cols = n // rows; images
    that do not fit the grid are dropped. An empty batch draws nothing.
    """
    images, labels = batch_to_numpy_images_and_labels(databatch)
    if labels is None:
        labels = [None] * len(images)
    if len(images) == 0:
        return  # nothing to draw (and rows would be 0 below)

    # auto-squaring: this will drop data that does not fit into square or square-ish rectangle
    rows = int(math.sqrt(len(images)))
    cols = len(images) // rows
    n_shown = rows * cols

    # size and spacing
    FIGSIZE = 13.0
    SPACING = 0.1
    if rows < cols:
        plt.figure(figsize=(FIGSIZE, FIGSIZE/cols*rows))
    else:
        plt.figure(figsize=(FIGSIZE/rows*cols, FIGSIZE))

    # magic formula tested to work from 1x1 to 10x10 images
    dynamic_titlesize = FIGSIZE*SPACING/max(rows, cols)*40 + 3
    has_labels = any(label is not None for label in labels[:n_shown])

    subplot = (rows, cols, 1)
    for i, (image, label) in enumerate(zip(images[:n_shown], labels[:n_shown])):
        parts = []
        if label is not None:
            parts.append(f'Pawpularity:{label}')
        if predictions is not None:
            parts.append(f'pred:{predictions[i]}')
        title = ' '.join(parts)
        subplot = display_one_flower(image, title, subplot, False, titlesize=dynamic_titlesize)

    # layout: titled panels need spacing; untitled grids are packed edge to edge.
    # (Decided from the whole batch rather than the leaked last loop variable.)
    plt.tight_layout()
    if not has_labels and predictions is None:
        plt.subplots_adjust(wspace=0, hspace=0)
    else:
        plt.subplots_adjust(wspace=SPACING, hspace=SPACING)
    plt.show()

In [None]:
def decode_image(image_data):
    """Decode a JPEG byte string into a float32 RGB tensor scaled to [0, 1].

    The result is reshaped to ``[*IMAGE_SIZE, 3]``; an explicit static size is
    needed for TPU.
    """
    decoded = tf.image.decode_jpeg(image_data, channels=3)
    scaled = tf.cast(decoded, tf.float32) / 255.0
    target_shape = [*IMAGE_SIZE, 3]
    return tf.reshape(scaled, target_shape)

def read_labeled_tfrecord(example):
    """Parse one serialized Example in the layout written by serialize_example.

    Returns:
        ``(image, label)``: the image decoded by decode_image and the int64
        ``Pawpularity`` value. The other columns are parsed but not returned.
    """
    int_columns = ('Subject Focus', 'Eyes', 'Face', 'Near', 'Action', 'Accessory',
                   'Group', 'Collage', 'Human', 'Occlusion', 'Info', 'Blur', 'Pawpularity')
    # Shape [] means a single element per example; tf.string holds raw bytes.
    schema = {
        'image': tf.io.FixedLenFeature([], tf.string),
        'image_name': tf.io.FixedLenFeature([], tf.string),
    }
    schema.update({column: tf.io.FixedLenFeature([], tf.int64) for column in int_columns})

    parsed = tf.io.parse_single_example(example, schema)
    return decode_image(parsed['image']), parsed['Pawpularity']

def load_dataset(filenames, labeled=True, ordered=False):
    """Build a dataset of ``(image, label)`` pairs from TFRecord files.

    Args:
        filenames: TFRecord paths to read; several files are read at once.
        labeled: NOTE(review): currently ignored. Every record is parsed with
            read_labeled_tfrecord, so ``(image, label)`` pairs are always returned.
        ordered: if False, tf.data may yield records out of their stored order,
            which speeds up reading (order does not matter when shuffling anyway).

    Returns:
        A ``tf.data.Dataset`` of ``(image, label)`` pairs.
    """
    options = tf.data.Options()
    if not ordered:
        options.experimental_deterministic = False  # disable order, increase speed

    # Automatically interleaves reads from multiple files.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)
    dataset = dataset.with_options(options)
    # Parse/decode in parallel; output order still follows the determinism option above.
    dataset = dataset.map(read_labeled_tfrecord, num_parallel_calls=AUTO)
    return dataset

def get_training_dataset():
    """Return the training pipeline read from TRAINING_FILENAMES.

    The data repeats indefinitely (for several epochs), is shuffled with a
    512-element buffer, batched by BATCH_SIZE and prefetched (autotuned buffer)
    so the next batch is prepared while training.
    """
    pipeline = load_dataset(TRAINING_FILENAMES, labeled=True)
    pipeline = pipeline.repeat().shuffle(512)
    return pipeline.batch(BATCH_SIZE).prefetch(AUTO)

def count_data_items(filenames):
    """Return the total number of examples across TFRecord files.

    Each count is read from the file name, which ends in ``-<count>.<ext>``
    (e.g. flowers00-230.tfrec holds 230 items), so no file is opened.

    Args:
        filenames: iterable of TFRecord paths.

    Returns:
        The summed item count as an int (0 for no files).

    Raises:
        ValueError: if a file name does not end in ``-<count>.<ext>``.
    """
    # Compiled once, and matched against the base name only, so dashes or dots in
    # directory names (e.g. "data-1.0/") cannot be mistaken for the count.
    pattern = re.compile(r"-([0-9]+)\.[^./\\]+$")
    total = 0
    for filename in filenames:
        match = pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"cannot read item count from TFRecord file name {filename!r}")
        total += int(match.group(1))
    return total

In [None]:
# INITIALIZE VARIABLES
AUTO = tf.data.experimental.AUTOTUNE
# Sorted because glob order is not guaranteed; a fixed file order keeps ordered reads reproducible.
TRAINING_FILENAMES = sorted(tf.io.gfile.glob('train*.tfrec'))
# Fail here with a clear message rather than later inside the tf.data pipeline.
assert TRAINING_FILENAMES, 'no train*.tfrec files found - run the TFRecord writing cell first'
print('There are %i train images'%count_data_items(TRAINING_FILENAMES))

In [None]:
# DISPLAY TRAIN IMAGES
# Regroup the shuffled training stream into batches of 20, drawn as a 4 x 5 grid.
training_dataset = get_training_dataset().unbatch().batch(20)
train_batch = iter(training_dataset)
display_batch_of_images(next(train_batch))
