# Further Pre-training MobileBERT MLM with Federated Averaging (Shakespeare)

In [1]:
# Copyright 2020, The TensorFlow Federated Authors.
# Copyright 2020, Ronald Seoh
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

## Google Colab settings

In [3]:
# Use Google Colab
use_colab = True

# Is this notebook running on Colab?
# If so, then google.colab package (github.com/googlecolab/colabtools)
# should be available in this environment

# Previous version used importlib, but we could do the same thing with
# just attempting to import google.colab
try:
    from google.colab import drive
    colab_available = True
except:
    colab_available = False

if use_colab and colab_available:
    # Mount Google Drive root directory
    drive.mount('/content/drive')

    # cd to the appropriate working directory under my Google Drive
    %cd '/content/drive/My Drive/Colab Notebooks/BERTerated'
    
    # List the directory contents
    !ls

### CUDA Multi GPU

In [2]:
import os

# Order CUDA devices by PCI bus ID and expose only device "0" to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [4]:
# IPython reloading magic: load the autoreload extension and enable mode 2,
# so edits to the local helper modules are picked up without restarting
# the kernel.
%load_ext autoreload
%autoreload 2

In [5]:
# Install required packages
#!pip install -r requirements.txt

## Import packages

In [6]:
import os
import sys
import random
import datetime
import json
import pathlib
import itertools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
import tensorflow_text as tf_text
import transformers

import nest_asyncio
nest_asyncio.apply()

import fedavg
import fedavg_client
import datasets
import utils


# Seed every random number generator used in this notebook with one value
random_seed = 692
random.seed(random_seed)         # Python
np.random.seed(random_seed)      # NumPy
tf.random.set_seed(random_seed)  # TensorFlow

# Report how many GPUs TensorFlow can see
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))

# Smoke-test TFF by running a trivial federated computation
tff.federated_computation(lambda: 'Hello, World!')()

TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.


Num GPUs Available:  1


b'Hello, World!'

In [7]:
# Record the interpreter and library versions used for this run
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")
print(f"TensorFlow version: {tf.__version__}")
print(f"TensorFlow Federated version: {tff.__version__}")
print(f"Transformers version: {transformers.__version__}")

Python version: 3.6.9 (default, Oct  8 2020, 12:12:24) 
[GCC 8.4.0]
NumPy version: 1.18.4
TensorFlow version: 2.5.0-dev20201121
TensorFlow Federated version: 0.17.0
Transformers version: 3.4.0


In [8]:
# Show the current GPU status using the nvidia-smi command-line tool
!nvidia-smi

Mon Nov 23 16:48:19 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.57       Driver Version: 450.57       CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  Off  | 00000000:03:00.0  On |                  N/A |
|  0%   26C    P2    46W / 250W |  10393MiB / 11018MiB |      3%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  GeForce RTX 208...  Off  | 00000000:04:00.0 Off |                  N/A |
|  0%   41C    P2   129W / 250W |  10537MiB / 11019MiB |     62%      Defaul

## Experiment Settings

In [9]:
# All tunable settings for this experiment live in one dictionary so they can
# be dumped to JSON alongside the results.
EXPERIMENT_CONFIG = {
    # Pretrained model to start from and where to cache its files
    'HUGGINGFACE_MODEL_NAME': 'google/mobilebert-uncased',
    'HUGGINGFACE_CACHE_DIR': os.path.join('.', 'transformers_cache'),

    'TOTAL_ROUNDS': 100,  # Number of total training rounds
    'ROUNDS_PER_EVAL': 10,  # How often to evaluate

    'TRAIN_CLIENTS_PER_ROUND': 10,  # How many clients to sample per round.
    'CLIENT_EPOCHS_PER_ROUND': 3,  # Passed as the repeat count of each client's dataset

    'BATCH_SIZE': 8,  # Batch size used on the client.
    'TEST_BATCH_SIZE': 16,  # Minibatch size of test data.

    # Maximum length of input token sequence for BERT.
    'BERT_MAX_SEQ_LENGTH': 128,

    # Optimizer configuration
    'SERVER_LEARNING_RATE': 1.0,  # Server learning rate.
    'CLIENT_LEARNING_RATE': 5e-5,  # Client learning rate

    # Client dataset setting: a positive TRAIN_NUM_CLIENT_LIMIT keeps only that
    # many training clients, any other value keeps all of them.
    'TRAIN_NUM_CLIENT_LIMIT': -1,
    'TEST_NUM_CLIENT_LIMIT': -1,

    # Path to save trained weights and logs (one timestamped directory per run)
    'RESULTS_DIRECTORY': os.path.join(
        '.', 'results',
        'mobilebert_mlm_shakespeare_fedavg',
        datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
    ),
}

# Sub-directories derived from RESULTS_DIRECTORY
EXPERIMENT_CONFIG.update({
    'RESULTS_LOG': os.path.join(EXPERIMENT_CONFIG['RESULTS_DIRECTORY'], "logs"),
    'RESULTS_MODEL': os.path.join(EXPERIMENT_CONFIG['RESULTS_DIRECTORY'], "model"),
    'RESULTS_CONFIG': os.path.join(EXPERIMENT_CONFIG['RESULTS_DIRECTORY'], "config"),
})

In [10]:
# Persist the whole configuration as JSON inside the run's config directory,
# creating the directory (and any missing parents) first.
config_directory = pathlib.Path(EXPERIMENT_CONFIG['RESULTS_CONFIG'])
config_directory.mkdir(parents=True, exist_ok=True)

with open(config_directory / "config.json", 'w') as config_file:
    json.dump(EXPERIMENT_CONFIG, config_file, indent=6)

In [11]:
# Install a local TFF execution context whose client count matches the
# number of clients sampled in each training round.
tff.backends.native.set_local_execution_context(
    num_clients=EXPERIMENT_CONFIG['TRAIN_CLIENTS_PER_ROUND'],
    max_fanout=10, clients_per_thread=1
)

## Dataset

### Dataset loader

In [12]:
# Load the TFF Shakespeare simulation dataset as a (train, test) pair of
# per-client datasets, caching downloaded files under ./tff_cache.
train_client_data, test_client_data = tff.simulation.datasets.shakespeare.load_data(cache_dir='./tff_cache')

### Tokenizer

In [13]:
# Load the Hugging Face tokenizer that matches the pretrained model,
# using the configured cache directory.
bert_tokenizer = transformers.AutoTokenizer.from_pretrained(
    EXPERIMENT_CONFIG['HUGGINGFACE_MODEL_NAME'], cache_dir=EXPERIMENT_CONFIG['HUGGINGFACE_CACHE_DIR'])

In [14]:
# Imitate transformers tokenizer with TF.Text Tokenizer.
# The helper returns three objects (a TF.Text tokenizer, a vocabulary lookup
# table and a special-ids mask table) that are passed to tokenize_and_mask
# in the preprocessing functions.
tokenizer_tf_text, vocab_lookup_table, special_ids_mask_table = datasets.preprocessing_for_bert.convert_huggingface_tokenizer(bert_tokenizer)

### Preprocessing

In [15]:
def check_empty_snippet(x):
    """Dataset filter predicate: keep examples whose 'snippets' string is non-empty.

    Args:
        x: a dataset element with a 'snippets' string entry.

    Returns:
        A boolean tensor that is True when the snippet has length > 0.
    """
    snippet_length = tf.strings.length(x['snippets'])
    return snippet_length > 0

def tokenizer_and_mask_wrapped(x):
    """Tokenize and mask one example with the notebook-level tokenizer objects.

    Thin adapter around `datasets.preprocessing_for_bert.tokenize_and_mask`
    that fixes all tokenizer-related arguments so it can be used directly
    in `Dataset.map`.

    Args:
        x: a dataset element with a 'snippets' entry; it is reshaped to
            shape [1] before being handed to `tokenize_and_mask`.

    Returns:
        The `(masked, labels)` pair returned by `tokenize_and_mask`.
    """
    # tokenize_and_mask is given a 1-element batch of snippets
    snippet_batch = tf.reshape(x['snippets'], shape=[1])

    masked_tokens, mlm_labels = datasets.preprocessing_for_bert.tokenize_and_mask(
        snippet_batch,
        max_seq_length=EXPERIMENT_CONFIG['BERT_MAX_SEQ_LENGTH'],
        bert_tokenizer_tf_text=tokenizer_tf_text,
        vocab_lookup_table=vocab_lookup_table,
        special_ids_mask_table=special_ids_mask_table,
        cls_token_id=bert_tokenizer.cls_token_id,
        sep_token_id=bert_tokenizer.sep_token_id,
        pad_token_id=bert_tokenizer.pad_token_id,
        mask_token_id=bert_tokenizer.mask_token_id)

    return masked_tokens, mlm_labels

def preprocess_for_train(train_dataset):
    """Build a client's training pipeline: tokenize/mask, shuffle, batch, repeat.

    Args:
        train_dataset: dataset whose elements are accepted by
            `tokenizer_and_mask_wrapped`.

    Returns:
        The transformed dataset of batched `(masked, labels)` pairs.
    """
    # Tokenize and mask each sample using the MobileBERT tokenizer.
    # Element order is allowed to vary (deterministic=False).
    tokenized = train_dataset.map(
        tokenizer_and_mask_wrapped,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False)

    # Shuffle with a buffer of 100000 elements
    shuffled = tokenized.shuffle(100000)

    # Form minibatches. drop_remainder is NOT enabled, so the last batch may
    # hold fewer than BATCH_SIZE examples and the batch dimension is dynamic.
    batched = shuffled.batch(EXPERIMENT_CONFIG['BATCH_SIZE'])

    # Repeat to make each client train multiple epochs
    return batched.repeat(count=EXPERIMENT_CONFIG['CLIENT_EPOCHS_PER_ROUND'])
    
def preprocess_for_test(test_dataset):
    """Build the evaluation pipeline: tokenize/mask, shuffle, batch.

    Args:
        test_dataset: dataset whose elements are accepted by
            `tokenizer_and_mask_wrapped`.

    Returns:
        The transformed dataset of batched `(masked, labels)` pairs.
    """
    # Tokenize and mask each sample using the MobileBERT tokenizer.
    # Element order is allowed to vary (deterministic=False).
    tokenized = test_dataset.map(
        tokenizer_and_mask_wrapped,
        num_parallel_calls=tf.data.experimental.AUTOTUNE,
        deterministic=False)

    # Shuffle with a buffer of 100000 elements
    shuffled = tokenized.shuffle(100000)

    # Form minibatches. drop_remainder is NOT enabled, so the last batch may
    # hold fewer than TEST_BATCH_SIZE examples and the batch dimension is dynamic.
    return shuffled.batch(EXPERIMENT_CONFIG['TEST_BATCH_SIZE'])

### Training set

In [16]:
# Since the dataset is pretty large, we randomly select TRAIN_NUM_CLIENT_LIMIT number of clients.
# Shuffle a copy of the id list: random.shuffle() works in place, and shuffling
# the list returned by `client_ids` directly could reorder state that belongs
# to the dataset object itself.
all_train_client_ids = list(train_client_data.client_ids)

random.shuffle(all_train_client_ids)

# A positive limit keeps the first TRAIN_NUM_CLIENT_LIMIT shuffled ids;
# any other value keeps every client.
if EXPERIMENT_CONFIG['TRAIN_NUM_CLIENT_LIMIT'] > 0:
    selected_train_client_ids = all_train_client_ids[0:EXPERIMENT_CONFIG['TRAIN_NUM_CLIENT_LIMIT']]
else:
    selected_train_client_ids = all_train_client_ids

In [17]:
# Drop empty snippets from every training client's dataset.
# NOTE: this rebinds train_client_data, so re-running the cell stacks another
# filter on top of the already-filtered data.
train_client_data = train_client_data.preprocess(preprocess_fn=lambda x: x.filter(check_empty_snippet))

In [18]:
# Apply the training pipeline (tokenize/mask, shuffle, batch, repeat) to every client.
# NOTE: this rebinds train_client_data, so the cell must be run exactly once
# per kernel session; a second run would feed already-tokenized data back in.
train_client_data = train_client_data.preprocess(preprocess_fn=preprocess_for_train)

Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


Instructions for updating:
The TensorFlow Distributions library has moved to TensorFlow Probability (https://github.com/tensorflow/probability). You should update all references to use `tfp.distributions` instead of `tf.distributions`.


In [19]:
# Show the element structure of the preprocessed training data; the same
# structure is later used as the model input spec.
print(train_client_data.element_type_structure)

(TensorSpec(shape=(None, 128), dtype=tf.int32, name=None), TensorSpec(shape=(None, 128), dtype=tf.int32, name=None))


### Test set

In [20]:
# Merge every test client's non-empty snippets into one dataset,
# concatenating the clients in client_ids order.
test_client_ids = test_client_data.client_ids

test_client_data_all_merged = test_client_data.create_tf_dataset_for_client(
    test_client_ids[0]).filter(check_empty_snippet)

for test_client_id in test_client_ids[1:]:
    test_client_data_all_merged = test_client_data_all_merged.concatenate(
        test_client_data.create_tf_dataset_for_client(test_client_id).filter(check_empty_snippet))

In [21]:
# Apply the evaluation pipeline (tokenize/mask, shuffle, batch) to the merged test set.
# NOTE: this rebinds test_client_data_all_merged, so run the cell only once per session.
test_client_data_all_merged = preprocess_for_test(test_client_data_all_merged)

In [22]:
# Show the element structure of the preprocessed test set for comparison
# with the training data's structure.
print(test_client_data_all_merged.element_spec)

(TensorSpec(shape=(None, 128), dtype=tf.int32, name=None), TensorSpec(shape=(None, 128), dtype=tf.int32, name=None))


## Model

In [23]:
# Load the pretrained Hugging Face TensorFlow model (with its pre-training
# heads) for the configured checkpoint, using the configured cache directory.
bert_model = transformers.TFAutoModelForPreTraining.from_pretrained(
    EXPERIMENT_CONFIG['HUGGINGFACE_MODEL_NAME'], cache_dir=EXPERIMENT_CONFIG['HUGGINGFACE_CACHE_DIR'])

All model checkpoint layers were used when initializing TFMobileBertForPreTraining.

All the layers of TFMobileBertForPreTraining were initialized from the model checkpoint at google/mobilebert-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFMobileBertForPreTraining for predictions without further training.


In [24]:
# Show the configuration of the loaded pretrained model
print(bert_model.config)

MobileBertConfig {
  "_name_or_path": "google/mobilebert-uncased",
  "attention_probs_dropout_prob": 0.1,
  "classifier_activation": false,
  "embedding_size": 128,
  "hidden_act": "relu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 512,
  "initializer_range": 0.02,
  "intermediate_size": 512,
  "intra_bottleneck_size": 128,
  "key_query_shared_bottleneck": true,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "mobilebert",
  "normalization_type": "no_norm",
  "num_attention_heads": 4,
  "num_feedforward_networks": 4,
  "num_hidden_layers": 24,
  "pad_token_id": 0,
  "trigram_input": true,
  "true_hidden_size": 128,
  "type_vocab_size": 2,
  "use_bottleneck": true,
  "use_bottleneck_attention": false,
  "vocab_size": 30522
}



In [25]:
# Due to the limitations with Keras subclasses, we can only use the main layer part from pretrained models
# and add output heads by ourselves.
# The converted model is built for inputs of BERT_MAX_SEQ_LENGTH tokens.
bert_keras_converted = utils.convert_huggingface_mlm_to_keras(
    huggingface_model=bert_model,
    max_seq_length=EXPERIMENT_CONFIG['BERT_MAX_SEQ_LENGTH'],
)

Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


Instructions for updating:
The `validate_indices` argument has no effect. Indices are always validated on CPU and never validated on GPU.


In [26]:
# Snapshot the pretrained weights as lists of NumPy arrays, keeping the
# trainable and non-trainable weights in separate lists.
bert_pretrained_trainable_weights = [
    weight.numpy() for weight in bert_keras_converted.trainable_weights
]

bert_pretrained_non_trainable_weights = [
    weight.numpy() for weight in bert_keras_converted.non_trainable_weights
]

In [27]:
def tff_model_fn():
    """Build a fresh wrapped model for federated averaging.

    Each call clones the architecture of `bert_keras_converted` and wraps it,
    together with a new masked-LM cross-entropy loss and the training data's
    element structure, in a `utils.KerasModelWrapper`.

    Returns:
        The newly created `utils.KerasModelWrapper`.
    """
    mlm_loss = utils.MaskedLMCrossEntropy()
    keras_model_clone = tf.keras.models.clone_model(bert_keras_converted)

    return utils.KerasModelWrapper(
        keras_model_clone,
        train_client_data.element_type_structure,
        mlm_loss)

## Training

### Training setups

In [28]:
# Summary writer for the training/validation scalars, writing into the run's log directory
summary_writer = tf.summary.create_file_writer(EXPERIMENT_CONFIG['RESULTS_LOG'])

In [29]:
def server_optimizer_fn():
    """Create a new SGD optimizer using the configured server learning rate."""
    server_learning_rate = EXPERIMENT_CONFIG['SERVER_LEARNING_RATE']
    return tf.keras.optimizers.SGD(learning_rate=server_learning_rate)

def client_optimizer_fn():
    """Create a new SGD optimizer using the configured client learning rate."""
    client_learning_rate = EXPERIMENT_CONFIG['CLIENT_LEARNING_RATE']
    return tf.keras.optimizers.SGD(learning_rate=client_learning_rate)

In [30]:
# Build the federated averaging process from the project's fedavg module.
# The pretrained weights captured earlier are passed in as the initial
# trainable / non-trainable weights, alongside the model and optimizer factories.
iterative_process = fedavg.build_federated_averaging_process(
    model_fn=tff_model_fn,
    model_input_spec=train_client_data.element_type_structure,
    initial_trainable_weights=bert_pretrained_trainable_weights,
    initial_non_trainable_weights=bert_pretrained_non_trainable_weights,
    server_optimizer_fn=server_optimizer_fn, 
    client_optimizer_fn=client_optimizer_fn)

In [31]:
# Create the initial server state; it is threaded through every training round.
server_state = iterative_process.initialize()

In [32]:
# Metric object used for validation and the final evaluation.
# NOTE(review): the metric class is a masked-LM cross-entropy, yet it is
# registered under the name 'test_accuracy' - confirm the name is intended.
metric_eval = utils.MaskedLMCrossEntropyMetric(name='test_accuracy')

In [33]:
# Wrapped model that receives the server weights for evaluation and saving
model_final = tff_model_fn() # The one to store the final weights

### Training loop

In [34]:
# Number of clients drawn per round. It is capped at the size of the selected
# client pool because sampling without replacement cannot draw more than that.
clients_per_round = min(
    EXPERIMENT_CONFIG['TRAIN_CLIENTS_PER_ROUND'],
    len(selected_train_client_ids))

with summary_writer.as_default():
    for round_num in range(1, EXPERIMENT_CONFIG['TOTAL_ROUNDS'] + 1):

        # Training clients selection
        print("Choosing clients to use for training...")

        # Sample from the selected pool so that TRAIN_NUM_CLIENT_LIMIT is
        # honored (when no limit is set, the pool holds every training client).
        sampled_clients = np.random.choice(
            selected_train_client_ids,
            size=clients_per_round,
            replace=False)

        sampled_train_data = [
            train_client_data.create_tf_dataset_for_client(client)
            for client in sampled_clients
        ]

        print("Training clients selection complete.")

        # FedAvg: run one round and carry the updated server state forward
        print(f'Round {round_num} start!')

        server_state, train_loss = iterative_process.next(server_state, sampled_train_data)

        print(f'Round {round_num} training loss: {train_loss}')

        # Write down the training loss to the log
        tf.summary.scalar('train_loss', train_loss, step=round_num)

        # Evaluation, once every ROUNDS_PER_EVAL rounds
        if round_num % EXPERIMENT_CONFIG['ROUNDS_PER_EVAL'] == 0:
            # Copy the current server weights into the evaluation model
            model_final.from_weights(server_state.model_weights)

            print("Calculating validation metric:")

            validation_metric = utils.keras_evaluate(
                model_final.keras_model, test_client_data_all_merged, metric_eval)

            print(f'Round {round_num} validation metric: {validation_metric}')

            # Write down the validation metric to the log
            tf.summary.scalar('validation_metric', validation_metric, step=round_num)

        print()

Choosing clients to use for training...
Training clients selection complete.
Round 1 start!
Client 25530 : updated the model with server message.
Anonymous client 25530 : training start.
Anonymous client 25530 : batch 1 , 8 examples processed
Anonymous client 25530 : batch 2 , 16 examples processed
Anonymous client 25530 : batch 3 , 24 examples processed
Anonymous client 25530 : batch 4 , 32 examples processed
Anonymous client 25530 : batch 5 , 40 examples processed
Anonymous client 25530 : batch 6 , 48 examples processed
Anonymous client 25530 : batch 7 , 56 examples processed
Anonymous client 25530 : batch 8 , 60 examples processed
Anonymous client 25530 : batch 9 , 68 examples processed
Anonymous client 25530 : batch 10 , 76 examples processed
Anonymous client 25530 : batch 11 , 84 examples processed
Anonymous client 25530 : batch 12 , 92 examples processed
Anonymous client 25530 : batch 13 , 100 examples processed
Anonymous client 25530 : batch 14 , 108 examples processed
Anonymous

Client 1219 : updated the model with server message.
Anonymous client 1219 : training start.
Client 14916 : updated the model with server message.
Anonymous client 14916 : training start.
Client 1014 : updated the model with server message.
Anonymous client 1014 : training start.
Client 4615 : updated the model with server message.
Anonymous client 4615 : training start.
Client 22479 : updated the model with server message.
Client 20304 : updated the model with server message.
Anonymous client 22479 : training start.
Anonymous client 20304 : training start.
Client 21452 : updated the model with server message.
Anonymous client 21452 : training start.
Client 14833 : updated the model with server message.
Anonymous client 14833 : training start.
Anonymous client 26682 : batch 1 , 8 examples processed
Anonymous client 16586 : batch 1 , 8 examples processed
Anonymous client 4615 : batch 1 , 5 examples processed
Anonymous client 21452 : batch 1 , 8 examples processed
Anonymous client 20304 

Anonymous client 21452 : batch 26 , 205 examples processed
Anonymous client 14833 : batch 27 , 213 examples processed
Anonymous client 26682 : batch 27 , 213 examples processed
Anonymous client 21452 : batch 27 , 213 examples processed
Anonymous client 14833 : batch 28 , 221 examples processed
Anonymous client 26682 : batch 28 , 221 examples processed
Anonymous client 21452 : batch 28 , 218 examples processed
Anonymous client 14833 : batch 29 , 229 examples processed
Anonymous client 26682 : batch 29 , 229 examples processed
Anonymous client 21452 : batch 29 , 226 examples processed
Anonymous client 14833 : batch 30 , 237 examples processed
Anonymous client 26682 : batch 30 , 237 examples processed
Anonymous client 21452 : batch 30 , 234 examples processed
Anonymous client 26682 : batch 31 , 245 examples processed
Anonymous client 14833 : batch 31 , 245 examples processed
Anonymous client 26682 : batch 32 , 253 examples processed
Anonymous client 21452 : batch 31 , 242 examples process

Anonymous client 9550 : batch 3 , 23 examples processed
Anonymous client 10585 : batch 4 , 30 examples processed
Anonymous client 8447 : training finished. 18  examples processed, loss: 54.1986923
Anonymous client 12144 : training finished. 3  examples processed, loss: 10.1544542
Anonymous client 7057 : batch 4 , 32 examples processed
Anonymous client 25169 : training finished. 6  examples processed, loss: 45.8482971
Anonymous client 10997 : batch 4 , 32 examples processed
Anonymous client 19466 : batch 4 , 32 examples processed
Anonymous client 29800 : batch 4 , 26 examples processed
Anonymous client 9550 : batch 4 , 30 examples processed
Anonymous client 10585 : batch 5 , 38 examples processed
Anonymous client 7057 : batch 5 , 33 examples processed
Anonymous client 10997 : batch 5 , 40 examples processed
Anonymous client 19466 : batch 5 , 40 examples processed
Anonymous client 29800 : batch 5 , 34 examples processed
Anonymous client 9550 : batch 5 , 38 examples processed
Anonymous cl

Anonymous client 15002 : batch 5 , 26 examples processed
Anonymous client 22664 : batch 5 , 40 examples processed
Anonymous client 26598 : batch 5 , 40 examples processed
Anonymous client 16778 : batch 7 , 44 examples processed
Anonymous client 15002 : batch 6 , 27 examples processed
Anonymous client 22664 : batch 6 , 48 examples processed
Anonymous client 26598 : batch 6 , 48 examples processed
Anonymous client 16778 : batch 8 , 52 examples processed
Anonymous client 15002 : training finished. 27  examples processed, loss: 112.553345
Anonymous client 22664 : batch 7 , 56 examples processed
Anonymous client 26598 : batch 7 , 56 examples processed
Anonymous client 16778 : batch 9 , 54 examples processed
Anonymous client 22664 : batch 8 , 64 examples processed
Anonymous client 26598 : batch 8 , 64 examples processed
Anonymous client 16778 : training finished. 54  examples processed, loss: 218.091309
Anonymous client 22664 : batch 9 , 72 examples processed
Anonymous client 26598 : batch 9

Anonymous client 13308 : batch 4 , 32 examples processed
Anonymous client 5649 : batch 3 , 24 examples processed
Anonymous client 23536 : batch 3 , 24 examples processed
Anonymous client 12519 : batch 3 , 3 examples processed
Anonymous client 12412 : batch 3 , 15 examples processed
Anonymous client 5759 : batch 3 , 15 examples processed
Anonymous client 4569 : batch 3 , 24 examples processed
Anonymous client 23203 : training finished. 9  examples processed, loss: 30.8446922
Anonymous client 22734 : batch 4 , 18 examples processed
Anonymous client 22007 : training finished. 12  examples processed, loss: 50.7827072
Anonymous client 13308 : batch 5 , 40 examples processed
Anonymous client 5649 : batch 4 , 32 examples processed
Anonymous client 23536 : batch 4 , 32 examples processed
Anonymous client 4569 : batch 4 , 32 examples processed
Anonymous client 12519 : training finished. 3  examples processed, loss: 6.50894213
Anonymous client 22734 : batch 5 , 26 examples processed
Anonymous cl

Anonymous client 23536 : batch 78 , 622 examples processed
Anonymous client 23536 : batch 79 , 630 examples processed
Anonymous client 23536 : batch 80 , 638 examples processed
Anonymous client 23536 : batch 81 , 646 examples processed
Anonymous client 23536 : batch 82 , 654 examples processed
Anonymous client 23536 : batch 83 , 662 examples processed
Anonymous client 23536 : batch 84 , 670 examples processed
Anonymous client 23536 : batch 85 , 678 examples processed
Anonymous client 23536 : batch 86 , 686 examples processed
Anonymous client 23536 : batch 87 , 694 examples processed
Anonymous client 23536 : batch 88 , 702 examples processed
Anonymous client 23536 : batch 89 , 710 examples processed
Anonymous client 23536 : batch 90 , 718 examples processed
Anonymous client 23536 : batch 91 , 726 examples processed
Anonymous client 23536 : batch 92 , 734 examples processed
Anonymous client 23536 : batch 93 , 741 examples processed
Anonymous client 23536 : training finished. 741  example

Anonymous client 8739 : batch 2 , 16 examples processed
Anonymous client 28564 : batch 2 , 13 examples processed
Anonymous client 17092 : batch 2 , 2 examples processed
Anonymous client 17182 : batch 2 , 16 examples processed
Anonymous client 13548 : batch 2 , 4 examples processed
Anonymous client 22174 : batch 2 , 6 examples processed
Anonymous client 9510 : batch 3 , 24 examples processed
Anonymous client 10630 : batch 3 , 21 examples processed
Anonymous client 18284 : batch 3 , 17 examples processed
Anonymous client 6980 : batch 3 , 3 examples processed
Anonymous client 8739 : batch 3 , 24 examples processed
Anonymous client 17092 : batch 3 , 3 examples processed
Anonymous client 17182 : batch 3 , 24 examples processed
Anonymous client 28564 : batch 3 , 21 examples processed
Anonymous client 13548 : batch 3 , 6 examples processed
Anonymous client 22174 : batch 3 , 9 examples processed
Anonymous client 9510 : batch 4 , 32 examples processed
Anonymous client 10630 : training finished.

Anonymous client 24069 : training start.
Client 6808 : updated the model with server message.
Anonymous client 6808 : training start.
Client 5938 : updated the model with server message.
Anonymous client 5938 : training start.
Client 5384 : updated the model with server message.
Anonymous client 5384 : training start.
Client 5353 : updated the model with server message.
Anonymous client 5353 : training start.
Anonymous client 12261 : batch 1 , 8 examples processed
Anonymous client 19643 : batch 1 , 2 examples processed
Anonymous client 24146 : batch 1 , 8 examples processed
Anonymous client 5938 : batch 1 , 8 examples processed
Anonymous client 5353 : batch 1 , 8 examples processed
Anonymous client 15287 : batch 2 , 4 examples processed
Anonymous client 5278 : batch 1 , 8 examples processed
Anonymous client 24069 : batch 1 , 8 examples processed
Anonymous client 5384 : batch 1 , 8 examples processed
Anonymous client 6808 : batch 1 , 7 examples processed
Anonymous client 12261 : batch 2

Anonymous client 5384 : batch 32 , 253 examples processed
Anonymous client 5384 : batch 33 , 261 examples processed
Anonymous client 5384 : batch 34 , 269 examples processed
Anonymous client 5384 : batch 35 , 277 examples processed
Anonymous client 5384 : batch 36 , 285 examples processed
Anonymous client 5384 : batch 37 , 293 examples processed
Anonymous client 5384 : batch 38 , 301 examples processed
Anonymous client 5384 : batch 39 , 309 examples processed
Anonymous client 5384 : batch 40 , 317 examples processed
Anonymous client 5384 : batch 41 , 325 examples processed
Anonymous client 5384 : batch 42 , 333 examples processed
Anonymous client 5384 : batch 43 , 341 examples processed
Anonymous client 5384 : batch 44 , 349 examples processed
Anonymous client 5384 : batch 45 , 357 examples processed
Anonymous client 5384 : batch 46 , 365 examples processed
Anonymous client 5384 : batch 47 , 373 examples processed
Anonymous client 5384 : batch 48 , 378 examples processed
Anonymous clie

Anonymous client 657 : batch 1 , 8 examples processed
Anonymous client 22573 : batch 1 , 8 examples processed
Anonymous client 23443 : batch 1 , 8 examples processed
Anonymous client 19925 : batch 2 , 8 examples processed
Anonymous client 15308 : batch 2 , 4 examples processed
Anonymous client 6613 : batch 2 , 16 examples processed
Anonymous client 11193 : batch 2 , 8 examples processed
Anonymous client 7591 : batch 2 , 16 examples processed
Anonymous client 657 : batch 2 , 16 examples processed
Anonymous client 22573 : batch 2 , 16 examples processed
Anonymous client 7668 : batch 2 , 4 examples processed
Anonymous client 3051 : batch 2 , 2 examples processed
Anonymous client 23443 : batch 2 , 16 examples processed
Anonymous client 19925 : batch 3 , 12 examples processed
Anonymous client 15308 : batch 3 , 6 examples processed
Anonymous client 6613 : batch 3 , 24 examples processed
Anonymous client 11193 : batch 3 , 12 examples processed
Anonymous client 7591 : batch 3 , 24 examples pro

Anonymous client 25037 : batch 2 , 16 examples processed
Anonymous client 19362 : batch 2 , 14 examples processed
Anonymous client 28570 : batch 2 , 16 examples processed
Anonymous client 9249 : batch 3 , 6 examples processed
Anonymous client 23592 : batch 2 , 16 examples processed
Anonymous client 5341 : batch 2 , 14 examples processed
Anonymous client 13992 : batch 2 , 4 examples processed
Anonymous client 13846 : batch 2 , 4 examples processed
Anonymous client 21425 : batch 4 , 32 examples processed
Anonymous client 8475 : batch 3 , 24 examples processed
Anonymous client 25037 : batch 3 , 24 examples processed
Anonymous client 19362 : batch 3 , 21 examples processed
Anonymous client 23592 : batch 3 , 24 examples processed
Anonymous client 28570 : batch 3 , 24 examples processed
Anonymous client 9249 : training finished. 6  examples processed, loss: 9.61578
Anonymous client 5341 : batch 3 , 22 examples processed
Anonymous client 13992 : batch 3 , 6 examples processed
Anonymous client

Anonymous client 23414 : batch 2 , 9 examples processed
Anonymous client 5375 : batch 2 , 4 examples processed
Anonymous client 26228 : batch 2 , 6 examples processed
Anonymous client 16767 : batch 2 , 2 examples processed
Anonymous client 17499 : batch 2 , 14 examples processed
Anonymous client 483 : batch 2 , 16 examples processed
Anonymous client 28013 : batch 2 , 16 examples processed
Anonymous client 1685 : batch 2 , 16 examples processed
Anonymous client 21858 : training finished. 12  examples processed, loss: 35.6622543
Anonymous client 3479 : batch 3 , 3 examples processed
Anonymous client 23414 : batch 3 , 17 examples processed
Anonymous client 5375 : batch 3 , 6 examples processed
Anonymous client 26228 : batch 3 , 9 examples processed
Anonymous client 17499 : batch 3 , 22 examples processed
Anonymous client 483 : batch 3 , 24 examples processed
Anonymous client 16767 : batch 3 , 3 examples processed
Anonymous client 28013 : batch 3 , 18 examples processed
Anonymous client 16

Anonymous client 22670 : batch 6 , 48 examples processed
Anonymous client 5637 : batch 5 , 40 examples processed
Anonymous client 25832 : batch 5 , 40 examples processed
Anonymous client 29824 : batch 5 , 40 examples processed
Anonymous client 2177 : batch 5 , 40 examples processed
Anonymous client 652 : batch 6 , 42 examples processed
Anonymous client 3507 : batch 6 , 42 examples processed
Anonymous client 22670 : batch 7 , 56 examples processed
Anonymous client 5637 : batch 6 , 48 examples processed
Anonymous client 25832 : batch 6 , 48 examples processed
Anonymous client 29824 : batch 6 , 48 examples processed
Anonymous client 2177 : batch 6 , 48 examples processed
Anonymous client 652 : batch 7 , 50 examples processed
Anonymous client 22670 : batch 8 , 64 examples processed
Anonymous client 3507 : training finished. 42  examples processed, loss: 189.938095
Anonymous client 5637 : batch 7 , 56 examples processed
Anonymous client 25832 : batch 7 , 56 examples processed
Anonymous clie

ResourceExhaustedError: 2 root error(s) found.
  (0) Resource exhausted:  OOM when allocating tensor with shape[8,128,30522] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node StatefulPartitionedCall/while/body/_3364/cond/then/_7918/model/standalone_tf_mobile_bert_mlm_head/predictions/add}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

	 [[StatefulPartitionedCall/while/LoopExecuted/_4518/_41]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

  (1) Resource exhausted:  OOM when allocating tensor with shape[8,128,30522] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
	 [[{{node StatefulPartitionedCall/while/body/_3364/cond/then/_7918/model/standalone_tf_mobile_bert_mlm_head/predictions/add}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

0 successful operations.
0 derived errors ignored. [Op:__inference_pruned_468984]

Function call stack:
pruned -> pruned


### Save the trained model

In [None]:
# Save the Keras model held by model_final to the run's model directory.
# model_final only receives the server weights on evaluation rounds
# (round_num % ROUNDS_PER_EVAL == 0), so this saves the weights from the
# most recent evaluation round.
model_final.keras_model.save(EXPERIMENT_CONFIG['RESULTS_MODEL'])

## Evaluation

In [None]:
# Evaluate the final model on the merged test set with the same metric
# object used for validation during training.
test_metric = utils.keras_evaluate(
    model_final.keras_model, test_client_data_all_merged, metric_eval)

print(test_metric)