In [1]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_datasets as tfds 

In [2]:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist', # dataset name in tensorflow_dataset_catalog
    split=['train', 'test'], # can also have validation
    shuffle_files=True, # multiple tf.records of 1000
                        # even though these every 1000 intenally will be shuffled
                        # we also want ordering of files to shuffle to not see same
                        # order of files multiple times
    as_supervised=True,
    with_info=True,

)


local data directory. If you'd instead prefer to read directly from our public
GCS bucket (recommended if you're running on GCP), you can instead set
data_dir=gs://tfds-data/datasets.



[1mDownloading and preparing dataset mnist/3.0.0 (download: 11.06 MiB, generated: Unknown size, total: 11.06 MiB) to /root/tensorflow_datasets/mnist/3.0.0...[0m


HBox(children=(FloatProgress(value=0.0, description='Dl Completed...', max=4.0, style=ProgressStyle(descriptio…



[1mDataset mnist downloaded and prepared to /root/tensorflow_datasets/mnist/3.0.0. Subsequent calls will reuse this data.[0m


In [3]:
def normalize_img(image, label):
  # normalize images
  return tf.cast(image, tf.float32)/255.0, label

AUTOTUNE = tf.data.experimental.AUTOTUNE
ds_train = ds_train.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(32)
ds_train = ds_train.prefetch(AUTOTUNE)

AUTOTUNE = tf.data.experimental.AUTOTUNE
ds_test = ds_test.map(normalize_img, num_parallel_calls=AUTOTUNE)
ds_test = ds_test.batch(32)
ds_test = ds_test.prefetch(AUTOTUNE)

In [20]:
model = keras.Sequential(
    [
     layers.Input((28,28,1)),
     layers.Conv2D(32, 3, activation='relu'),
     layers.Flatten(),
     layers.Dense(10)
    ]
)

## Some Predefined callbacks <br>
If there is validation set need to tell <br>
 train_loss, val_loss, test_loss OR <br>
 train_accuracy, val_accuracy, test_accuracy

In [10]:
# Save Model callback
save_callback = keras.callbacks.ModelCheckpoint(
    'checkpoint/', save_weights_only=True,
    monitor='accuracy', save_best_only=False,
)


# Learning rate

def scheduler(epoch, lr):
  if epoch < 2:
    return lr
  else:
    return lr * 0.99

lr_scheduler = keras.callbacks.LearningRateScheduler(scheduler, verbose=1)

##Custom callbacks

In [17]:
class CustomCallback(keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    print(logs.keys())
    if logs.get('accuracy') > 0.90:
      print("accuracy over 90%, quitting training")
      self.model.stop_training=True
    

## Training

In [21]:
model.compile(
    optimizer=keras.optimizers.Adam(0.01),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

model.fit(ds_train, epochs=10, verbose=2, 
          callbacks=[save_callback, lr_scheduler, CustomCallback()])


Epoch 00001: LearningRateScheduler reducing learning rate to 0.009999999776482582.
Epoch 1/10
dict_keys(['loss', 'accuracy', 'lr'])
accuracy over 90%, quitting training
1875/1875 - 5s - loss: 0.1401 - accuracy: 0.9579


<tensorflow.python.keras.callbacks.History at 0x7efff9c54e10>