##### Copyright 2023 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Train a Perceiver model for downstream tasks

This tutorial demonstrates how to use the [Perceiver](https://arxiv.org/abs/2107.14795) model for downstream tasks using TensorFlow Model Garden.

[TensorFlow Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.


## Clone models repository

In [None]:
!git clone -q https://github.com/tensorflow/models.git

## Install necessary dependencies

In [None]:
!pip install -q tensorflow==2.13.0
!pip install -q -U tensorflow_datasets
!pip install -q --user -r models/official/requirements.txt
!pip install -q tensorflow-text==2.13.0

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m524.1/524.1 MB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m86.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m95.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m440.8/440.8 kB[0m [31m44.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m241.2/241.2 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m175.1/175.1 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.9 MB/s[0m et

**Note**: Restart the runtime once the libraries are installed so that the newly installed versions are the ones imported below.

## Add the `models` directory to the Python path (`PYTHONPATH` and `sys.path`)

In [None]:
import os
import sys

# Assumes the clone above ran in /content (Colab's default working directory).
MODELS_DIR = "/content/models"

# Expose the repo to child processes via PYTHONPATH. Tolerate an unset or
# empty variable and avoid appending a duplicate entry when the cell is re-run.
python_path_entries = [
    entry for entry in os.environ.get("PYTHONPATH", "").split(os.pathsep)
    if entry
]
if MODELS_DIR not in python_path_entries:
  python_path_entries.append(MODELS_DIR)
os.environ["PYTHONPATH"] = os.pathsep.join(python_path_entries)

# The running interpreter built sys.path at startup and does not re-read
# PYTHONPATH, so the directory must also be added here for imports to work.
if MODELS_DIR not in sys.path:
  sys.path.append(MODELS_DIR)

## Import necessary libraries

In [None]:
import os
import pprint
import tensorflow as tf
import tensorflow_datasets as tfds

from IPython import display
from official.core import task_factory
from official.core import train_lib
from official.core import train_utils
from official.projects.perceiver.tasks import sentence_prediction
from official.projects.perceiver.configs import perceiver as exp_cfg
from official.nlp.modeling.layers import FastWordpieceBertTokenizer
from official.nlp.modeling.layers import BertPackInputs


# Pretty-printer with a 4-space indent, reused later to dump the experiment config.
pp = pprint.PrettyPrinter(indent=4)
# Confirm which TensorFlow build the kernel actually imported.
print(tf.version.VERSION)

2.13.0


## Download the `glue/mrpc` dataset

In [None]:
# GLUE MRPC: sentence pairs annotated for semantic equivalence.
tfds_name = 'glue/mrpc'
ds, ds_info = tfds.load(name=tfds_name, with_info=True)
# Display the dataset metadata (features, splits and sizes).
ds_info

Downloading and preparing dataset 1.43 MiB (download: 1.43 MiB, generated: 1.74 MiB, total: 3.17 MiB) to /root/tensorflow_datasets/glue/mrpc/2.0.0...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Generating splits...:   0%|          | 0/3 [00:00<?, ? splits/s]

Generating train examples...:   0%|          | 0/3668 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/glue/mrpc/2.0.0.incompleteGFRJBN/glue-train.tfrecord*...:   0%|          |…

Generating validation examples...:   0%|          | 0/408 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/glue/mrpc/2.0.0.incompleteGFRJBN/glue-validation.tfrecord*...:   0%|      …

Generating test examples...:   0%|          | 0/1725 [00:00<?, ? examples/s]

Shuffling /root/tensorflow_datasets/glue/mrpc/2.0.0.incompleteGFRJBN/glue-test.tfrecord*...:   0%|          | …

Dataset glue downloaded and prepared to /root/tensorflow_datasets/glue/mrpc/2.0.0. Subsequent calls will reuse this data.


tfds.core.DatasetInfo(
    name='glue',
    full_name='glue/mrpc/2.0.0',
    description="""
    GLUE, the General Language Understanding Evaluation benchmark
    (https://gluebenchmark.com/) is a collection of resources for training,
    evaluating, and analyzing natural language understanding systems.
    """,
    config_description="""
    The Microsoft Research Paraphrase Corpus (Dolan & Brockett, 2005) is a corpus of
    sentence pairs automatically extracted from online news sources, with human annotations
    for whether the sentences in the pair are semantically equivalent.
    """,
    homepage='https://www.microsoft.com/en-us/download/details.aspx?id=52398',
    data_path=PosixGPath('/tmp/tmpyoq7f3i8tfds'),
    file_format=tfrecord,
    download_size=1.43 MiB,
    dataset_size=1.74 MiB,
    features=FeaturesDict({
        'idx': int32,
        'label': ClassLabel(shape=(), dtype=int64, num_classes=2),
        'sentence1': Text(shape=(), dtype=string),
        'sentence2': Tex

## Download the BERT-Base checkpoint for its vocabulary file

In [None]:
!wget https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz -O ./uncased_L-12_H-768_A-12.tar.gz
!tar -zxvf ./uncased_L-12_H-768_A-12.tar.gz -C ./
!rm ./uncased_L-12_H-768_A-12.tar.gz

--2023-07-10 17:28:55--  https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz
Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.196.128, 173.194.215.128, 173.194.216.128, ...
Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.196.128|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 405351325 (387M) [application/octet-stream]
Saving to: ‘./uncased_L-12_H-768_A-12.tar.gz’


2023-07-10 17:28:58 (120 MB/s) - ‘./uncased_L-12_H-768_A-12.tar.gz’ saved [405351325/405351325]

uncased_L-12_H-768_A-12/
uncased_L-12_H-768_A-12/vocab.txt
uncased_L-12_H-768_A-12/bert_model.ckpt.index
uncased_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001
uncased_L-12_H-768_A-12/params.yaml
uncased_L-12_H-768_A-12/bert_config.json


## Configure the Perceiver model for custom dataset training

In [None]:
# Directory created by extracting the BERT checkpoint archive above.
gs_folder_bert = './uncased_L-12_H-768_A-12'

### Load the registered configuration

In [None]:
# Start from the registered Perceiver WordPiece sentence-prediction experiment;
# the next cell overrides only what this notebook needs.
exp_config = exp_cfg.exp_factory.get_exp_config(
    'perceiver/word_piece_raw_sentence_prediction')

### Change the parameters required to train the model

In [None]:
# Training hyper-parameters.
BATCH_SIZE = 8
epochs = 5
initial_learning_rate = 2e-5
# Vocabulary shipped with the BERT checkpoint extracted into gs_folder_bert.
vocab_file = os.path.join(gs_folder_bert, 'vocab.txt')

# Step counts derived from the TFDS split sizes. Floor division means a
# trailing partial batch is not counted as a step.
train_data_size = ds_info.splits['train'].num_examples
validation_data_size = ds_info.splits['validation'].num_examples
steps_per_epoch = train_data_size // BATCH_SIZE
num_train_steps = steps_per_epoch * epochs
validation_steps = validation_data_size // BATCH_SIZE
# Warm up over the first 10% of the training steps.
warmup_steps = int(0.1 * num_train_steps)


exp_config.runtime.num_gpus = 1
exp_config.runtime.enable_xla = False
# NOTE(review): TF's mixed-precision guide recommends 'mixed_float16' on GPUs
# and 'mixed_bfloat16' on TPUs; confirm this choice for the target hardware.
exp_config.runtime.mixed_precision_dtype = 'mixed_bfloat16'
# MRPC labels are binary (ClassLabel with num_classes=2 in ds_info).
exp_config.task.model.num_classes = 2


def configure_text_data(data_config, split):
  """Points one data config at a split of `tfds_name` with WordPiece inputs.

  Train and validation inputs share every setting except the split name, so
  both are configured through this single helper.

  Args:
    data_config: The task's train or validation data config to mutate.
    split: TFDS split name to read, e.g. 'train' or 'validation'.
  """
  data_config.tfds_name = tfds_name
  data_config.tfds_split = split
  data_config.text_fields = ['sentence1', 'sentence2']
  data_config.global_batch_size = BATCH_SIZE
  data_config.lower_case = True
  data_config.tokenization = 'WordPiece'
  data_config.vocab_file = vocab_file


configure_text_data(exp_config.task.train_data, 'train')
configure_text_data(exp_config.task.validation_data, 'validation')

# Checkpoint, write summaries and evaluate once per epoch.
exp_config.trainer.checkpoint_interval = steps_per_epoch
exp_config.trainer.optimizer_config.learning_rate.polynomial.initial_learning_rate = initial_learning_rate
exp_config.trainer.optimizer_config.learning_rate.polynomial.decay_steps = num_train_steps
exp_config.trainer.optimizer_config.warmup.polynomial.warmup_steps = warmup_steps
exp_config.trainer.steps_per_loop = steps_per_epoch
exp_config.trainer.summary_interval = steps_per_epoch
exp_config.trainer.train_steps = num_train_steps
exp_config.trainer.validation_interval = steps_per_epoch
exp_config.trainer.validation_steps = validation_steps

### Detect the hardware

In [None]:
# Bound up front so every branch below can reference it, even when a TPU is found.
gpus = []
try:
  tpu_resolver = tf.distribute.cluster_resolver.TPUClusterResolver()  # TPU detection
except ValueError:
  # No TPU attached: fall back to whatever GPUs are visible.
  tpu_resolver = None
  gpus = tf.config.list_logical_devices("GPU")

# Select appropriate distribution strategy
if tpu_resolver:
  tf.config.experimental_connect_to_cluster(tpu_resolver)
  tf.tpu.experimental.initialize_tpu_system(tpu_resolver)
  # Stable (non-experimental) TPUStrategy API.
  distribution_strategy = tf.distribute.TPUStrategy(tpu_resolver)
  print('Running on TPU ', tpu_resolver.cluster_spec().as_dict()['worker'])
elif len(gpus) > 1:
  gpu_names = [gpu.name for gpu in gpus]
  distribution_strategy = tf.distribute.MirroredStrategy(gpu_names)
  print('Running on multiple GPUs ', gpu_names)
else:
  # The default strategy covers both a single GPU and CPU-only execution.
  distribution_strategy = tf.distribute.get_strategy()
  if gpus:
    print('Running on single GPU ', gpus[0].name)
  else:
    print('Running on CPU')

print("Number of accelerators: ", distribution_strategy.num_replicas_in_sync)

Running on single GPU  /device:GPU:0
Number of accelerators:  1


### Print the modified configuration.

In [None]:
# Dump every field of the experiment config to review the overrides made above.
config_as_dict = exp_config.as_dict()
pp.pprint(config_as_dict)
# Colab-only: enlarge the output frame so the long dump is scrollable in place.
display.Javascript('google.colab.output.setIframeHeight("500px");')

{   'runtime': {   'all_reduce_alg': None,
                   'batchnorm_spatial_persistent': False,
                   'dataset_num_private_threads': None,
                   'default_shard_dim': -1,
                   'distribution_strategy': 'mirrored',
                   'enable_xla': False,
                   'gpu_thread_mode': None,
                   'loss_scale': None,
                   'mixed_precision_dtype': 'mixed_bfloat16',
                   'num_cores_per_replica': 1,
                   'num_gpus': 1,
                   'num_packs': 1,
                   'per_gpu_thread_count': 0,
                   'run_eagerly': False,
                   'task_index': -1,
                   'tpu': None,
                   'tpu_enable_xla_dynamic_padder': None,
                   'use_tpu_mp_strategy': False,
                   'worker_hosts': None},
    'task': {   'allow_image_summary': False,
                'differential_privacy_config': None,
                'hub_module_url': '',


<IPython.core.display.Javascript object>

## Create the `Task` object (`tfm.core.base_task.Task`) from the `config_definitions.TaskConfig`.

The `Task` object has all the methods necessary for building the dataset, building the model, and running training & evaluation. These methods are driven by `tfm.core.train_lib.run_experiment`.

In [None]:
# Output directory; training checkpoints are saved here.
model_dir = './trained_model/'

# Build the task inside the strategy scope so it is created under the
# distribution strategy selected above.
task_config = exp_config.task
with distribution_strategy.scope():
  task = task_factory.get_task(task_config, logging_dir=model_dir)

## Train and evaluate the model

In [None]:
# Alternate training and evaluation (every `validation_interval` steps, as
# configured above); returns the model and the evaluation logs.
model, eval_logs = train_lib.run_experiment(
    params=exp_config,
    task=task,
    distribution_strategy=distribution_strategy,
    model_dir=model_dir,
    mode='train_and_eval',
)

restoring or initializing model...
train | step:      0 | training until step 458...
train | step:    458 | steps/sec:    0.8 | output: 
    {'auc': 0.72914344,
     'cls_accuracy': 0.66348255,
     'learning_rate': 1.6e-05,
     'training_loss': 0.63586426}
saved checkpoint to ./trained_model/ckpt-458.
 eval | step:    458 | running 51 steps of evaluation...
 eval | step:    458 | steps/sec:    2.2 | eval time:   23.1 sec | output: 
    {'auc': 0.7645757,
     'cls_accuracy': 0.67401963,
     'steps_per_second': 2.2112783063481327,
     'validation_loss': 0.6187927}
train | step:    458 | training until step 916...
train | step:    916 | steps/sec:    0.8 | output: 
    {'auc': 0.8580948,
     'cls_accuracy': 0.7363537,
     'learning_rate': 1.2e-05,
     'training_loss': 0.5320095}
saved checkpoint to ./trained_model/ckpt-916.
 eval | step:    916 | running 51 steps of evaluation...
 eval | step:    916 | steps/sec:    2.9 | eval time:   17.8 sec | output: 
    {'auc': 0.800959,
    

## Testing the trained model

### Helper functions for pre-processing test data

In [None]:
# Rebuild the preprocessing used for training: same vocabulary file, casing
# and sequence length as the training data config.
train_data_config = exp_config.task.train_data
tokenizer = FastWordpieceBertTokenizer(
    vocab_file=vocab_file, lower_case=train_data_config.lower_case)
packer = BertPackInputs(
    seq_length=train_data_config.seq_length,
    special_tokens_dict=tokenizer.get_special_tokens_dict())


class BertInputProcessor(tf.keras.layers.Layer):
  """Tokenizes and packs the text fields of an example dict into model inputs.

  Args:
    tokenizer: Callable mapping a batch of strings to token ids (in this
      notebook, a `FastWordpieceBertTokenizer`).
    packer: Callable that packs a list of tokenized segments into model inputs
      (in this notebook, a `BertPackInputs` layer).
    text_fields: Keys of the input dict to tokenize, in segment order. The
      default matches the MRPC sentence pair used for training.
  """

  def __init__(self, tokenizer, packer, text_fields=('sentence1', 'sentence2')):
    super().__init__()
    self.tokenizer = tokenizer
    self.packer = packer
    self.text_fields = tuple(text_fields)

  def call(self, inputs):
    """Returns the packed inputs, paired with `inputs['label']` when present."""
    # Tokenize each field separately; the packer joins them as segments.
    segments = [self.tokenizer(inputs[field]) for field in self.text_fields]
    packed = self.packer(segments)

    if 'label' in inputs:
      return packed, inputs['label']
    return packed

### Pre-process test data

In [None]:
# Wrap the tokenizer and packer into a single preprocessing layer.
bert_inputs_processor = BertInputProcessor(tokenizer=tokenizer, packer=packer)
# Test split, one example per batch, converted to packed model inputs.
test_ds = ds['test'].batch(1).map(bert_inputs_processor)

### Get the predictions

In [None]:
# Show predictions for the first few test examples. `test_ds` holds the same
# examples (same order, batch size 1) already preprocessed, so reuse it rather
# than running the tokenizer again; the raw records are kept for display.
num_examples_to_show = 8
for record, (packed_inputs, test_label) in zip(
    ds['test'].batch(1).take(num_examples_to_show),
    test_ds.take(num_examples_to_show)):
  print(f"Sentence 1:{record['sentence1'].numpy()}")
  print(f"Sentence 2:{record['sentence2'].numpy()}")
  # Index of the highest-scoring class for each example in the batch.
  prediction = tf.argmax(
      model.predict(packed_inputs),
      axis=1)
  print(f"Prediction: {prediction[0]}")

Sentence 1:[b'Shares in BA were down 1.5 percent at 168 pence by 1420 GMT , off a low of 164p , in a slightly stronger overall London market .']
Sentence 2:[b'Shares in BA were down three percent at 165-1 / 4 pence by 0933 GMT , off a low of 164 pence , in a stronger market .']
Prediction: 1
Sentence 1:[b'The South Korean Agriculture and Forestry Ministry also said it would throw out or send back all Canadian beef currently in store .']
Sentence 2:[b'The South Korean Agriculture and Forestry Ministry said it would scrap or return all Canadian beef in store .']
Prediction: 1
Sentence 1:[b'" New Yorkers didn \'t embrace these units like they could have , " said Matthew Daus , chairman of the commission .']
Sentence 2:[b'" New Yorkers didn \'t embrace these units like they could have , " Matthew W. Daus , the commission \'s chairman , said yesterday .']
Prediction: 1
Sentence 1:[b'" I really liked him and I still do , " Cohen Alon told the Herald yesterday .']
Sentence 2:[b'And I really l