In [2]:
%tensorflow_version 2.x

TensorFlow 2.x selected.


In [0]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm_notebook

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [5]:
x_train.shape

(60000, 28, 28)

In [0]:
x_train = np.expand_dims(x_train, 3)
x_test = np.expand_dims(x_test, 3)

In [7]:
x_train.shape

(60000, 28, 28, 1)

In [0]:
x_train = x_train / 255.
x_test = x_test / 255.

In [0]:
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.batch(128)
train_dataset = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [0]:
valid_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
valid_dataset = valid_dataset.batch(128)
valid_dataset = valid_dataset.prefetch(tf.data.experimental.AUTOTUNE)

In [11]:
train_dataset

<PrefetchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float64, tf.uint8)>

In [12]:
valid_dataset

<PrefetchDataset shapes: ((None, 28, 28, 1), (None,)), types: (tf.float64, tf.uint8)>

In [0]:
simple_model = tf.keras.Sequential([tf.keras.layers.Input(shape=(28,28,1)),
                                    tf.keras.layers.Conv2D(256, 7, 1, 'same', activation='relu'),
                                    tf.keras.layers.MaxPool2D(),
                                    tf.keras.layers.Conv2D(256, 5, 1, 'same', activation='relu'),
                                    tf.keras.layers.MaxPool2D(),
                                    tf.keras.layers.Conv2D(256, 5, 1, 'same', activation='relu'),
                                    tf.keras.layers.GlobalMaxPool2D(),
                                    tf.keras.layers.Dense(10, activation='softmax')])

In [14]:
simple_model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              (None, 28, 28, 256)       12800     
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 14, 14, 256)       0         
_________________________________________________________________
conv2d_1 (Conv2D)            (None, 14, 14, 256)       1638656   
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 7, 7, 256)         0         
_________________________________________________________________
conv2d_2 (Conv2D)            (None, 7, 7, 256)         1638656   
_________________________________________________________________
global_max_pooling2d (Global (None, 256)               0         
_________________________________________________________________
dense (Dense)                (None, 10)                2

In [0]:
checkpoint_dir = '/content'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

In [16]:
checkpoint_prefix

'/content/ckpt'

In [0]:
my_opt = tf.keras.optimizers.Nadam(learning_rate=0.008, clipnorm=1.)

In [0]:
checkpoint = tf.train.Checkpoint(model=simple_model,
                                 optimizer=my_opt)

In [0]:
manager = tf.train.CheckpointManager(checkpoint, checkpoint_dir, 3)

In [0]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc = tf.keras.metrics.SparseCategoricalAccuracy()
loss_metric = tf.keras.metrics.Mean()

@tf.function
def train_step(img, label):
  with tf.GradientTape() as tape:  
    pred = simple_model(img)
    loss = loss_object(label, pred)
  variables = simple_model.trainable_variables
  gradient = tape.gradient(loss, variables)
  my_opt.apply_gradients(zip(gradient, variables))

  train_acc(label, pred)
  loss_metric(loss)

In [21]:
EPOCH = 10

for e in range(10):
  for n, (img, label) in tqdm_notebook(enumerate(train_dataset)):
    loss = train_step(img, label)


  print('Epoch: ', e)
  print(train_acc.result())
  print(loss_metric.result())
  print('save ckpt...')
  manager.save()


Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(IntProgress(value=1, bar_style='info', max=1), HTML(value='')))



To change all layers to have dtype float64 by default, call `tf.keras.backend.set_floatx('float64')`. To change just this layer, pass dtype='float64' to the layer constructor. If you are the author of this layer, you can disable autocasting by passing autocast=False to the base Layer constructor.



KeyboardInterrupt: ignored