# Fine-Tuning MobileBERT with Federated Averaging

In [1]:
# Copyright 2020, The TensorFlow Federated Authors.
# Copyright 2020, Ronald Seoh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Google Colab settings

In [2]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except:
    colab_available = False

if use_colab and colab_available:
    drive.mount('/content/drive')
    
    # If there's a package I need to install separately, do it here
    !pip install pyro-ppl

    # cd to the appropriate working directory under my Google Drive
    %cd 'drive/My Drive/Colab Notebooks/BERTerated'
    
    # List the directory contents
    !ls

# IPython reloading magic
%load_ext autoreload
%autoreload 2

# Random seeds
# Based on https://pytorch.org/docs/stable/notes/randomness.html
random_seed = 692

In [3]:
# Install required packages
!pip install tensorflow-federated-nightly==0.16.1.dev20201021 transformers==3.4.0

Collecting tensorflow-model-optimization~=0.4.0 (from tensorflow-federated-nightly==0.16.1.dev20201021)
  Using cached https://files.pythonhosted.org/packages/1a/cc/4b0831f492396f03a4563cc749ad94cbf7af986aaa7a7d89e5979029ce5c/tensorflow_model_optimization-0.4.1-py2.py3-none-any.whl
Collecting attrs~=19.3.0 (from tensorflow-federated-nightly==0.16.1.dev20201021)
  Using cached https://files.pythonhosted.org/packages/a2/db/4313ab3be961f7a763066401fb77f7748373b6094076ae2bda2806988af6/attrs-19.3.0-py2.py3-none-any.whl
Collecting numpy~=1.18.4 (from tensorflow-federated-nightly==0.16.1.dev20201021)
[?25l  Downloading https://files.pythonhosted.org/packages/b3/a9/b1bc4c935ed063766bce7d3e8c7b20bd52e515ff1c732b02caacf7918e5a/numpy-1.18.5-cp36-cp36m-manylinux1_x86_64.whl (20.1MB)
[K     |█████████████████▋              | 11.1MB 3.3MB/s eta 0:00:03^C  |█                               | 614kB 1.9MB/s eta 0:00:11     |██▌                             | 1.5MB 1.9MB/s eta 0:00:10     |████████▉    

## Import packages

In [4]:
import random
import sys

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import transformers

import nest_asyncio
# Allow nested asyncio event loops inside the notebook's already-running
# loop (the pattern TFF tutorials use when running under Jupyter).
nest_asyncio.apply()

import simple_fedavg_tf
import simple_fedavg_tff

# Random seed settings: seed every RNG the notebook may draw from.
# np.random drives client sampling in the training loop; Python's `random`
# is seeded too so any library code using it is reproducible as well.
random.seed(random_seed)
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()

 The versions of TensorFlow you are currently using is 2.4.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


b'Hello, World!'

In [5]:
# Report the interpreter and library versions this notebook ran with.
version_report = (
    ("Python", sys.version),
    ("NumPy", np.__version__),
    ("TensorFlow", tf.__version__),
    ("TensorFlow Federated", tff.__version__),
    ("Transformers", transformers.__version__),
)
for library_name, version_string in version_report:
    print(library_name + " version: " + version_string)

Python version: 3.6.9 |Intel Corporation| (default, Sep 11 2019, 16:40:08) 
[GCC 4.8.2 20140120 (Red Hat 4.8.2-15)]
NumPy version: 1.17.0
TensorFlow version: 2.4.0
TensorFlow Federated version: 0.16.1
Transformers version: 3.4.0


## Experiment Settings

In [6]:
total_rounds = 256 # Number of total training rounds
rounds_per_eval = 1 # How often to evaluate
train_clients_per_round = 2 # How many clients to sample per round.
client_epochs_per_round = 1 # Number of epochs in the client to take per round.
# NOTE(review): batch_size and test_batch_size are not used by any visible cell.
batch_size = 20 # Batch size used on the client.
test_batch_size = 100 # Minibatch size of test data.

# Optimizer configuration
server_learning_rate = 1.0 # Server learning rate.
client_learning_rate = 0.1 # Client learning rate.
# Backward-compatible alias for the original, misspelled name.
fclient_learning_rate = client_learning_rate


# Optimizer factories passed to
# simple_fedavg_tff.build_federated_averaging_process in the Training section.
# They are zero-argument callables (as in TFF's simple_fedavg example) that
# build a fresh optimizer on every call.
def server_optimizer_fn():
    """Returns a new SGD optimizer for the server-side model update."""
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)


def client_optimizer_fn():
    """Returns a new SGD optimizer for local training on each client."""
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

## Dataset

In [7]:
# Load TFF's federated Shakespeare data, caching it locally so re-runs
# reuse the files in ./tff_shakespeare_cache.
shakespeare_cache_dir = './tff_shakespeare_cache'
train_data, test_data = tff.simulation.datasets.shakespeare.load_data(
    cache_dir=shakespeare_cache_dir)

In [8]:
# Tokenizer matching the same MobileBERT checkpoint that tff_model_fn loads;
# downloaded files are cached under ./transformers_cache.
mobilebert_tokenizer = transformers.MobileBertTokenizer.from_pretrained(
    'google/mobilebert-uncased',
    cache_dir='./transformers_cache',
)

### Preprocessing

In [9]:
# Codes based on the tips from
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
# https://stackoverflow.com/questions/61555097/mapping-text-data-through-huggingface-tokenizer
def tokenize_snippets(x, max_length=512):
    """Tokenizes one text snippet with the MobileBERT tokenizer.

    Runs eagerly, because it is invoked through ``tf.py_function``, so
    ``x.numpy()`` is available.

    Args:
        x: String tensor holding one snippet (presumably scalar; it is the
            ``'snippets'`` entry of a dataset element).
        max_length: Maximum number of tokens, including the special tokens
            added by ``add_special_tokens=True``. Longer snippets are
            truncated. The default of 512 is assumed to match MobileBERT's
            position-embedding size; confirm against the model config.

    Returns:
        The token-id ``tf.Tensor`` returned by
        ``encode(..., return_tensors='tf')``, presumably of shape
        ``(1, num_tokens)``.
    """
    tokenized = mobilebert_tokenizer.encode(
        tf.compat.as_str(x.numpy()),
        add_special_tokens=True,
        # Without truncation, a snippet longer than the model's maximum
        # sequence length yields position ids the model cannot embed.
        truncation=True,
        max_length=max_length,
        return_tensors='tf')

    return tokenized

def tf_tokenize(x):
    """Graph-compatible wrapper around ``tokenize_snippets``.

    ``Dataset.map`` traces its function into a graph, where ``.numpy()`` is
    not available, so the tokenizer call goes through ``tf.py_function``.
    ``Tout`` is a one-element list, so the result is a one-element list of
    int32 tensors, which yields the 1-tuple element_spec seen downstream.

    Args:
        x: One dataset element; only its ``'snippets'`` entry is used.
    """
    snippet = x['snippets']
    return tf.py_function(tokenize_snippets, [snippet], [tf.int32])

def preprocess(dataset):
    """Tokenizes every snippet of a client dataset with MobileBERT.

    This applies the MobileBERT tokenizer vocabulary to whole snippets;
    it does not map individual characters to indexes.

    Args:
        dataset: A client ``tf.data.Dataset`` whose elements carry a
            ``'snippets'`` entry, e.g. from
            ``train_data.create_tf_dataset_for_client``.

    Returns:
        A dataset whose elements are the outputs of ``tf_tokenize``.
    """
    tokenized_dataset = dataset.map(tf_tokenize)
    return tokenized_dataset

In [10]:
# Tokenize one client's data to inspect the element structure. Its
# element_spec is also what tff_model_fn uses as the model's input spec.
example_client_id = 'THE_TRAGEDY_OF_KING_LEAR_KING'
raw_example_dataset = train_data.create_tf_dataset_for_client(example_client_id)
example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)

(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None),)


## Model

In [11]:
def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging.

    Loads the pretrained MobileBERT masked-LM model and wraps it, together
    with its loss, in ``simple_fedavg_tf.KerasModelWrapper``. The wrapper's
    input spec is taken from the global ``example_dataset.element_spec``, so
    client datasets fed to the federated process must be preprocessed the
    same way as ``example_dataset``.

    Returns:
        A ``simple_fedavg_tf.KerasModelWrapper`` around the Keras model.
    """
    keras_model = transformers.TFMobileBertForMaskedLM.from_pretrained(
        'google/mobilebert-uncased', cache_dir='./transformers_cache')

    # Hugging Face TF heads return unnormalized prediction scores (logits),
    # not probabilities, so the loss must apply the softmax itself.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    return simple_fedavg_tf.KerasModelWrapper(keras_model, example_dataset.element_spec, loss)

# Local (non-federated) copy of the model; the training loop loads the
# server weights into it for evaluation.
model = tff_model_fn()

Some layers from the model checkpoint at google/mobilebert-uncased were not used when initializing TFMobileBertForMaskedLM: ['predictions___cls', 'seq_relationship___cls']
- This IS expected if you are initializing TFMobileBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFMobileBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFMobileBertForMaskedLM were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['mlm___cls']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Training

In [None]:
# Assemble the federated averaging process from the model constructor and
# the server/client optimizer factories, then create the initial server state.
iterative_process = simple_fedavg_tff.build_federated_averaging_process(
    tff_model_fn,
    server_optimizer_fn,
    client_optimizer_fn,
)
server_state = iterative_process.initialize()

# Evaluation metric reused by the training loop across rounds.
metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [None]:
for round_num in range(total_rounds):
    # Sample this round's cohort of distinct training clients.
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=train_clients_per_round,
        replace=False)

    # The model's input spec comes from `example_dataset`, i.e. the
    # *tokenized* data, so each client's raw dataset must go through the
    # same `preprocess` step before it reaches the iterative process.
    # `repeat` gives `client_epochs_per_round` passes over the local data,
    # assuming the client update iterates its dataset once; confirm in
    # simple_fedavg_tf.
    sampled_train_data = [
        preprocess(train_data.create_tf_dataset_for_client(client))
        .repeat(client_epochs_per_round)
        for client in sampled_clients
    ]

    server_state, train_metrics = iterative_process.next(server_state, sampled_train_data)

    print(f'Round {round_num} training loss: {train_metrics}')

    if round_num % rounds_per_eval == 0:
        # Copy the current global weights into the local Keras model.
        model.from_weights(server_state.model_weights)

        # NOTE(review): `test_data` is the ClientData returned by
        # `load_data`, not a tf.data.Dataset, and it is not tokenized.
        # Confirm that simple_fedavg_tf.keras_evaluate accepts it; it likely
        # needs something like
        # `preprocess(test_data.create_tf_dataset_from_all_clients())`.
        accuracy = simple_fedavg_tf.keras_evaluate(model.keras_model, test_data, metric)

        print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')