# Fine-Tuning MobileBERT with Federated Averaging

In [None]:
# Copyright 2020, The TensorFlow Federated Authors.
# Copyright 2020, Ronald Seoh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Google Colab settings

In [None]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except:
    colab_available = False

if use_colab and colab_available:
    drive.mount('/content/drive')
    
    # If there's a package I need to install separately, do it here
    !pip install pyro-ppl

    # cd to the appropriate working directory under my Google Drive
    %cd 'drive/My Drive/Colab Notebooks/BERTerated'
    
    # List the directory contents
    !ls

# IPython reloading magic
%load_ext autoreload
%autoreload 2

# Random seeds
# Based on https://pytorch.org/docs/stable/notes/randomness.html
random_seed = 692

## Import packages

In [None]:
import sys
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import transformers

import nest_asyncio
nest_asyncio.apply()

import simple_fedavg_tf
import simple_fedavg_tff

# Seed NumPy first, then TensorFlow, from the notebook-wide `random_seed`
# so client sampling and TF-side randomness start from a fixed state.
for seed_rng in (np.random.seed, tf.random.set_seed):
    seed_rng(random_seed)

# Smoke test for TFF: build a trivial federated computation and invoke it.
# The call is the cell's last expression, so its result is displayed.
hello_world = tff.federated_computation(lambda: 'Hello, World!')
hello_world()

In [None]:
# Record the interpreter and library versions used for this run
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"TensorFlow version: {tf.__version__}")
print(f"TensorFlow Federated version: {tff.__version__}")
print(f"Transformers version: {transformers.__version__}")

## Experiment Settings

In [None]:
# Federated training schedule
total_rounds = 256 # Number of total training rounds
rounds_per_eval = 1 # Evaluate every `rounds_per_eval` rounds
train_clients_per_round = 2 # How many clients to sample per round.

# Client-side data settings.
# NOTE(review): these three are not referenced by any cell visible in this
# notebook -- presumably meant for a dataset preprocessing step; confirm.
client_epochs_per_round = 1 # Number of epochs in the client to take per round.
batch_size = 20 # Batch size used on the client.
test_batch_size = 100 # Minibatch size of test data.

# Optimizer configuration
server_learning_rate = 1.0 # Server learning rate.
client_learning_rate = 0.1 # Client learning rate

# Backward-compatible alias for the original (misspelled) variable name.
fclient_learning_rate = client_learning_rate

## Dataset

In [None]:
# Load the Shakespeare federated dataset bundled with TFF and unpack it
# into its train and test parts. The training loop samples clients from
# `train_data`; `test_data` is used for evaluation.
shakespeare_splits = tff.simulation.datasets.shakespeare.load_data()
train_data, test_data = shakespeare_splits

## Model

In [None]:
def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging.

    Loads the pretrained `google/mobilebert-uncased` masked-LM model (cached
    under `./transformers_cache`) and wraps it, together with the dataset's
    element spec and a sparse categorical cross-entropy loss, in
    `simple_fedavg_tf.KerasModelWrapper`.

    Returns:
        A `simple_fedavg_tf.KerasModelWrapper` around a freshly loaded model.
    """
    keras_model = transformers.TFMobileBertForMaskedLM.from_pretrained(
        'google/mobilebert-uncased', cache_dir='./transformers_cache')

    # The masked-LM head returns unnormalized vocabulary scores (logits),
    # so the loss has to apply the softmax itself.
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    # `test_data` is a TFF ClientData object; it exposes the per-element
    # tf.TensorSpec structure as `element_type_structure`, not `element_spec`.
    # NOTE(review): assumes KerasModelWrapper takes that structure as its
    # input spec -- confirm against simple_fedavg_tf.
    return simple_fedavg_tf.KerasModelWrapper(
        keras_model, test_data.element_type_structure, loss)

# Stand-alone copy of the model; the training loop loads the server weights
# into it before each evaluation.
model = tff_model_fn()

## Training

In [None]:
def server_optimizer_fn():
    """Returns a new SGD optimizer configured with `server_learning_rate`."""
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)


def client_optimizer_fn():
    """Returns a new SGD optimizer configured with the client learning rate."""
    # `fclient_learning_rate` is the name the experiment-settings cell defines.
    return tf.keras.optimizers.SGD(learning_rate=fclient_learning_rate)


# NOTE(review): plain SGD and zero-argument optimizer factories follow the
# TFF simple_fedavg example -- confirm against simple_fedavg_tff's signature.
iterative_process = simple_fedavg_tff.build_federated_averaging_process(
    tff_model_fn, server_optimizer_fn, client_optimizer_fn)

# Initial server state; the training loop replaces it after every round.
server_state = iterative_process.initialize()

# Accuracy metric handed to the evaluation helper in the training loop.
metric = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')

In [None]:
# Federated training: each round samples clients, runs one step of the
# iterative process on their datasets, and periodically evaluates.
for round_num in range(total_rounds):
    # Draw this round's participants without replacement.
    sampled_clients = np.random.choice(
        train_data.client_ids, size=train_clients_per_round, replace=False)

    # One tf.data dataset per sampled client, in sampling order.
    sampled_train_data = list(
        map(train_data.create_tf_dataset_for_client, sampled_clients))

    # Advance the process by one round; the returned state feeds the next one.
    server_state, train_metrics = iterative_process.next(
        server_state, sampled_train_data)

    print(f'Round {round_num} training loss: {train_metrics}')

    # Only evaluate on every `rounds_per_eval`-th round.
    if round_num % rounds_per_eval != 0:
        continue

    # Copy the current server weights into the stand-alone evaluation model.
    model.from_weights(server_state.model_weights)

    # NOTE(review): `test_data` is passed exactly as returned by the dataset
    # loader -- confirm that keras_evaluate accepts it in that form.
    accuracy = simple_fedavg_tf.keras_evaluate(
        model.keras_model, test_data, metric)

    print(f'Round {round_num} validation accuracy: {accuracy * 100.0}')