In [23]:
import glob
import os

import IPython.display as display
import numpy as np
import tensorflow as tf
import tqdm

In [2]:
# The following functions can be used to convert a value to a type compatible
# with tf.train.Example.

def _bytes_feature(value):
  """Wrap one bytes/str value (or an eager string tensor) in a tf.train.Feature.

  Args:
    value: a bytes-like scalar, or an EagerTensor holding one.

  Returns:
    A tf.train.Feature whose bytes_list contains exactly that one value.
  """
  eager_tensor_type = type(tf.constant(0))
  # BytesList cannot unpack an EagerTensor, so pull out its NumPy value first.
  raw_value = value.numpy() if isinstance(value, eager_tensor_type) else value
  bytes_list = tf.train.BytesList(value=[raw_value])
  return tf.train.Feature(bytes_list=bytes_list)

def _float_feature(value):
  """Wrap a single float/double in a tf.train.Feature (float_list of length 1)."""
  float_list = tf.train.FloatList(value=[value])
  return tf.train.Feature(float_list=float_list)

def _int64_feature(value):
  """Wrap a single bool/enum/int/uint in a tf.train.Feature (int64_list of length 1)."""
  int64_list = tf.train.Int64List(value=[value])
  return tf.train.Feature(int64_list=int64_list)

def serialize_array(array):
  """Serialize `array` with tf.io.serialize_tensor and return the resulting string tensor.

  The dtype is not encoded in a way the reader infers automatically: the reader
  must pass a matching `out_type` to tf.io.parse_tensor.
  """
  return tf.io.serialize_tensor(array)

In [10]:
def parse_single_image(image, label):
  """Build a tf.train.Example holding one image and its integer label.

  Args:
    image: an array-like with a `.shape` of at least three dimensions
      (height, width, depth); it is stored as a serialized tensor, so its dtype
      is preserved and must be matched when decoding.
    label: an integer class label. A plain int, a NumPy integer, or a
      single-element array (e.g. one row of an `(N, 1)` label matrix) is accepted.

  Returns:
    A tf.train.Example with 'height', 'width', 'depth', 'raw_image' and 'label' features.

  Raises:
    ValueError: if `label` contains more than one element.
  """
  # Collapse size-1 arrays to a Python int explicitly; handing an ndim>0 array to
  # Int64List relies on an implicit array->scalar conversion NumPy has deprecated.
  label_value = int(np.asarray(label).item())

  # The structure of a single example: one feature per named field.
  data = {
      'height': _int64_feature(image.shape[0]),
      'width': _int64_feature(image.shape[1]),
      'depth': _int64_feature(image.shape[2]),
      'raw_image': _bytes_feature(serialize_array(image)),
      'label': _int64_feature(label_value),
  }
  return tf.train.Example(features=tf.train.Features(feature=data))

# Creating a Large Dataset
For a larger dataset, we shard the data across multiple TFRecord files instead of writing everything into a single file.

In [3]:
image_large_shape = (400, 750, 3)
number_of_images_large = 500  # constrained to 500 images so the int16 array (~900 MB) fits in RAM

# Seed NumPy's global RNG so the synthetic images (and any later np.random draws) are reproducible.
np.random.seed(42)
# int16 must match the out_type used when decoding records (tf.io.parse_tensor(..., out_type=tf.int16)).
images_large = np.random.randint(low=0, high=256, size=(number_of_images_large, *image_large_shape), dtype=np.int16)

In [4]:
num_classes_large = 5  # labels are drawn uniformly from {0, ..., num_classes_large - 1}
# 1-D so each labels_large[i] is a scalar integer; an (N, 1) shape would make every
# label a 1-element array, which Int64List does not take cleanly.
labels_large = np.random.randint(low=0, high=num_classes_large, size=number_of_images_large)

In [18]:

def write_images_to_tfr_long(images, labels, filename:str="large_images", max_files:int=10, out_dir:str="/content/"):
  """Write (image, label) pairs into sharded TFRecord files.

  Each shard holds at most `max_files` examples. It is named
  "{shard_number}_{total_shards}{filename}.tfrecords" and placed inside `out_dir`.

  Args:
    images: indexable sequence of images, each accepted by `parse_single_image`.
    labels: indexable sequence of labels, the same length as `images`.
    filename: suffix appended to every shard name.
    max_files: maximum number of examples per shard; must be >= 1.
    out_dir: directory the shards are written into (trailing slash optional).

  Raises:
    ValueError: if `max_files` < 1 or `labels` and `images` differ in length.
  """
  if max_files < 1:
    raise ValueError(f"max_files must be >= 1, got {max_files}")
  if len(labels) != len(images):
    raise ValueError(f"got {len(images)} images but {len(labels)} labels")

  # Number of shards needed: ceil(len(images) / max_files).
  splits = (len(images) + max_files - 1) // max_files
  print(f"\nUsing {splits} shard(s) for {len(images)} files, with up to {max_files} samples per shard")

  file_count = 0
  for i in tqdm.tqdm(range(splits)):
    start = i * max_files
    stop = min(start + max_files, len(images))
    current_shard_name = os.path.join(out_dir, "{}_{}{}.tfrecords".format(i + 1, splits, filename))

    # Context manager guarantees the shard is flushed and closed even if serialization fails.
    with tf.io.TFRecordWriter(current_shard_name) as writer:
      for index in range(start, stop):
        out = parse_single_image(image=images[index], label=labels[index])
        writer.write(out.SerializeToString())
        file_count += 1

  print(f"\nWrote {file_count} elements to TFRecord")

In [19]:
write_images_to_tfr_long(images=images_large, labels=labels_large, max_files=30)


Using 17 shard(s) for 500 files, with up to 30 samples per shard


100%|██████████| 17/17 [00:03<00:00,  4.70it/s]


Wrote 500 elements to TFRecord





In [20]:
def parse_tfr_element(element):
  """Decode one serialized tf.train.Example back into an (image, label) pair.

  The feature spec mirrors the dictionary built by `parse_single_image`.

  Args:
    element: a scalar string tensor holding one serialized Example.

  Returns:
    A tuple (image, label). `image` is an int16 tensor reshaped to
    [height, width, depth] as recorded in the example; `label` is a scalar int64.
  """
  feature_spec = {
      'height': tf.io.FixedLenFeature([], tf.int64),
      'width': tf.io.FixedLenFeature([], tf.int64),
      'label': tf.io.FixedLenFeature([], tf.int64),
      'raw_image': tf.io.FixedLenFeature([], tf.string),
      'depth': tf.io.FixedLenFeature([], tf.int64),
  }
  parsed = tf.io.parse_single_example(element, feature_spec)

  # raw_image holds the bytes produced by tf.io.serialize_tensor; out_type must
  # match the dtype of the arrays that were written (int16 in this notebook).
  image = tf.io.parse_tensor(parsed['raw_image'], out_type=tf.int16)
  target_shape = [parsed['height'], parsed['width'], parsed['depth']]
  image = tf.reshape(image, shape=target_shape)
  return (image, parsed['label'])

In [21]:
def get_dataset_large(tfr_dir:str="/content/", pattern:str="*large_images.tfrecords"):
    """Build a tf.data pipeline over every TFRecord shard matching `pattern` in `tfr_dir`.

    Args:
        tfr_dir: directory containing the shards (trailing slash optional).
        pattern: glob pattern for shard file names, relative to `tfr_dir`.

    Returns:
        A tf.data.Dataset yielding (image, label) pairs decoded by `parse_tfr_element`.

    Raises:
        FileNotFoundError: if no file matches, instead of silently returning an empty dataset.
    """
    # glob returns files in arbitrary order; sort so the shard order is deterministic.
    files = sorted(glob.glob(os.path.join(tfr_dir, pattern), recursive=False))
    if not files:
        raise FileNotFoundError(f"no TFRecord files match {os.path.join(tfr_dir, pattern)!r}")

    dataset = tf.data.TFRecordDataset(files)

    # Decode every serialized Example into an (image, label) pair.
    dataset = dataset.map(parse_tfr_element)

    return dataset

In [24]:

dataset_large = get_dataset_large()

# Peek at a single decoded example to confirm the image and label shapes.
for image, label in dataset_large.take(1):
  print(image.shape)
  print(label.shape)

(400, 750, 3)
()
