<a href="https://colab.research.google.com/github/phrasenmaeher/TFRecord_walkthrough/blob/main/practical_guide_to_tfr.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# A practical introduction to TFRecords

This notebook explains how to create TFRecord files for various data types. It contains the code for the post at [TDS](https://towardsdatascience.com/a-practical-guide-to-tfrecords-584536bc786c).

## Imports and helper functions


Let's start by importing our required packages: TensorFlow, NumPy, and librosa (used for the audio examples).


In [None]:
import tensorflow as tf
import numpy as np
import librosa

Next, we define four small helper functions that convert our values into the feature types we'll store in our TFRecord files:

In [None]:
def _bytes_feature(value):
  """Wrap a single string / bytes value in a `tf.train.Feature` holding a bytes_list."""
  # Eager tensors are unwrapped to their numpy value first, so BytesList
  # receives raw bytes rather than a tensor object.
  eager_tensor_type = type(tf.constant(0))
  if isinstance(value, eager_tensor_type):
    value = value.numpy()
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap a single float / double value in a `tf.train.Feature` holding a float_list."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap a single bool / enum / int / uint value in a `tf.train.Feature` holding an int64_list."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

def serialize_array(array):
  """Serialize `array` (numpy array, string, ...) with `tf.io.serialize_tensor`.

  The result is what the `_bytes_feature` helper stores; readers undo it with
  `tf.io.parse_tensor`.
  """
  return tf.io.serialize_tensor(array)

## Image data

### A couple of images

This section assumes that you want to write image data to your disk. Rather than downloading an image dataset, we'll create random numpy arrays of a realistic shape.

In [None]:
# How many small dummy images to create, and the (height, width, channels) of each.
number_of_images_small = 100
image_small_shape = (250, 250, 3)

In [None]:
# Random pixel values in [0, 256), stored as int16 (parse_tfr_element decodes with out_type=tf.int16).
small_batch_shape = (number_of_images_small,) + tuple(image_small_shape)
images_small = np.random.randint(0, 256, size=small_batch_shape, dtype=np.int16)
print(images_small.shape)

(100, 250, 250, 3)


Now we create some random labels, and inspect them

In [None]:
# One random class id in [0, 5) per image, as an (N, 1) column.
label_column_shape = (number_of_images_small, 1)
labels_small = np.random.randint(0, 5, size=label_column_shape)
print(labels_small.shape)
print(labels_small[:10])

(100, 1)
[[2]
 [4]
 [3]
 [3]
 [2]
 [4]
 [2]
 [3]
 [3]
 [0]]


Define a function that wraps a single image and its label into a `tf.train.Example`, which we can then write to disk:

In [None]:
def parse_single_image(image, label):
  """Wrap one image and its label into a `tf.train.Example`.

  The first three dimensions of `image` are stored as separate int64 features
  ('height', 'width', 'depth'). The pixels themselves go into 'raw_image' as a
  serialized tensor. `parse_tfr_element` uses the dimensions to reshape the
  decoded tensor.
  """
  height = image.shape[0]
  width = image.shape[1]
  depth = image.shape[2]

  feature_map = {
      'height': _int64_feature(height),
      'width': _int64_feature(width),
      'depth': _int64_feature(depth),
      'raw_image': _bytes_feature(serialize_array(image)),
      'label': _int64_feature(label),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:
def write_images_to_tfr_short(images, labels, filename:str="images"):
  """Write all images and their labels into a single TFRecord file.

  Args:
    images: indexable collection of images; each one is converted with
      `parse_single_image`.
    labels: labels paired with `images` by position (labels[i] for images[i]).
    filename: output path without extension; ".tfrecords" is appended.

  Returns:
    The number of examples written.
  """
  filename = filename + ".tfrecords"
  count = 0

  # The context manager closes (and flushes) the writer even if building or
  # serializing an example raises partway through.
  with tf.io.TFRecordWriter(filename) as writer:
    for index in range(len(images)):
      out = parse_single_image(image=images[index], label=labels[index])
      writer.write(out.SerializeToString())
      count += 1

  print(f"Wrote {count} elements to TFRecord")
  return count

In [None]:
count = write_images_to_tfr_short(images=images_small, labels=labels_small, filename="small_images")  # -> small_images.tfrecords

In [None]:
def parse_tfr_element(element):
  """Decode one serialized image Example back into an `(image, label)` pair.

  The feature spec mirrors the keys written by `parse_single_image`. The image
  is decoded as int16 and reshaped to the stored (height, width, depth).
  """
  spec = {
      'height': tf.io.FixedLenFeature([], tf.int64),
      'width': tf.io.FixedLenFeature([], tf.int64),
      'label': tf.io.FixedLenFeature([], tf.int64),
      'raw_image': tf.io.FixedLenFeature([], tf.string),
      'depth': tf.io.FixedLenFeature([], tf.int64),
  }
  parsed = tf.io.parse_single_example(element, spec)

  image_shape = [parsed['height'], parsed['width'], parsed['depth']]
  image = tf.io.parse_tensor(parsed['raw_image'], out_type=tf.int16)
  image = tf.reshape(image, shape=image_shape)
  return (image, parsed['label'])


In [None]:
def get_dataset_small(filename):
  """Return a `tf.data.Dataset` of `(image, label)` pairs read from `filename`."""
  raw_records = tf.data.TFRecordDataset(filename)
  return raw_records.map(parse_tfr_element)

In [None]:
dataset_small = get_dataset_small(filename="/content/small_images.tfrecords")  # file written above

In [None]:
# Peek at the first decoded example: image shape, then label shape.
for sample_image, sample_label in dataset_small.take(1):
  print(sample_image.shape)
  print(sample_label.shape)

(250, 250, 3)
()


### More than a couple of images

Now, what if we had not 100 but 50,000 images, each of a larger shape? They would still fit into a single file, but that file would become very large. The TensorFlow documentation recommends sharding the data across multiple files to enable parallel I/O, while keeping each shard reasonably large (ideally 100 MB or more). Let's see how we can do this!

In [None]:
import tqdm
import glob

In [None]:
number_of_images_large = 500  # limited to 500 images so the arrays still fit into RAM
image_large_shape = (400, 750, 3)

In [None]:
# Random int16 pixel values in [0, 256) for every large image.
large_batch_shape = (number_of_images_large,) + tuple(image_large_shape)
images_large = np.random.randint(0, 256, size=large_batch_shape, dtype=np.int16)
print(images_large.shape)

(500, 400, 750, 3)


Now we create some random labels, and inspect them

In [None]:
# One random class id in [0, 5) per large image, as an (N, 1) column.
large_label_shape = (number_of_images_large, 1)
labels_large = np.random.randint(0, 5, size=large_label_shape)
print(labels_large.shape)
print(labels_large[:10])

(500, 1)
[[3]
 [1]
 [3]
 [1]
 [0]
 [1]
 [3]
 [4]
 [3]
 [1]]


In [None]:
def write_images_to_tfr_long(images, labels, filename:str="large_images", max_files:int=10, out_dir:str="/content/"):
  """Write images and labels into several sharded TFRecord files.

  Shards are named "<out_dir><shard>_<num_shards>_<filename>.tfrecords",
  e.g. "/content/1_17_large_images.tfrecords".

  Args:
    images: indexable collection of images; each one is converted with
      `parse_single_image`.
    labels: labels paired with `images` by position.
    filename: common suffix of every shard name.
    max_files: maximum number of samples (not files) written to each shard.
    out_dir: prefix prepended verbatim to each shard name (keep the trailing
      slash).

  Returns:
    The total number of samples written across all shards.
  """
  num_samples = len(images)
  # Ceiling division: the number of shards needed to hold every sample.
  splits = -(-num_samples // max_files)
  print(f"\nUsing {splits} shard(s) for {num_samples} files, with up to {max_files} samples per shard")

  file_count = 0
  for i in tqdm.tqdm(range(splits)):
    # The "_" before `filename` keeps the shard count and the name apart.
    current_shard_name = "{}{}_{}_{}.tfrecords".format(out_dir, i+1, splits, filename)
    start = i * max_files
    stop = min(start + max_files, num_samples)

    # The context manager closes (and flushes) the shard even if writing fails.
    with tf.io.TFRecordWriter(current_shard_name) as writer:
      for index in range(start, stop):
        out = parse_single_image(image=images[index], label=labels[index])
        writer.write(out.SerializeToString())
        file_count += 1

  print(f"\nWrote {file_count} elements to TFRecord")
  return file_count

In [None]:
write_images_to_tfr_long(images=images_large, labels=labels_large, max_files=30)  # up to 30 samples per shard

In [None]:
def get_dataset_large(tfr_dir:str="/content/", pattern:str="*large_images.tfrecords"):
    """Build a dataset of `(image, label)` pairs from every matching TFRecord shard.

    Args:
      tfr_dir: directory that contains the shards (with or without a trailing
        slash).
      pattern: glob pattern, relative to `tfr_dir`, that selects the shard files.

    Returns:
      A `tf.data.Dataset` yielding the `(image, label)` tuples produced by
      `parse_tfr_element`. If no file matches, the dataset is empty.
    """
    import os

    # glob returns matches in arbitrary, filesystem-dependent order. Sorting
    # makes the record order reproducible across runs and machines.
    files = sorted(glob.glob(os.path.join(tfr_dir, pattern), recursive=False))

    dataset = tf.data.TFRecordDataset(files)
    # decode every serialized record into an (image, label) pair
    dataset = dataset.map(parse_tfr_element)
    return dataset

In [None]:
dataset_large = get_dataset_large(tfr_dir="/content/", pattern="*large_images.tfrecords")  # the defaults, spelled out

In [None]:
# Peek at the first decoded example from the sharded dataset.
for sample_image, sample_label in dataset_large.take(1):
  print(sample_image.shape)
  print(sample_label.shape)

## Audio data

Let's construct an artificial dataset first:


In [None]:
import librosa

The audio samples differ in length. That is not a problem: each waveform is stored together with its length, so TFRecords handle variable-length data naturally.

In [None]:
def create_dummy_audio_dataset(size:int=100):
  """Build an artificial audio dataset from librosa's bundled example clips.

  Each index ``i`` in ``range(size)`` selects exactly one clip. The first
  matching rule wins:

    * ``i`` divisible by 2 -> 'fishin'     (label 0)
    * ``i`` divisible by 3 -> 'brahms'     (label 1)
    * ``i`` divisible by 5 -> 'nutcracker' (label 2)
    * ``i`` divisible by 7 -> 'trumpet'    (label 3)
    * otherwise            -> 'vibeace'    (label 4)

  Args:
    size: number of samples to generate.

  Returns:
    A tuple ``(files, labels)`` of two lists of length ``size``. ``files[i]``
    is ``[y, sr]`` as returned by ``librosa.load``, and ``labels[i]`` is the
    label of that clip. Samples that use the same clip share one waveform
    array, so copy it before modifying it in place.
  """
  files = []
  labels = []
  loaded = {}  # example name -> (y, sr); each clip is decoded only once

  for i in range(size):
    # An if/elif chain picks exactly one clip and one label per sample.
    # Independent `if`s would append several labels for a single clip and
    # misalign `labels` with `files`.
    if i % 2 == 0:
      example, label = 'fishin', 0
    elif i % 3 == 0:
      example, label = 'brahms', 1
    elif i % 5 == 0:
      example, label = 'nutcracker', 2
    elif i % 7 == 0:
      example, label = 'trumpet', 3
    else:
      example, label = 'vibeace', 4

    if example not in loaded:
      loaded[example] = librosa.load(librosa.ex(example))
    y, sr = loaded[example]
    files.append([y, sr])
    labels.append(label)
  return files, labels

In [None]:
def parse_single_audio_file(audio, label):
  """Wrap one ``[waveform, sample_rate]`` pair and its label into a `tf.train.Example`.

  The waveform length is stored alongside the serialized samples so the
  reader can reshape the decoded tensor.
  """
  waveform = audio[0]
  sample_rate = audio[1]

  feature_map = {
      'sr': _int64_feature(sample_rate),
      'len': _int64_feature(len(waveform)),
      'y': _bytes_feature(serialize_array(waveform)),
      'label': _int64_feature(label),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:
def write_audio_to_tfr(audios, labels, filename:str="audio"):
  """Write audio samples and their labels into a single TFRecord file.

  Args:
    audios: indexable collection of ``[waveform, sample_rate]`` pairs; each
      one is converted with `parse_single_audio_file`.
    labels: labels paired with `audios` by position (labels[i] for audios[i]).
    filename: output path without extension; ".tfrecords" is appended.

  Returns:
    The number of examples written.
  """
  filename = filename + ".tfrecords"
  count = 0

  # The context manager closes (and flushes) the writer even if building or
  # serializing an example raises partway through.
  with tf.io.TFRecordWriter(filename) as writer:
    for index in range(len(audios)):
      out = parse_single_audio_file(audio=audios[index], label=labels[index])
      writer.write(out.SerializeToString())
      count += 1

  print(f"Wrote {count} elements to TFRecord")
  return count

In [None]:
# Build the dummy audio samples ([y, sr] pairs) and their integer labels.
# NOTE(review): write_audio_to_tfr pairs audios[i] with labels[i] by position;
# confirm len(labels) == len(audios) before writing.
audios, labels = create_dummy_audio_dataset()

In [None]:
write_audio_to_tfr(audios=audios, labels=labels, filename="audio")  # -> audio.tfrecords

In [None]:
def parse_tfr_audio_element(element):
  """Decode one serialized audio Example back into a `(waveform, label)` pair.

  The feature spec mirrors the keys written by `parse_single_audio_file`. The
  waveform is decoded as float32 and reshaped to its stored length.
  """
  data = {
      'sr': tf.io.FixedLenFeature([], tf.int64),
      'len': tf.io.FixedLenFeature([], tf.int64),
      'y': tf.io.FixedLenFeature([], tf.string),
      'label': tf.io.FixedLenFeature([], tf.int64),
  }

  content = tf.io.parse_single_example(element, data)

  # `length` rather than `len`, so the builtin `len` is not shadowed.
  length = content['len']
  y = content['y']
  label = content['label']

  # NOTE(review): assumes waveforms were serialized as float32 (librosa.load's
  # default dtype). parse_tensor fails if the stored dtype differs.
  feature = tf.io.parse_tensor(y, out_type=tf.float32)
  feature = tf.reshape(feature, shape=[length])
  return (feature, label)


In [None]:
def get_audio_dataset(filename):
  """Return a `tf.data.Dataset` of `(waveform, label)` pairs read from `filename`."""
  raw_records = tf.data.TFRecordDataset(filename)
  return raw_records.map(parse_tfr_audio_element)

In [None]:
dataset_audio = get_audio_dataset(filename="/content/audio.tfrecords")  # file written above

In [None]:
# Peek at the first decoded audio example.
for sample_waveform, sample_label in dataset_audio.take(1):
  print(sample_waveform.shape)  # the audio data
  print(sample_label)  # the label

## Text data

So far we have worked with numerical data only: both the images and the audio files were represented as numbers (integer pixel values and floating-point samples). For the following example, let's cover the third large domain: text data.

In [None]:
def create_dummy_text_dataset(size:int=100):
  """Create `size` text samples that alternate between two fixed sentences.

  Even indices get the "Hey, ..." sentence with label 0. Odd indices get the
  "A point ..." sentence with label 1.

  Returns:
    A tuple ``(text_data, labels)`` of two lists, each of length ``size``.
  """
  choices = (
      ("Hey, this is a sample text. We can use many different symbols.", 0),
      ("A point is exactly what the folks think of it; after Gauss.", 1),
  )
  picked = [choices[i % 2] for i in range(size)]
  text_data = [sentence for sentence, _ in picked]
  labels = [label for _, label in picked]
  return text_data, labels

In [None]:
def parse_single_text_data(text, label):
  """Wrap one text string and its label into a `tf.train.Example`.

  The text is stored as a serialized string tensor under 'text'.
  """
  serialized_text = serialize_array(text)
  feature_map = {
      'text': _bytes_feature(serialized_text),
      'label': _int64_feature(label),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:
def write_text_to_tfr(text_data, labels, filename:str="text"):
  """Write text samples and their labels into a single TFRecord file.

  Args:
    text_data: indexable collection of strings; each one is converted with
      `parse_single_text_data`.
    labels: labels paired with `text_data` by position.
    filename: output path without extension; ".tfrecords" is appended.

  Returns:
    The number of examples written.
  """
  filename = filename + ".tfrecords"
  count = 0

  # The context manager closes (and flushes) the writer even if building or
  # serializing an example raises partway through.
  with tf.io.TFRecordWriter(filename) as writer:
    for index in range(len(text_data)):
      out = parse_single_text_data(text=text_data[index], label=labels[index])
      writer.write(out.SerializeToString())
      count += 1

  print(f"Wrote {count} elements to TFRecord")
  return count

In [None]:
text, labels = create_dummy_text_dataset(size=100)  # 100 alternating sentences with labels 0/1

In [None]:
text[0:5]  # show the first five samples

['Hey, this is a sample text. We can use many different symbols.',
 'A point is exactly what the folks think of it; after Gauss.',
 'Hey, this is a sample text. We can use many different symbols.',
 'A point is exactly what the folks think of it; after Gauss.',
 'Hey, this is a sample text. We can use many different symbols.']

In [None]:
write_text_to_tfr(text, labels, filename="text")  # -> text.tfrecords

In [None]:
def parse_tfr_text_element(element):
  """Decode one serialized text Example back into a `(text, label)` pair.

  The feature spec mirrors the keys written by `parse_single_text_data`.
  """
  spec = {
      'text': tf.io.FixedLenFeature([], tf.string),
      'label': tf.io.FixedLenFeature([], tf.int64),
  }
  parsed = tf.io.parse_single_example(element, spec)

  # The text was stored as a serialized tensor, so parse it back into a string tensor.
  decoded_text = tf.io.parse_tensor(parsed['text'], out_type=tf.string)
  return (decoded_text, parsed['label'])


In [None]:
def get_text_dataset(filename):
  """Return a `tf.data.Dataset` of `(text, label)` pairs read from `filename`."""
  raw_records = tf.data.TFRecordDataset(filename)
  return raw_records.map(parse_tfr_text_element)

In [None]:
text_dataset = get_text_dataset(filename="/content/text.tfrecords")  # file written above

In [None]:
# Peek at the first two decoded text examples.
for sample_text, sample_label in text_dataset.take(2):
  print(sample_text.numpy())  # the text data
  print(sample_label)  # the label

b'Hey, this is a sample text. We can use many different symbols.'
tf.Tensor(0, shape=(), dtype=int64)
b'A point is exactly what the folks think of it; after Gauss.'
tf.Tensor(1, shape=(), dtype=int64)


## Multiple data types

So far, we have examined single domains only. Of course, nothing speaks against combining multiple domains! For the following, consider this outline:

We have multiple images:

In [None]:
# `size` random int16 "images" of shape `images_shape`, with pixel values in [0, 256).
images_shape = (256, 256, 3)
size = 100
# Use `size` rather than a hard-coded 100 so changing it actually changes the count.
images_combined = np.random.randint(low=0, high=256, size=(size, *images_shape), dtype=np.int16)
print(images_combined.shape)

(100, 256, 256, 3)


Secondly, we have a short description of each image, describing the scenery that the image shows:

In [None]:
def create_dummy_text_dataset_combined(size:int=100):
  """Create `size` dummy image descriptions with integer labels.

  Each index ``i`` gets exactly one description. The first matching rule wins:
  divisible by 2 -> bridge (label 0), by 3 -> sun flower (1), by 5 ->
  children (2), by 7 -> house (3), otherwise horse and zebra (4).

  Returns:
    A tuple ``(text_data, labels)`` of two lists, each of length ``size``.
  """
  text_data = []
  labels = []

  for i in range(size):
    # if/elif: the first matching rule wins. With independent `if`s the final
    # if/else overwrote every earlier assignment, so only labels 3 and 4 were
    # ever produced.
    if i % 2 == 0:
      text = "This image shows a wooden bridge. It connects South Darmian with the norther parts of Frenklund."
      label = 0
    elif i % 3 == 0:
      text = "This image shows a sun flower. It's leaves are green, the petals are of strong yellow"
      label = 1
    elif i % 5 == 0:
      text = "This image shows five children playing in the sandbox. They are laughing"
      label = 2
    elif i % 7 == 0:
      text = "This image shows a house on a cliff. The house is painted in red and brown tones."
      label = 3
    else:
      text = "This image shows a horse and a zebra. They come from a CycleGAN."
      label = 4

    text_data.append(text)
    labels.append(label)

  return text_data, labels

In [None]:
# One description per image: write_combined_data_to_tfr indexes text by image position.
text, text_labels = create_dummy_text_dataset_combined(size=size)

Lastly, we also have an auditory description of the scenery. We'll reuse the dummy audio data from above:

In [None]:
def create_dummy_audio_dataset(size:int=100):
  """Build an artificial audio dataset from librosa's bundled example clips.

  Each index ``i`` in ``range(size)`` selects exactly one clip. The first
  matching rule wins:

    * ``i`` divisible by 2 -> 'fishin'     (label 0)
    * ``i`` divisible by 3 -> 'brahms'     (label 1)
    * ``i`` divisible by 5 -> 'nutcracker' (label 2)
    * ``i`` divisible by 7 -> 'trumpet'    (label 3)
    * otherwise            -> 'vibeace'    (label 4)

  Args:
    size: number of samples to generate.

  Returns:
    A tuple ``(files, labels)`` of two lists of length ``size``. ``files[i]``
    is ``[y, sr]`` as returned by ``librosa.load``, and ``labels[i]`` is the
    label of that clip. Samples that use the same clip share one waveform
    array, so copy it before modifying it in place.
  """
  files = []
  labels = []
  loaded = {}  # example name -> (y, sr); each clip is decoded only once

  for i in range(size):
    # An if/elif chain picks exactly one clip and one label per sample.
    # Independent `if`s would append several labels for a single clip and
    # misalign `labels` with `files`.
    if i % 2 == 0:
      example, label = 'fishin', 0
    elif i % 3 == 0:
      example, label = 'brahms', 1
    elif i % 5 == 0:
      example, label = 'nutcracker', 2
    elif i % 7 == 0:
      example, label = 'trumpet', 3
    else:
      example, label = 'vibeace', 4

    if example not in loaded:
      loaded[example] = librosa.load(librosa.ex(example))
    y, sr = loaded[example]
    files.append([y, sr])
    labels.append(label)
  return files, labels

In [None]:
# One audio clip per image: write_combined_data_to_tfr indexes audio by image position.
audio, audio_labels = create_dummy_audio_dataset(size=size)

Downloading file 'Karissa_Hobbs_-_Let's_Go_Fishin'.ogg' from 'https://librosa.org/data/audio/Karissa_Hobbs_-_Let's_Go_Fishin'.ogg' to '/root/.cache/librosa'.
Downloading file 'Hungarian_Dance_number_5_-_Allegro_in_F_sharp_minor_(string_orchestra).ogg' from 'https://librosa.org/data/audio/Hungarian_Dance_number_5_-_Allegro_in_F_sharp_minor_(string_orchestra).ogg' to '/root/.cache/librosa'.
Downloading file 'Kevin_MacLeod_-_P_I_Tchaikovsky_Dance_of_the_Sugar_Plum_Fairy.ogg' from 'https://librosa.org/data/audio/Kevin_MacLeod_-_P_I_Tchaikovsky_Dance_of_the_Sugar_Plum_Fairy.ogg' to '/root/.cache/librosa'.
Downloading file 'sorohanro_-_solo-trumpet-06.ogg' from 'https://librosa.org/data/audio/sorohanro_-_solo-trumpet-06.ogg' to '/root/.cache/librosa'.
Downloading file 'Kevin_MacLeod_-_Vibe_Ace.ogg' from 'https://librosa.org/data/audio/Kevin_MacLeod_-_Vibe_Ace.ogg' to '/root/.cache/librosa'.


Now, let's combine them into the TFRecord files:

In [None]:
def parse_combined_data(image, text, text_label, audio, audio_label):
  """Pack one image, its text description and its audio clip into a single `tf.train.Example`.

  Args:
    image: array whose first three dimensions are stored as height/width/depth.
    text: description string, stored as a serialized tensor.
    text_label: integer label of the text.
    audio: ``[waveform, sample_rate]`` pair.
    audio_label: integer label of the audio clip.
  """
  waveform = audio[0]
  sample_rate = audio[1]

  image_features = {
      'height': _int64_feature(image.shape[0]),
      'width': _int64_feature(image.shape[1]),
      'depth': _int64_feature(image.shape[2]),
      'raw_image': _bytes_feature(serialize_array(image)),
  }
  text_features = {
      'text': _bytes_feature(serialize_array(text)),
      'text_label': _int64_feature(text_label),
  }
  audio_features = {
      'sr': _int64_feature(sample_rate),
      'len': _int64_feature(len(waveform)),
      'y': _bytes_feature(serialize_array(waveform)),
      'audio_label': _int64_feature(audio_label),
  }

  feature_map = {**image_features, **text_features, **audio_features}
  return tf.train.Example(features=tf.train.Features(feature=feature_map))

In [None]:
def write_combined_data_to_tfr(images, text_data, text_labels, audio_data, audio_labels, filename:str="combined"):
  """Write image, text and audio samples into a single TFRecord file.

  Record ``i`` combines ``images[i]``, ``text_data[i]``, ``text_labels[i]``,
  ``audio_data[i]`` and ``audio_labels[i]`` via `parse_combined_data`. Every
  sequence is indexed by image position, so each must hold at least
  ``len(images)`` entries.

  Args:
    filename: output path without extension; ".tfrecords" is appended.

  Returns:
    The number of examples written.
  """
  filename = filename + ".tfrecords"
  count = 0

  # The context manager closes (and flushes) the writer even if building or
  # serializing an example raises partway through.
  with tf.io.TFRecordWriter(filename) as writer:
    for index in range(len(images)):
      out = parse_combined_data(
          image=images[index],
          text=text_data[index],
          text_label=text_labels[index],
          audio=audio_data[index],
          audio_label=audio_labels[index],
      )
      writer.write(out.SerializeToString())
      count += 1

  print(f"Wrote {count} elements to TFRecord")
  return count

In [None]:
write_combined_data_to_tfr(
    images=images_combined,
    text_data=text,
    text_labels=text_labels,
    audio_data=audio,
    audio_labels=audio_labels,
)

Wrote 100 elements to TFRecord


100

In [None]:
def parse_combined_tfr_element(element):
  """Decode one serialized combined Example into its image, text and audio parts.

  The feature spec mirrors the keys written by `parse_combined_data`.

  Returns:
    A tuple ``(image_feature, text_feature, text_label, audio_feature,
    audio_label)``:
      * the image decoded as int16 and reshaped to (height, width, depth),
      * the text decoded as a string tensor,
      * its label,
      * the waveform decoded as float32 and reshaped to its stored length,
      * its label.
  """
  data = {
      # for the images
      'height': tf.io.FixedLenFeature([], tf.int64),
      'width': tf.io.FixedLenFeature([], tf.int64),
      'raw_image': tf.io.FixedLenFeature([], tf.string),
      'depth': tf.io.FixedLenFeature([], tf.int64),
      # for the text
      'text': tf.io.FixedLenFeature([], tf.string),
      'text_label': tf.io.FixedLenFeature([], tf.int64),
      # for the audio
      'sr': tf.io.FixedLenFeature([], tf.int64),
      'len': tf.io.FixedLenFeature([], tf.int64),
      'y': tf.io.FixedLenFeature([], tf.string),
      'audio_label': tf.io.FixedLenFeature([], tf.int64),
  }

  content = tf.io.parse_single_example(element, data)

  # image data
  height = content['height']
  width = content['width']
  depth = content['depth']
  raw_image = content['raw_image']

  image_feature = tf.io.parse_tensor(raw_image, out_type=tf.int16)
  image_feature = tf.reshape(image_feature, shape=[height, width, depth])

  # audio data; `length` rather than `len`, so the builtin is not shadowed
  length = content['len']
  y = content['y']
  audio_label = content['audio_label']

  # NOTE(review): assumes waveforms were serialized as float32 (librosa.load's
  # default dtype). parse_tensor fails if the stored dtype differs.
  audio_feature = tf.io.parse_tensor(y, out_type=tf.float32)
  audio_feature = tf.reshape(audio_feature, shape=[length])

  # text data
  text = content['text']
  text_label = content['text_label']

  text_feature = tf.io.parse_tensor(text, out_type=tf.string)

  return image_feature, text_feature, text_label, audio_feature, audio_label

In [None]:
def get_combined_dataset(filename):
  """Return a `tf.data.Dataset` of decoded combined examples read from `filename`.

  Each element is the 5-tuple produced by `parse_combined_tfr_element`.
  """
  raw_records = tf.data.TFRecordDataset(filename)
  return raw_records.map(parse_combined_tfr_element)

In [None]:
ds = get_combined_dataset(filename="/content/combined.tfrecords")  # file written above

In [None]:
next(iter(ds.take(1)))  # display the first decoded (image, text, text_label, audio, audio_label) tuple

(<tf.Tensor: shape=(256, 256, 3), dtype=int16, numpy=
 array([[[160, 224, 213],
         [ 45, 231, 164],
         [157, 167, 117],
         ...,
         [221,  46, 247],
         [207,  30, 251],
         [127, 133, 154]],
 
        [[137, 211, 154],
         [172, 160,  55],
         [125, 171,  19],
         ...,
         [ 78,  10, 144],
         [191, 131, 125],
         [101,  32, 140]],
 
        [[182, 191,  61],
         [112, 247,  29],
         [248, 203, 166],
         ...,
         [145,  91, 130],
         [165, 108,  59],
         [  6, 125,  19]],
 
        ...,
 
        [[169, 122, 229],
         [160, 185, 109],
         [ 29, 255, 210],
         ...,
         [129,  37, 226],
         [194, 130,  64],
         [126,  32, 218]],
 
        [[193,  93, 110],
         [ 15, 130,  75],
         [122,  46,  23],
         ...,
         [ 72, 104, 223],
         [253, 149,  46],
         [123,  28,  44]],
 
        [[163, 140,  96],
         [146, 244, 244],
         [109,