# Fine-Tuning MobileBERT with Federated Averaging

In [None]:
# Copyright 2020, The TensorFlow Federated Authors.
# Copyright 2020, Ronald Seoh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Google Colab settings

In [None]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except:
    colab_available = False

if use_colab and colab_available:
    # If there's a package I need to install separately, do it here
    #!pip install tensorflow-federated-nightly==0.16.1.dev20201021 tf-nightly-gpu==2.4.0.dev20201021 transformers==3.4.0
    #!pip uninstall tf-nightly

    # Mount Google Drive root directory
    drive.mount('/content/drive')

    # cd to the appropriate working directory under my Google Drive
    %cd 'drive/My Drive/Colab Notebooks/BERTerated'
    
    # List the directory contents
    !ls

# IPython reloading magic
%load_ext autoreload
%autoreload 2

In [None]:
!pip install tensorflow-federated==0.17.0 transformers==3.4.0

## Import packages

In [None]:
import sys
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import transformers

import nest_asyncio
nest_asyncio.apply()

import simple_fedavg_tf
import simple_fedavg_tff
import utils

# Seed the NumPy and TensorFlow global random generators with one shared value
random_seed = 692
np.random.seed(random_seed)
tf.random.set_seed(random_seed)

# Report how many GPUs TensorFlow can see
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

# Smoke test: build and immediately invoke a trivial federated computation
tff.federated_computation(lambda: 'Hello, World!')()

In [None]:
# Record the interpreter and library versions this run was executed with
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"TensorFlow version: {tf.__version__}")
print(f"TensorFlow Federated version: {tff.__version__}")
print(f"Transformers version: {transformers.__version__}")

## Experiment Settings

In [None]:
# Federated training schedule
TOTAL_ROUNDS = 256 # Number of total training rounds
ROUNDS_PER_EVAL = 1 # How often to evaluate (in rounds)
TRAIN_CLIENTS_PER_ROUND = 2 # How many clients to sample per round.
# NOTE(review): CLIENT_EPOCHS_PER_ROUND is not referenced by any visible cell.
CLIENT_EPOCHS_PER_ROUND = 1 # Number of epochs in the client to take per round.

# Input pipeline (consumed by preprocess())
BATCH_SIZE = 20 # Batch size used on the client.
BUFFER_SIZE = 1000  # For dataset shuffling
# NOTE(review): TEST_BATCH_SIZE is not referenced by any visible cell.
TEST_BATCH_SIZE = 100 # Minibatch size of test data.

# Optimizer configuration (both optimizers are plain SGD)
SERVER_LEARNING_RATE = 1.0 # Server learning rate.
CLIENT_LEARNING_RATE = 0.1 # Client learning rate

## Dataset

In [None]:
# Originally tff.simulation.datasets.stackoverflow.load_data(cache_dir='./tff_cache')
!curl -o ./tff_cache/datasets/stackoverflow.tar.bz2 https://storage.googleapis.com/tff-datasets-public/stackoverflow.tar.bz2
!tar xvfj ./tff_cache/datasets/stackoverflow.tar.bz2 -C ./tff_cache/datasets

In [None]:
import os
from tensorflow_federated.python.simulation import hdf5_client_data

# Directory the Stack Overflow archive was extracted into
datasets_dir = os.path.join('.', 'tff_cache', 'datasets')

# Training split. It is bound to `train_data` because that is the name the
# later cells use when sampling clients for each round; previously the load
# was commented out, leaving `train_data` undefined.
train_data = hdf5_client_data.HDF5ClientData(
    os.path.join(datasets_dir, 'stackoverflow_train.h5'))

# Held-out split, used below for inspecting individual clients
held_out_client_data = hdf5_client_data.HDF5ClientData(
    os.path.join(datasets_dir, 'stackoverflow_held_out.h5'))

# Test split: still disabled, as in the original cell.
#test_client_data = hdf5_client_data.HDF5ClientData(
#    os.path.join(datasets_dir, 'stackoverflow_test.h5'))

In [None]:
# Build the dataset of a single held-out client (id '00045530') so its raw
# examples can be inspected in the next cell.
test = held_out_client_data.create_tf_dataset_for_client('00045530')

In [None]:
# Peek at the first ten examples of the sample client: for each one, show its
# 'type' field followed by its 'tokens' field.
for example in test.take(10):
    for field in ('type', 'tokens'):
        print(example[field])

In [None]:
# Load the tokenizer for the 'google/mobilebert-uncased' checkpoint, keeping
# downloaded files under ./transformers_cache. It is used by the preprocessing
# functions below.
mobilebert_tokenizer = transformers.MobileBertTokenizer.from_pretrained(
    'google/mobilebert-uncased', cache_dir='./transformers_cache')

### Preprocessing

In [None]:
# This code is based on tips from:
# https://www.tensorflow.org/federated/tutorials/federated_learning_for_text_generation
# https://stackoverflow.com/questions/61555097/mapping-text-data-through-huggingface-tokenizer
# https://stackoverflow.com/questions/61022109/how-to-return-a-dictionary-of-tensors-from-tf-py-function
def tokenize_and_mask(x, max_length=512):
    """Tokenize one raw text example and apply masked-LM masking.

    This function is invoked through tf.py_function, so it runs eagerly and
    can read the tensor's value with .numpy().

    Args:
        x: Scalar string tensor holding the raw text of one example.
        max_length: Fixed length every encoded sequence is padded or
            truncated to. Defaults to 512.

    Returns:
        The (masked, labels) pair produced by
        utils.get_masked_input_and_labels for the encoded sequence.
    """
    # return_tensors='tf' makes 'tokenized' a TensorFlow tensor.
    tokenized = mobilebert_tokenizer.encode(
        tf.compat.as_str(x.numpy()),
        add_special_tokens=True,
        padding='max_length',
        # Padding alone only lengthens short inputs. Without truncation a
        # long text would yield a sequence longer than max_length, which
        # cannot be batched with the padded ones.
        truncation=True,
        max_length=max_length,
        return_tensors='tf')

    masked, labels = utils.get_masked_input_and_labels(tokenized, mobilebert_tokenizer)

    return masked, labels

def tf_tokenize(x):
    """Graph-compatible wrapper around tokenize_and_mask for Dataset.map.

    Args:
        x: One dataset element; only its 'tokens' entry is used.

    Returns:
        A (masked, labels) tuple of int32 tensors, each passed through
        tf.squeeze to drop size-1 dimensions.
    """
    # The tokenizer is plain Python, so it has to run inside tf.py_function.
    py_outputs = tf.py_function(
        tokenize_and_mask, [x['tokens']], [tf.int32, tf.int32])

    return tf.squeeze(py_outputs[0]), tf.squeeze(py_outputs[1])

def preprocess(dataset):
    """Turn a raw client dataset into shuffled minibatches of model inputs.

    Args:
        dataset: Dataset of raw examples accepted by tf_tokenize.

    Returns:
        The dataset after tokenizing every example, shuffling with a buffer
        of BUFFER_SIZE, and batching into groups of BATCH_SIZE.
    """
    # Tokenize each sample using the MobileBERT tokenizer
    tokenized = dataset.map(tf_tokenize)
    # Shuffle, then form minibatches
    shuffled = tokenized.shuffle(BUFFER_SIZE)
    return shuffled.batch(BATCH_SIZE)

In [None]:
# Build one preprocessed example dataset; its element_spec is what the model
# wrapper below is constructed with.
#
# The previous version read from `train_data` with the client id
# 'THE_TRAGEDY_OF_KING_LEAR_KING'. That name is not defined at this point and
# the id is a Shakespeare client id, not a Stack Overflow one. Use the
# held-out split loaded above with the client id already inspected earlier.
raw_example_dataset = held_out_client_data.create_tf_dataset_for_client('00045530')
example_dataset = preprocess(raw_example_dataset)
print(example_dataset.element_spec)

## Model

In [None]:
def tff_model_fn():
    """Build a ready-to-train model wrapper for federated averaging.

    Returns:
        A simple_fedavg_tf.KerasModelWrapper around the pretrained
        'google/mobilebert-uncased' masked-LM model, using
        example_dataset.element_spec as its input spec and sparse categorical
        cross-entropy on logits as its loss.
    """
    mlm_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # Pretrained weights are cached under ./transformers_cache
    mobilebert_mlm = transformers.TFMobileBertForMaskedLM.from_pretrained(
        'google/mobilebert-uncased', cache_dir='./transformers_cache')

    return simple_fedavg_tf.KerasModelWrapper(
        mobilebert_mlm, example_dataset.element_spec, mlm_loss)

# Keep one wrapper instance outside the federated process; the training loop
# loads the server weights into it before each evaluation.
model = tff_model_fn()

## Training

### Training setups

In [None]:
def server_optimizer_fn():
    """Create a new SGD optimizer for the server, using SERVER_LEARNING_RATE."""
    server_sgd = tf.keras.optimizers.SGD(learning_rate=SERVER_LEARNING_RATE)
    return server_sgd

In [None]:
def client_optimizer_fn():
    """Create a new SGD optimizer for a client, using CLIENT_LEARNING_RATE."""
    client_sgd = tf.keras.optimizers.SGD(learning_rate=CLIENT_LEARNING_RATE)
    return client_sgd

In [None]:
# Build the federated averaging process from the model and optimizer factories
# defined above.
iterative_process = simple_fedavg_tff.build_federated_averaging_process(
    tff_model_fn, server_optimizer_fn, client_optimizer_fn)

# Initial server state; the training loop replaces it after every round.
server_state = iterative_process.initialize()

# NOTE(review): this is a cross-entropy loss object, although it is named
# 'test_accuracy' and the training loop reports the evaluation result as an
# accuracy percentage. Confirm what simple_fedavg_tf.keras_evaluate expects.
metric = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, name='test_accuracy')

In [None]:
# Federated training: each round samples a few clients, runs one step of the
# iterative process on their data, and periodically evaluates the server model.
for round_num in range(TOTAL_ROUNDS):
    # Pick this round's participants without replacement
    sampled_clients = np.random.choice(
        train_data.client_ids,
        size=TRAIN_CLIENTS_PER_ROUND,
        replace=False)

    # Raw client examples must go through preprocess() so that their element
    # structure matches example_dataset.element_spec, the input spec the
    # model wrapper was constructed with. Passing the raw datasets (as the
    # previous version did) cannot match that spec.
    sampled_train_data = [
        preprocess(train_data.create_tf_dataset_for_client(client))
        for client in sampled_clients
    ]

    server_state, train_metrics = iterative_process.next(server_state, sampled_train_data)

    print(f'Round {round_num} training loss: {train_metrics}')

    # The evaluation interval is the upper-case setting ROUNDS_PER_EVAL; the
    # lower-case name used previously is not defined anywhere.
    if round_num % ROUNDS_PER_EVAL == 0:
        # Copy the current server weights into the standalone model wrapper
        model.from_weights(server_state.model_weights)

        # NOTE(review): `test_data` is not assigned in any visible cell;
        # it has to be defined before this branch can run.
        # NOTE(review): `metric` is built as a SparseCategoricalCrossentropy
        # loss, so the "accuracy" label and the x100 scaling below look
        # inherited from another example; confirm the intended metric.
        accuracy = simple_fedavg_tf.keras_evaluate(model.keras_model, test_data, metric)

        print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')