# Contents
------------------------------------------------
This tutorial walks through the workflow for fine-tuning a model on a TPU with TensorFlow 2.

## Install Dependencies

In [1]:
!pip install tensorflow_addons transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensorflow_addons
  Downloading tensorflow_addons-0.19.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m16.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting transformers
  Downloading transformers-4.26.1-py3-none-any.whl (6.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.3/6.3 MB[0m [31m74.3 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m71.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.12.1-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.3/190.3 

# Configure TPU to Maximize Speed

In [2]:
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras import mixed_precision
import tensorflow_addons as tfa
from tensorflow_addons.optimizers import SWA
from tensorflow_addons.optimizers import Lookahead
from tensorflow_addons.optimizers import AdamW
from tensorflow.keras import mixed_precision

In [3]:
# Accelerator switches. `strategy` is defined on every branch below because
# the later cells call `strategy.scope()` and read
# `strategy.num_replicas_in_sync` unconditionally.
use_tpu = True
use_gpu = False
if use_tpu:
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
    tf.config.experimental_connect_to_cluster(resolver)
    # This is the TPU initialization code that has to be at the beginning.
    tf.tpu.experimental.initialize_tpu_system(resolver)
    print("All devices: ", tf.config.list_logical_devices('TPU'))
    strategy = tf.distribute.TPUStrategy(resolver)
elif use_gpu:
    # Synchronous data-parallel training across the locally visible GPUs.
    strategy = tf.distribute.MirroredStrategy()
else:
    # Default (no-op) strategy so the rest of the notebook still runs.
    strategy = tf.distribute.get_strategy()

Running on TPU  grpc://10.81.8.138:8470




REPLICAS:  8


In [3]:
# Choose the Keras mixed-precision policy name for the selected accelerator:
# bfloat16 compute on TPU, float16 compute on GPU, plain float32 otherwise.
policy_name = (
    'mixed_bfloat16' if use_tpu
    else 'mixed_float16' if use_gpu
    else 'float32'
)
policy = mixed_precision.Policy(policy_name)
# Install the policy globally so layers created afterwards pick it up.
mixed_precision.set_global_policy(policy)


In [None]:
# Switch for TensorFlow's XLA JIT setting; set to False to leave it untouched.
XLA_ACCELERATE = True

if XLA_ACCELERATE:
    # The flag is True on this branch, so this is the same as set_jit(True).
    tf.config.optimizer.set_jit(XLA_ACCELERATE)
    print("XLA Enabled")

## Add Data Retrieval

In [None]:
def get_dataset(batch_size, is_training=True):
  """Build the MNIST input pipeline.

  Args:
    batch_size: Number of examples per batch emitted by the returned dataset.
    is_training: When True, read the 'train' split, shuffle it and repeat it
      indefinitely. When False, read the 'test' split once, in order.

  Returns:
    A batched, prefetched `tf.data.Dataset` of `(image, label)` pairs whose
    images have been cast to float32 and divided by 255.
  """
  split = 'train' if is_training else 'test'
  dataset = tfds.load(name='mnist', split=split,
                      as_supervised=True, try_gcs=True)

  # Normalize the input data.
  def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255.0
    return image, label

  dataset = dataset.map(scale, num_parallel_calls=tf.data.AUTOTUNE)

  if is_training:
    dataset = dataset.shuffle(10000)
    dataset = dataset.repeat()

  # Batch with the caller-supplied size; callers pass the per-replica batch
  # size here, so a global constant must not be used.
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(tf.data.AUTOTUNE)
  return dataset

## Training Template: LR Example

In [None]:
# LR_START = 0.00001
# LR_MAX = 0.00004 * strategy.num_replicas_in_sync
# LR_MIN = 0.00001
# LR_RAMPUP_EPOCHS = 3
# LR_SUSTAIN_EPOCHS = 0
# LR_EXP_DECAY = .7
# @tf.function
# def lrfn(epoch):
#     if float(epoch) < LR_RAMPUP_EPOCHS:
#         lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * float(epoch) + LR_START
#     elif float(epoch) < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
#         lr = LR_MAX
#     else:
#         lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(float(epoch) - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
#     return lr

### This is the final training template: 
WARNING: Don't use gradient accumulation with bfloat16. Doing so will result in considerable trouble.

In [5]:
# Use this to test:
# Global batch size, i.e. summed over all replicas.
batch_size = 256
with strategy.scope():
  # Variables (model weights, optimizer slots, metrics) must be created
  # inside the strategy scope so they are placed on the replicas.
  # NOTE(review): `create_model` is not defined in the visible part of this
  # notebook; it must be defined in an earlier cell or this line raises
  # NameError.
  model = create_model()
  optimizer = tf.keras.optimizers.Adam()
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

# `distribute_datasets_from_function` is the current name of the deprecated
# `experimental_distribute_datasets_from_function`.
train_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

# Original note: comment this decorator out first (presumably to debug the
# step eagerly before tracing it as a graph).
@tf.function
def train_step(iterator):
  """Fetch one batch from `iterator` and run a single training step on it."""

  def step_fn(inputs):
    """Per-replica forward pass, backward pass, and metric update."""
    features, targets = inputs
    with tf.GradientTape() as tape:
      predictions = model(features, training=True)
      per_example_loss = tf.keras.losses.sparse_categorical_crossentropy(
          targets, predictions, from_logits=True)
      # Divide by the global batch size so that summing the gradients over
      # replicas gives the gradient of the global-batch mean loss.
      scaled_loss = tf.nn.compute_average_loss(
          per_example_loss, global_batch_size=batch_size)
    variables = model.trainable_variables
    gradients = tape.gradient(scaled_loss, variables)
    optimizer.apply_gradients(list(zip(gradients, variables)))
    # Undo the global-batch scaling before recording the loss metric.
    training_loss.update_state(scaled_loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(targets, predictions)

  batch = next(iterator)
  strategy.run(step_fn, args=(batch,))

NameError: ignored

## Once everything works, train multiple steps per call

In [None]:
with strategy.scope():
  # Variables (model weights, optimizer slots, metrics) must be created
  # inside the strategy scope so they are placed on the replicas.
  model = create_model()
  # Optimizer stack: Lookahead wraps SWA, which wraps AdamW.
  optimizer_inner = AdamW(weight_decay=1e-6)
  optimizer_middle = SWA(optimizer_inner)
  optimizer = Lookahead(optimizer_middle)
  training_loss = tf.keras.metrics.Mean('training_loss', dtype=tf.float32)
  training_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
      'training_accuracy', dtype=tf.float32)

# Calculate per replica batch size, and distribute the `tf.data.Dataset`s
# on each TPU worker.
actual_batch_size = 128
gradient_accumulation_step = 1
# Global batch size, i.e. summed over all replicas.
batch_size = actual_batch_size * gradient_accumulation_step
# 60000 and 10000 are the sizes of the MNIST train and test splits.
steps_per_epoch = 60000 // batch_size
validation_steps = 10000 // batch_size

per_replica_batch_size = batch_size // strategy.num_replicas_in_sync

# The dataset is built once, per replica, by the distribution function; a
# separate undistributed `get_dataset(batch_size, ...)` call is not needed.
# `distribute_datasets_from_function` is the current name of the deprecated
# `experimental_distribute_datasets_from_function`.
train_dataset = strategy.distribute_datasets_from_function(
    lambda _: get_dataset(per_replica_batch_size, is_training=True))

# Plain `tf.function`, as in the TensorFlow TPU guide: this function reads
# batches from a dataset iterator and dispatches work with `strategy.run`,
# so it is not wrapped in `jit_compile=True`.
@tf.function
def train_multiple_steps(iterator, steps):
  """Run `steps` consecutive training steps inside one traced function.

  Args:
    iterator: Iterator over the distributed training dataset; one batch is
      pulled from it per step.
    steps: Number of training steps to run. Pass a `tf.Tensor` so that a
      different value does not retrace the function.
  """

  def step_fn(inputs):
    """The computation to run on each TPU device."""
    images, labels = inputs
    with tf.GradientTape() as tape:
      logits = model(images, training=True)
      loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, logits, from_logits=True)
      # Divide by the global batch size so that summing the gradients over
      # replicas gives the gradient of the global-batch mean loss.
      loss = tf.nn.compute_average_loss(loss, global_batch_size=batch_size)
    grads = tape.gradient(loss, model.trainable_variables)

    optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
    # Undo the global-batch scaling before recording the loss metric.
    training_loss.update_state(loss * strategy.num_replicas_in_sync)
    training_accuracy.update_state(labels, logits)

  for _ in tf.range(steps):
    strategy.run(step_fn, args=(next(iterator),))

train_iterator = iter(train_dataset)
num_epochs = 10

for epoch in range(num_epochs):
  # Report epochs 1-based so the last line reads "10/10".
  print('Epoch: {}/{}'.format(epoch + 1, num_epochs))

  # Convert `steps_per_epoch` to `tf.Tensor` so the `tf.function` won't get
  # retraced if the value changes.
  train_multiple_steps(train_iterator, tf.convert_to_tensor(steps_per_epoch))

  print('Training loss: {}, accuracy: {}%'.format(
      round(float(training_loss.result()), 4),
      round(float(training_accuracy.result()) * 100, 2)))
  # Reset the metrics so each epoch reports its own averages instead of a
  # running average over all epochs so far.
  training_loss.reset_state()
  training_accuracy.reset_state()

