In [1]:
!pip install imgaug -U

Collecting imgaug
  Downloading imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
     |████████████████████████████████| 948 kB 8.3 MB/s
Installing collected packages: imgaug
  Attempting uninstall: imgaug
    Found existing installation: imgaug 0.2.9
    Uninstalling imgaug-0.2.9:
      Successfully uninstalled imgaug-0.2.9
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
albumentations 0.1.12 requires imgaug<0.2.7,>=0.2.5, but you have imgaug 0.4.0 which is incompatible.
Successfully installed imgaug-0.4.0


In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
import imgaug.augmenters as iaa
import imgaug as ia

In [3]:
SEED = 42

# TensorFlow's global seed: random ops created without an explicit op-level
# seed (e.g. the Dataset.shuffle calls below) derive their seeds from it.
tf.random.set_seed(SEED)
# imgaug's global RNG: the RandAugment augmenter below is built without a seed
# of its own.
ia.seed(SEED)

In [4]:
# CIFAR-10 via tf.keras.datasets: a (train, test) pair of (images, labels).
train_split, test_split = tf.keras.datasets.cifar10.load_data()
x_train, y_train = train_split
x_test, y_test = test_split

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [27]:
# Configuration shared by every pipeline and training cell below.
IMAGE_SIZE = 72   # images are resized to IMAGE_SIZE x IMAGE_SIZE before the model
BATCH_SIZE = 128
EPOCHS = 10
AUTO = tf.data.AUTOTUNE   # let tf.data tune map parallelism / prefetch depth

In [6]:
# imgaug RandAugment: n = number of ops applied per image, m = their magnitude.
rand_aug = iaa.RandAugment(m=7, n=3)

In [7]:
def augment(images):
  """Run imgaug RandAugment (``rand_aug``) over one batch of images.

  Intended to be called eagerly through ``tf.py_function``: ``.numpy()``
  requires an eager tensor. The batch is converted to uint8 before it is
  handed to imgaug.

  Args:
    images: a batch of images. In the training pipeline these are the float
      outputs of ``tf.image.resize`` (presumably in [0, 255] -- confirm if the
      pipeline changes). Integer inputs are cast directly.

  Returns:
    The augmented batch as returned by ``rand_aug`` (a uint8 NumPy array);
    the caller's ``tf.py_function(..., [tf.float32])`` converts it to float32.
  """
  images = tf.convert_to_tensor(images)
  if images.dtype.is_floating:
    # A bare float -> uint8 cast truncates toward zero, shifting pixels down
    # by up to one level relative to the (uncast, float) test images. Round
    # to the nearest level and clamp to the uint8 range so the cast is exact.
    images = tf.clip_by_value(tf.round(images), 0.0, 255.0)
  images = tf.cast(images, tf.uint8)
  return rand_aug(images=images.numpy())

In [8]:
# RandAugment training pipeline: shuffle -> batch -> resize -> RandAugment.
# Batching before the augment map lets each py_function call process a whole
# batch instead of a single image.
train_ds_rand = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    .map(
        # tf.py_function returns tensors of unknown static shape. ensure_shape
        # re-attaches the shape the model's Input expects (batch dimension
        # left open because the final batch may be smaller) and fails loudly
        # if the augmentation ever changes the image size.
        lambda x, y: (
            tf.ensure_shape(
                tf.py_function(augment, [x], [tf.float32])[0],
                [None, IMAGE_SIZE, IMAGE_SIZE, 3],
            ),
            y,
        ),
        num_parallel_calls=AUTO,
    )
    .prefetch(AUTO)
)

In [9]:
def _resize_to_model_input(images, labels):
  """Resize an image batch to IMAGE_SIZE x IMAGE_SIZE; labels pass through."""
  return tf.image.resize(images, (IMAGE_SIZE, IMAGE_SIZE)), labels


# Evaluation pipeline: no shuffling and no augmentation, only the resize the
# model's input layer expects.
test_ds = (
    tf.data.Dataset.from_tensor_slices((x_test, y_test))
    .batch(BATCH_SIZE)
    .map(_resize_to_model_input, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

In [10]:
# Baseline ("simple") augmentation built from Keras preprocessing layers:
# resize to the model's input size, then a random horizontal flip, a small
# random rotation, and a random zoom.
simple_aug = tf.keras.Sequential()
simple_aug.add(tf.keras.layers.Resizing(IMAGE_SIZE, IMAGE_SIZE))
simple_aug.add(tf.keras.layers.RandomFlip('horizontal'))
simple_aug.add(tf.keras.layers.RandomRotation(factor=0.02))
simple_aug.add(tf.keras.layers.RandomZoom(height_factor=0.2, width_factor=0.2))

In [11]:
def _apply_simple_aug(images, labels):
  """Run the `simple_aug` layer stack on an image batch; labels are unchanged."""
  return simple_aug(images), labels


# Simple-augmentation training pipeline: shuffle -> batch -> simple_aug
# (which also performs the resize).
train_ds_simple = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(BATCH_SIZE * 100)
    .batch(BATCH_SIZE)
    .map(_apply_simple_aug, num_parallel_calls=AUTO)
    .prefetch(AUTO)
)

In [12]:
def get_training_model():
  """Build a freshly initialised classifier for IMAGE_SIZE x IMAGE_SIZE x 3 input.

  The model rescales raw pixel values (presumably 0-255) to [-1, 1] and feeds
  them to a ResNet50V2 with random weights (``weights=None``) and a 10-class
  classification top.

  Returns:
    An uncompiled ``tf.keras.Sequential`` model.
  """
  backbone = tf.keras.applications.ResNet50V2(
      include_top=True,
      weights=None,
      classes=10,
      input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
  )
  model = tf.keras.Sequential()
  model.add(tf.keras.layers.Input((IMAGE_SIZE, IMAGE_SIZE, 3)))
  # scale = 1/127.5, offset = -1 maps 0 -> -1 and 255 -> +1.
  model.add(tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1))
  model.add(backbone)
  return model

In [13]:
# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" to the output.
get_training_model().summary()

Model: "sequential_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 rescaling (Rescaling)       (None, 72, 72, 3)         0         
                                                                 
 resnet50v2 (Functional)     (None, 10)                23585290  
                                                                 
Total params: 23,585,290
Trainable params: 23,539,850
Non-trainable params: 45,440
_________________________________________________________________
None


In [14]:
# Snapshot a single random initialisation so that both training runs below
# start from identical weights.
initial_weights_path = 'initial_weights.h5'
initial_model = get_training_model()
initial_model.save_weights(initial_weights_path)

In [21]:
# Early stopping: halt training once val_loss has not improved for
# ES_PATIENCE consecutive epochs, and restore the best epoch's weights.
# Patience must be smaller than EPOCHS for the callback to ever trigger;
# with patience=10 and EPOCHS=10 it could never stop training.
# NOTE(review): the fit cells pass test_ds as validation_data, so the stopping
# point is selected on the test set. Hold out part of the training data for
# validation if an unbiased test accuracy is needed.
ES_PATIENCE = 3
es = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=ES_PATIENCE, restore_best_weights=True
)

In [28]:
# RandAugment run: start from the shared initial weights, train, then report
# accuracy on test_ds.
rand_aug_model = get_training_model()
rand_aug_model.load_weights('initial_weights.h5')
rand_aug_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
rand_aug_model.fit(
    train_ds_rand,
    epochs=EPOCHS,
    validation_data=test_ds,
    callbacks=[es],
)
# evaluate() returns [loss, accuracy] for the compiled metrics above.
test_acc = rand_aug_model.evaluate(test_ds)[1]
print('test_acc: {:.2f}%'.format(100 * test_acc))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
test_acc: 78.04%


In [29]:
# Simple-augmentation run: same initial weights and training setup as the
# RandAugment run; only the training pipeline differs.
simple_aug_model = get_training_model()
simple_aug_model.load_weights('initial_weights.h5')
simple_aug_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)
simple_aug_model.fit(
    train_ds_simple,
    epochs=EPOCHS,
    validation_data=test_ds,
    callbacks=[es],
)
# evaluate() returns [loss, accuracy] for the compiled metrics above.
test_acc = simple_aug_model.evaluate(test_ds)[1]
print('test_acc: {:.2f}%'.format(100 * test_acc))

Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
test_acc: 82.51%


In [17]:
# Second evaluation set loaded through TensorFlow Datasets, resized the same
# way as test_ds.
# NOTE(review): the TFDS 'cifar10' test split appears to contain the same
# CIFAR-10 test images as tf.keras.datasets.cifar10's test set (used for
# test_ds); the accuracies reported below match the earlier test_acc values
# exactly. To measure robustness to distribution shift, point
# ROBUSTNESS_DATASET at a corrupted variant such as
# 'cifar10_corrupted/saturate_5' (large download).
ROBUSTNESS_DATASET = 'cifar10'
cifar_10 = (
    tfds.load(ROBUSTNESS_DATASET, split='test', as_supervised=True)
    .batch(BATCH_SIZE)
    .map(
        lambda x, y: (tf.image.resize(x, (IMAGE_SIZE, IMAGE_SIZE)), y),
        num_parallel_calls=AUTO,
    )
    # Overlap input preparation with evaluation, like the other pipelines.
    .prefetch(AUTO)
)

Downloading and preparing dataset cifar10/3.0.2 (download: 162.17 MiB, generated: 132.40 MiB, total: 294.58 MiB) to /root/tensorflow_datasets/cifar10/3.0.2...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]






0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteYRV78G/cifar10-train.tfrecord


  0%|          | 0/50000 [00:00<?, ? examples/s]

0 examples [00:00, ? examples/s]

Shuffling and writing examples to /root/tensorflow_datasets/cifar10/3.0.2.incompleteYRV78G/cifar10-test.tfrecord


  0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.


In [30]:
# Accuracy of the RandAugment model on the TFDS evaluation set
# (evaluate() returns [loss, accuracy]).
test_acc = rand_aug_model.evaluate(cifar_10, verbose=0)[1]
print('acc with randaugment on cifar10: {:.2f}%'.format(100 * test_acc))

acc with randaugment on cifar10: 78.04%


In [31]:
# Accuracy of the simple-augmentation model on the TFDS evaluation set
# (evaluate() returns [loss, accuracy]).
test_acc = simple_aug_model.evaluate(cifar_10, verbose=0)[1]
print('acc with simple augment on cifar10: {:.2f}%'.format(100 * test_acc))

acc with simple augment on cifar10: 82.51%
