In [1]:
try:
  %tensorflow_version 2.x
except:
  pass

TensorFlow 2.x selected.


In [0]:
from typing import Dict, Union, Tuple
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [0]:
ds, info = tfds.load('cifar10', split='train', with_info=True)

In [0]:
def preprocessing(raw: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, int]:
  image = tf.cast(raw['image'], tf.float32) / 256  
  return image, raw['label']

def augmentation(image: tf.Tensor, label: int) -> Tuple[tf.Tensor, int]:
  return image, label

In [0]:
batch_size = 64
train_ds_size = int(info.splits['train'].num_examples * 0.75)
val_ds_size = info.splits['train'].num_examples - train_ds_size

train_ds = ds.take(train_ds_size)
val_ds = ds.skip(train_ds_size)

train_ds = (train_ds
            .batch(batch_size)
            .take(1) # single batch
            .repeat()
            .map(preprocessing,
                 num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .map(augmentation,
                 num_parallel_calls=tf.data.experimental.AUTOTUNE))

val_ds = (val_ds
          .batch(batch_size)
          .repeat()
          .map(preprocessing,
               num_parallel_calls=tf.data.experimental.AUTOTUNE))

In [0]:
@tf.function
def mixup_augmentation(image, label, alpha: float = 1.0) -> Dict[str, Union[tf.Tensor, Tuple[tf.Tensor]]]:
    if alpha > 0.:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1.
    
    batch_size = len(label)
    
    indices = tf.range(batch_size)
    indices = tf.random.shuffle(indices)
    
    mixed_image = lam * image + (1 - lam) * tf.gather(image, indices)
    mixed_label = tf.gather(label, indices)
    
    return mixed_image, (mixed_label, label, lam)

In [0]:
train_ds = train_ds.map(mixup_augmentation,
                        num_parallel_calls=tf.data.experimental.AUTOTUNE)
val_ds = val_ds.map(lambda x, y: mixup_augmentation(x, y, alpha=0),
                    num_parallel_calls=tf.data.experimental.AUTOTUNE)

In [0]:
class MixupLoss(tf.keras.losses.Loss):
    def __init__(self, base_loss: tf.keras.losses.Loss):
        name = 'mixup_' + base_loss.name
        super().__init__(name=name, reduction=base_loss.reduction)
        self.base_loss = base_loss
        
    def call(self, y_true, y_pred):
        y_true_mixed = y_true[0]
        lam = y_true[2]
        y_true = y_true[1]

        print(lam)
        return lam * self.base_loss(y_true, y_pred) + (1 - lam) * self.base_loss(y_true_mixed, y_pred)

    def get_config(self):
        return self.base_loss.get_config()

In [0]:
from tensorflow.keras import layers

model = tf.keras.Sequential([
    layers.Conv2D(16, 3, activation='relu', padding='same',
                  input_shape=(32, 32, 3)),
    layers.Conv2D(16, 3, activation='relu', padding='same'),
    layers.MaxPool2D(2),
    layers.Conv2D(32, 3, activation='relu', padding='same'),
    layers.Conv2D(32, 3, activation='relu', padding='same'),
    layers.MaxPool2D(2),
    layers.Conv2D(64, 3, activation='relu', padding='same'),
    layers.Conv2D(64, 3, activation='relu', padding='same'),
    layers.Flatten(),
    layers.Dense(128 * 2, activation='relu'),
    layers.Dense(128 * 2, activation='relu'),
    layers.Dense(10, activation='softmax')
])

In [0]:
a, b = next(iter(train_ds))
c, d, e = b

In [35]:
model.compile(tf.keras.optimizers.Adam(5e-3),
              MixupLoss(tf.keras.losses.SparseCategoricalCrossentropy()))

Tensor("loss_5/dense_17_loss/strided_slice_1:0", shape=(None,), dtype=float32)


In [36]:
model.fit(train_ds, validation_data=val_ds, epochs=100,
          steps_per_epoch=1, validation_steps=val_ds_size // batch_size)

ValueError: ignored

In [10]:
# overfit without Mixup 
# model.fit(train_ds, validation_data=val_ds, epochs=100,
#          steps_per_epoch=1, validation_steps=val_ds_size // batch_size)

Train for 1 steps, validate for 195 steps
Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100
Epoch 33/100
Epoch 34/100
Epoch 35/100
Epoch 36/100
Epoch 37/100
Epoch 38/100
Epoch 39/100
Epoch 40/100
Epoch 41/100
Epoch 42/100
Epoch 43/100
Epoch 44/100
Epoch 45/100
Epoch 46/100
Epoch 47/100
Epoch 48/100
Epoch 49/100
Epoch 50/100
Epoch 51/100
Epoch 52/100
Epoch 53/100
Epoch 54/100
Epoch 55/100
Epoch 56/100
Epoch 57/100
Epoch 58/100
Epoch 59/100
Epoch 60/100
Epoch 61/100
Epoch 62/100
Epoch 63/100
Epoch 64/100
Epoch 65/100
Epoch 66/100
Epoch 67/100
Epoch 68/100
Epoch 69/100
Epoch 70/100
Epoch 71/100
Epoch 72/100
Epoch 73/100
Epoch 74/100
Epoch

KeyboardInterrupt: ignored