-
Notifications
You must be signed in to change notification settings - Fork 584
/
client_works.py
295 lines (251 loc) · 13.3 KB
/
client_works.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
# This module disables the PyType analyzer, see
# https://github.com/tensorflow/federated/blob/main/docs/pytype.md for more
# information.
"""Abstractions for client work in learning algorithms."""
from typing import Optional
import attr
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.templates import errors
from tensorflow_federated.python.core.templates import measured_process
from tensorflow_federated.python.learning.templates import tunable
@attr.s(frozen=True)
class ClientResult:
  """Per-client output of the `next` computation of a `ClientWorkProcess`.

  Instances are immutable (the class is declared with `frozen=True`).

  Attributes:
    update: The update to the model weights computed locally by a client.
    update_weight: The weight attached to `update` when client updates are
      combined using weighted aggregation.
  """
  update = attr.ib()
  update_weight = attr.ib()
class ClientDataTypeError(TypeError):
  """Raised when the client data argument of `next_fn` has a disallowed type."""
class ClientResultTypeError(TypeError):
  """Raised when the client result of `next_fn` is not a `ClientResult`."""
class GetHparamsTypeError(TypeError):
  """Raised when `get_hparams` has an incompatible input or output type."""
class SetHparamsTypeError(TypeError):
  """`TypeError` for incorrect input and output of set_hparams."""
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _is_allowed_client_data_type(type_spec: computation_types.Type) -> bool:
  """Returns whether `type_spec` is a sequence or a nested struct of sequences.

  A sequence type is accepted only when its element type is TensorFlow
  compatible. A struct type is accepted only when every one of its children is
  itself accepted (so an empty struct is accepted). Every other kind of type is
  rejected.

  Args:
    type_spec: The (unplaced) type of the client data.

  Returns:
    `True` if the type is allowed as client data, `False` otherwise.
  """
  if type_spec.is_sequence():
    element_type = type_spec.element
    return type_analysis.is_tensorflow_compatible_type(element_type)
  if not type_spec.is_struct():
    return False
  # Stop at the first disallowed child, mirroring `all(...)` short-circuiting.
  for child_type in type_spec.children():
    if not _is_allowed_client_data_type(child_type):
      return False
  return True
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _type_check_initialize_fn(initialize_fn: computation_base.Computation):
  """Checks that `initialize_fn` returns a federated value placed at SERVER.

  Args:
    initialize_fn: The `initialize` computation of a `ClientWorkProcess`.

  Raises:
    TemplateNotFederatedError: If the result type is not federated.
    TemplatePlacementError: If the result type is not placed at `SERVER`.
  """
  init_result_type = initialize_fn.type_signature.result
  if not init_result_type.is_federated():
    raise errors.TemplateNotFederatedError(
        'Provided `initialize_fn` must return a federated type, but found '
        f'return type:\n{init_result_type}\nTip: If you '
        'see a collection of federated types, try wrapping the returned '
        'value in `tff.federated_zip` before returning.')
  if init_result_type.placement != placements.SERVER:
    raise errors.TemplatePlacementError(
        'The state controlled by a `ClientWorkProcess` must be placed at '
        f'the SERVER, but found type: {init_result_type}.')
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _check_next_fn_is_federated(next_fn: computation_base.Computation):
  """Checks that a given `next_fn` has federated inputs and outputs.

  Both the parameter type and the result type of `next_fn` are flattened into
  their leaf types, and every leaf must be a federated type.

  Args:
    next_fn: The `next` computation of a `ClientWorkProcess`.

  Raises:
    TemplateNotFederatedError: If any leaf of the parameter or result type of
      `next_fn` is not federated.
  """
  next_types = (
      structure.flatten(next_fn.type_signature.parameter) +
      structure.flatten(next_fn.type_signature.result))
  offending_types = [t for t in next_types if not t.is_federated()]
  if offending_types:
    # The leaves are type objects, not strings; `str.join` would raise a
    # `TypeError` on them, masking the intended error below.
    offending_types_str = '\n- '.join(str(t) for t in offending_types)
    raise errors.TemplateNotFederatedError(
        f'Provided `next_fn` must be a *federated* computation, that is, '
        f'operate on `tff.FederatedType`s, but found\n'
        f'next_fn with type signature:\n{next_fn.type_signature}\n'
        f'The non-federated types are:\n- {offending_types_str}.')
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _type_check_next_fn_parameters(next_fn: computation_base.Computation):
  """Validates the input types of `next_fn` in a `ClientWorkProcess`.

  The parameter of `next_fn` must be a struct of exactly three elements. The
  second and third elements must be federated types placed at `CLIENTS`, and
  the member type of the third (client data) must be a sequence or a
  (possibly nested) structure of sequences.

  Args:
    next_fn: The `next` computation of a `ClientWorkProcess`.

  Raises:
    TemplateNextFnNumArgsError: If the parameter is not a struct, or is a
      struct with a number of elements other than three.
    TemplatePlacementError: If the second or third argument is not a
      federated type placed at `CLIENTS`.
    ClientDataTypeError: If the member type of the third argument is not an
      allowed client data type.
  """
  next_fn_param = next_fn.type_signature.parameter
  if not next_fn_param.is_struct():
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'the following input type which is not a Struct: {next_fn_param}.')
  if len(next_fn_param) != 3:
    next_param_str = '\n- '.join([str(t) for t in next_fn_param])
    raise errors.TemplateNextFnNumArgsError(
        f'The `next_fn` must have exactly three input arguments, but found '
        f'{len(next_fn_param)} input arguments:\n{next_param_str}')
  second_next_param = next_fn_param[1]
  client_data_param = next_fn_param[2]
  # An argument that is a struct of federated values has no `placement`
  # attribute; check `is_federated()` first so such inputs produce a
  # placement error rather than an `AttributeError`.
  if (not second_next_param.is_federated() or
      second_next_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The second input argument of `next_fn` must be placed at CLIENTS '
        f'but found {second_next_param}.')
  if (not client_data_param.is_federated() or
      client_data_param.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The third input argument of `next_fn` must be placed at CLIENTS '
        f'but found {client_data_param}.')
  # Safe: the check above guarantees `client_data_param` is federated.
  if not _is_allowed_client_data_type(client_data_param.member):
    raise ClientDataTypeError(
        f'The third input argument of `next_fn` must be a sequence or '
        f'a structure of sequences, but found {client_data_param}.')
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _type_check_next_fn_result(next_fn: computation_base.Computation):
  """Validates the output types of `next_fn` in a `ClientWorkProcess`.

  The `result` attribute of the return type must be federated, placed at
  `CLIENTS`, and have a struct member whose Python container is
  `ClientResult`. The `measurements` attribute must be federated and placed
  at `SERVER`.

  Args:
    next_fn: The `next` computation of a `ClientWorkProcess`. Its result type
      is assumed to expose `result` and `measurements` attributes (presumably
      enforced by the `MeasuredProcess` base class — TODO confirm).

  Raises:
    TemplatePlacementError: If `result` is not federated at `CLIENTS`, or
      `measurements` is not federated at `SERVER`.
    ClientResultTypeError: If the member of `result` does not have the
      `ClientResult` container.
  """
  next_fn_result = next_fn.type_signature.result
  if (not next_fn_result.result.is_federated() or
      next_fn_result.result.placement != placements.CLIENTS):
    raise errors.TemplatePlacementError(
        f'The "result" attribute of the return type of `next_fn` must be '
        f'placed at CLIENTS, but found {next_fn_result.result}.')
  if (not next_fn_result.result.member.is_struct_with_python() or
      next_fn_result.result.member.python_container is not ClientResult):
    raise ClientResultTypeError(
        f'The "result" attribute of the return type of `next_fn` must have '
        f'the `ClientResult` container, but found {next_fn_result.result}.')
  # Guard `placement` as for `result` above: a struct of federated
  # measurements has no `placement` and would raise `AttributeError`.
  if (not next_fn_result.measurements.is_federated() or
      next_fn_result.measurements.placement != placements.SERVER):
    raise errors.TemplatePlacementError(
        f'The "measurements" attribute of return type of `next_fn` must be '
        f'placed at SERVER, but found {next_fn_result.measurements}.')
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _type_check_get_hparams_fn(get_hparams_fn: computation_base.Computation,
                               state_type: computation_types.Type):
  """Checks that `get_hparams_fn` is a computation able to accept `state_type`.

  Args:
    get_hparams_fn: Candidate getter; validated with `py_typecheck.check_type`
      to be a `computation_base.Computation`.
    state_type: The state type `S` the getter must accept.

  Raises:
    GetHparamsTypeError: If the parameter type of `get_hparams_fn` is not
      assignable from `state_type`.
  """
  py_typecheck.check_type(get_hparams_fn, computation_base.Computation)
  accepted_type = get_hparams_fn.type_signature.parameter
  if accepted_type.is_assignable_from(state_type):
    return
  raise GetHparamsTypeError(
      'The input to get_hparams must be compatible with the state type '
      f'{state_type}, but found type {accepted_type}.')
# TODO(b/240314933): Move this (or refactor this) to a more general location.
def _type_check_set_hparams_fn(set_hparams_fn: computation_base.Computation,
                               state_type: computation_types.Type):
  """Checks `set_hparams_fn` against the expected `(<S, H> -> S)` signature.

  Args:
    set_hparams_fn: Candidate setter; validated with `py_typecheck.check_type`
      to be a `computation_base.Computation`.
    state_type: The state type `S` the setter must accept and produce.

  Raises:
    SetHparamsTypeError: If the parameter is not a two-element struct, if its
      first element is not assignable from `state_type`, or if the result type
      is not assignable from `state_type`.
  """
  py_typecheck.check_type(set_hparams_fn, computation_base.Computation)
  signature = set_hparams_fn.type_signature
  params = signature.parameter
  if not (params.is_struct() and len(params) == 2):
    raise SetHparamsTypeError(
        'Expected two input arguments to set_hparams, but found '
        f'{params}.')

  state_param = params[0]
  if not state_param.is_assignable_from(state_type):
    raise SetHparamsTypeError(
        'The first input to set_hparams must be compatible with the state '
        f'type {state_type}, but found {state_param}.')

  output_type = signature.result
  if not output_type.is_assignable_from(state_type):
    raise SetHparamsTypeError(
        'The output of set_hparams must be compatible with the state '
        f'of type {state_type} but found {output_type}.')
class ClientWorkProcess(measured_process.MeasuredProcess, tunable.Tunable):
  """A stateful process capturing work at clients during learning.

  Client work encapsulates the main work performed by clients as part of a
  federated learning algorithm, such as several steps of gradient descent based
  on the client data, and returning an update to the initial model weights.

  A `ClientWorkProcess` is a `tff.templates.MeasuredProcess` that formalizes the
  type signature of `initialize` and `next` for the core work performed by
  clients in a learning process.
  """

  def __init__(self,
               initialize_fn: computation_base.Computation,
               next_fn: computation_base.Computation,
               *,
               get_hparams_fn: Optional[computation_base.Computation] = None,
               set_hparams_fn: Optional[computation_base.Computation] = None):
    """Initializes a `ClientWorkProcess`.

    The `initialize_fn` and `next_fn` must have the following type signatures:
    ```
      - initialize_fn: ( -> S@SERVER)
      - next_fn: (<S@SERVER,
                   A@CLIENTS,
                   {D*}@CLIENTS>
                  ->
                  <state=S@SERVER,
                   result=ClientResult(B, C)@CLIENTS,
                   measurements=M@SERVER>)
    ```

    with `A`, `B`, `C`, and `D` not dependent on other types here. `A`
    represents a parameter informing the client update (such as a client's
    model weights). `D*` is a `tff.SequenceType` of client data.

    Note that the output of `next_fn` must have a structure matching
    `tff.templates.MeasuredProcessOutput`. The `result` field of this output
    has type `tff.learning.templates.ClientResult(B, C)` where `B` represents a
    client's update (such as a model update) and `C` represents the weight of
    this update when using weighted aggregation across clients.

    If provided, the `get_hparams_fn` and `set_hparams_fn` must be
    non-federated computations with the following type signatures:
    ```
      - get_hparams_fn: (S -> H)
      - set_hparams_fn: (<S, H> -> S)
    ```

    Here, `S` must match the (unplaced) state `S` of `initialize_fn` and
    `next_fn`, and `H` represents the hyperparameter type.

    Args:
      initialize_fn: A `tff.Computation` matching the criteria above.
      next_fn: A `tff.Computation` matching the criteria above.
      get_hparams_fn: An optional `tff.Computation` matching the criteria above.
      set_hparams_fn: An optional `tff.Computation` matching the criteria above.

    Raises:
      TemplateNotFederatedError: If any of the federated computations provided
        do not return a federated type.
      TemplateNextFnNumArgsError: If the `next_fn` has an incorrect number
        of arguments.
      TemplatePlacementError: If any of the federated computations have an
        incorrect placement.
      ClientDataTypeError: If the third input of `next_fn` is not a sequence
        type placed at `CLIENTS`.
      ClientResultTypeError: If the second output of `next_fn` does not meet the
        criteria outlined above.
      GetHparamsTypeError: If the type signature of `get_hparams_fn` does not
        meet the criteria above.
      SetHparamsTypeError: If the type signature of `set_hparams_fn` does not
        meet the criteria above.
    """
    super().__init__(initialize_fn, next_fn, next_is_multi_arg=True)
    _type_check_initialize_fn(initialize_fn)
    _check_next_fn_is_federated(next_fn)
    _type_check_next_fn_parameters(next_fn)
    _type_check_next_fn_result(next_fn)

    # The hparams computations are documented as non-federated functions of
    # `S`, whereas `initialize_fn` returns `S@SERVER`. Checking them against
    # the federated type would reject valid getters/setters, so compare
    # against the unplaced member type. `_type_check_initialize_fn` above has
    # already ensured the result is federated, so `.member` exists.
    state_type = initialize_fn.type_signature.result.member
    if get_hparams_fn is not None:
      _type_check_get_hparams_fn(get_hparams_fn, state_type)
    if set_hparams_fn is not None:
      _type_check_set_hparams_fn(set_hparams_fn, state_type)
    self._get_hparams_fn = get_hparams_fn
    self._set_hparams_fn = set_hparams_fn

  @property
  def get_hparams(self) -> computation_base.Computation:
    """The computation mapping process state to its hyperparameters.

    Raises:
      NotImplementedError: If no `get_hparams_fn` was given at construction.
    """
    if self._get_hparams_fn is None:
      raise NotImplementedError('get_hparams is not implemented for this '
                                'ClientWorkProcess.')
    return self._get_hparams_fn

  @property
  def set_hparams(self) -> computation_base.Computation:
    """The computation writing hyperparameters back into process state.

    Raises:
      NotImplementedError: If no `set_hparams_fn` was given at construction.
    """
    if self._set_hparams_fn is None:
      raise NotImplementedError('set_hparams is not implemented for this '
                                'ClientWorkProcess.')
    return self._set_hparams_fn