In [3]:
from typing import Dict, Union, Tuple
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [4]:
# Fetch the CIFAR-10 training split together with its dataset metadata.
ds, info = tfds.load(
    name='cifar10',
    split='train',
    with_info=True,
)

In [5]:
def preprocessing(raw: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Turn a raw feature dictionary into an ``(image, label)`` pair.

    Args:
        raw: Feature dictionary providing at least the ``'image'`` and
            ``'label'`` entries.

    Returns:
        A tuple ``(image, label)``: ``image`` is ``raw['image']`` cast to
        ``float32`` and divided by 256; ``label`` is ``raw['label']``
        passed through unchanged.
    """
    # Presumably the pixels are 8-bit (0..255), in which case dividing by 256
    # maps them into [0, 1) -- TODO confirm against the dataset's image dtype.
    image = tf.cast(raw['image'], tf.float32) / 256

    return image, raw['label']

In [6]:
batch_size = 64
shuffle_buffer_size = 10000

# Training input pipeline:
#   shuffle  -> examples are drawn in a different order on every pass, so the
#               batches do not have the same composition epoch after epoch
#   repeat   -> infinite stream; the epoch length is set by `steps_per_epoch`
#   batch    -> preprocessing below runs on whole batches
#   prefetch -> overlap input preparation with the training step
#
# NOTE: this cell rebinds `ds`; re-running it without first re-running the
# loading cell stacks the transformations on top of each other.
ds = (ds
      .shuffle(shuffle_buffer_size)
      .repeat()
      .batch(batch_size)
      .map(preprocessing,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .prefetch(tf.data.experimental.AUTOTUNE))

In [7]:
@tf.function
def mixup_augmentation(image, label, alpha: float = 1.0) -> Tuple[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]:
    """Apply mixup to one batch by blending it with a shuffled copy of itself.

    A single mixing coefficient ``lam`` and a single random permutation are
    drawn per call and applied to the whole batch.

    Args:
        image: Batch of images with the batch on axis 0. Expected to be a
            floating-point tensor, since it is scaled by ``lam``.
        label: Batch of labels aligned with ``image`` along axis 0.
        alpha: Concentration of the ``Beta(alpha, alpha)`` distribution that
            ``lam`` is drawn from. When ``alpha <= 0`` mixing is disabled
            (``lam == 1``). Must be a Python number, not a tensor.

    Returns:
        ``(mixed_image, (mixed_label, label, lam))`` where ``mixed_image`` is
        ``lam * image + (1 - lam) * image[perm]``, ``mixed_label`` is
        ``label[perm]``, ``label`` is the untouched input and ``lam`` is a
        scalar ``float32`` tensor.
    """
    if alpha > 0.:
        # Sample lam with TensorFlow ops so a fresh value is drawn on every
        # call. A NumPy draw here would be evaluated once while tf.function
        # traces the Python body and then baked into the graph as a constant.
        # Beta(a, a) is obtained as X / (X + Y) with X, Y ~ Gamma(a, 1);
        # sampling in float64 keeps X + Y away from underflow for small alpha.
        gamma_draws = tf.random.gamma([2], alpha, dtype=tf.float64)
        lam = tf.cast(gamma_draws[0] / (gamma_draws[0] + gamma_draws[1]), tf.float32)
    else:
        lam = tf.constant(1.0, dtype=tf.float32)

    # Dynamic batch size: works even when the static batch dimension is unknown.
    batch_size = tf.shape(label)[0]

    indices = tf.range(batch_size)
    indices = tf.random.shuffle(indices)

    # Match the image dtype so float16/float64 batches mix without a dtype error.
    weight = tf.cast(lam, image.dtype)
    mixed_image = weight * image + (1 - weight) * tf.gather(image, indices)
    mixed_label = tf.gather(label, indices)

    return mixed_image, (mixed_label, label, lam)

In [8]:
class MixupLoss(tf.keras.losses.Loss):
    """Loss wrapper for mixup targets.

    Evaluates ``base_loss`` against both label sets of a mixed batch and
    blends the two results with the mixing coefficient:
    ``lam * base_loss(y_true, y_pred) + (1 - lam) * base_loss(y_true_mixed, y_pred)``.
    """

    def __init__(self, base_loss: tf.keras.losses.Loss):
        """Wrap ``base_loss``, inheriting its reduction and prefixing its name.

        Args:
            base_loss: Loss instance applied to each of the two label sets.
        """
        name = 'mixup_' + base_loss.name
        super().__init__(name=name, reduction=base_loss.reduction)
        self.base_loss = base_loss

    def call(self, y_true: Tuple[tf.Tensor], y_pred: tf.Tensor) -> tf.Tensor:
        """Compute the blended, unreduced loss.

        Args:
            y_true: Triple ``(y_true_mixed, y_true, lam)``: labels of the
                shuffled samples, labels of the original samples, and the
                scalar weight given to the original samples.
            y_pred: Model predictions for the mixed inputs.

        Returns:
            ``lam * loss(y_true) + (1 - lam) * loss(y_true_mixed)`` as produced
            by ``base_loss.call``; reduction and sample weighting are applied
            afterwards by ``Loss.__call__`` of this wrapper.
        """
        y_true_mixed, y_true, lam = y_true
        # Use `call` rather than `__call__`: the latter already reduces to a
        # scalar, which would then be reduced (and sample-weighted) a second
        # time by this wrapper.
        loss_original = self.base_loss.call(y_true, y_pred)
        loss_mixed = self.base_loss.call(y_true_mixed, y_pred)
        # lam may arrive with a different float dtype than the loss values.
        lam = tf.cast(lam, loss_original.dtype)
        return lam * loss_original + (1 - lam) * loss_mixed

    def get_config(self):
        """Return a config from which ``from_config`` can rebuild this loss."""
        config = super().get_config()
        config['base_loss'] = tf.keras.losses.serialize(self.base_loss)
        return config

    @classmethod
    def from_config(cls, config):
        """Rebuild the wrapper from the output of ``get_config``.

        Only the ``'base_loss'`` entry is used; name and reduction are derived
        from the deserialized base loss in ``__init__``.
        """
        base_loss = tf.keras.losses.deserialize(config['base_loss'])
        return cls(base_loss)

In [9]:
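# TODO: mixup is currently disabled. `Dataset.map` returns a new dataset, so
# enabling it needs a re-assignment -- ds = ds.map(mixup_augmentation) -- and
# the model must then be compiled with MixupLoss wrapping the base loss.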

In [10]:
from tensorflow.keras import layers

# Three stages of paired 3x3 convolutions (16 -> 32 -> 64 filters) with 2x2
# max-pooling between stages, global average pooling, and a dense head.
# Every hidden layer uses ReLU: without activations the stacked convolutions
# and dense layers compose into linear maps. The final Dense(10) has no
# activation and therefore outputs raw logits.
model = tf.keras.Sequential([
    layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPool2D(2),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPool2D(2),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dense(128 * 2, activation='relu'),
    layers.Dense(128 * 2, activation='relu'),
    layers.Dense(10)
])

In [11]:
# The pipeline passes `raw['label']` through unchanged (integer class ids for
# tfds CIFAR-10, not one-hot vectors) and the last Dense layer has no
# activation, so the loss must accept sparse labels and raw logits.
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))

In [12]:
# `ds` repeats forever, so the epoch length has to be given explicitly.
# Take the split size from the dataset metadata instead of hard-coding it.
steps_per_epoch = info.splits['train'].num_examples // batch_size
model.fit(ds, steps_per_epoch=steps_per_epoch)

Train for 781 steps
  1/781 [..............................] - ETA: 15:57

UnknownError:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[node sequential/conv2d/Conv2D (defined at <ipython-input-12-dbc6497d913b>:1) ]] [Op:__inference_distributed_function_1640]

Function call stack:
distributed_function
