# Fine-Tuning MobileBERT with Federated Averaging

In [1]:
# Copyright 2020, The TensorFlow Federated Authors.
# Copyright 2020, Ronald Seoh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Google Colab settings

In [2]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except:
    colab_available = False

if use_colab and colab_available:
    # If there's a package I need to install separately, do it here
    #!pip install tensorflow-federated-nightly==0.16.1.dev20201021 tf-nightly-gpu==2.4.0.dev20201021 transformers==3.4.0
    #!pip uninstall tf-nightly

    # Mount Google Drive root directory
    drive.mount('/content/drive')

    # cd to the appropriate working directory under my Google Drive
    %cd 'drive/My Drive/Colab Notebooks/BERTerated'
    
    # List the directory contents
    !ls

# IPython reloading magic
%load_ext autoreload
%autoreload 2

Mounted at /content/drive
/content/drive/My Drive/Colab Notebooks/BERTerated
bert_fedavg_main.ipynb	__pycache__  simple_fedavg_test.py  simple_fedavg_tf.py
LICENSE			README.md    simple_fedavg_tff.py   transformers_cache


In [3]:
!pip install tensorflow-federated==0.17.0 transformers==3.4.0

Collecting tensorflow-federated==0.17.0
[?25l  Downloading https://files.pythonhosted.org/packages/5c/54/900d99d3cff21b6a570281b51f4878a745c0eece7732bb7fc26eee61ef57/tensorflow_federated-0.17.0-py2.py3-none-any.whl (517kB)
[K     |████████████████████████████████| 522kB 5.4MB/s 
[?25hCollecting transformers==3.4.0
[?25l  Downloading https://files.pythonhosted.org/packages/2c/4e/4f1ede0fd7a36278844a277f8d53c21f88f37f3754abf76a5d6224f76d4a/transformers-3.4.0-py3-none-any.whl (1.3MB)
[K     |████████████████████████████████| 1.3MB 13.5MB/s 
[?25hCollecting attrs~=19.3.0
  Downloading https://files.pythonhosted.org/packages/a2/db/4313ab3be961f7a763066401fb77f7748373b6094076ae2bda2806988af6/attrs-19.3.0-py2.py3-none-any.whl
Collecting tensorflow-addons~=0.11.1
[?25l  Downloading https://files.pythonhosted.org/packages/b3/f8/d6fca180c123f2851035c4493690662ebdad0849a9059d56035434bff5c9/tensorflow_addons-0.11.2-cp36-cp36m-manylinux2010_x86_64.whl (1.1MB)
[K     |███████████████████████

## Import packages

In [1]:
import sys
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import transformers

import nest_asyncio
nest_asyncio.apply()

import simple_fedavg_tf
import simple_fedavg_tff

# Random seed settings
random_seed = 692
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Tensorflow GPU
print("Num GPUs Available: ", len(tf.config.experimental.list_physical_devices('GPU')))

# Test the TFF is working:
tff.federated_computation(lambda: 'Hello, World!')()

 The versions of TensorFlow you are currently using is 2.4.0 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


Num GPUs Available:  0


b'Hello, World!'

In [2]:
# Print version information
print("Python version: " + sys.version)
print("NumPy version: " + np.__version__)
print("TensorFlow version: " + tf.__version__)
print("TensorFlow Federated version: " + tff.__version__)
print("Transformers version: " + transformers.__version__)

Python version: 3.6.9 |Intel Corporation| (default, Sep 11 2019, 16:40:08) 
[GCC 4.8.2 20140120 (Red Hat 4.8.2-15)]
NumPy version: 1.17.0
TensorFlow version: 2.4.0
TensorFlow Federated version: 0.16.1
Transformers version: 3.4.0


## Experiment Settings

In [3]:
TOTAL_ROUNDS = 256 # Number of total training rounds
ROUNDS_PER_EVAL = 1 # How often to evaluate
TRAIN_CLIENTS_PER_ROUND = 2 # How many clients to sample per round.
CLIENT_EPOCHS_PER_ROUND = 1 # Number of epochs in the client to take per round.
BATCH_SIZE = 20 # Batch size used on the client.
BUFFER_SIZE = 1000  # For dataset shuffling
TEST_BATCH_SIZE = 100 # Minibatch size of test data.

# Optimizer configuration
SERVER_LEARNING_RATE = 1.0 # Server learning rate.
CLIENT_LEARNING_RATE = 0.1 # Client learning rate

## Dataset

In [4]:
train_data, test_data = tff.simulation.datasets.shakespeare.load_data(cache_dir='./tff_shakespeare_cache')

In [5]:
mobilebert_tokenizer = transformers.MobileBertTokenizer.from_pretrained(
    'google/mobilebert-uncased', cache_dir='./transformers_cache')

### Preprocessing

In [6]:
# Codes based on the tips from
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
# https://stackoverflow.com/questions/61555097/mapping-text-data-through-huggingface-tokenizer
# https://stackoverflow.com/questions/61022109/how-to-return-a-dictionary-of-tensors-from-tf-py-function
def tokenize_snippets(x):
    tokenized = mobilebert_tokenizer.encode(
        tf.compat.as_str(x.numpy()), add_special_tokens=True, padding='max_length', return_tensors='tf')
    
    return tokenized

def tf_tokenize(x):
    tokenized = tf.py_function(
        func=tokenize_snippets, inp=[x['snippets']], Tout=[tf.int32])

    return tokenized

def preprocess(dataset):
    return (
        # Tokenize each samples using MobileBERT tokenizer
        dataset.map(tf_tokenize)
        # Shuffle and form minibatches
        .shuffle(BUFFER_SIZE).batch(BATCH_SIZE))

In [7]:
raw_example_dataset = train_data.create_tf_dataset_for_client('THE_TRAGEDY_OF_KING_LEAR_KING')
example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)

(TensorSpec(shape=<unknown>, dtype=tf.int32, name=None),)


## Model

In [11]:
def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = transformers.TFMobileBertForMaskedLM.from_pretrained(
        'google/mobilebert-uncased', cache_dir='./transformers_cache')

    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    return simple_fedavg_tf.KerasModelWrapper(keras_model, example_dataset.element_spec, loss)

model = tff_model_fn()

Some layers from the model checkpoint at google/mobilebert-uncased were not used when initializing TFMobileBertForMaskedLM: ['seq_relationship___cls', 'predictions___cls']
- This IS expected if you are initializing TFMobileBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFMobileBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFMobileBertForMaskedLM were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['mlm___cls']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## Training

### Training setups

In [None]:
def server_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=SERVER_LEARNING_RATE)

In [None]:
def client_optimizer_fn():
    return tf.keras.optimizers.SGD(learning_rate=CLIENT_LEARNING_RATE)

In [12]:
iterative_process = simple_fedavg_tff.build_federated_averaging_process(
    tff_model_fn, server_optimizer_fn, client_optimizer_fn)

server_state = iterative_process.initialize()

metric = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, name='test_accuracy')

Some layers from the model checkpoint at google/mobilebert-uncased were not used when initializing TFMobileBertForMaskedLM: ['seq_relationship___cls', 'predictions___cls']
- This IS expected if you are initializing TFMobileBertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFMobileBertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some layers of TFMobileBertForMaskedLM were not initialized from the model checkpoint at google/mobilebert-uncased and are newly initialized: ['mlm___cls']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some layers from the model checkpoint at google/mobilebert-uncased were not used

Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.


Instructions for updating:
Use `tf.data.Iterator.get_next_as_optional()` instead.


TypeError: ignored

In [None]:
for round_num in range(TOTAL_ROUNDS):
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=TRAIN_CLIENTS_PER_ROUND,
        replace=False)

    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]

    server_state, train_metrics = iterative_process.next(server_state, sampled_train_data)

    print(f'Round {round_num} training loss: {train_metrics}')

    if round_num % rounds_per_eval == 0:
        model.from_weights(server_state.model_weights)

        accuracy = simple_fedavg_tf.keras_evaluate(model.keras_model, test_data, metric)

        print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')