In [None]:
pip install tensorflow-addons

In [None]:
from typing import List

import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
import tensorflow_hub as hub
from tensorflow import keras
from tensorflow.keras import layers

# Seed all RNGs used by Keras (Python, NumPy, TensorFlow) before anything random
# runs, so augmentation and shuffling are repeatable across Restart & Run All.
tf.keras.utils.set_random_seed(42)
# Keep TFDS download/preparation output from flooding the cell output.
tfds.disable_progress_bar()

In [None]:
# --- Model architecture ------------------------------------------------------
MODEL_TYPE = 'deit_distilled_tiny_patch16_224'
RESOLUTION = 224
PATCH_SIZE = 16
NUM_PATCHES = (RESOLUTION // PATCH_SIZE) ** 2  # (224 // 16) ** 2 = 196
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 192
NUM_HEADS = 3
NUM_LAYERS = 12
# MLP widths [4 * PROJECTION_DIM, PROJECTION_DIM]; presumably hidden and output
# widths of each transformer MLP -- confirm against the model-building code.
MLP_UNITS = [PROJECTION_DIM * 4, PROJECTION_DIM]
DROPOUT_RATE = 0.0
DROP_PATH_RATE = 0.1

# --- Optimisation ------------------------------------------------------------
NUM_EPOCHS = 20
BASE_LR = 5e-4
WEIGHT_DECAY = 1e-4

# --- Input pipeline ----------------------------------------------------------
BATCH_SIZE = 256
AUTO = tf.data.AUTOTUNE
# One-hot depth for labels; must equal the label count of the dataset loaded
# in the "Dataset preparation" section (tf_flowers).
NUM_CLASSES = 5

## Dataset preparation

In [None]:
def preprocess_dataset(is_training=True, crop_padding=20):
  """Build the per-example map function used with `tf.data.Dataset.map`.

  Args:
    is_training: If True, augment: resize to a larger square, randomly crop
      back to RESOLUTION x RESOLUTION, then randomly flip left/right. If False,
      only resize to RESOLUTION x RESOLUTION.
    crop_padding: Pixels added to the target height and width (in total, not
      per side) before the random crop during training. Defaults to 20, the
      value previously hard-coded. Ignored when `is_training` is False.

  Returns:
    A function `(image, label) -> (image, one_hot_label)`. The image is the
    output of `tf.image.resize` (no rescaling is applied here); the label is
    one-hot encoded with depth NUM_CLASSES.

  Raises:
    ValueError: If `crop_padding` is negative (the crop would be larger than
      the resized image).
  """
  if crop_padding < 0:
    raise ValueError(f'crop_padding must be non-negative, got {crop_padding}')

  def fn(image, label):
    if is_training:
      # Resize to a bigger spatial resolution and take the random crops.
      padded_size = RESOLUTION + crop_padding
      image = tf.image.resize(image, (padded_size, padded_size))
      # Crop size fixes 3 channels; assumes RGB input images.
      image = tf.image.random_crop(image, (RESOLUTION, RESOLUTION, 3))
      image = tf.image.random_flip_left_right(image)
    else:
      image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
    label = tf.one_hot(label, depth=NUM_CLASSES)
    return image, label

  return fn

In [None]:
def prepare_dataset(dataset, is_training=True):
  """Turn a raw `(image, label)` dataset into a batched, prefetched pipeline.

  Training data is shuffled with a buffer of `BATCH_SIZE * 10` elements before
  the per-example preprocessing from `preprocess_dataset`; evaluation data keeps
  its original order. Both are then batched by BATCH_SIZE and prefetched.

  Args:
    dataset: The `tf.data.Dataset` of raw examples.
    is_training: Selects shuffling and the train-time augmentation branch.

  Returns:
    The transformed dataset.
  """
  pipeline = dataset.shuffle(BATCH_SIZE * 10) if is_training else dataset
  pipeline = pipeline.map(preprocess_dataset(is_training), num_parallel_calls=AUTO)
  pipeline = pipeline.batch(BATCH_SIZE)
  return pipeline.prefetch(AUTO)

In [None]:
# Carve a 90/10 train/validation split out of tf_flowers' "train" split.
# `with_info=True` also returns the dataset metadata, used below to check that
# the hard-coded NUM_CLASSES matches the real number of labels.
(train_raw, val_raw), ds_info = tfds.load(
    "tf_flowers",
    split=['train[:90%]', 'train[90%:]'],
    as_supervised=True,
    with_info=True,
)
num_classes_in_data = ds_info.features["label"].num_classes
assert num_classes_in_data == NUM_CLASSES, (
    f"NUM_CLASSES={NUM_CLASSES} but tf_flowers has {num_classes_in_data} labels"
)

num_train = train_raw.cardinality()
num_val = val_raw.cardinality()
print(f'num of train: {num_train}')
print(f'num of val: {num_val}')

# Raw splits keep their own names, so the prepared pipelines can be rebuilt by
# re-running these two lines without shuffling/batching an already-batched set.
train_dataset = prepare_dataset(train_raw, is_training=True)
val_dataset = prepare_dataset(val_raw, is_training=False)