/
emnist_fedavg_main.py
192 lines (162 loc) · 6.27 KB
/
emnist_fedavg_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple FedAvg to train EMNIST.
This is intended to be a minimal stand-alone experiment script demonstrating
usage of TFF's Federated Compute API for a from-scratch Federated Averaging
implementation.
"""
import collections
import functools
from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
from examplessimple_fedavg import simple_fedavg_tff
# Training hyperparameters. Integer flags carry lower bounds so that invalid
# values are rejected when flags are parsed, instead of failing partway through
# training. For example:
#   - a zero `rounds_per_eval` divides by zero;
#   - a `client_epochs_per_round` of -1 makes `Dataset.repeat` loop forever.
flags.DEFINE_integer(
    'total_rounds', 256, 'Number of total training rounds.', lower_bound=0
)
flags.DEFINE_integer(
    'rounds_per_eval', 1, 'How often to evaluate', lower_bound=1
)
flags.DEFINE_integer(
    'train_clients_per_round',
    2,
    'How many clients to sample per round.',
    lower_bound=1,
)
flags.DEFINE_integer(
    'client_epochs_per_round',
    1,
    'Number of epochs in the client to take per round.',
    lower_bound=1,
)
flags.DEFINE_integer(
    'batch_size', 16, 'Batch size used on the client.', lower_bound=1
)
flags.DEFINE_integer(
    'test_batch_size', 128, 'Minibatch size of test data.', lower_bound=1
)
# Optimizer configuration: learning rates for the server and client SGD
# optimizers.
flags.DEFINE_float('server_learning_rate', 1.0, 'Server learning rate.')
flags.DEFINE_float('client_learning_rate', 0.1, 'Client learning rate.')
FLAGS = flags.FLAGS
def evaluate(keras_model, test_dataset):
  """Computes the classification accuracy of `keras_model` on `test_dataset`.

  Args:
    keras_model: A Keras model (or any callable) mapping a batch of inputs to
      per-class scores.
    test_dataset: An iterable of batches, each a mapping with an `'x'` entry
      (model inputs) and a `'y'` entry (integer class labels).

  Returns:
    A scalar tensor holding the accuracy accumulated over every batch.
  """
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='test_accuracy')
  for example_batch in test_dataset:
    labels = example_batch['y']
    scores = keras_model(example_batch['x'])
    accuracy.update_state(y_true=labels, y_pred=scores)
  return accuracy.result()
def get_emnist_dataset():
  """Loads digits-only Federated EMNIST and prepares it for training and eval.

  Returns:
    A pair `(emnist_train, emnist_test)`.
    - `emnist_train` is a `tff.simulation.datasets.ClientData`. Each client's
      dataset is mapped to `x`/`y` features, shuffled, repeated for
      `FLAGS.client_epochs_per_round` epochs and batched by `FLAGS.batch_size`.
    - `emnist_test` is a single `tf.data.Dataset` pooling the test data of all
      clients, batched by `FLAGS.test_batch_size`.
  """
  # The largest per-client dataset in Federated EMNIST has 418 examples. This
  # value is used as the shuffle buffer size for training data.
  max_client_examples = 418

  def to_model_inputs(example):
    # Add a trailing channel axis to the pixels and expose the features under
    # the `x` / `y` keys.
    pixels = tf.expand_dims(example['pixels'], -1)
    return collections.OrderedDict(x=pixels, y=example['label'])

  def prepare_client_train_data(client_ds):
    mapped = client_ds.map(to_model_inputs)
    shuffled = mapped.shuffle(buffer_size=max_client_examples)
    repeated = shuffled.repeat(count=FLAGS.client_epochs_per_round)
    return repeated.batch(FLAGS.batch_size, drop_remainder=False)

  def prepare_test_data(test_ds):
    return test_ds.map(to_model_inputs).batch(
        FLAGS.test_batch_size, drop_remainder=False
    )

  raw_train, raw_test = tff.simulation.datasets.emnist.load_data(
      only_digits=True
  )
  return (
      raw_train.preprocess(prepare_client_train_data),
      prepare_test_data(raw_test.create_tf_dataset_from_all_clients()),
  )
def create_original_fedavg_cnn_model(only_digits=True):
  """Builds the CNN used in https://arxiv.org/abs/1602.05629.

  The architecture, with 'same' padding throughout, is:
    - input: 28x28x1 images;
    - two 5x5 ReLU convolutions (32 and 64 filters), each followed by 2x2 max
      pooling;
    - a 512-unit ReLU dense layer;
    - a linear output layer.

  Args:
    only_digits: If True, the output layer has 10 units, for the digits-only
      EMNIST dataset. If False, it has 62 units, for the larger dataset.

  Returns:
    An uncompiled `tf.keras.Model`.
  """
  layout = 'channels_last'

  def conv_layer(num_filters, **extra_kwargs):
    return tf.keras.layers.Conv2D(
        filters=num_filters,
        kernel_size=5,
        padding='same',
        data_format=layout,
        activation=tf.nn.relu,
        **extra_kwargs,
    )

  def pooling_layer():
    return tf.keras.layers.MaxPooling2D(
        pool_size=(2, 2), padding='same', data_format=layout
    )

  num_outputs = 10 if only_digits else 62
  stack = [
      conv_layer(32, input_shape=[28, 28, 1]),
      pooling_layer(),
      conv_layer(64),
      pooling_layer(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(num_outputs),
  ]
  return tf.keras.models.Sequential(stack)
def server_optimizer_fn():
  """Returns a new SGD optimizer for the server-side model update.

  The learning rate is taken from `--server_learning_rate` each time this is
  called.
  """
  lr = FLAGS.server_learning_rate
  return tf.keras.optimizers.SGD(learning_rate=lr)
def client_optimizer_fn():
  """Returns a new SGD optimizer for local client training.

  The learning rate is taken from `--client_learning_rate` each time this is
  called.
  """
  lr = FLAGS.client_learning_rate
  return tf.keras.optimizers.SGD(learning_rate=lr)
def main(argv):
  """Trains the EMNIST CNN with simple FedAvg and prints per-round metrics.

  Each of the `--total_rounds` rounds:
    1. Samples `--train_clients_per_round` distinct training clients uniformly
       at random.
    2. Runs one step of the iterative process built by
       `simple_fedavg_tff.build_federated_averaging_process`.
    3. Prints the returned training metrics.

  The global model's accuracy on the pooled test set is printed on rounds whose
  index is a multiple of `--rounds_per_eval` (starting at round 0), and always
  after the final round.

  Args:
    argv: Command-line arguments left after absl flag parsing. Only the program
      name is expected.

  Raises:
    app.UsageError: If positional arguments are supplied, or if
      `--rounds_per_eval` is not a positive integer.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  # Fail fast: a non-positive interval would otherwise raise
  # ZeroDivisionError (or never evaluate) only after a round of training.
  if FLAGS.rounds_per_eval < 1:
    raise app.UsageError('--rounds_per_eval must be a positive integer.')
  tff.backends.native.set_sync_local_cpp_execution_context()

  train_data, test_data = get_emnist_dataset()

  def tff_model_fn():
    """Constructs a fully initialized model for use in federated averaging."""
    keras_model = create_original_fedavg_cnn_model(only_digits=True)
    loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    metrics = [tf.keras.metrics.SparseCategoricalAccuracy()]
    return tff.learning.models.from_keras_model(
        keras_model,
        loss=loss,
        metrics=metrics,
        input_spec=train_data.element_type_structure,
    )

  iterative_process = simple_fedavg_tff.build_federated_averaging_process(
      tff_model_fn, server_optimizer_fn, client_optimizer_fn
  )
  server_state = iterative_process.initialize()

  # Keras model that receives the global weights for test-set evaluation.
  keras_model = create_original_fedavg_cnn_model(only_digits=True)
  final_round = FLAGS.total_rounds - 1
  for round_num in range(FLAGS.total_rounds):
    sampled_clients = np.random.choice(
        train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False
    )
    sampled_train_data = [
        train_data.create_tf_dataset_for_client(client)
        for client in sampled_clients
    ]
    server_state, train_metrics = iterative_process.next(
        server_state, sampled_train_data
    )
    print(f'Round {round_num}')
    print(f'\tTraining metrics: {train_metrics}')
    # Also evaluate after the last round, so the final model's accuracy is
    # reported even when total_rounds - 1 is not a multiple of rounds_per_eval.
    if round_num == final_round or round_num % FLAGS.rounds_per_eval == 0:
      server_state.model.assign_weights_to(keras_model)
      accuracy = evaluate(keras_model, test_data)
      print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')


if __name__ == '__main__':
  app.run(main)