<a href="https://colab.research.google.com/github/yangli2/rl_to_learn/blob/master/mnist.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%tensorflow_version 2.x
import tensorflow as tf
from google.colab import drive
import tensorflow_datasets as tfds

# Mount Google Drive inside the Colab VM and derive the folders this notebook
# reads from and writes to. Every path below hangs off the mount point.
GOOGLE_DRIVE = '/content/gdrive'
drive.mount(GOOGLE_DRIVE)

DRIVE_DIR = GOOGLE_DRIVE + '/My Drive'       # root of the user's Drive
MNIST_DIR = DRIVE_DIR + '/mnist_dataset'     # passed to TFDS as download_dir
MODEL_DIR = DRIVE_DIR + '/models'            # where model weights are kept

TensorFlow 2.x selected.
Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
# Look MNIST up through the public TFDS registry rather than the internal
# `tfds.image.mnist` module path, which is not stable across TFDS releases.
mnist = tfds.builder('mnist')
# Downloaded source files go to MNIST_DIR (on Drive).
mnist.download_and_prepare(download_dir=MNIST_DIR)
# No split is requested; the result is indexed by split name further down
# (mnist_ds['train'], mnist_ds['test']).
mnist_ds = mnist.as_dataset()

In [0]:
from tensorflow.keras.layers import Dense, Flatten, Conv2D
from tensorflow.keras import Model

class MnistModel(Model):
  """Small convolutional classifier.

  Layer stack: Conv2D(32 filters, 3x3, relu) -> Flatten -> Dense(128, relu)
  -> Dense(10, softmax).
  """

  def __init__(self):
    super().__init__()
    self.conv1 = Conv2D(filters=32, kernel_size=3, activation='relu')
    self.flatten = Flatten()
    self.d1 = Dense(units=128, activation='relu')
    self.d2 = Dense(units=10, activation='softmax')

  def call(self, x):
    """Run the forward pass and return the 10-way softmax output.

    Args:
      x: input batch fed to the convolution (presumably float images shaped
        (batch, height, width, channels) -- confirm against the caller).
    """
    features = self.flatten(self.conv1(x))
    return self.d2(self.d1(features))

def train_step(images, labels, loss_obj, optimizer, loss_fn,
               accuracy_fn=None, model=None):
  """Run one optimisation step on a single batch.

  Args:
    images: batch of inputs passed straight to the model.
    labels: targets passed to `loss_obj` and `accuracy_fn`.
    loss_obj: callable `(labels, predictions) -> loss` that is differentiated.
    optimizer: object exposing `apply_gradients(grads_and_vars)`.
    loss_fn: callable applied to the computed loss (e.g. a running-mean
      metric). When None, a fresh `tf.keras.metrics.Mean()` is used.
    accuracy_fn: optional callable `(labels, predictions) -> accuracy`.
    model: the model to train. When None, the module-level name `model` is
      used, which keeps older call sites working.

  Returns:
    A tuple `(loss_fn(loss), accuracy)`, where `accuracy` is None when
    `accuracy_fn` is None.

  Raises:
    NameError: if `model` is None and no module-level `model` exists.
  """
  if model is None:
    # Backward-compatible fallback: earlier callers relied on a global model.
    try:
      model = globals()['model']
    except KeyError:
      raise NameError(
          "train_step needs a model: pass model=... or define a "
          "module-level 'model'")
  if loss_fn is None:
    loss_fn = tf.keras.metrics.Mean()

  # The forward pass and the loss must happen under the tape so that
  # gradients with respect to the model variables can be taken afterwards.
  with tf.GradientTape() as tape:
    predictions = model(images)
    loss = loss_obj(labels, predictions)
  variables = model.trainable_variables
  gradients = tape.gradient(loss, variables)
  optimizer.apply_gradients(zip(gradients, variables))

  return (loss_fn(loss),
          accuracy_fn(labels, predictions)
          if accuracy_fn is not None else None)

def test_step(images, labels, loss_obj, loss_fn, accuracy_fn=None):
  loss_fn = loss_fn if loss_fn is not None else tf.keras.metrics.Mean()

  predictions = model(images)
  t_loss = loss_obj(labels, predictions)

  return (loss_fn(t_loss),
          accuracy_fn(labels, predictions)
          if accuracy_fn is not None else None)

def generate_model():
  """Build a fresh model together with everything needed to train it.

  Returns:
    A 7-tuple `(model, optimizer, loss, train_loss, train_accuracy,
    test_loss, test_accuracy)`: a new MnistModel, an Adam optimizer, a
    sparse categorical cross-entropy loss, and a Mean /
    SparseCategoricalAccuracy metric pair for each of training and test.
  """
  metrics = tf.keras.metrics

  network = MnistModel()
  adam = tf.keras.optimizers.Adam()
  criterion = tf.keras.losses.SparseCategoricalCrossentropy()

  # One (loss, accuracy) metric pair per phase.
  training_loss_metric = metrics.Mean(name='training_loss')
  training_accuracy_metric = metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  eval_loss_metric = metrics.Mean(name='test_loss')
  eval_accuracy_metric = metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

  return (network, adam, criterion,
          training_loss_metric, training_accuracy_metric,
          eval_loss_metric, eval_accuracy_metric)


In [0]:
NUM_EPOCHS = 5
BATCH_SIZE = 32


def _to_float_batch(example):
  """Return (images scaled by 1/255 as float32, labels as float32)."""
  return (tf.cast(example['image'], tf.float32) / 255.,
          tf.cast(example['label'], tf.float32))


def train_model(mnist_ds, model_filename, num_epochs=NUM_EPOCHS,
                batch_size=BATCH_SIZE):
  """Return an MNIST model, restoring saved weights or training from scratch.

  A new model is built with `generate_model()`. If weights can be loaded
  from `model_filename` the model is returned right away; otherwise it is
  trained for `num_epochs` epochs, evaluated on the test split after each
  epoch, and its weights are saved to `model_filename`.

  Args:
    mnist_ds: mapping with 'train' and 'test' entries; each must support
      `.batch(n)` and yield examples with 'image' and 'label' keys.
    model_filename: path handed to `load_weights` / `save_weights`.
    num_epochs: number of passes over the training split.
    batch_size: number of examples per batch for both splits.

  Returns:
    The model, either restored or freshly trained.
  """
  # train_step/test_step look the model up under the module-level name
  # `model` when none is passed to them, so the new instance has to be bound
  # there; otherwise they would step a different (or undefined) model.
  global model
  (model, optimizer, loss, train_loss, train_accu,
   test_loss, test_accu) = generate_model()

  try:
    model.load_weights(model_filename)
  except Exception as load_error:  # not bare: let KeyboardInterrupt through
    print('Could not load weights from {} ({}); training from scratch.'.format(
        model_filename, load_error))
  else:
    return model

  template = 'Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, Test Accuracy: {}'
  for epoch in range(num_epochs):
    for example in mnist_ds['train'].batch(batch_size):
      images, labels = _to_float_batch(example)
      train_step(images, labels, loss_obj=loss,
                 optimizer=optimizer, loss_fn=train_loss,
                 accuracy_fn=train_accu)

    for test_example in mnist_ds['test'].batch(batch_size):
      images, labels = _to_float_batch(test_example)
      test_step(images, labels, loss_obj=loss,
                loss_fn=test_loss, accuracy_fn=test_accu)

    print(template.format(epoch+1,
                          train_loss.result(),
                          train_accu.result()*100,
                          test_loss.result(),
                          test_accu.result()*100))

    # Reset the metrics for the next epoch
    train_loss.reset_states()
    train_accu.reset_states()
    test_loss.reset_states()
    test_accu.reset_states()

  model.save_weights(model_filename)
  return model

In [0]:
model = train_model(mnist_ds, '{}/mnist_vanilla.mdl'.format(MODEL_DIR))