In [193]:
import tensorflow as tf
import numpy as np
from tqdm import tqdm_notebook as tqdm

from dyn_fed.data.mnist import MNist
import dyn_fed as df

%load_ext autoreload
%autoreload
%config Completer.use_jedi=False

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
tf.executing_eagerly()

True

In [6]:
X_train, y_train, X_test, y_test = df.data.mnist.load_data()

In [7]:
train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))

In [10]:
BATCH_SIZE = 64
SHUFFLE_BUFFER_SIZE = 100

train_dataset = train_dataset.shuffle(SHUFFLE_BUFFER_SIZE).batch(BATCH_SIZE)
test_dataset = test_dataset.batch(BATCH_SIZE)

Non customizable training loop way

In [226]:
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    # tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="sigmoid")
])

model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=0.01),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['sparse_categorical_accuracy'])

In [227]:
%%time
model.fit(train_dataset, epochs=10)

Train for 938 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
CPU times: user 14 s, sys: 2.09 s, total: 16.1 s
Wall time: 8.86 s


<tensorflow.python.keras.callbacks.History at 0x1461e10f0>

In [228]:
model.evaluate(test_dataset)



[1.688275268882703, 0.8612]

Customizable training loop

In [229]:
# Define logistic regression model
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    # tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation="sigmoid")
])

# Define optimizer
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

# Define loss function
loss_func = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

epoch_loss_avg = tf.keras.metrics.Mean()
epoch_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

In [230]:
@tf.function
def train_loop(x, y):

    # Calculate gradients
    with tf.GradientTape() as t:
        # training=training is needed only if there are layers with different
        # behavior during training versus inference (e.g. Dropout).
        predictions = model(x, training=True)
        loss = loss_func(y, predictions)

    grads = t.gradient(loss, model.trainable_variables)

    # Optimize the model
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    # Track progress
    epoch_loss_avg(loss)

    # Compare predicted label to actual
    epoch_accuracy.update_state(y, predictions)

In [231]:
%%time
train_loss_results = []
train_accuracy_results = []
epochs = 10
n_batches = len(list(train_dataset))

for epoch in tqdm(np.arange(epochs)):
    
    for x, y in tqdm(train_dataset, total=n_batches, leave=False):
        train_loop(x, y)
    # End epoch
    train_loss_results.append(epoch_loss_avg.result())
    train_accuracy_results.append(epoch_accuracy.result())
    
    print(
        "Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(
            epoch,
            epoch_loss_avg.result(),
            epoch_accuracy.result()
        )
    )
    
    # Clear the current state of the metrics
    epoch_loss_avg.reset_states()
    epoch_accuracy.reset_states()
    # valid_loss.reset_states(), valid_acc.reset_states()
    
    

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 000: Loss: 2.097, Accuracy: 59.857%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 001: Loss: 1.901, Accuracy: 79.412%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 002: Loss: 1.826, Accuracy: 81.612%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 003: Loss: 1.786, Accuracy: 82.668%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 004: Loss: 1.760, Accuracy: 83.310%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 005: Loss: 1.742, Accuracy: 83.820%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 006: Loss: 1.728, Accuracy: 84.243%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 007: Loss: 1.717, Accuracy: 84.575%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 008: Loss: 1.708, Accuracy: 84.867%


HBox(children=(IntProgress(value=0, max=938), HTML(value='')))

Epoch 009: Loss: 1.700, Accuracy: 85.103%

CPU times: user 8.38 s, sys: 1.54 s, total: 9.92 s
Wall time: 7.26 s
