In [1]:
import tensorflow as tf
import numpy as np
import os

print(tf.__version__)

2.3.1


In [2]:
a = np.ones((5,5))
a.reshape(5, 5, -1).shape

(5, 5, 1)

In [3]:
fashion_mnist = tf.keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

# for conv2d input shape
train_images = train_images[..., None] # (60000, 28, 28, 1)
test_images = test_images[..., None] # (10000, 28, 28, 1)

train_images = train_images / np.float32(255) # rescale
test_images = test_images / np.float32(255)

In [4]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)


In [5]:
print('Num of devices: {}'.format(strategy.num_replicas_in_sync))

Num of devices: 1


In [6]:
BUFFER_SIZE = len(train_images)

BATCH_SIZE_PER_REPLICA = 64
GLOBAL_BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync

EPOCHS = 10

In [7]:
with strategy.scope():
    train_dataset = tf.data.Dataset.from_tensor_slices((train_images, train_labels)).shuffle(BUFFER_SIZE).batch(GLOBAL_BATCH_SIZE)
    train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
    
    test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

In [8]:
def create_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(32,3,activation = 'relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Conv2D(64,3, activation = 'relu'),
        tf.keras.layers.MaxPooling2D(),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(64, activation = 'relu'),
        tf.keras.layers.Dense(10, activation = 'softmax')
    ])
    
    return model

In [9]:
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

In [10]:
with strategy.scope():
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        reduction = tf.keras.losses.Reduction.NONE
    )
    def compute_loss(labels, predictions):
        per_example_loss = loss_object(labels, predictions)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size = GLOBAL_BATCH_SIZE)

In [11]:
with strategy.scope():
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='train_accuracy')
    test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
        name='test_accuracy')

In [12]:
with strategy.scope():
    model = create_model()
    optimizer = tf.keras.optimizers.Adam()
    checkpoint = tf.train.Checkpoint(optimizer = optimizer, model = model)

In [13]:
with strategy.scope():
    def train_step(inputs):
        images, labels = inputs
        
        with tf.GradientTape() as tape:
            predictions = model(images, training=True)
            loss = compute_loss(labels, predictions)
            
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        
        train_accuracy.update_state(labels, predictions)
        return loss
    
    def test_step(inputs):
        images, labels = inputs
        
        predictions = model(images, training = False)
        t_loss = loss_object(labels, predictions)
        
        test_loss.update_state(t_loss)
        test_accuracy.update_state(labels, predictions)

In [14]:
with strategy.scope():
    @tf.function
    def distributed_train_step(dataset_inputs):
        per_replica_losses = strategy.experimental_run_v2(train_step, args=(dataset_inputs,))
        return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
    
    @tf.function
    def distributed_test_step(dataset_inputs):
        return strategy.experimental_run_v2(test_step, args=(dataset_inputs,))
    
    for epoch in range(EPOCHS):
        total_loss = 0.0
        num_batches = 0
        for x in train_dist_dataset:
            total_loss += distributed_train_step(x)
            num_batches += 1
        train_loss = total_loss / num_batches
        
        for x in test_dist_dataset:
            distributed_test_step(x)
            
        if epoch % 2 ==0:
            checkpoint.save(checkpoint_prefix)
            
        template = ("Epoch {}, Loss: {}, Accuracy: {}, Test_Loss: {}, Test_Accuracy: {}")
        print(template.format(epoch + 1, train_loss, train_accuracy.result()*100, test_loss.result(), test_accuracy.result()*100))
        
        test_loss.reset_states()
        train_accuracy.reset_states()
        test_accuracy.reset_states()

Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.
Instructions for updating:
renamed to `run`
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
Epoch 1, Loss: 0.5057141184806824,

In [15]:
eval_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='eval_accuracy')

new_model = create_model()
new_optimizer = tf.keras.optimizers.Adam()

test_dataset = tf.data.Dataset.from_tensor_slices((test_images, test_labels)).batch(GLOBAL_BATCH_SIZE)

In [16]:
@tf.function
def eval_step(images, labels):
    predictions = new_model(images, training=False)
    eval_accuracy(labels, predictions)

In [17]:
checkpoint = tf.train.Checkpoint(optimizer = new_optimizer, model = new_model)
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

for images, labels in test_dataset:
    eval_step(images, labels)
    
print('Restored model not using strategy: {}'.format(eval_accuracy.result()*100))

Restored model not using strategy: 90.58999633789062


In [18]:
with strategy.scope():
    for _ in range(EPOCHS):
        total_loss = 0.0
        num_batches = 0
        train_iter = iter(train_dist_dataset)
        
        for _ in range(10):
            total_loss += distributed_train_step(next(train_iter))
            num_batches += 1
        average_train_loss = total_loss / num_batches
        
        template=("Epoch {}, Loss: {}, Accuracy: {}")
        print(template.format(epoch +1, average_train_loss, train_accuracy.result()*100))
        train_accuracy.reset_states()

Epoch 10, Loss: 0.15883013606071472, Accuracy: 94.53125
Epoch 10, Loss: 0.12042137235403061, Accuracy: 95.9375
Epoch 10, Loss: 0.11272521317005157, Accuracy: 96.09375
Epoch 10, Loss: 0.12551774084568024, Accuracy: 96.09375
Epoch 10, Loss: 0.16023202240467072, Accuracy: 94.6875
Epoch 10, Loss: 0.1287182867527008, Accuracy: 95.0
Epoch 10, Loss: 0.12287051975727081, Accuracy: 95.3125
Epoch 10, Loss: 0.11338094621896744, Accuracy: 95.15625
Epoch 10, Loss: 0.13191018998622894, Accuracy: 94.53125
Epoch 10, Loss: 0.11839397996664047, Accuracy: 94.84375


In [19]:
with strategy.scope():
    @tf.function
    def distributed_train_epoch(dataset):
        total_loss = 0.0
        num_batches = 0
        for x in dataset:
            per_replica_losses = strategy.experimental_run_v2(train_step, args=(x, ))
            total_loss += strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)
            num_batches += 1
        return total_loss / tf.cast(num_batches, dtype=tf.float32)
    
    for epoch in range(EPOCHS):
        train_loss = distributed_train_epoch(train_dist_dataset)
        
        template = ("Epoch {}, LOSS: {}, Accuracy: {}")
        print(template.format(epoch + 1, train_loss, train_accuracy.result()*100))
        
        train_accuracy.reset_states()

Epoch 1, LOSS: 0.14341290295124054, Accuracy: 94.69000244140625
Epoch 2, LOSS: 0.13122905790805817, Accuracy: 95.11833190917969
Epoch 3, LOSS: 0.12313083559274673, Accuracy: 95.36332702636719
Epoch 4, LOSS: 0.11050769686698914, Accuracy: 95.86499786376953
Epoch 5, LOSS: 0.10170077532529831, Accuracy: 96.25333404541016
Epoch 6, LOSS: 0.09410367161035538, Accuracy: 96.37667083740234
Epoch 7, LOSS: 0.08666103333234787, Accuracy: 96.7366714477539
Epoch 8, LOSS: 0.0795668363571167, Accuracy: 97.0
Epoch 9, LOSS: 0.07453759014606476, Accuracy: 97.16999816894531
Epoch 10, LOSS: 0.06742487847805023, Accuracy: 97.53166198730469
