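Run this cell first to make sure you are using TensorFlow 2.x. The `%tensorflow_version` magic only exists in Colab; anywhere else the cell does nothing.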

In [None]:
# Ask Colab for TensorFlow 2.x. Outside Colab this magic is not defined, so the
# resulting error is deliberately swallowed and the installed TF is used as-is.
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

In [None]:
# At time of creation, tfds was at version 2.0 which had
# some bugs in some common datasets. Advise to update to
# tfds 2.1.0 like this:
pip install tensorflow_datasets==2.1.0

# Create the Model Architecture


In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_addons as tfa
import numpy as np
import multiprocessing

def create_model():
  """Build and compile a small CNN for binary classification of 300x300 RGB images.

  The network stacks five Conv2D(3x3, relu) + MaxPooling2D(2x2) stages with
  16, 32, 64, 64 and 64 filters, then flattens into a 512-unit relu layer and
  a single sigmoid output unit.

  Returns:
    A tf.keras Sequential model compiled with the Adam optimizer, binary
    cross-entropy loss and an accuracy metric.
  """
  conv_stack = []
  for stage, filters in enumerate((16, 32, 64, 64, 64)):
    conv_options = {'activation': 'relu'}
    if stage == 0:
      # Only the first layer declares the expected input: 300x300 images with
      # 3 colour channels.
      conv_options['input_shape'] = (300, 300, 3)
    conv_stack.append(tf.keras.layers.Conv2D(filters, (3, 3), **conv_options))
    conv_stack.append(tf.keras.layers.MaxPooling2D(2, 2))

  classifier_head = [
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation='relu'),
      tf.keras.layers.Dense(1, activation='sigmoid'),
  ]

  model = tf.keras.models.Sequential(conv_stack + classifier_head)
  model.compile(optimizer='Adam', loss='binary_crossentropy',
                metrics=['accuracy'])
  return model

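# EXTRACT

Use TFDS only to download and prepare the dataset on disk, then read its raw TFRecord shards directly. This lets each pipeline stage (interleave, map, batch, prefetch) be configured by hand.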


In [None]:
# tfds.load downloads and prepares cats_vs_dogs on disk; the pipeline below reads
# the prepared TFRecord shards directly. With with_info=True the call returns a
# (dataset, info) pair, so unpack it instead of binding the whole tuple to one name.
train_data, info = tfds.load('cats_vs_dogs', split='train', with_info=True)


In [None]:
# Locate the prepared TFRecord shards from the dataset builder instead of
# hard-coding a Colab-only home directory and a specific dataset version: the
# builder's data_dir is the versioned directory under the default TFDS data dir.
data_dir = tfds.builder('cats_vs_dogs').data_dir
file_pattern = f'{data_dir}/cats_vs_dogs-train.tfrecord*'
files = tf.data.Dataset.list_files(file_pattern)

In [None]:
# Read several shards at once: up to 4 files are open concurrently and their
# records are interleaved, with tf.data choosing the degree of parallelism.
train_dataset = files.interleave(
    map_func=tf.data.TFRecordDataset,
    cycle_length=4,
    num_parallel_calls=tf.data.experimental.AUTOTUNE,
)


# TRANSFORM


In [None]:
def read_tfrecord(serialized_example):
  """Parse one serialized cats_vs_dogs record into an (image, label) pair.

  Args:
    serialized_example: scalar string tensor holding a serialized
      tf.train.Example with an "image" bytes feature (decoded here as JPEG)
      and an int64 "label" feature.

  Returns:
    (image, label): image is a float32 tensor of shape (300, 300, 3) with pixel
    values rescaled from [0, 255] to [0, 1]; label is the int64 label, or -1
    when the record has no label.
  """
  parsed = tf.io.parse_single_example(
      serialized_example,
      {
          "image": tf.io.FixedLenFeature((), tf.string, ""),
          "label": tf.io.FixedLenFeature((), tf.int64, -1),
      },
  )
  # Decode to 3-channel uint8, rescale to floats, then resize to the model's input size.
  pixels = tf.cast(tf.io.decode_jpeg(parsed['image'], channels=3), tf.float32) / 255
  return tf.image.resize(pixels, (300, 300)), parsed['label']

In [None]:
# Parse and preprocess up to `cores` records concurrently (one per CPU core).
cores = multiprocessing.cpu_count()
print(cores)

# NOTE: this rebinds train_dataset; re-running the cell would apply
# read_tfrecord to already-parsed (image, label) pairs and fail.
train_dataset = train_dataset.map(
    read_tfrecord,
    num_parallel_calls=cores,
)

# Optional: train_dataset.cache() would keep decoded images for reuse after the
# first pass. Each image is 300*300*3 float32 values (about 1 MB), so enable it
# only when the whole decoded dataset fits in memory.

# LOAD

In [None]:
# Shuffle within a 1024-example buffer, group into batches of 32, and let tf.data
# prefetch upcoming batches while the model trains on the current one.
# NOTE: this rebinds train_dataset, so re-running the cell batches it twice.
train_dataset = (
    train_dataset
    .shuffle(1024)
    .batch(32)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

In [None]:
# Build a fresh, untrained model and train it on the batched pipeline for 10 epochs.
# Keep the returned History so per-epoch loss/accuracy can be inspected or plotted
# later; this also stops the cell from ending with a bare History object repr.
model = create_model()
history = model.fit(train_dataset, epochs=10, verbose=1)