# Further Pre-training MobileBERT MLM with Federated Averaging

In [1]:
# Copyright 2020, The TensorFlow Federated Authors.
# Copyright 2020, Ronald Seoh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Google Colab settings

In [2]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except:
    colab_available = False

if use_colab and colab_available:
    # Mount Google Drive root directory
    drive.mount('/content/drive')

    # cd to the appropriate working directory under my Google Drive
    %cd '/content/drive/My Drive/Colab Notebooks/BERTerated'
    
    # List the directory contents
    !ls

# IPython reloading magic
%load_ext autoreload
%autoreload 2

Collecting tensorflow-federated==0.17.0
[?25l  Downloading https://files.pythonhosted.org/packages/5c/54/900d99d3cff21b6a570281b51f4878a745c0eece7732bb7fc26eee61ef57/tensorflow_federated-0.17.0-py2.py3-none-any.whl (517kB)
[K     |████████████████████████████████| 522kB 7.8MB/s 
[?25hCollecting tensorflow-text==2.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/28/b2/2dbd90b93913afd07e6101b8b84327c401c394e60141c1e98590038060b3/tensorflow_text-2.3.0-cp36-cp36m-manylinux1_x86_64.whl (2.6MB)
[K     |████████████████████████████████| 2.6MB 14.1MB/s 
[?25hCollecting transformers==3.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/2c/4e/4f1ede0fd7a36278844a277f8d53c21f88f37f3754abf76a5d6224f76d4a/transformers-3.4.0-py3-none-any.whl (1.3MB)
[K     |████████████████████████████████| 1.3MB 43.7MB/s 
[?25hCollecting tensorflow-addons~=0.11.1
[?25l  Downloading https://files.pythonhosted.org/packages/b3/f8/d6fca180c123f2851035c4493690662ebdad0849a9059d5603543

In [None]:
# Install required packages
!pip install -r requirements.txt

## Import packages

In [3]:
import os
import sys
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_text as tf_text
import transformers

import nest_asyncio
nest_asyncio.apply()

import fedavg
import fedavg_client
import utils

# Random seed settings
random_seed = 692
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Tensorflow GPU
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()

Num GPUs Available:  1


b'Hello, World!'

In [4]:
# Print version information
print("Python version: " + sys.version)
print("NumPy version: " + np.__version__)
print("TensorFlow version: " + tf.__version__)
print("TensorFlow Federated version: " + tff.__version__)
print("Transformers version: " + transformers.__version__)

Python version: 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0]
NumPy version: 1.18.5
TensorFlow version: 2.3.0
TensorFlow Federated version: 0.17.0
Transformers version: 3.4.0


## Experiment Settings

In [5]:
TOTAL_ROUNDS = 2 # Number of total training rounds
ROUNDS_PER_EVAL = 1 # How often to evaluate

TRAIN_CLIENTS_PER_ROUND = 2 # How many clients to sample per round.
TEST_CLIENTS_PER_ROUND = 2 # How many clients to sample per round for test data

CLIENT_EPOCHS_PER_ROUND = 3 # Number of epochs in the client to take per round.
BATCH_SIZE = 8 # Batch size used on the client.
TEST_BATCH_SIZE = 8 # Minibatch size of test data.

BUFFER_SIZE = 100 # For dataset shuffling

# Maximum length of input token sequence for BERT.
BERT_MAX_SEQ_LENGTH = 128

# Optimizer configuration
SERVER_LEARNING_RATE = 1.0 # Server learning rate.
CLIENT_LEARNING_RATE = 0.000001 # Client learning rate

## Dataset

In [6]:
train_client_data, test_client_data = tff.simulation.datasets.shakespeare.load_data(cache_dir='./tff_cache')

In [7]:
mobilebert_tokenizer = transformers.MobileBertTokenizer.from_pretrained(
    'google/mobilebert-uncased', cache_dir='./transformers_cache')

### Tokenizer

In [8]:
# Imitate transformers tokenizer with TF.Text Tokenizer
mobilebert_vocab_lookup_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=list(mobilebert_tokenizer.vocab.keys()),
        values=tf.constant(list(mobilebert_tokenizer.vocab.values()), dtype=tf.int64)),
    default_value=0)

mobilebert_special_ids_mask_table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(
        keys=tf.constant(mobilebert_tokenizer.all_special_ids, dtype=tf.int32),
        values=tf.constant(1, dtype=tf.int32, shape=len(mobilebert_tokenizer.all_special_ids)),
        key_dtype=tf.int32, value_dtype=tf.int32),
    default_value=tf.constant(0, dtype=tf.int32))

mobilebert_tokenizer_tf_text = tf_text.BertTokenizer(
    vocab_lookup_table=mobilebert_vocab_lookup_table,
    suffix_indicator="##",
    max_bytes_per_word=mobilebert_tokenizer.wordpiece_tokenizer.max_input_chars_per_word,
    max_chars_per_token=None,
    token_out_type=tf.int32,
    unknown_token=mobilebert_tokenizer.unk_token,
    split_unknown_characters=True,
    lower_case=True,
    keep_whitespace=False,
    normalization_form=None,
    preserve_unused_token=True)

In [9]:
# Test if our new tokenizer works
ttt = mobilebert_tokenizer_tf_text.tokenize("This is a test.")
print("TF Text tokenizer output shape:", tf.shape(ttt.to_tensor()))
print(tf.squeeze(ttt.to_tensor(), axis=-1))
mobilebert_tokenizer.decode(tf.squeeze(ttt, axis=-1).to_list()[0])

TF Text tokenizer output shape: tf.Tensor([1 5 1], shape=(3,), dtype=int32)
tf.Tensor([[2023 2003 1037 3231 1012]], shape=(1, 5), dtype=int32)


'this is a test.'

### Preprocessing

In [10]:
# Based on the answers from
# https://stackoverflow.com/questions/42334646/tensorflow-pad-unknown-size-tensor-to-a-specific-size/51936821#51936821
def dynamic_padding(inp, min_size, constant_values):

    pad_size = min_size - tf.shape(inp)[1]
    paddings = [[0,0], [0, pad_size]] # assign here, during graph execution

    return tf.cast(tf.pad(inp, paddings, constant_values=constant_values), dtype=tf.int32)

# New preprocessing steps based on TF.text tokenizer
def tokenize_and_mask(x):
    # TF.text tokenizer returns RaggedTensor. Convert this to a regular tensor.
    # Note: In the third dimension, 2nd and 3rd indexes contain some sort of offset information,
    # which we will ignore for now.
    tokenized = mobilebert_tokenizer_tf_text.tokenize(tf.reshape(x['snippets'], shape=[1])).to_tensor()[:, :, 0]

    # Add special tokens: [CLS]
    cls_tensor_for_tokenized = tf.constant(mobilebert_tokenizer.cls_token_id, shape=[len(x), 1], dtype=tf.int32)
    tokenized_with_special_tokens = tf.concat([cls_tensor_for_tokenized, tokenized], axis=1)

    # Truncate if the sequence is already longer than BERT_MAX_SEQ_LENGTH
    tokenized_with_special_tokens = tf.cond(
        tf.greater_equal(tf.shape(tokenized_with_special_tokens)[1], BERT_MAX_SEQ_LENGTH),
        true_fn=lambda: tokenized_with_special_tokens[:, 0:BERT_MAX_SEQ_LENGTH-1],
        false_fn=lambda: tokenized_with_special_tokens)     

    # Add special tokens: [SEP]
    sep_tensor_for_tokenized = tf.constant(mobilebert_tokenizer.sep_token_id, shape=[len(x), 1], dtype=tf.int32)
    tokenized_with_special_tokens = tf.concat([tokenized_with_special_tokens, sep_tensor_for_tokenized], axis=1)

    # Padding with [PAD]
    # Final sequence should have the length of BERT_MAX_SEQ_LENGTH
    # Pad only if necessary
    tokenized_with_special_tokens = tf.cond(
        tf.less(tf.shape(tokenized_with_special_tokens)[1], BERT_MAX_SEQ_LENGTH),
        true_fn=lambda: dynamic_padding(tokenized_with_special_tokens, BERT_MAX_SEQ_LENGTH, mobilebert_tokenizer.pad_token_id),
        false_fn=lambda: tokenized_with_special_tokens)  

    tokenized_with_special_tokens = tf.cast(tokenized_with_special_tokens, dtype=tf.int32)

    # Random masking for the BERT MLM
    masked, labels = utils.get_masked_input_and_labels(
        tokenized_with_special_tokens,
        mobilebert_vocab_lookup_table,
        mobilebert_special_ids_mask_table,
        tf.constant(mobilebert_tokenizer.mask_token_id, dtype=tf.int32))

    # Squeeze out the first dimension
    masked = tf.squeeze(masked)
    labels = tf.squeeze(labels)

    # Manually settting the shape here so that TensorFlow graph
    # could know the sizes in advnace
    masked.set_shape(BERT_MAX_SEQ_LENGTH)
    labels.set_shape(BERT_MAX_SEQ_LENGTH)
    
    return masked, labels

def preprocess_for_train(train_dataset):
    return (
        # Filter out empty strings
        train_dataset.filter(lambda x: tf.strings.length(x['snippets']) > 0)
        # Tokenize each samples using MobileBERT tokenizer
        .map(tokenize_and_mask)
        # Shuffle
        .shuffle(BUFFER_SIZE)
        # Repeat to make each client train multiple epochs
        .repeat(count=CLIENT_EPOCHS_PER_ROUND)
        # Form minibatches
        # Use drop_remainder=True to force the batch size to be exactly BATCH_SIZE
        # and make the shape **exactly** (BATCH_SIZE, SEQ_LENGTH)
        .batch(BATCH_SIZE, drop_remainder=True))
    
def preprocess_for_test(test_dataset):
    return (
        # Filter out empty strings
        test_dataset.filter(lambda x: tf.strings.length(x['snippets']) > 0)
        # Tokenize each samples using MobileBERT tokenizer
        .map(tokenize_and_mask)
        # Shuffle
        .shuffle(BUFFER_SIZE)
        # Form minibatches
        # Use drop_remainder=True to force the batch size to be exactly TEST_BATCH_SIZE
        # and make the shape **exactly** (TEST_BATCH_SIZE, SEQ_LENGTH)
        .batch(TEST_BATCH_SIZE, drop_remainder=True))

In [11]:
train_client_data = train_client_data.preprocess(preprocess_fn=preprocess_for_train)
test_client_data = test_client_data.preprocess(preprocess_fn=preprocess_for_test)





Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.






In [12]:
# Create a test client dataset, just to get the element_spec info
example_dataset = train_client_data.create_tf_dataset_for_client('THE_TRAGEDY_OF_KING_LEAR_KING')
print(example_dataset.element_spec)





(TensorSpec(shape=(8, 128), dtype=tf.int32, name=None), TensorSpec(shape=(8, 128), dtype=tf.int32, name=None))


In [13]:
# Did the random masking go well?
for example_input, example_labels in example_dataset.take(3):
    print(example_input)
    print(example_labels)

    print(mobilebert_tokenizer.batch_decode(tf.squeeze(example_input).numpy()))

tf.Tensor(
[[  101 14383  1997 ...     0     0     0]
 [  101 21658  1010 ...     0     0     0]
 [  101  2065  2017 ...     0     0     0]
 ...
 [  101  2588  3067 ...     0     0     0]
 [  101 10590  1997 ...     0     0     0]
 [  101  1996  7909 ...  2000  2031   102]], shape=(8, 128), dtype=int32)
tf.Tensor(
[[-100 -100 -100 ... -100 -100 -100]
 [-100 -100 -100 ... -100 -100 -100]
 [-100 -100 -100 ... -100 -100 -100]
 ...
 [-100 -100 -100 ... -100 -100 -100]
 [-100 -100 -100 ... -100 -100 -100]
 [-100 -100 -100 ... -100 -100 -100]], shape=(8, 128), dtype=int32)
["[CLS] dom of navarre, my soul's earth [MASK] s god and body's f'ring patron'- [ reads [MASK]'so it is'- [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

## Model

In [14]:
mobilebert_model = transformers.TFMobileBertModel.from_pretrained(
    'google/mobilebert-uncased', cache_dir='./transformers_cache')

Some layers from the model checkpoint at google/mobilebert-uncased were not used when initializing TFMobileBertModel: ['predictions___cls', 'seq_relationship___cls']
- This IS expected if you are initializing TFMobileBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFMobileBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
All the layers of TFMobileBertModel were initialized from the model checkpoint at google/mobilebert-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMobileBertModel for predictions without further training.


In [15]:
# Due to the limitations with Keras subclasses, we can only use the main layer part from pretrained models
# and add output heads by ourselves
mobilebert_keras_converted = utils.convert_huggingface_mlm_to_keras(
    huggingface_model=mobilebert_model,
    max_seq_length=BERT_MAX_SEQ_LENGTH,
    batch_size=BATCH_SIZE)

In [16]:
# Use lists of NumPy arrays to backup pretained weights
mobilebert_pretrained_trainable_weights = []
mobilebert_pretrained_non_trainable_weights = []

for w in mobilebert_keras_converted.trainable_weights:
    mobilebert_pretrained_trainable_weights.append(w.numpy())

for w in mobilebert_keras_converted.non_trainable_weights:
    mobilebert_pretrained_non_trainable_weights.append(w.numpy())

In [17]:
def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""

    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    model_wrapped = utils.KerasModelWrapper(
        tf.keras.models.clone_model(mobilebert_keras_converted), example_dataset.element_spec, loss)

    return model_wrapped

## Training

### Training setups

In [18]:
def server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=SERVER_LEARNING_RATE)

In [19]:
def client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=CLIENT_LEARNING_RATE)

In [20]:
iterative_process = fedavg.build_federated_averaging_process(
    model_fn=tff_model_fn,
    model_input_spec=example_dataset.element_spec,
    initial_trainable_weights=mobilebert_pretrained_trainable_weights,
    initial_non_trainable_weights=mobilebert_pretrained_non_trainable_weights,
    server_optimizer_fn=server_optimizer_fn, 
    client_optimizer_fn=client_optimizer_fn)

server_state = iterative_process.initialize()

Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.


Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.


In [21]:
model_final = tff_model_fn()
metric_test = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True, name='test_accuracy')

In [None]:
for round_num in range(TOTAL_ROUNDS):

    # Training clients selection
    print("Choosing clients to use for training...")

    sampled_clients = np.random.choice(
        train_client_data.client_ids,
        size=TRAIN_CLIENTS_PER_ROUND,
        replace=False)

    sampled_train_data = [
        train_client_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

    print("Training clients selection complete.")

    # FedAvg
    print(f'Round {round_num} start!')

    server_state, train_metrics = iterative_process.next(server_state, sampled_train_data)

    print(f'Round {round_num} training loss: {train_metrics}')

    # Evaluation
    if round_num % ROUNDS_PER_EVAL == 0:
        model_final.from_weights(server_state.model_weights)

        # Test dataset generation for this round
        print("Sampling clients to use for testing...")

        sampled_test_clients = np.random.choice(
            test_client_data.client_ids,
            size=TEST_CLIENTS_PER_ROUND,
            replace=False)

        sampled_test_data = [
            test_client_data.create_tf_dataset_for_client(client)
            for client in sampled_test_clients
        ]

        sampled_test_data_merged = sampled_test_data[0]

        if len(sampled_test_data) > 1:
            for client_test in range(1, len(sampled_test_data)):
                sampled_test_data_merged.concatenate(sampled_test_data[1])

        print("Test clients selected.")

        metric_validation = utils.keras_evaluate(model_final.keras_model, sampled_test_data_merged, metric_test)

        print(f'Round {round_num} validation metric: {metric_validation}')

Choosing clients to use for training...








Training clients selection complete.
Round 0 start!
Client 9650 : updated the model with server message.
Client 9650 : training start.
Client 21408 : updated the model with server message.
Client 21408 : training start.
Client 9650 : 8 processed
Client 21408 : 8 processed
Client 9650 : 16 processed
Client 21408 : 16 processed
Client 21408 : 24 processed
Client 21408 : 32 processed
Client 21408 : 40 processed
Client 21408 : 48 processed
Client 21408 : 56 processed
Client 21408 : 64 processed
Client 21408 : 72 processed
Client 21408 : 80 processed
Client 21408 : 88 processed
Client 21408 : 96 processed
Client 21408 : 104 processed
Client 21408 : 112 processed
Client 21408 : 120 processed
