In [3]:
from typing import Dict, Union, Tuple
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

In [2]:
!pip install tensorflow_datasets

Defaulting to user installation because normal site-packages is not writeable
Collecting tensorflow_datasets
  Using cached tensorflow_datasets-2.1.0-py3-none-any.whl (3.1 MB)
Collecting promise
  Using cached promise-2.3.tar.gz (19 kB)
Collecting tensorflow-metadata
  Using cached tensorflow_metadata-0.21.1-py2.py3-none-any.whl (31 kB)
Collecting dill
  Using cached dill-0.3.1.1.tar.gz (151 kB)
Collecting googleapis-common-protos
  Using cached googleapis-common-protos-1.51.0.tar.gz (35 kB)
Building wheels for collected packages: promise, dill, googleapis-common-protos
  Building wheel for promise (setup.py) ... [?25ldone
[?25h  Created wheel for promise: filename=promise-2.3-py3-none-any.whl size=21495 sha256=a3ac4960a0ee4264a2ea62b378988ad5753a7a98efd455234ab7d9e62d33b427
  Stored in directory: /home/evida/.cache/pip/wheels/59/9a/1d/3f1afbbb5122d0410547bf9eb50955f4a7a98e53a6d8b99bd1
  Building wheel for dill (setup.py) ... [?25ldone
[?25h  Created wheel for dill: filename=dill-0

In [2]:
# Raw CIFAR-10 training split (elements are feature dicts with 'image' and 'label'),
# plus the DatasetInfo metadata object.
ds, info = tfds.load(
    name='cifar10',
    split='train',
    with_info=True,
)

In [3]:
def preprocessing(raw: Dict[str, tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Convert a tfds feature dict into the ``(image, label)`` pair Keras expects.

    Only elementwise ops are used, so this works on a single example or on a
    whole batch (the pipeline below maps it after ``batch``).

    Args:
        raw: Feature dict with ``'image'`` and ``'label'`` entries.

    Returns:
        ``(image, label)``. The image is cast to float32 and divided by 256
        (for uint8 pixels this lands in [0, 1) -- assumes uint8 input, TODO confirm).
        The label is passed through unchanged.
    """
    image = tf.cast(raw['image'], tf.float32) / 256

    return image, raw['label']

In [4]:
batch_size = 64
# Number of examples held in the shuffle buffer; larger = closer to a full shuffle, more memory.
shuffle_buffer = 10000

# Input pipeline: shuffle -> repeat -> batch -> preprocess -> prefetch.
# Shuffling *before* repeat keeps epoch boundaries and reshuffles on every pass;
# without it, every epoch would iterate the training set in the same fixed order.
# `preprocessing` only does elementwise ops, so mapping it over whole batches is vectorized.
# NOTE: this rebinds `ds`; re-running this cell without re-running the load cell
# stacks the transforms a second time.
ds = (ds
      .shuffle(shuffle_buffer)
      .repeat()
      .batch(batch_size)
      .map(preprocessing,
           num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .prefetch(tf.data.experimental.AUTOTUNE))

In [5]:
@tf.function
def mixup_augmentation(image, label, alpha: float = 1.0) -> Tuple[tf.Tensor, Tuple[tf.Tensor, tf.Tensor, tf.Tensor]]:
    """Apply mixup to one batch.

    Each example is blended with a randomly chosen partner from the same batch:
    ``mixed = lam * image + (1 - lam) * image[perm]``, where a single
    ``lam ~ Beta(alpha, alpha)`` is drawn per call (i.e. per batch).

    Args:
        image: Batched images, batch dimension first. Assumed float (e.g. the
            float32 output of ``preprocessing``) -- TODO confirm for other callers.
        label: Batched labels aligned with ``image``.
        alpha: Beta concentration parameter. ``alpha <= 0`` disables mixing
            (``lam = 1``). Must be a Python number: it is branched on at trace time.

    Returns:
        ``(mixed_image, (partner_label, label, lam))``, where ``partner_label`` are
        the labels of the shuffled partners and ``lam`` is a float32 scalar tensor.
    """
    if alpha > 0.:
        # Sample with TF ops, not NumPy: inside tf.function / Dataset.map a NumPy
        # call executes only once, at trace time, which would freeze `lam` to one
        # constant for every batch. Beta(a, a) == X / (X + Y), X, Y ~ Gamma(a) i.i.d.
        gamma_1 = tf.random.gamma(shape=[], alpha=alpha, dtype=tf.float32)
        gamma_2 = tf.random.gamma(shape=[], alpha=alpha, dtype=tf.float32)
        # divide_no_nan guards the rare small-alpha case where both draws underflow to 0.
        lam = tf.math.divide_no_nan(gamma_1, gamma_1 + gamma_2)
    else:
        lam = tf.constant(1., dtype=tf.float32)

    # len() is undefined for symbolic tensors with an unknown batch dim; use the dynamic shape.
    batch_size = tf.shape(label)[0]

    indices = tf.random.shuffle(tf.range(batch_size))

    mixed_image = lam * image + (1 - lam) * tf.gather(image, indices)
    mixed_label = tf.gather(label, indices)

    return mixed_image, (mixed_label, label, lam)

In [6]:
class MixupLoss(tf.keras.losses.Loss):
    """Wrap a Keras loss so it scores mixup-augmented batches.

    ``y_true`` must be the tuple ``(y_true_mixed, y_true, lam)``: the labels of the
    shuffled partner examples, the original labels, and the mixing coefficient.
    The per-sample loss is ``lam * base(y_true) + (1 - lam) * base(y_true_mixed)``.
    Name and reduction are inherited from ``base_loss``.
    """

    def __init__(self, base_loss: tf.keras.losses.Loss):
        name = 'mixup_' + base_loss.name
        super().__init__(name=name, reduction=base_loss.reduction)
        self.base_loss = base_loss

    def call(self, y_true: Tuple[tf.Tensor], y_pred: tf.Tensor) -> tf.Tensor:
        y_true_mixed, y_true, lam = y_true
        # Use the base loss's unreduced, per-sample `call`, so that this wrapper's own
        # `__call__` applies reduction and any `sample_weight` exactly once. Calling
        # `self.base_loss(...)` would first collapse to a scalar, making sample weights
        # meaningless.
        loss_true = self.base_loss.call(y_true, y_pred)
        loss_mixed = self.base_loss.call(y_true_mixed, y_pred)
        lam = tf.cast(lam, loss_true.dtype)
        return lam * loss_true + (1 - lam) * loss_mixed

    def get_config(self):
        # `base_loss` is the only constructor argument; name/reduction are derived from it.
        # (Returning the base loss's own config would make from_config call
        # MixupLoss(reduction=..., name=..., ...) and fail.)
        return {'base_loss': tf.keras.losses.serialize(self.base_loss)}

    @classmethod
    def from_config(cls, config):
        """Rebuild the wrapper from the dict produced by ``get_config``."""
        return cls(tf.keras.losses.deserialize(config['base_loss']))

In [45]:
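# TODO: mixup is not wired into training yet. To enable it, map `mixup_augmentation`
# over the batched pipeline and compile with a `MixupLoss` wrapping the base loss
# (presumably -- verify Keras accepts the tuple-structured y_true).
# NOTE: the output below is stale: it shows a dict-structured dataset, whereas
# `preprocessing` now returns (image, label) tuples.
# ds.map(mixup_augmentation)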

<DatasetV1Adapter shapes: {image: (None, 32, 32, 3), label: ((None,), (None,), ())}, types: {image: tf.float32, label: (tf.int64, tf.int64, tf.float32)}>

In [7]:
from tensorflow.keras import layers

# Small CNN for 32x32x3 inputs: three stages of two 3x3 convolutions, with
# max-pooling between stages, then a global average pool and a dense head.
# Every hidden layer gets a ReLU. Without one, the Dense stack collapses into a
# single linear map, and the conv stack is linear apart from max-pooling, so the
# extra depth would add almost nothing.
# The final layer deliberately has no activation: it emits raw logits, so the
# loss must be configured with `from_logits=True`.
model = tf.keras.Sequential([
    layers.Conv2D(16, 3, padding='same', activation='relu', input_shape=(32, 32, 3)),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPool2D(2),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPool2D(2),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.GlobalAveragePooling2D(),
    layers.Dense(128 * 2, activation='relu'),
    layers.Dense(128 * 2, activation='relu'),
    layers.Dense(10),  # logits, one per CIFAR-10 class
])

In [8]:
# `preprocessing` yields integer class labels (not one-hot), and the model's final
# Dense(10) has no activation (it outputs logits). So use the *sparse* cross-entropy
# with from_logits=True: CategoricalCrossentropy() would expect one-hot targets and
# probabilities. The optimizer is spelled out explicitly; 'rmsprop' is the Keras
# compile() default, so optimization is unchanged.
model.compile(
    optimizer='rmsprop',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

In [9]:
# Derive the epoch length from the dataset metadata instead of a hard-coded 50000.
# One "epoch" is then about one pass over the training split; floor division drops
# the final partial batch. The dataset repeats, so this bound is required.
steps_per_epoch = info.splits['train'].num_examples // batch_size
model.fit(ds, steps_per_epoch=steps_per_epoch)

Train for 781 steps
  1/781 [..............................] - ETA: 26:16

UnknownError:  Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.
	 [[node sequential/conv2d/Conv2D (defined at <ipython-input-9-dbc6497d913b>:1) ]] [Op:__inference_distributed_function_1640]

Function call stack:
distributed_function
