##### Copyright 2022 The TensorFlow Authors.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Load LM Checkpoints using Model Garden

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://www.tensorflow.org/tfmodels/nlp/load_lm_ckpts"><img src="https://www.tensorflow.org/images/tf_logo_32px.png" />View on TensorFlow.org</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/tensorflow/models/blob/master/docs/nlp/load_lm_ckpts.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
  <td>
    <a href="https://storage.googleapis.com/tensorflow_docs/models/docs/nlp/load_lm_ckpts.ipynb"><img src="https://www.tensorflow.org/images/download_logo_32px.png" />Download notebook</a>
  </td>
</table>

This tutorial demonstrates how to load BERT, ALBERT and ELECTRA pretrained checkpoints and use them for downstream tasks.

[Model Garden](https://www.tensorflow.org/tfmodels) contains a collection of state-of-the-art models, implemented with TensorFlow's high-level APIs. The implementations demonstrate the best practices for modeling, letting users take full advantage of TensorFlow for their research and product development.

## Install TF Model Garden package

In [None]:
!pip install -U -q "tf-models-official"

## Import necessary libraries

In [None]:
import os
import yaml
import json

import tensorflow as tf

In [None]:
import tensorflow_models as tfm

from official.core import exp_factory

## Load BERT model pretrained checkpoints

### Select required BERT model

In [None]:
# @title Download Checkpoint of the Selected Model { display-mode: "form", run: "auto" }
model_display_name = 'BERT-base cased English'  # @param ['BERT-base uncased English','BERT-base cased English','BERT-large uncased English', 'BERT-large cased English', 'BERT-large, Uncased (Whole Word Masking)', 'BERT-large, Cased (Whole Word Masking)', 'BERT-base MultiLingual','BERT-base Chinese']

if model_display_name == 'BERT-base uncased English':
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-12_H-768_A-12.tar.gz"
  !tar -xvf "uncased_L-12_H-768_A-12.tar.gz"
elif model_display_name == 'BERT-base cased English':
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-12_H-768_A-12.tar.gz"
  !tar -xvf "cased_L-12_H-768_A-12.tar.gz"
elif model_display_name == "BERT-large uncased English":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/uncased_L-24_H-1024_A-16.tar.gz"
  !tar -xvf "uncased_L-24_H-1024_A-16.tar.gz"
elif model_display_name == "BERT-large cased English":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/cased_L-24_H-1024_A-16.tar.gz"
  !tar -xvf "cased_L-24_H-1024_A-16.tar.gz"
elif model_display_name == "BERT-large, Uncased (Whole Word Masking)":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_uncased_L-24_H-1024_A-16.tar.gz"
  !tar -xvf "wwm_uncased_L-24_H-1024_A-16.tar.gz"
elif model_display_name == "BERT-large, Cased (Whole Word Masking)":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/wwm_cased_L-24_H-1024_A-16.tar.gz"
  !tar -xvf "wwm_cased_L-24_H-1024_A-16.tar.gz"
elif model_display_name == "BERT-base MultiLingual":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/multi_cased_L-12_H-768_A-12.tar.gz"
  !tar -xvf "multi_cased_L-12_H-768_A-12.tar.gz"
elif model_display_name == "BERT-base Chinese":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/bert/v3/chinese_L-12_H-768_A-12.tar.gz"
  !tar -xvf "chinese_L-12_H-768_A-12.tar.gz"

In [None]:
# Lookup table of the directory name corresponding to each model checkpoint.
# Each value matches the archive name used in the download cell, minus its
# ".tar.gz" suffix.
folder_bert_dict = {
    'BERT-base uncased English': 'uncased_L-12_H-768_A-12',
    'BERT-base cased English': 'cased_L-12_H-768_A-12',
    'BERT-large uncased English': 'uncased_L-24_H-1024_A-16',
    'BERT-large cased English': 'cased_L-24_H-1024_A-16',
    'BERT-large, Uncased (Whole Word Masking)': 'wwm_uncased_L-24_H-1024_A-16',
    'BERT-large, Cased (Whole Word Masking)': 'wwm_cased_L-24_H-1024_A-16',
    # Fixed: this entry previously ended in "A-1", which does not match the
    # downloaded "multi_cased_L-12_H-768_A-12" archive.
    'BERT-base MultiLingual': 'multi_cased_L-12_H-768_A-12',
    'BERT-base Chinese': 'chinese_L-12_H-768_A-12'
}

# Directory for the model picked in the form cell (None if the name is unknown).
folder_bert = folder_bert_dict.get(model_display_name)
folder_bert

### Construct BERT Model Using the New `params.yaml`

params.yaml can be used for training with the bundled trainer in addition to constructing the BERT encoder here.

In [None]:
# Parse the params.yaml found in the selected BERT checkpoint directory into a dict.
config_file = os.path.join(folder_bert, "params.yaml")
# Read through a context manager so the file handle is closed after reading.
with tf.io.gfile.GFile(config_file) as params_file:
  config_dict = yaml.safe_load(params_file.read())
config_dict

In [None]:
# Method 1: hand the encoder section of the parsed params straight to EncoderConfig
bert_encoder_params = config_dict["task"]["model"]["encoder"]
encoder_config = tfm.nlp.encoders.EncoderConfig(bert_encoder_params)
# Display the resulting encoder settings as a plain dict.
encoder_config.get().as_dict()

In [None]:
# Method 2: start from a default EncoderConfig and override its params with
# override_params_dict, using the encoder section of the parsed params
encoder_config = tfm.nlp.encoders.EncoderConfig()
tfm.hyperparams.override_params_dict(
    encoder_config,
    config_dict["task"]["model"]["encoder"],
    is_strict=True,
)
# Display the resulting encoder settings as a plain dict.
encoder_config.get().as_dict()

### Construct BERT Model Using the Old `bert_config.json`

In [None]:
# Parse the legacy bert_config.json found in the selected BERT checkpoint directory.
bert_config_file = os.path.join(folder_bert, "bert_config.json")
# Read through a context manager so the file handle is closed after reading.
with tf.io.gfile.GFile(bert_config_file) as bert_config_json:
  config_dict = json.loads(bert_config_json.read())
config_dict

In [None]:
# Select the 'bert' encoder type and supply the legacy JSON settings under the
# matching 'bert' key.
legacy_bert_params = {'type': 'bert', 'bert': config_dict}
encoder_config = tfm.nlp.encoders.EncoderConfig(legacy_bert_params)

# Display the resulting encoder settings as a plain dict.
encoder_config.get().as_dict()

### Construct a classifier with `encoder_config`

Here, we construct a new BERT Classifier with 2 classes and plot its model architecture. A BERT Classifier consists of a BERT encoder using the selected encoder config, a Dropout layer and an MLP classification head.

In [None]:
# Build the encoder network from `encoder_config`, then wrap it in a
# two-class classifier.
bert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)
bert_classifier = tfm.nlp.models.BertClassifier(
    num_classes=2,
    network=bert_encoder,
)

# The plot is the cell's output.
tf.keras.utils.plot_model(bert_classifier)

### Load Pretrained Weights into the BERT Classifier

The provided pretrained checkpoint only contains weights for the BERT Encoder within the BERT Classifier. Weights for the Classification Head are still randomly initialized.

In [None]:
# Register only the encoder with the checkpoint, under the name `encoder`.
checkpoint = tf.train.Checkpoint(encoder=bert_encoder)
bert_ckpt_path = os.path.join(folder_bert, 'bert_model.ckpt')
bert_load_status = checkpoint.read(bert_ckpt_path)
# Mark the read as partial, then assert that the objects built here were matched.
bert_load_status.expect_partial().assert_existing_objects_matched()

## Load ALBERT model pretrained checkpoints

In [None]:
# @title Download Checkpoint of the Selected Model { display-mode: "form", run: "auto" }
albert_model_display_name = 'ALBERT-xxlarge English'  # @param ['ALBERT-base English', 'ALBERT-large English', 'ALBERT-xlarge English', 'ALBERT-xxlarge English']

if albert_model_display_name == 'ALBERT-base English':
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_base.tar.gz"
  !tar -xvf "albert_base.tar.gz"
elif albert_model_display_name == 'ALBERT-large English':
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_large.tar.gz"
  !tar -xvf "albert_large.tar.gz"
elif albert_model_display_name == "ALBERT-xlarge English":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xlarge.tar.gz"
  !tar -xvf "albert_xlarge.tar.gz"
elif albert_model_display_name == "ALBERT-xxlarge English":
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/albert/albert_xxlarge.tar.gz"
  !tar -xvf "albert_xxlarge.tar.gz"

In [None]:
# Lookup table of the directory name corresponding to each model checkpoint.
# Display names and directory names differ only in the size tag, so both are
# generated from the list of sizes.
folder_albert_dict = {
    f'ALBERT-{size} English': f'albert_{size}'
    for size in ('base', 'large', 'xlarge', 'xxlarge')
}

# Directory for the model picked in the form cell (None if the name is unknown).
folder_albert = folder_albert_dict.get(albert_model_display_name)
folder_albert

### Construct ALBERT Model Using the New `params.yaml`

params.yaml can be used for training with the bundled trainer in addition to constructing the ALBERT encoder here.

In [None]:
# Parse the params.yaml found in the selected ALBERT checkpoint directory into a dict.
config_file = os.path.join(folder_albert, "params.yaml")
# Read through a context manager so the file handle is closed after reading.
with tf.io.gfile.GFile(config_file) as params_file:
  config_dict = yaml.safe_load(params_file.read())
config_dict

In [None]:
# Method 1: hand the encoder section of the parsed params straight to EncoderConfig
albert_encoder_params = config_dict["task"]["model"]["encoder"]
encoder_config = tfm.nlp.encoders.EncoderConfig(albert_encoder_params)
# Display the resulting encoder settings as a plain dict.
encoder_config.get().as_dict()

In [None]:
# Method 2: start from a default EncoderConfig and override its params with
# override_params_dict, using the encoder section of the parsed params
encoder_config = tfm.nlp.encoders.EncoderConfig()
tfm.hyperparams.override_params_dict(
    encoder_config,
    config_dict["task"]["model"]["encoder"],
    is_strict=True,
)
# Display the resulting encoder settings as a plain dict.
encoder_config.get().as_dict()

### Construct ALBERT Model Using the Old `albert_config.json`

In [None]:
# Parse the legacy albert_config.json found in the selected ALBERT checkpoint directory.
albert_config_file = os.path.join(folder_albert, "albert_config.json")
# Read through a context manager so the file handle is closed after reading.
with tf.io.gfile.GFile(albert_config_file) as albert_config_json:
  config_dict = json.loads(albert_config_json.read())
config_dict

In [None]:
# Select the 'albert' encoder type and supply the legacy JSON settings under
# the matching 'albert' key.
legacy_albert_params = {'type': 'albert', 'albert': config_dict}
encoder_config = tfm.nlp.encoders.EncoderConfig(legacy_albert_params)

# Display the resulting encoder settings as a plain dict.
encoder_config.get().as_dict()

### Construct a Classifier with `encoder_config`

Here, we construct a new classifier with 2 classes and plot its model architecture. The classifier (a `BertClassifier`) consists of an ALBERT encoder using the selected encoder config, a Dropout layer and an MLP classification head.

In [None]:
# Build the encoder network from `encoder_config`, then wrap it in a
# two-class classifier (BertClassifier is reused with the ALBERT encoder).
albert_encoder = tfm.nlp.encoders.build_encoder(encoder_config)
albert_classifier = tfm.nlp.models.BertClassifier(
    num_classes=2,
    network=albert_encoder,
)

# The plot is the cell's output.
tf.keras.utils.plot_model(albert_classifier)

### Load Pretrained Weights into the Classifier

The provided pretrained checkpoint only contains weights for the ALBERT Encoder within the ALBERT Classifier. Weights for the Classification Head are still randomly initialized.

In [None]:
# Register only the encoder with the checkpoint, under the name `encoder`.
checkpoint = tf.train.Checkpoint(encoder=albert_encoder)
# The ALBERT checkpoint file is also named 'bert_model.ckpt'.
albert_ckpt_path = os.path.join(folder_albert, 'bert_model.ckpt')
albert_load_status = checkpoint.read(albert_ckpt_path)
# Mark the read as partial, then assert that the objects built here were matched.
albert_load_status.expect_partial().assert_existing_objects_matched()

## Load ELECTRA model pretrained checkpoints

In [None]:
# @title Download Checkpoint of the Selected Model { display-mode: "form", run: "auto" }
electra_model_display_name = 'ELECTRA-small English'  # @param ['ELECTRA-small English', 'ELECTRA-base English']

if electra_model_display_name == 'ELECTRA-small English':
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/electra/small.tar.gz"
  !tar -xvf "small.tar.gz"
elif electra_model_display_name == 'ELECTRA-base English':
  !wget "https://storage.googleapis.com/tf_model_garden/nlp/electra/base.tar.gz"
  !tar -xvf "base.tar.gz"

In [None]:
# Lookup table of the directory name corresponding to each model checkpoint.
# The directory name is simply the size tag embedded in the display name.
folder_electra_dict = {
    f'ELECTRA-{size} English': size
    for size in ('small', 'base')
}

# Directory for the model picked in the form cell (None if the name is unknown).
folder_electra = folder_electra_dict.get(electra_model_display_name)
folder_electra

### Construct ELECTRA Model Using the `params.yaml`

params.yaml can be used for training with the bundled trainer in addition to constructing the ELECTRA discriminator encoder here.

In [None]:
# Parse the params.yaml found in the selected ELECTRA checkpoint directory into a dict.
config_file = os.path.join(folder_electra, "params.yaml")
# Read through a context manager so the file handle is closed after reading.
with tf.io.gfile.GFile(config_file) as params_file:
  config_dict = yaml.safe_load(params_file.read())
config_dict

In [None]:
# Take the discriminator's encoder settings from the `model` section of the
# parsed params and wrap them in an EncoderConfig.
disc_encoder_params = config_dict['model']['discriminator_encoder']
disc_encoder_config = tfm.nlp.encoders.EncoderConfig(disc_encoder_params)

# Display the resulting encoder settings as a plain dict.
disc_encoder_config.get().as_dict()

### Construct a Classifier with `encoder_config`

Here, we construct a Classifier with 2 classes and plot its model architecture. A Classifier consists of an ELECTRA discriminator encoder using the selected encoder config, a Dropout layer and an MLP classification head.

**Note**: The generator is discarded and the discriminator is used for downstream tasks

In [None]:
# Build the discriminator encoder from `disc_encoder_config`, then wrap it in
# a two-class classifier (BertClassifier is reused with the ELECTRA encoder).
disc_encoder = tfm.nlp.encoders.build_encoder(disc_encoder_config)
elctra_dic_classifier = tfm.nlp.models.BertClassifier(
    num_classes=2,
    network=disc_encoder,
)
# The plot is the cell's output.
tf.keras.utils.plot_model(elctra_dic_classifier)

### Load Pretrained Weights into the Classifier

The provided pretrained checkpoint contains weights for the entire ELECTRA model. We are only loading its discriminator (conveniently named `encoder`) weights within the Classifier. Weights for the Classification Head are still randomly initialized.

In [None]:
# Register only the discriminator encoder with the checkpoint, under the name `encoder`.
checkpoint = tf.train.Checkpoint(encoder=disc_encoder)
# os.path.join with a single argument returns that path unchanged.
electra_ckpt_dir = os.path.join(folder_electra)
# Resolve the newest checkpoint in the extracted directory.
latest_electra_ckpt = tf.train.latest_checkpoint(electra_ckpt_dir)
electra_load_status = checkpoint.read(latest_electra_ckpt)
# Mark the read as partial, then assert that the objects built here were matched.
electra_load_status.expect_partial().assert_existing_objects_matched()