In [1]:
import tensorflow as tf
import argparse
import os
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
from tensorflow.keras.utils import multi_gpu_model
from tensorflow.keras.optimizers import Adam, SGD

# Image geometry: used to reshape each decoded record and as the model's
# input shape (height, width, channels).
HEIGHT = 32
WIDTH  = 32
DEPTH  = 3
# Number of label classes; used as the one-hot depth of the labels.
NUM_CLASSES = 10
# Dataset sizes; each is integer-divided by the batch size to obtain the
# number of steps for training, validation and evaluation respectively.
NUM_TRAIN_IMAGES = 40000
NUM_VALID_IMAGES = 10000
NUM_TEST_IMAGES  = 10000

In [2]:
def train_preprocess_fn(image):
    """Apply pad / random-crop / random-flip augmentation to one image.

    Args:
        image: a single image tensor (the caller in this file passes a
            float32 tensor shaped [HEIGHT, WIDTH, DEPTH]).

    Returns:
        A tensor cropped to [HEIGHT, WIDTH, DEPTH] that may have been
        mirrored left-to-right.
    """
    # Four extra pixels on every side -> eight extra per dimension.
    border = 4
    padded = tf.image.resize_image_with_crop_or_pad(
        image, HEIGHT + 2 * border, WIDTH + 2 * border)

    # Cut a random window of the original size out of the padded image.
    window = tf.random_crop(padded, [HEIGHT, WIDTH, DEPTH])

    # Mirror the window horizontally at random.
    return tf.image.random_flip_left_right(window)

In [3]:
def make_batch(filenames, batch_size):
    """Build a batched input pipeline over the TFRecord file(s) in 'filenames'.

    Args:
        filenames: value handed straight to tf.data.TFRecordDataset.
        batch_size: number of examples per batch; incomplete final batches
            are dropped.

    Returns:
        A tuple (image_batch, label_batch) of tensors obtained from a
        one-shot iterator over the pipeline.
    """
    pipeline = (
        tf.data.TFRecordDataset(filenames)
        # Loop over the records forever; callers bound the work with step counts.
        .repeat()
        # Decode every serialized record with the parser defined in this file.
        .map(single_example_parser, num_parallel_calls=1)
        # Group into fixed-size batches.
        .batch(batch_size, drop_remainder=True)
    )

    # NOTE: the logged run reports make_one_shot_iterator as deprecated.
    images, labels = pipeline.make_one_shot_iterator().get_next()
    return images, labels

In [4]:
def single_example_parser(serialized_example):
    """Turn one serialized tf.Example into an (image, one-hot label) pair.

    The record is expected to carry a raw-bytes 'image' feature and an int64
    'label' feature. See http://www.cs.toronto.edu/~kriz/cifar.html for a
    description of the CIFAR-10 input format.

    Note: train_preprocess_fn is applied unconditionally, so every dataset
    parsed through this function receives the random crop / flip
    augmentation.

    Args:
        serialized_example: scalar string tensor holding one tf.Example.

    Returns:
        (image, label): the augmented float32 image and the label one-hot
        encoded over NUM_CLASSES.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(serialized_example, features=feature_spec)

    # The image bytes decode to a flat vector of DEPTH * HEIGHT * WIDTH uint8s.
    flat_pixels = tf.decode_raw(parsed['image'], tf.uint8)
    flat_pixels.set_shape([DEPTH * HEIGHT * WIDTH])

    # View as [depth, height, width], then move channels last:
    # [height, width, depth].
    channels_first = tf.reshape(flat_pixels, [DEPTH, HEIGHT, WIDTH])
    channels_last = tf.transpose(channels_first, [1, 2, 0])
    pixels = tf.cast(channels_last, tf.float32)
    class_index = tf.cast(parsed['label'], tf.int32)

    augmented = train_preprocess_fn(pixels)
    one_hot_label = tf.one_hot(class_index, NUM_CLASSES)

    return augmented, one_hot_label

In [5]:
def cifar10_model(input_shape):
    """Build a ResNet50-backed classifier with a NUM_CLASSES-way softmax head.

    Args:
        input_shape: shape tuple of a single input image, e.g.
            (HEIGHT, WIDTH, DEPTH).

    Returns:
        A Keras Model mapping the input image to class probabilities.
    """
    input_tensor = Input(shape=input_shape)

    # Convolutional base only (no ImageNet classification head), initialised
    # from the 'imagenet' weights.
    # NOTE(review): this file feeds raw uint8-valued floats into the network;
    # confirm whether the ImageNet preprocessing for ResNet50 should be applied.
    base_model = keras.applications.resnet50.ResNet50(include_top=False,
                                                      weights='imagenet',
                                                      input_tensor=input_tensor,
                                                      input_shape=input_shape,
                                                      classes=None)

    x = Flatten()(base_model.output)
    # Size the output layer from NUM_CLASSES so it always matches the one-hot
    # label width instead of a separately hard-coded 10.
    predictions = Dense(NUM_CLASSES, activation='softmax')(x)
    return Model(inputs=base_model.input, outputs=predictions)

In [6]:
#%%
def main(args):
    """Train the CIFAR-10 model, then print its loss and accuracy on the eval set.

    Args:
        args: namespace providing epochs, learning_rate, batch_size, momentum,
            weight_decay, optimizer, gpu_count, training, validation and eval.

    Returns:
        None. The evaluation loss and accuracy are printed.
    """
    # Hyper-parameters
    epochs = args.epochs
    lr = args.learning_rate
    batch_size = args.batch_size
    momentum = args.momentum
    weight_decay = args.weight_decay
    optimizer = args.optimizer

    # Other options
    gpu_count = args.gpu_count

    # Each make_batch call yields an (images, labels) pair of tensors.
    # NOTE(review): make_batch parses every dataset with the same parser, which
    # always applies the training augmentation -- validation and eval data are
    # therefore augmented as well; confirm this is intended.
    train_images, train_labels = make_batch(args.training, batch_size)
    val_dataset = make_batch(args.validation, batch_size)
    eval_images, eval_labels = make_batch(args.eval, batch_size)

    input_shape = (HEIGHT, WIDTH, DEPTH)
    model = cifar10_model(input_shape)

    # Multi-GPU training
    if gpu_count > 1:
        model = multi_gpu_model(model, gpus=gpu_count)

    # Optimizer: 'sgd' (case-insensitive) selects SGD, anything else Adam.
    # NOTE(review): the Keras `decay` argument is documented as learning-rate
    # decay per update, not L2 weight decay -- confirm that forwarding
    # `weight_decay` here is the intended behavior.
    if optimizer.lower() == 'sgd':
        opt = SGD(lr=lr, decay=weight_decay, momentum=momentum)
    else:
        opt = Adam(lr=lr, decay=weight_decay)

    # Compile model
    model.compile(optimizer=opt,
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Train model. The datasets repeat forever, so the step counts define how
    # many batches make up one epoch / one validation pass.
    model.fit(x=train_images, y=train_labels,
              steps_per_epoch=NUM_TRAIN_IMAGES // batch_size,
              validation_data=val_dataset,
              validation_steps=NUM_VALID_IMAGES // batch_size,
              epochs=epochs)

    # Evaluate model performance (same local batch_size as training).
    score = model.evaluate(eval_images,
                           eval_labels,
                           steps=NUM_TEST_IMAGES // batch_size,
                           verbose=0)
    print('Test loss    :', score[0])
    print('Test accuracy:', score[1])

    # Save model to model directory
    #tf.contrib.saved_model.save_keras_model(model, args.model_output_dir)

In [7]:
%%time
if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # Hyper-parameters
    parser.add_argument('--epochs',        type=int,   default=25)
    parser.add_argument('--learning-rate', type=float, default=0.01)
    parser.add_argument('--batch-size',    type=int,   default=128)
    parser.add_argument('--weight-decay',  type=float, default=2e-4)
    # Float literal rather than the string '0.9', consistent with the other
    # defaults (argparse would convert the string, so the value is unchanged).
    parser.add_argument('--momentum',      type=float, default=0.9)
    parser.add_argument('--optimizer',     type=str,   default='sgd')

    # Data directories and other options
    parser.add_argument('--gpu-count',        type=int,   default=4)
    parser.add_argument('--model_output_dir', type=str,   default='./models')
    parser.add_argument('--training',      type=str,   default='data/train/train.tfrecords')
    parser.add_argument('--validation',    type=str,   default='data/validation/validation.tfrecords')
    parser.add_argument('--eval',          type=str,   default='data/eval/eval.tfrecords')

    # An explicit empty argument list makes the parser ignore sys.argv, so the
    # defaults above are always used (presumably because this runs inside a
    # notebook kernel, whose sys.argv is not meant for this parser).
    args = parser.parse_args(args=[])
    main(args)

W1014 01:56:00.487950 140411681085184 deprecation.py:323] From <ipython-input-3-281db9a6ad1c>:11: DatasetV1.make_one_shot_iterator (from tensorflow.python.data.ops.dataset_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use `for ... in dataset:` to iterate over a dataset. If using `tf.estimator`, return the `Dataset` object directly from your input function. As a last resort, you can use `tf.compat.v1.data.make_one_shot_iterator(dataset)`.
W1014 01:56:08.129960 140411681085184 deprecation.py:506] From /home/ubuntu/anaconda3/envs/tensorflow_p36/lib/python3.6/site-packages/tensorflow/python/ops/init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor
W1014 01:56:15.679027 140411681085184 deprecation.py:323] From /home/ubuntu/anaconda3/en

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25
Epoch 8/25
Epoch 9/25
Epoch 10/25
Epoch 11/25
Epoch 12/25
Epoch 13/25
Epoch 14/25
Epoch 15/25
Epoch 16/25
Epoch 17/25
Epoch 18/25
Epoch 19/25
Epoch 20/25
Epoch 21/25
Epoch 22/25
Epoch 23/25
Epoch 24/25
Epoch 25/25
Test loss    : 1.0591415632993748
Test accuracy: 0.63000804
CPU times: user 1h 16min 35s, sys: 13min 54s, total: 1h 30min 29s
Wall time: 35min 32s


In [8]:
# Single gpu
# CPU times: user 3min 47s, sys: 53.3 s, total: 4min 40s
# Wall time: 2min 48s

# 4 gpu
# CPU times: user 16min 35s, sys: 2min 47s, total: 19min 22s
# Wall time: 8min 12s