In [1]:
import tensorflow as tf
import modules.tf.load_img as load_img
import importlib
importlib.reload(load_img)

<module 'modules.tf.load_img' from '/media/taesookim/0C8ECA3C8ECA1E58/Users/terry/source/repos/tensorflow2_notebooks/modules/tf/load_img.py'>

In [2]:
import os
from tensorflow.keras import mixed_precision

In [3]:
os.environ['TF_XLA_FLAGS']="--tf_xla_enable_xla_devices"
policy = mixed_precision.Policy('mixed_float16')
mixed_precision.set_global_policy(policy)

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPUs will likely run quickly with dtype policy mixed_float16 as they all have compute capability of at least 7.0


In [4]:
tf.config.experimental.list_physical_devices()


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:1', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:2', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:1', device_type='GPU'),
 PhysicalDevice(name='/physical_device:GPU:2', device_type='GPU')]

In [5]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2')


In [6]:
def seq(lyrs):
  return [tf.keras.models.Sequential(lyr) for lyr in lyrs]

#Note: This design can be improved (i.e Conv->BN->Activation).
def model_autoencoder():
  inp = tf.keras.layers.Input([16, 16, 1])

  layers = seq([tf.keras.layers.Conv2D(16, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#16
                                tf.keras.layers.Conv2D(32, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#8
                                tf.keras.layers.Conv2D(64, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#4
                                tf.keras.layers.Conv2D(128, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#2
                                tf.keras.layers.Conv2D(256, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#1
                                tf.keras.layers.Conv2D(512, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#1
                                tf.keras.layers.Conv2D(256, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#1
                                tf.keras.layers.Conv2DTranspose(128, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#2
                                tf.keras.layers.Conv2DTranspose(64, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#4
                                tf.keras.layers.Conv2DTranspose(32, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.leaky_relu),#8
                                tf.keras.layers.Conv2DTranspose(1, kernel_size=(3,3), strides=2, padding="same", activation=tf.nn.tanh, dtype=tf.float32 ) #16
          ])
  prev = inp
  skips = []
  for layer in layers[:6]:
    prev = layer(prev)
    skips.append(prev)

  skips = skips[:5]
  for skip, layer in zip(reversed(skips), layers[6:]):
    prev = tf.keras.layers.concatenate([skip, prev])
    prev = layer(prev)

  return tf.keras.Model(inputs=inp, outputs=prev)

In [7]:
BATCH_SIZE = 32
GLOBAL_BATCH_SIZE = 32 * strategy.num_replicas_in_sync

In [8]:
train_imgs = load_img.load_mnist(GLOBAL_BATCH_SIZE, tiny=False)

In [9]:
train_imgs = strategy.experimental_distribute_dataset(train_imgs)

In [10]:
# for datamap in train_imgs:
#     img, img_2, img_fn = datamap
#     break

In [11]:
with strategy.scope():
    model = model_autoencoder()
    loss_object = tf.keras.losses.MeanAbsoluteError(reduction=tf.keras.losses.Reduction.NONE)
    def compute_loss(real, pred):
        per_example_loss = loss_object(real, pred)
        return tf.nn.compute_average_loss(per_example_loss, global_batch_size=GLOBAL_BATCH_SIZE)
    optimizer = tf.keras.optimizers.Adam()
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)

In [12]:
@tf.function
def train_step(inputs):
    img, img_fn = inputs
    with tf.GradientTape() as tape:
        pred = model(img * 0.5, training=True)
        loss = compute_loss(img, pred)
        scaled_loss = optimizer.get_scaled_loss(loss)
    scaled_gradients = tape.gradient(scaled_loss, model.trainable_variables)
    gradients = optimizer.get_unscaled_gradients(scaled_gradients)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss

def distributed_train_step(inputs):
    per_replica_losses = strategy.run(train_step, args=(inputs, ))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                           axis=None)

In [13]:
for epoch in range(50):
    for data in train_imgs:
        loss = distributed_train_step(data)
        tf.print(loss)


INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 22 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 22 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
83.8918457
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
86.2702332
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
93.0373383
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job

KeyboardInterrupt: 