/
federated_evaluation.py
380 lines (331 loc) · 14.7 KB
/
federated_evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A simple implementation of federated evaluation."""
import collections
from collections.abc import Callable, Mapping
from typing import Optional, Union
import tensorflow as tf
from tensorflow_federated.python.common_libs import deprecation
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import iterative_process
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning import dataset_reduce
from tensorflow_federated.python.learning.metrics import aggregator
from tensorflow_federated.python.learning.metrics import types
from tensorflow_federated.python.learning.models import functional
from tensorflow_federated.python.learning.models import model_weights as model_weights_lib
from tensorflow_federated.python.learning.models import variable
# Short alias for the TFF sequence type constructor; used below to declare the
# dataset parameter of the local evaluation computations.
_SequenceType = computation_types.SequenceType


def build_local_evaluation(
    model_fn: Callable[[], variable.VariableModel],
    model_weights_type: computation_types.StructType,
    batch_type: computation_types.Type,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
  """Builds the local TFF computation for evaluation of the given model.

  The returned computation is unplaced: it evaluates a
  `tff.learning.models.VariableModel` on a single `tf.data.Dataset`, and can
  therefore be mapped onto placed data (`_build_federated_evaluation` maps it
  over client-placed datasets).

  In TFF type notation the returned computation has the signature:

  ```
  (<M, D*> → <local_outputs=N, num_examples=tf.int64>)
  ```

  where `M` is the model weights type structure, `D` is the type structure of a
  single data point, and `N` is the type structure of the local unfinalized
  metrics reported by `report_local_unfinalized_metrics()`.

  Args:
    model_fn: A no-arg function that returns a
      `tff.learning.models.VariableModel`. It is called inside the computation
      to construct the model that is evaluated.
    model_weights_type: The `tff.Type` of the model parameters that are
      assigned into the constructed model before evaluation.
    batch_type: The type of one entry in the dataset.
    use_experimental_simulation_loop: Controls the reduce loop function for
      the input dataset. An experimental reduce loop is used for simulation.

  Returns:
    An unplaced `tff.Computation` that accepts model weights and a dataset, and
    returns the model's local unfinalized metrics together with the number of
    examples processed.
  """

  @tensorflow_computation.tf_computation(
      model_weights_type, _SequenceType(batch_type)
  )
  @tf.function
  def client_eval(incoming_model_weights, dataset):
    """Returns local outputs of evaluating the incoming weights on `dataset`."""
    # Build the model under `tf.init_scope()` so its variables are created
    # outside the function currently being traced.
    with tf.init_scope():
      model = model_fn()

    # Overwrite the freshly constructed model's weights with the incoming ones.
    local_weights = model_weights_lib.ModelWeights.from_model(model)
    tf.nest.map_structure(
        lambda var, value: var.assign(value),
        local_weights,
        incoming_model_weights,
    )

    def add_batch_size(running_total, batch):
      output = model.forward_pass(batch, training=False)
      if output.num_examples is not None:
        batch_size = tf.cast(output.num_examples, tf.int64)
      else:
        # The model did not report a batch size, so infer it from the leading
        # dimension of its predictions.
        batch_size = tf.shape(output.predictions, out_type=tf.int64)[0]
      return running_total + batch_size

    reduce_over_dataset = dataset_reduce.build_dataset_reduce_fn(
        use_experimental_simulation_loop
    )
    num_examples = reduce_over_dataset(
        add_batch_size, dataset, lambda: tf.zeros([], dtype=tf.int64)
    )
    return collections.OrderedDict(
        local_outputs=model.report_local_unfinalized_metrics(),
        num_examples=num_examples,
    )

  return client_eval
def build_functional_local_evaluation(
    model: functional.FunctionalModel,
    model_weights_type: computation_types.StructType,
    batch_type: Union[
        computation_types.StructType, computation_types.TensorType
    ],
) -> computation_base.Computation:
  """Creates client evaluation logic for a functional model.

  This produces an unplaced function that evaluates a
  `tff.learning.models.FunctionalModel` on a `tf.data.Dataset`. This function
  can be mapped to placed data.

  The TFF type notation for the returned computation is:

  ```
  (<M, D*> → N)
  ```

  Where `M` is the model weights type structure, `D` is the type structure of a
  single data point, and `N` is the type structure of the unfinalized metrics
  state produced by `model.update_metrics_state`. The metrics state is returned
  as-is: it is not wrapped in a `local_outputs` field and no separate example
  count is returned.

  Each dataset element must be either a mapping with `'x'` (model input) and
  `'y'` (label) keys, or a two-element `(x, y)` structure.

  Args:
    model: A `tff.learning.models.FunctionalModel`.
    model_weights_type: The `tff.Type` of the model parameters that will be used
      in the forward pass.
    batch_type: The type of one entry in the dataset.

  Returns:
    An unplaced `tff.Computation` that accepts model parameters and sequential
    data, and returns the unfinalized metrics state accumulated over the
    dataset.
  """

  @tensorflow_computation.tf_computation(
      model_weights_type, _SequenceType(batch_type)
  )
  @tf.function
  def local_eval(weights, dataset):
    metrics_state = model.initialize_metrics_state()

    for batch in iter(dataset):
      # Batches may be dict-like (`{'x': ..., 'y': ...}`) or an `(x, y)` pair.
      if isinstance(batch, Mapping):
        x = batch['x']
        y = batch['y']
      else:
        x, y = batch

      batch_output = model.predict_on_batch(weights, x, training=False)
      batch_loss = model.loss(output=batch_output, label=y)
      # The batch size is read from the leading dimension of the first
      # (flattened) prediction tensor.
      predictions = tf.nest.flatten(batch_output)[0]
      batch_num_examples = tf.shape(predictions)[0]

      # TODO: b/272099796 - Update `update_metrics_state` of FunctionalModel
      metrics_state = model.update_metrics_state(
          metrics_state,
          batch_output=variable.BatchOutput(
              loss=batch_loss,
              predictions=batch_output,
              num_examples=batch_num_examples,
          ),
          labels=y,
      )

    # The accumulated state is the unfinalized metrics structure itself.
    return metrics_state

  return local_eval
@deprecation.deprecated(
    '`tff.learning.build_federated_evaluation` is deprecated, use '
    '`tff.learning.algorithms.build_fed_eval` instead.'
)
def build_federated_evaluation(
    model_fn: Union[
        Callable[[], variable.VariableModel], functional.FunctionalModel
    ],
    broadcast_process: Optional[measured_process.MeasuredProcess] = None,
    metrics_aggregator: Optional[types.MetricsAggregatorType] = None,
    use_experimental_simulation_loop: bool = False,
) -> computation_base.Computation:
  """Builds the TFF computation for federated evaluation of the given model.

  Args:
    model_fn: A no-arg function that returns a
      `tff.learning.models.VariableModel`, or an instance of a
      `tff.learning.models.FunctionalModel`. When passing a callable, the
      callable must *not* capture TensorFlow tensors or variables and use them.
      The model must be constructed entirely from scratch on each invocation,
      returning the same pre-constructed model each call will result in an
      error.
    broadcast_process: A `tff.templates.MeasuredProcess` that broadcasts the
      model weights on the server to the clients. It must support the signature
      `(input_values@SERVER -> output_values@CLIENTS)` and have empty state. If
      set to default None, the server model is broadcast to the clients using
      the default tff.federated_broadcast.
    metrics_aggregator: An optional function that takes in the metric
      finalizers (for a `VariableModel`, `metric_finalizers()`; for a
      `FunctionalModel`, `finalize_metrics`) and, for a `FunctionalModel`, the
      TFF type of the unfinalized metrics, and returns a federated TFF
      computation of the following type signature
      `local_unfinalized_metrics@CLIENTS -> aggregated_metrics@SERVER`. If
      `None`, uses `tff.learning.metrics.sum_then_finalize`, which returns a
      federated TFF computation that sums the unfinalized metrics from
      `CLIENTS`, and then applies the corresponding metric finalizers at
      `SERVER`.
    use_experimental_simulation_loop: Controls the reduce loop function for
      input dataset. An experimental reduce loop is used for simulation. Only
      used when `model_fn` is a callable; it is ignored for a
      `FunctionalModel`.

  Returns:
    A federated computation (an instance of `tff.Computation`) that accepts
    model parameters and federated data, and returns the evaluation metrics.

  Raises:
    TypeError: If `model_fn` is neither callable nor an instance of
      `tff.learning.models.FunctionalModel`.
    ValueError: If `broadcast_process` is not a `MeasuredProcess`, or if it is
      stateful.
  """
  # A non-callable `model_fn` selects the functional-model code path.
  is_functional_model = not callable(model_fn)
  if is_functional_model and not isinstance(
      model_fn, functional.FunctionalModel
  ):
    raise TypeError(
        'If `model_fn` is not a callable, it must be an instance of '
        f'tff.learning.models.FunctionalModel. Got {type(model_fn)}'
    )

  if broadcast_process is not None:
    # NOTE(review): a wrong type raises `ValueError` (not `TypeError`) here;
    # left unchanged so existing callers keep catching the same exception.
    if not isinstance(broadcast_process, measured_process.MeasuredProcess):
      raise ValueError(
          '`broadcast_process` must be a `MeasuredProcess`, got '
          f'{type(broadcast_process)}.'
      )
    if iterative_process.is_stateful(broadcast_process):
      raise ValueError(
          'Cannot create a federated evaluation with a stateful '
          'broadcast process, must be stateless (have empty state), has state: '
          f'{broadcast_process.initialize.type_signature.result!r}'
      )

  if metrics_aggregator is None:
    metrics_aggregator = aggregator.sum_then_finalize

  if is_functional_model:
    return _build_functional_federated_evaluation(
        model=model_fn,
        broadcast_process=broadcast_process,
        metrics_aggregator=metrics_aggregator,
    )
  return _build_federated_evaluation(
      model_fn=model_fn,
      broadcast_process=broadcast_process,
      metrics_aggregator=metrics_aggregator,
      use_experimental_simulation_loop=use_experimental_simulation_loop,
  )
def _build_federated_evaluation(
    *,
    model_fn: Callable[[], variable.VariableModel],
    broadcast_process: Optional[measured_process.MeasuredProcess],
    metrics_aggregator: types.MetricsAggregatorType,
    use_experimental_simulation_loop: bool,
) -> computation_base.Computation:
  """Builds federated evaluation for a `tff.learning.models.VariableModel`.

  The resulting federated computation sends the server-placed model weights to
  the clients (through `broadcast_process` when one is given, otherwise with
  `tff.federated_broadcast`). It then runs the local evaluation built by
  `build_local_evaluation` on every client dataset, and feeds the clients'
  `local_outputs` into the computation produced by `metrics_aggregator`. The
  result is a server-placed `OrderedDict(eval=...)`.

  Args:
    model_fn: A no-arg function returning a `tff.learning.models.VariableModel`.
    broadcast_process: Optional stateless `MeasuredProcess` used to broadcast
      the weights; `None` means `tff.federated_broadcast`.
    metrics_aggregator: Called with the model's metric finalizers to build the
      metrics aggregation computation.
    use_experimental_simulation_loop: Forwarded to `build_local_evaluation`.

  Returns:
    A `tff.Computation` taking `(weights@SERVER, datasets@CLIENTS)`.
  """
  # Stamp out a throwaway model in a scratch graph purely to read the metadata
  # (weights type, input spec, metric finalizers) that the computations below
  # are typed with.
  # TODO: b/124477628 - Ideally replace the need for stamping throwaway models
  # with some other mechanism.
  with tf.Graph().as_default():
    scratch_model = model_fn()
    model_weights_type = model_weights_lib.weights_type_from_model(
        scratch_model
    )
    batch_type = computation_types.tensorflow_to_type(scratch_model.input_spec)
    metrics_aggregation_computation = metrics_aggregator(
        scratch_model.metric_finalizers(),
    )

  local_eval = build_local_evaluation(
      model_fn=model_fn,
      model_weights_type=model_weights_type,
      batch_type=batch_type,
      use_experimental_simulation_loop=use_experimental_simulation_loop,
  )

  server_weights_type = computation_types.FederatedType(
      model_weights_type, placements.SERVER
  )
  client_data_type = computation_types.FederatedType(
      _SequenceType(batch_type), placements.CLIENTS
  )

  @federated_computation.federated_computation(
      server_weights_type, client_data_type
  )
  def server_eval(server_model_weights, federated_dataset):
    if broadcast_process is None:
      client_outputs = intrinsics.federated_map(
          local_eval,
          [
              intrinsics.federated_broadcast(server_model_weights),
              federated_dataset,
          ],
      )
    else:
      # TODO: b/179091838 - Zip the measurements from the broadcast_process with
      # the result of `model_metrics` below to avoid dropping these metrics.
      broadcast_state = broadcast_process.initialize()
      broadcast_output = broadcast_process.next(
          broadcast_state, server_model_weights
      )
      client_outputs = intrinsics.federated_map(
          local_eval, (broadcast_output.result, federated_dataset)
      )

    aggregated_metrics = metrics_aggregation_computation(
        client_outputs.local_outputs
    )
    return intrinsics.federated_zip(
        collections.OrderedDict(eval=aggregated_metrics)
    )

  return server_eval
def _build_functional_federated_evaluation(
    *,
    model: functional.FunctionalModel,
    broadcast_process: Optional[measured_process.MeasuredProcess],
    metrics_aggregator: types.MetricsAggregatorType,
) -> computation_base.Computation:
  """Builds a federated evaluation computation for a functional model.

  The server-placed weights are sent to the clients (through
  `broadcast_process` when one is given, otherwise with
  `tff.federated_broadcast`). Each client runs the local evaluation built by
  `build_functional_local_evaluation`. The aggregator returned by
  `metrics_aggregator(model.finalize_metrics, <unfinalized metrics type>)`
  reduces the clients' unfinalized metrics into a server-placed
  `OrderedDict(eval=...)`.

  Args:
    model: A `tff.learning.models.FunctionalModel`.
    broadcast_process: Optional stateless `MeasuredProcess` used to broadcast
      the weights; `None` means `tff.federated_broadcast`.
    metrics_aggregator: Builds the metrics aggregation computation.

  Returns:
    A `tff.Computation` taking `(weights@SERVER, datasets@CLIENTS)`.
  """
  # Describe every entry of the model's initial weights as a `tf.TensorSpec`.
  # NOTE(review): assumes each leaf exposes numpy-style `.shape`/`.dtype`
  # (e.g. an ndarray) — confirm against `FunctionalModel.initial_weights`.
  weights_spec = tf.nest.map_structure(
      lambda array: tf.TensorSpec(
          shape=array.shape, dtype=tf.dtypes.as_dtype(array.dtype)
      ),
      model.initial_weights,
  )
  weights_type = computation_types.tensorflow_to_type(weights_spec)
  batch_type = computation_types.tensorflow_to_type(model.input_spec)
  local_eval = build_functional_local_evaluation(
      model,
      weights_type,
      batch_type,  # pytype: disable=wrong-arg-types
  )

  server_weights_type = computation_types.FederatedType(
      weights_type, placements.SERVER
  )
  client_data_type = computation_types.FederatedType(
      _SequenceType(batch_type), placements.CLIENTS
  )

  @federated_computation.federated_computation(
      server_weights_type, client_data_type
  )
  def federated_eval(server_weights, client_data):
    if broadcast_process is None:
      client_weights = intrinsics.federated_broadcast(server_weights)
    else:
      # TODO: b/179091838 - Zip the measurements from the broadcast_process with
      # the result of `model_metrics` below to avoid dropping these metrics.
      client_weights = broadcast_process.next(
          broadcast_process.initialize(), server_weights
      ).result

    unfinalized_metrics = intrinsics.federated_map(
        local_eval, (client_weights, client_data)
    )
    # The aggregator is typed by the member type of the mapped client metrics.
    aggregate_metrics = metrics_aggregator(
        model.finalize_metrics,
        unfinalized_metrics.type_signature.member,  # pytype: disable=attribute-error
    )
    return intrinsics.federated_zip(
        collections.OrderedDict(eval=aggregate_metrics(unfinalized_metrics))
    )

  return federated_eval