## Simple Neural Nets with tensorflow

In [1]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm_notebook as tqdm

%load_ext autoreload
%autoreload
%config Completer.use_jedi=False

In [2]:
tf.executing_eagerly()

True

## Read in data

Using MNIST

In [3]:
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

In [4]:
X_train.shape, y_train.shape

((60000, 28, 28), (60000,))

Convert numpy tensors to tensorflow tensors and create batches

In [43]:
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))

BATCH_SIZE = 128
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

## Train with non customizable loop

Non customizable training loop

In [99]:
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="sigmoid")
])

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['sparse_categorical_accuracy'])

In [100]:
%%time
model.fit(train_dataset, epochs=50)

Train for 469 steps
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50
CPU times: user 1min 22s, sys: 10.2 s, total: 1min 32s
Wall time: 33.2 s


<tensorflow.python.keras.callbacks.History at 0x15deb2400>

In [101]:
model.evaluate(test_dataset)



[1.5001950399785102, 0.9577]

Customizable training loop

In [109]:
tf.random.set_seed(42)

In [110]:
# Define simple neural net
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="sigmoid")
])

# Define optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Define loss function
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

epoch_loss_avg = tf.keras.metrics.Mean()
epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
test_epoch_loss_avg = tf.keras.metrics.Mean()
test_epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

In [111]:
@tf.function
def train_loop(x, y):

    # Calculate gradients
    with tf.GradientTape() as t:
        # training=training is needed only if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(x, training=True)
        loss = loss_func(y, predictions)

    grads = t.gradient(loss, model.trainable_variables)

    # Optimize the model
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    
    # Track progress
    epoch_loss_avg(loss)

    # Compare predicted label to actual
    epoch_accuracy.update_state(y, predictions)
    
#     return loss, predictions

In [112]:
%%time
train_loss_results = []
train_accuracy_results = []
epochs = 50
n_batches = len(list(train_dataset))

for epoch in tqdm(np.arange(epochs)):
    
    for x, y in tqdm(train_dataset, total=n_batches, leave=False):
        train_loop(x, y)

    # End epoch
    train_loss_results.append(epoch_loss_avg.result())
    train_accuracy_results.append(epoch_accuracy.result())
    
    # Test
    for (x_valid, y_valid) in test_dataset:
        preds_test = model(x_valid, training=False)
        test_loss = loss_func(y_valid, preds_test)
        test_epoch_loss_avg(test_loss)
        test_epoch_accuracy.update_state(y_valid, preds_test)
    
    print(f"Epoch {epoch:03d}: train_loss: {epoch_loss_avg.result():.3f}, "
          f"test_loss: {test_epoch_loss_avg.result():.3f} "
          f"Accuracy: {epoch_accuracy.result():.3f}%"
          f"Test accuracy={test_epoch_accuracy.result():.3f}"
    )
    
    # Clear the current state of the metrics
    epoch_loss_avg.reset_states()
    epoch_accuracy.reset_states()
    test_epoch_loss_avg.reset_states()
    test_epoch_accuracy.reset_states()
    # valid_loss.reset_states(), valid_acc.reset_states()
    
    

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 000: train_loss: 1.786, test_loss: 1.637 Accuracy: 0.611%Test accuracy=0.787


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 001: train_loss: 1.621, test_loss: 1.588 Accuracy: 0.790%Test accuracy=0.824


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 002: train_loss: 1.585, test_loss: 1.569 Accuracy: 0.833%Test accuracy=0.845


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 003: train_loss: 1.571, test_loss: 1.557 Accuracy: 0.852%Test accuracy=0.887


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 004: train_loss: 1.559, test_loss: 1.550 Accuracy: 0.876%Test accuracy=0.883


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 005: train_loss: 1.552, test_loss: 1.541 Accuracy: 0.880%Test accuracy=0.897


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 006: train_loss: 1.545, test_loss: 1.542 Accuracy: 0.886%Test accuracy=0.904


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 007: train_loss: 1.540, test_loss: 1.538 Accuracy: 0.898%Test accuracy=0.897


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 008: train_loss: 1.537, test_loss: 1.536 Accuracy: 0.896%Test accuracy=0.896


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 009: train_loss: 1.533, test_loss: 1.530 Accuracy: 0.900%Test accuracy=0.911


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 010: train_loss: 1.529, test_loss: 1.535 Accuracy: 0.905%Test accuracy=0.912


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 011: train_loss: 1.526, test_loss: 1.528 Accuracy: 0.909%Test accuracy=0.912


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 012: train_loss: 1.523, test_loss: 1.526 Accuracy: 0.913%Test accuracy=0.917


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 013: train_loss: 1.521, test_loss: 1.522 Accuracy: 0.920%Test accuracy=0.925


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 014: train_loss: 1.519, test_loss: 1.520 Accuracy: 0.922%Test accuracy=0.921


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 015: train_loss: 1.517, test_loss: 1.523 Accuracy: 0.925%Test accuracy=0.927


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 016: train_loss: 1.516, test_loss: 1.518 Accuracy: 0.925%Test accuracy=0.923


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 017: train_loss: 1.514, test_loss: 1.517 Accuracy: 0.927%Test accuracy=0.923


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 018: train_loss: 1.512, test_loss: 1.514 Accuracy: 0.930%Test accuracy=0.926


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 019: train_loss: 1.511, test_loss: 1.514 Accuracy: 0.931%Test accuracy=0.925


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 020: train_loss: 1.509, test_loss: 1.514 Accuracy: 0.932%Test accuracy=0.928


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 021: train_loss: 1.509, test_loss: 1.512 Accuracy: 0.935%Test accuracy=0.935


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 022: train_loss: 1.507, test_loss: 1.515 Accuracy: 0.937%Test accuracy=0.936


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 023: train_loss: 1.507, test_loss: 1.511 Accuracy: 0.938%Test accuracy=0.934


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 024: train_loss: 1.505, test_loss: 1.508 Accuracy: 0.940%Test accuracy=0.940


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 025: train_loss: 1.505, test_loss: 1.508 Accuracy: 0.941%Test accuracy=0.938


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 026: train_loss: 1.503, test_loss: 1.510 Accuracy: 0.941%Test accuracy=0.943


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 027: train_loss: 1.502, test_loss: 1.509 Accuracy: 0.943%Test accuracy=0.944


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 028: train_loss: 1.502, test_loss: 1.507 Accuracy: 0.944%Test accuracy=0.945


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 029: train_loss: 1.500, test_loss: 1.506 Accuracy: 0.947%Test accuracy=0.940


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 030: train_loss: 1.500, test_loss: 1.506 Accuracy: 0.948%Test accuracy=0.945


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 031: train_loss: 1.499, test_loss: 1.506 Accuracy: 0.949%Test accuracy=0.944


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 032: train_loss: 1.499, test_loss: 1.508 Accuracy: 0.949%Test accuracy=0.946


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 033: train_loss: 1.498, test_loss: 1.507 Accuracy: 0.949%Test accuracy=0.945


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 034: train_loss: 1.498, test_loss: 1.504 Accuracy: 0.949%Test accuracy=0.945


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 035: train_loss: 1.497, test_loss: 1.504 Accuracy: 0.951%Test accuracy=0.946


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 036: train_loss: 1.497, test_loss: 1.505 Accuracy: 0.951%Test accuracy=0.946


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 037: train_loss: 1.495, test_loss: 1.503 Accuracy: 0.953%Test accuracy=0.946


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 038: train_loss: 1.495, test_loss: 1.505 Accuracy: 0.954%Test accuracy=0.948


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 039: train_loss: 1.494, test_loss: 1.503 Accuracy: 0.955%Test accuracy=0.952


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 040: train_loss: 1.493, test_loss: 1.505 Accuracy: 0.956%Test accuracy=0.944


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 041: train_loss: 1.494, test_loss: 1.503 Accuracy: 0.955%Test accuracy=0.942


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 042: train_loss: 1.493, test_loss: 1.506 Accuracy: 0.956%Test accuracy=0.948


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 043: train_loss: 1.493, test_loss: 1.502 Accuracy: 0.957%Test accuracy=0.948


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 044: train_loss: 1.492, test_loss: 1.502 Accuracy: 0.957%Test accuracy=0.951


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 045: train_loss: 1.492, test_loss: 1.504 Accuracy: 0.956%Test accuracy=0.949


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 046: train_loss: 1.491, test_loss: 1.502 Accuracy: 0.958%Test accuracy=0.950


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 047: train_loss: 1.491, test_loss: 1.504 Accuracy: 0.958%Test accuracy=0.951


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 048: train_loss: 1.490, test_loss: 1.501 Accuracy: 0.959%Test accuracy=0.951


HBox(children=(IntProgress(value=0, max=469), HTML(value='')))

Epoch 049: train_loss: 1.490, test_loss: 1.502 Accuracy: 0.958%Test accuracy=0.953

CPU times: user 1min 16s, sys: 8.02 s, total: 1min 24s
Wall time: 39.2 s
