In [2]:
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import MaxPooling2D, Flatten, Dense, Conv2D, Rescaling
import os
import tensorflow_datasets as tfds

In [44]:
strategy = tf.distribute.MirroredStrategy()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3')


In [90]:
# Load data and distribute it

batch_size = 128
img_height = 224
img_width = 224
scratch = os.environ['SCRATCH']
train_dir = os.path.join(scratch,'imagenette/imagenette2/train/')
val_dir = os.path.join(scratch,'imagenette/imagenette2/val/')

train_dataset = tf.keras.utils.image_dataset_from_directory(
    train_dir,
    labels='inferred',
    label_mode='int',
    class_names=None,
    color_mode='rgb',
    batch_size=batch_size,
    image_size=(img_height, img_width),
    shuffle=True,
)

val_dataset = tf.keras.utils.image_dataset_from_directory(
    val_dir,
    labels='inferred',
    label_mode='int',
    class_names=None,
    color_mode='rgb',
    batch_size=batch_size,
    image_size=(img_height, img_width),
    shuffle=True,
)

# with strategy.scope():
  # Create the layer(s) under scope.

def dataset_fn(input_context):
  # a tf.data.Dataset
  dataset = train_dataset

  # Custom your batching, sharding, prefetching, etc.
  global_batch_size = 128
  batch_siz = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = dataset.batch(batch_siz, drop_remainder=True)
  dataset = dataset.shard(
      input_context.num_input_pipelines,
      input_context.input_pipeline_id
  )
  return dataset
# with strategy.scope():
    
# distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)
distributed_dataset = strategy.experimental_distribute_dataset(train_dataset)
test_dist_dataset = strategy.experimental_distribute_dataset(val_dataset)


Found 9469 files belonging to 10 classes.
Found 3925 files belonging to 10 classes.


2024-04-30 22:49:38.172469: W tensorflow/core/grappler/optimizers/data/auto_shard.cc:776] AUTO sharding policy will apply DATA sharding policy as it failed to apply FILE sharding policy because of the following reason: Found an unshardable source dataset: name: "TensorSliceDataset/_1"
op: "TensorSliceDataset"
input: "Placeholder/_0"
attr {
  key: "Toutput_types"
  value {
    list {
      type: DT_STRING
    }
  }
}
attr {
  key: "_cardinality"
  value {
    i: 9469
  }
}
attr {
  key: "is_files"
  value {
    b: false
  }
}
attr {
  key: "metadata"
  value {
    s: "\n\026TensorSliceDataset:585"
  }
}
attr {
  key: "output_shapes"
  value {
    list {
      shape {
      }
    }
  }
}
experimental_type {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_DATASET
    args {
      type_id: TFT_PRODUCT
      args {
        type_id: TFT_TENSOR
        args {
          type_id: TFT_STRING
        }
      }
    }
  }
}

2024-04-30 22:49:38.181155: W tensorflow/core/grappler/optimizers/data/au

In [81]:
# Create model

def vgg16():
    model = Sequential([
        Rescaling(1./255),
        Conv2D(64, (3, 3), activation='relu', padding='same', input_shape=(224, 224, 3), strides=1),
        Conv2D(64, (3, 3), activation='relu', padding='same', strides=1),
        MaxPooling2D((2, 2), strides=(2, 2)),
        
        Conv2D(128, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(128, (3, 3), activation='relu', padding='same', strides=1),
        MaxPooling2D((2, 2), strides=(2, 2)),
        
        Conv2D(256, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(256, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(256, (3, 3), activation='relu', padding='same', strides=1),
        MaxPooling2D((2, 2), strides=(2, 2)),
        
        Conv2D(512, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(512, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(512, (3, 3), activation='relu', padding='same', strides=1),
        MaxPooling2D((2, 2), strides=(2, 2)),
        
        Conv2D(512, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(512, (3, 3), activation='relu', padding='same', strides=1),
        Conv2D(512, (3, 3), activation='relu', padding='same', strides=1),
        MaxPooling2D((2, 2), strides=(2, 2)),
        
        Flatten(),
        Dense(4096, activation='relu'),
        Dense(4096, activation='relu'),
        Dense(1000, activation='softmax')
    ])
    return model

In [82]:
# Create a checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")

In [91]:
with strategy.scope():
  # Set reduction to `NONE` so you can do the reduction yourself.
  loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
      from_logits=True,
      reduction=tf.keras.losses.Reduction.NONE)
  def compute_loss(labels, predictions, model_losses):
    per_example_loss = loss_object(labels, predictions)
    loss = tf.nn.compute_average_loss(per_example_loss)
    if model_losses:
      loss += tf.nn.scale_regularization_loss(tf.add_n(model_losses))
    return loss

In [92]:
with strategy.scope():
  test_loss = tf.keras.metrics.Mean(name='test_loss')

  train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='train_accuracy')
  test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      name='test_accuracy')

In [97]:
# A model, an optimizer, and a checkpoint must be created under `strategy.scope`.
with strategy.scope():
  model = vgg16()

  optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)

  checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)

In [98]:
def train_step(inputs):
  images, labels = inputs

  with tf.GradientTape() as tape:
    predictions = model(images, training=True)
    loss = compute_loss(labels, predictions, model.losses)

  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

  train_accuracy.update_state(labels, predictions)
  return loss

def test_step(inputs):
  images, labels = inputs

  predictions = model(images, training=False)
  t_loss = loss_object(labels, predictions)

  test_loss.update_state(t_loss)
  test_accuracy.update_state(labels, predictions)

In [99]:
EPOCHS = 10

In [None]:
# `run` replicates the provided computation and runs it
# with the distributed input.
@tf.function
def distributed_train_step(dataset_inputs):
  per_replica_losses = strategy.run(train_step, args=(dataset_inputs,))
  return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica_losses,
                         axis=None)

@tf.function
def distributed_test_step(dataset_inputs):
  return strategy.run(test_step, args=(dataset_inputs,))

for epoch in range(EPOCHS):
  # TRAIN LOOP
  total_loss = 0.0
  num_batches = 0
  for x in distributed_dataset:
    total_loss += distributed_train_step(x)
    num_batches += 1
  train_loss = total_loss / num_batches

  # TEST LOOP
  for x in test_dist_dataset:
    distributed_test_step(x)

  if epoch % 2 == 0:
    checkpoint.save(checkpoint_prefix)

  template = ("Epoch {}, Loss: {}, Accuracy: {}, Test Loss: {}, "
              "Test Accuracy: {}")
  print(template.format(epoch + 1, train_loss,
                         train_accuracy.result() * 100, test_loss.result(),
                         test_accuracy.result() * 100))

  test_loss.reset_states()
  train_accuracy.reset_states()
  test_accuracy.reset_states()

INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:batch_all_reduce: 32 all-reduces with algorithm = nccl, num_packs = 1
Epoch 1, Loss: 2.8015942573547363, Accuracy: 11.0043306350708, Test Loss: 2.3056492805480957, Test Accuracy: 11.745223045349121
Epoch 2, Loss: 2.2550034523010254, Accuracy: 17.752666473388672, Test Loss: 2.2164299488067627, Test Accuracy: 19.439489364624023
Epoch 3, Loss: 2.0420689582824707, Accuracy: 27.83820915222168, Test Loss: 1.9922412633895874, Test Accuracy: 31.03184700012207
Epoch 4, Loss: 1.7149308919906616, Accuracy: 41.472171783447266, Test Loss: 1.664320707321167, Test Accuracy: 44.356689453125
Epoch 5, Loss: 1.4539991617202759, Accuracy: 51.65275955200195, Test Loss: 1.4173789024353027, Test Accuracy: 52.178340911865234
Epoch 6, Loss: 1.2592525482177734, Accuracy: 58.49614715576172, Test Loss: 1.4011811017990112, Test Accuracy: 53.47770309448242
Epoch 7, Loss: 1.116281270980835, Accuracy: 62.731014251708

In [None]:
gpus = tf.config.list_logical_devices('GPU')
print(gpus)