-
Notifications
You must be signed in to change notification settings - Fork 578
/
mean.py
282 lines (237 loc) · 10.6 KB
/
mean.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
# Copyright 2020, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for mean."""
import collections
import typing
from typing import Optional
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import sum_factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
class MeanFactory(factory.WeightedAggregationFactory):
  """Aggregation factory producing a weighted mean.

  The `tff.templates.AggregationProcess` built by `create` takes values placed
  at `CLIENTS` and emits their weighted mean placed at `SERVER`.

  The `next` attribute of that process accepts `<state, value, weight>`. The
  `weight` is a scalar that is broadcast over the structure of `value`, and the
  weighted mean is the expression `sum(value * weight) / sum(weight)`.

  Two inner aggregation factories carry out the summations. At a high level:
    - `value` is multiplied by `weight` at `CLIENTS`.
    - The inner `value_sum_factory` and `weight_sum_factory` sum the weighted
      values and the weights, respectively.
    - The summed weighted values are divided by the summed weights at `SERVER`.

  The division at `SERVER` can guard against division by 0; this is controlled
  by the `no_nan_division` constructor argument.

  The `state` of the created process is composed of the `state` of the two
  processes created by the inner factories, and likewise for `measurements`.
  """

  def __init__(
      self,
      value_sum_factory: Optional[factory.UnweightedAggregationFactory] = None,
      weight_sum_factory: Optional[factory.UnweightedAggregationFactory] = None,
      no_nan_division: bool = False,
  ):
    """Initializes `MeanFactory`.

    Args:
      value_sum_factory: An optional
        `tff.aggregators.UnweightedAggregationFactory` that sums the weighted
        values. Defaults to `tff.aggregators.SumFactory` when not given.
      weight_sum_factory: An optional
        `tff.aggregators.UnweightedAggregationFactory` that sums the weights.
        Defaults to `tff.aggregators.SumFactory` when not given.
      no_nan_division: A bool. If True, the computed mean is 0 if sum of
        weights is equal to 0.

    Raises:
      TypeError: If provided `value_sum_factory` or `weight_sum_factory` is not
        an instance of `tff.aggregators.UnweightedAggregationFactory`.
    """
    # Resolve, validate and store each inner factory in turn.
    value_sum_factory = (
        sum_factory.SumFactory()
        if value_sum_factory is None
        else value_sum_factory
    )
    py_typecheck.check_type(
        value_sum_factory, factory.UnweightedAggregationFactory
    )
    self._value_sum_factory = value_sum_factory

    weight_sum_factory = (
        sum_factory.SumFactory()
        if weight_sum_factory is None
        else weight_sum_factory
    )
    py_typecheck.check_type(
        weight_sum_factory, factory.UnweightedAggregationFactory
    )
    self._weight_sum_factory = weight_sum_factory

    py_typecheck.check_type(no_nan_division, bool)
    self._no_nan_division = no_nan_division

  def create(
      self, value_type: factory.ValueType, weight_type: factory.ValueType
  ) -> aggregation_process.AggregationProcess:
    """Builds the weighted-mean aggregation process.

    Args:
      value_type: Type of the client values; validated by `_check_value_type`.
      weight_type: Type of the client weights; must be an instance of one of
        the types making up `factory.ValueType`.

    Returns:
      An `AggregationProcess` whose state is an ordered structure with keys
      `value_sum_process` and `weight_sum_process`, and whose measurements are
      an ordered structure with keys `mean_value` and `mean_weight`.
    """
    _check_value_type(value_type)
    py_typecheck.check_type(weight_type, typing.get_args(factory.ValueType))

    inner_value_sum = self._value_sum_factory.create(value_type)
    inner_weight_sum = self._weight_sum_factory.create(weight_type)
    # The server-side division is fixed once, when the process is built.
    divide_fn = _div_no_nan if self._no_nan_division else _div

    @federated_computation.federated_computation()
    def init_fn():
      return intrinsics.federated_zip(
          collections.OrderedDict(
              value_sum_process=inner_value_sum.initialize(),
              weight_sum_process=inner_weight_sum.initialize(),
          )
      )

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.FederatedType(value_type, placements.CLIENTS),
        computation_types.FederatedType(weight_type, placements.CLIENTS),
    )
    def next_fn(state, value, weight):
      # At CLIENTS: scale every value by its weight.
      scaled_value = intrinsics.federated_map(_mul, (value, weight))

      # Delegate both summations to the inner processes.
      summed_values = inner_value_sum.next(
          state['value_sum_process'], scaled_value
      )
      summed_weights = inner_weight_sum.next(
          state['weight_sum_process'], weight
      )

      # At SERVER: divide the summed weighted values by the summed weights.
      mean = intrinsics.federated_map(
          divide_fn, (summed_values.result, summed_weights.result)
      )

      # Compose the inner states and measurements into the outputs.
      new_state = intrinsics.federated_zip(
          collections.OrderedDict(
              value_sum_process=summed_values.state,
              weight_sum_process=summed_weights.state,
          )
      )
      new_measurements = intrinsics.federated_zip(
          collections.OrderedDict(
              mean_value=summed_values.measurements,
              mean_weight=summed_weights.measurements,
          )
      )
      return measured_process.MeasuredProcessOutput(
          new_state, mean, new_measurements
      )

    return aggregation_process.AggregationProcess(init_fn, next_fn)
class UnweightedMeanFactory(factory.UnweightedAggregationFactory):
  """Aggregation factory for unweighted mean.

  The created `tff.templates.AggregationProcess` computes the unweighted mean of
  values placed at `CLIENTS`, and outputs the mean placed at `SERVER`.

  The input arguments of the `next` attribute of the process returned by
  `create` are `<state, value>`, and the unweighted mean refers to the
  expression `sum(value) / count(value)` where `count(value)` is the
  cardinality of the `CLIENTS` placement.

  The implementation is parameterized by two inner aggregation factories: one
  responsible for the summation of values, and one responsible for summing a
  constant `1` from every client to obtain the count.

  The division at `SERVER` uses the plain `_div` helper, so unlike
  `MeanFactory` there is no option here to guard against a zero count.

  The `state` is a tuple of the states of the two inner processes, in the order
  (value sum, count sum). The `measurements` are an ordered structure with keys
  `mean_value` and `mean_count` holding the measurements of the respective
  inner processes.
  """

  def __init__(
      self,
      value_sum_factory: Optional[factory.UnweightedAggregationFactory] = None,
      count_sum_factory: Optional[factory.UnweightedAggregationFactory] = None,
  ):
    """Initializes `UnweightedMeanFactory`.

    Args:
      value_sum_factory: An optional
        `tff.aggregators.UnweightedAggregationFactory` responsible for summation
        of values. If not specified, `tff.aggregators.SumFactory` is used.
      count_sum_factory: An optional
        `tff.aggregators.UnweightedAggregationFactory` responsible for summation
        of ones to determine the cardinality of the `CLIENTS` placement. If not
        specified, `tff.aggregators.SumFactory` is used.

    Raises:
      TypeError: If provided `value_sum_factory` or `count_sum_factory` is not
        an instance of `tff.aggregators.UnweightedAggregationFactory`.
    """
    if value_sum_factory is None:
      value_sum_factory = sum_factory.SumFactory()
    py_typecheck.check_type(
        value_sum_factory, factory.UnweightedAggregationFactory
    )
    self._value_sum_factory = value_sum_factory

    if count_sum_factory is None:
      count_sum_factory = sum_factory.SumFactory()
    py_typecheck.check_type(
        count_sum_factory, factory.UnweightedAggregationFactory
    )
    self._count_sum_factory = count_sum_factory

  def create(
      self, value_type: factory.ValueType
  ) -> aggregation_process.AggregationProcess:
    """Builds the unweighted-mean aggregation process.

    Args:
      value_type: Type of the client values; validated by `_check_value_type`.

    Returns:
      An `AggregationProcess` whose `next` takes `<state, value>` and returns
      the mean of `value` over the clients.
    """
    _check_value_type(value_type)
    value_sum_process = self._value_sum_factory.create(value_type)
    # The count is aggregated as an int32 scalar, one unit per client.
    count_sum_process = self._count_sum_factory.create(
        computation_types.TensorType(np.int32)
    )

    @federated_computation.federated_computation()
    def init_fn():
      return intrinsics.federated_zip(
          (value_sum_process.initialize(), count_sum_process.initialize())
      )

    @federated_computation.federated_computation(
        init_fn.type_signature.result,
        computation_types.FederatedType(value_type, placements.CLIENTS),
    )
    def next_fn(state, value):
      value_sum_state, count_sum_state = state

      # Inner aggregations: sum of values and sum of a constant 1 per client.
      value_sum_output = value_sum_process.next(value_sum_state, value)
      count_sum_output = count_sum_process.next(
          count_sum_state, intrinsics.federated_value(1, placements.CLIENTS)
      )

      # Server computation: summed values divided by the client count.
      mean_value = intrinsics.federated_map(
          _div, (value_sum_output.result, count_sum_output.result)
      )

      # Output preparation.
      state = intrinsics.federated_zip(
          (value_sum_output.state, count_sum_output.state)
      )
      measurements = intrinsics.federated_zip(
          collections.OrderedDict(
              mean_value=value_sum_output.measurements,
              mean_count=count_sum_output.measurements,
          )
      )
      return measured_process.MeasuredProcessOutput(
          state, mean_value, measurements
      )

    return aggregation_process.AggregationProcess(init_fn, next_fn)
def _check_value_type(value_type):
  """Validates that `value_type` is a supported all-float value type.

  Args:
    value_type: The type to validate.

  Raises:
    TypeError: If `value_type` is not an instance of one of the types making up
      `factory.ValueType`, or if it is not a structure of floats.
  """
  py_typecheck.check_type(value_type, typing.get_args(factory.ValueType))
  if type_analysis.is_structure_of_floats(value_type):
    return
  raise TypeError(
      'All values in provided value_type must be of floating '
      f'dtype. Provided value_type: {value_type}'
  )
@tensorflow_computation.tf_computation()
def _mul(value, weight):
  """Multiplies every leaf of `value` by `weight` cast to the leaf's dtype."""

  def _scale(leaf):
    return leaf * tf.cast(weight, leaf.dtype)

  return tf.nest.map_structure(_scale, value)
@tensorflow_computation.tf_computation()
def _div(weighted_value_sum, weight_sum):
  """Divides every leaf of `weighted_value_sum` by `weight_sum`.

  `weight_sum` is cast to the dtype of each leaf before the division, which is
  performed with `tf.math.divide`.
  """

  def _divide_leaf(leaf):
    denominator = tf.cast(weight_sum, leaf.dtype)
    return tf.math.divide(leaf, denominator)

  return tf.nest.map_structure(_divide_leaf, weighted_value_sum)
@tensorflow_computation.tf_computation()
def _div_no_nan(weighted_value_sum, weight_sum):
  """Divides every leaf of `weighted_value_sum` by `weight_sum`, NaN-free.

  `weight_sum` is cast to the dtype of each leaf before the division, which is
  performed with `tf.math.divide_no_nan`.
  """

  def _divide_leaf(leaf):
    denominator = tf.cast(weight_sum, leaf.dtype)
    return tf.math.divide_no_nan(leaf, denominator)

  return tf.nest.map_structure(_divide_leaf, weighted_value_sum)