/
sampling.py
511 lines (447 loc) · 20.2 KB
/
sampling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
# Copyright 2021, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Aggregator for sampling of CLIENT placed values."""
import collections
from typing import Any, Optional
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_analysis
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import type_transformations
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
# A sentinel value used to indicate that the computation was _not_ created with
# fixed seed. In this case the `accumulate` function should generate a seed
# based on the timestamp around when the first sample was seen.
SEED_SENTINEL = -1
def _is_tensor_or_structure_of_tensors(
    value_type: computation_types.Type,
) -> bool:
  """Returns True iff `value_type` is built only from tensor-like types.

  "Tensor-like" here means `TensorType` leaves, possibly nested inside
  `StructWithPythonType` containers.
  """
  # TODO: b/181365504 - relax this to allow `StructType` once a `Struct` can be
  # returned from `tf.function` decorated methods.
  allowed_types = (
      computation_types.TensorType,
      computation_types.StructWithPythonType,
  )

  def _is_allowed(type_spec: computation_types.Type) -> bool:
    return isinstance(type_spec, allowed_types)

  return type_analysis.contains_only(value_type, _is_allowed)
def build_reservoir_type(
    sample_value_type: computation_types.Type,
) -> computation_types.Type:
  """Creates the TFF type for the reservoir's state.

  `UnweightedReservoirSamplingFactory` will use this type as the "state" type
  in a `tff.federated_aggregate` (an input to `accumulate`, `merge` and
  `report`).

  Args:
    sample_value_type: The `tff.Type` of the values that will be aggregated
      from clients.

  Returns:
    A `collections.OrderedDict` with three keys:
      random_seed: A 2-tuple of `tf.int64` scalars. This keeps track of the
        `seed` parameter for `tf.random.stateless_uniform` calls for sampling
        during aggregation.
      random_values: A 1-d tensor of `int32` values randomly generated from
        `tf.random.stateless_uniform`. The size of the tensor will be the same
        as the first dimension of each leaf in `samples` (these can be thought
        of as parallel lists). These values are used to determine whether a
        sample stays in the reservoir, or is evicted, as the values are
        aggregated. If the i-th value of this list is evicted, then the i-th
        (in the first dimension) tensor in each of the `samples` structure
        leaves must be evicted.
      samples: A tensor or structure of tensors representing the actual
        sampled values. If a structure, the shape of the structure matches
        that of `sample_value_type`. All tensors have one additional dimension
        prepended which has an unknown size. This will be used to concatenate
        samples and store them in the reservoir.

  Raises:
    TypeError: If `sample_value_type` contains types other than `TensorType`
      or `StructWithPythonType`.
  """
  if not _is_tensor_or_structure_of_tensors(sample_value_type):
    raise TypeError(
        'Cannot create a reservoir for type structure. Sample type '
        'must only contain `TensorType` or `StructWithPythonType`, '
        f'got a {sample_value_type!r}.'
    )

  def _prepend_unknown_dim(type_spec):
    # Only tensor leaves are rewritten; any other type passes through
    # unchanged (second element of the tuple signals "transformed").
    if not isinstance(type_spec, computation_types.TensorType):
      return type_spec, False
    expanded = computation_types.TensorType(
        dtype=type_spec.dtype, shape=(None,) + type_spec.shape
    )
    return expanded, True

  samples_type, _ = type_transformations.transform_type_postorder(
      sample_value_type, _prepend_unknown_dim
  )
  # TODO: b/181155367 - creating a value from a type for the `zero` is a common
  # pattern for users of `tff.federated_aggregate` that could be made easier
  # for TFF users. Replace this once such helper exists.
  return computation_types.to_type(
      collections.OrderedDict(
          random_seed=computation_types.TensorType(np.int64, shape=[2]),
          random_values=computation_types.TensorType(np.int32, shape=[None]),
          samples=samples_type,
      )
  )  # pytype: disable=bad-return-type
def build_initial_sample_reservoir(
    sample_value_type: computation_types.Type, seed: Optional[Any] = None
):
  """Builds up the initial state of the reservoir for sampling.

  Args:
    sample_value_type: The type structure of the values that will be sampled.
    seed: An optional tensor, or Python value convertible to a tensor, that
      serves as the initial seed to the random process.

  Returns:
    A value structure containing the algebraic zero for samples and metadata
    used during reservoir sampling.

  Raises:
    TypeError: If `sample_value_type` contains types other than
      `StructWithPythonType` or `TensorType`.
  """

  @tensorflow_computation.tf_computation
  def initialize():
    # Allow fixed seeds, otherwise set a sentinel that signals a seed should be
    # generated upon the first `accumulate` call of the `federated_aggregate`.
    if seed is None:
      real_seed = tf.convert_to_tensor(SEED_SENTINEL, dtype=tf.int64)
    elif tf.is_tensor(seed):
      real_seed = tf.cast(seed, dtype=tf.int64)
    else:
      real_seed = tf.convert_to_tensor(seed, dtype=tf.int64)

    def zero_for_tensor_type(t: computation_types.TensorType):
      """Adds an extra first dimension to create a tensor collecting samples.

      The first dimension will have size `0` for the algebraic zero, resulting
      in an "empty" tensor. Samples will be concatenated along this dimension
      as they fill the reservoir.

      Args:
        t: A `tff.TensorType` to build a sampling zero value for.

      Returns:
        A tensor whose rank is one larger than before, and whose first
        dimension is zero.

      Raises:
        TypeError: If `t` is not a `tff.TensorType`.
        ValueError: If `t.shape` is `None`.
      """
      if not isinstance(t, computation_types.TensorType):
        # Bug fix: error message previously misspelled "TesnorType".
        raise TypeError(f'Cannot create zero for non TensorType: {type(t)}')
      if t.shape is None:
        raise ValueError('Expected `t.shape` to not be `None`.')
      return tf.zeros((0,) + t.shape, dtype=t.dtype)

    try:
      initial_samples = type_conversions.structure_from_tensor_type_tree(
          zero_for_tensor_type, sample_value_type
      )
    except ValueError as e:
      raise TypeError(
          'Cannot build initial reservoir for structure that has '
          'types other than StructWithPythonType or TensorType, '
          f'got {sample_value_type!r}.'
      ) from e
    return collections.OrderedDict(
        random_seed=tf.fill(dims=(2,), value=real_seed),
        random_values=tf.zeros([0], tf.int32),
        samples=initial_samples,
    )

  return initialize()
def _build_sample_value_computation(
    value_type: computation_types.Type, sample_size: int
) -> computation_base.Computation:
  """Builds the `accumulate` computation for sampling.

  The returned computation folds one client value into a reservoir state,
  keeping at most `sample_size` entries; each entry's survival is governed by
  a random `int32` tag drawn from `tf.random.stateless_uniform`.
  """
  reservoir_type = build_reservoir_type(value_type)

  def add_sample(reservoir, new_seed, sample_random_value, sample):
    """Adds a sample to the reservoir state.

    Args:
      reservoir: The reservoir state holding all samples selected during the
        aggregation process.
      new_seed: A `tf.int64` scalar representing the new seed for the next
        `tf.random.stateless_uniform` call.
      sample_random_value: A `tf.int32` scalar representing the sampling
        identifier for this sample.
      sample: The sample value being aggregated into the reservoir state.

    Returns:
      A `collections.OrderedDict` of the new reservoir state containing the
      sample.
    """
    new_random_values = tf.concat(
        [reservoir['random_values'], sample_random_value], axis=0
    )
    # Prepend a size-1 dimension to the incoming sample so it concatenates
    # onto the leading "samples" dimension of each reservoir leaf.
    new_samples = tf.nest.map_structure(
        lambda a, b: tf.concat([a, tf.expand_dims(b, axis=0)], axis=0),
        reservoir['samples'],
        sample,
    )
    return collections.OrderedDict(
        random_seed=new_seed,
        random_values=new_random_values,
        samples=new_samples,
    )

  def pop_one_minimum_value(reservoir):
    """Removes the element with the minimum random value from the reservoir."""
    size_after_pop = tf.size(reservoir['random_values']) - 1
    # Keeping the top (size - 1) random values drops exactly the minimum one;
    # `sorted=False` avoids an unnecessary ordering of the survivors.
    _, indices = tf.nn.top_k(
        reservoir['random_values'], k=size_after_pop, sorted=False
    )
    return collections.OrderedDict(
        random_seed=reservoir['random_seed'],
        random_values=tf.gather(reservoir['random_values'], indices),
        samples=tf.nest.map_structure(
            lambda t: tf.gather(t, indices), reservoir['samples']
        ),
    )

  def initialize_seed():
    """Generates a seed based on the current timestamp (microsecond scale)."""
    # tf.timestamp() returns fractional seconds, which will be quantized
    # into a tf.int64 value for the random state seed.
    scale_factor = 1_000_000.0
    quantized_fractional_seconds = tf.cast(
        tf.timestamp() * scale_factor, tf.int64
    )
    return tf.fill(dims=(2,), value=quantized_fractional_seconds)

  @tensorflow_computation.tf_computation(reservoir_type, value_type)
  @tf.function
  def perform_sampling(reservoir, sample):
    # The sentinel seed means the computation was created without a fixed
    # seed; derive one lazily from the wall clock on the first call.
    if tf.reduce_all(tf.equal(reservoir['random_seed'], SEED_SENTINEL)):
      seed = initialize_seed()
    else:
      seed = reservoir['random_seed']
    # Pick a new random number for the incoming sample, and advance the seed
    # for the next sample.
    sample_random_value = tf.random.stateless_uniform(
        shape=(1,), minval=None, seed=seed, dtype=tf.int32
    )
    new_seed = tf.stack(
        [seed[0], tf.squeeze(tf.cast(sample_random_value, tf.int64))]
    )
    # If the reservoir isn't full, unconditionally add the sample.
    if tf.less(tf.size(reservoir['random_values']), sample_size):
      return add_sample(reservoir, new_seed, sample_random_value, sample)
    else:
      # Determine if the random value for this sample belongs in the
      # reservoir: its random value is larger than the smallest seen so far.
      # Otherwise the sample is discarded (its random value is smaller than
      # the smallest we've already seen), keeping only the advanced seed.
      min_reservoir_value = tf.reduce_min(reservoir['random_values'])
      if sample_random_value < min_reservoir_value:
        return collections.OrderedDict(reservoir, random_seed=new_seed)
      reservoir = pop_one_minimum_value(reservoir)
      return add_sample(reservoir, new_seed, sample_random_value, sample)

  return perform_sampling
def build_merge_samples_computation(
    value_type: computation_types.Type, sample_size: int
) -> computation_base.Computation:
  """Builds the `merge` computation for a sampling.

  The returned computation combines two reservoir states, retaining at most
  `sample_size` entries — those with the largest random values.
  """
  reservoir_type = build_reservoir_type(value_type)

  @tensorflow_computation.tf_computation(reservoir_type, reservoir_type)
  @tf.function
  def merge_samples(a, b):
    # First concatenate all the values together. If the size of the resulting
    # structure is less than the sample size we don't need to do anything else.
    merged_random_values = tf.concat(
        [a['random_values'], b['random_values']], axis=0
    )
    merged_samples = tf.nest.map_structure(
        lambda x, y: tf.concat([x, y], axis=0), a['samples'], b['samples']
    )
    # `random_seed` is no longer used, but we need to keep the structure
    # for this reduction method. Arbitrarily forward the seed from `a`.
    forwarded_random_seed = a['random_seed']
    # If the merged reservoir is within the sample size, keep everything.
    if tf.size(merged_random_values) <= sample_size:
      return collections.OrderedDict(
          random_seed=forwarded_random_seed,
          random_values=merged_random_values,
          samples=merged_samples,
      )
    # Otherwise we need to select just the top values based on sample size.
    _, indices = tf.nn.top_k(merged_random_values, sample_size, sorted=False)
    # Build a boolean mask from the winning indices so the same selection can
    # be applied uniformly to the random values and every sample leaf.
    selection_mask = tf.scatter_nd(
        indices=tf.expand_dims(indices, axis=-1),
        updates=tf.fill(dims=tf.shape(indices), value=True),
        shape=tf.shape(merged_random_values),
    )
    selected_random_values = tf.boolean_mask(
        merged_random_values, mask=selection_mask
    )
    selected_samples = tf.nest.map_structure(
        lambda t: tf.boolean_mask(t, mask=selection_mask), merged_samples
    )
    return collections.OrderedDict(
        random_seed=forwarded_random_seed,
        random_values=selected_random_values,
        samples=selected_samples,
    )

  return merge_samples
def _build_finalize_sample_computation(
    value_type: computation_types.Type,
    return_sampling_metadata: bool = False,
) -> computation_base.Computation:
  """Builds the `report` computation for sampling."""
  reservoir_type = build_reservoir_type(value_type)

  @tensorflow_computation.tf_computation(reservoir_type)
  @tf.function
  def finalize_samples(reservoir):
    # `return_sampling_metadata` is a Python bool fixed when the computation
    # is built, so exactly one branch is ever traced.
    if not return_sampling_metadata:
      # Strip the bookkeeping fields and surface only the sampled values.
      return reservoir['samples']
    # Surface the entire reservoir sampling state, metadata included.
    return reservoir

  return finalize_samples
def _build_check_non_finite_leaves_computation(
    value_type: computation_types.Type,
) -> computation_base.Computation:
  """Builds the computation for checking non-finite leaves in the client value.

  Args:
    value_type: The `tff.types.Type` of the client value. Must only contain
      `tff.types.TensorType`s or `tff.types.StructWithPythonType`s.

  Returns:
    A TFF computation (constructed by the `tff.tf_computation` decoration)
    that takes in a client-side value as input, and returns a value of the
    same structure as the client value, with all the leaves being a `tf.int64`
    0/1 scalar tensor indicating whether the corresponding leaf tensor in the
    input client value has any non-finite (`NaN` or `Inf`) value.

  Raises:
    TypeError: if `value_type` contains types other than
      `tff.types.TensorType` or `tff.types.StructWithPythonType`.
  """
  if not _is_tensor_or_structure_of_tensors(value_type):
    raise TypeError(
        'Cannot check non-finite leaves for the client value. Expected the '
        'client value type to only contain `TensorType`s or '
        f'`StructWithPythonType`s, got a {value_type!r}.'
    )

  @tensorflow_computation.tf_computation(value_type)
  @tf.function
  def check_non_finite_leaves(client_value):
    def _flag_non_finite(leaf: tf.Tensor) -> tf.Tensor:
      """Returns a 0/1 `tf.int64` scalar: 1 iff `leaf` has a non-finite entry."""
      # `tf.math.is_finite` only accepts floating-point tensors, and only
      # float dtypes can hold `NaN`/`Inf`, so every other dtype is treated
      # as finite.
      if not leaf.dtype.is_floating:
        has_non_finite = tf.constant(False)
      else:
        # TODO: b/201213657 - replaces `tf.math.is_finite` by a
        # memory-efficient way of checking finite tensors.
        all_finite = tf.reduce_all(tf.math.is_finite(leaf))
        has_non_finite = tf.math.logical_not(all_finite)
      return tf.cast(has_non_finite, tf.int64)

    if isinstance(client_value, tf.Tensor):
      return _flag_non_finite(client_value)
    # Mirror the input structure, replacing every leaf tensor with its
    # integer 0/1 non-finiteness flag.
    return tf.nest.map_structure(_flag_non_finite, client_value)

  return check_non_finite_leaves
class UnweightedReservoirSamplingFactory(factory.UnweightedAggregationFactory):
  """An `UnweightedAggregationFactory` for reservoir sampling values.

  The created `tff.templates.AggregationProcess` samples values placed at
  `CLIENTS`, and outputs the sample placed at `SERVER`.

  The process has empty `state`. The `measurements` of this factory counts the
  number of non-finite (`NaN` or `Inf` values) leaves in the client values
  *before* sampling. Specifically, the returned `measurements` has the same
  structure as the client value, and every leaf node is a `tf.int64` scalar
  tensor counting the number of clients having non-finite value in that leaf.
  For example, suppose we are aggregating from three clients:

  ```
  client_value_1 = collections.OrderedDict(a=[1.0, 2.0], b=[1.0, np.nan])
  client_value_2 = collections.OrderedDict(a=[np.nan, np.inf], b=[1.0, 2.0])
  client_value_3 = collections.OrderedDict(a=[1.0, 2.0], b=[np.inf, np.nan])
  ```

  Then `measurements` will be:

  ```
  collections.OrderedDict(a=tf.constant(1, dtype=int64),
                          b=tf.constant(2, dtype=int64))
  ```

  For more about reservoir sampling see
  https://en.wikipedia.org/wiki/Reservoir_sampling.
  """

  def __init__(self, sample_size: int, return_sampling_metadata: bool = False):
    """Initializes the `UnweightedReservoirSamplingFactory`.

    Args:
      sample_size: An integer specifying the number of clients sampled (by
        reservoir sampling algorithm). Values from the sampled clients are
        collected at the server (see the class documentation for details).
      return_sampling_metadata: If True, the `result` property in the returned
        `tff.templates.MeasuredProcessOutput` object contains a dictionary of
        sampled values and other sampling metadata (such as random values
        generated during reservoir sampling). Otherwise, it only contains the
        sampled values.

    Raises:
      TypeError: If any argument type mismatches.
      ValueError: If `sample_size` is not positive.
    """
    py_typecheck.check_type(sample_size, int)
    py_typecheck.check_type(return_sampling_metadata, bool)
    if sample_size <= 0:
      raise ValueError('`sample_size` must be positive.')
    self._sample_size = sample_size
    self._return_sampling_metadata = return_sampling_metadata

  def create(
      self,
      value_type: computation_types.Type,
  ) -> aggregation_process.AggregationProcess:
    if not type_analysis.is_structure_of_tensors(value_type):
      raise TypeError(
          f'`value_type` must be a structure of tensors, got a {value_type!r}.'
      )

    @federated_computation.federated_computation()
    def init_fn():
      # Empty/null state, nothing is tracked across invocations.
      return intrinsics.federated_value((), placements.SERVER)

    @federated_computation.federated_computation(
        computation_types.FederatedType((), placements.SERVER),
        computation_types.FederatedType(value_type, placements.CLIENTS),
    )
    def next_fn(unused_state, value):
      # Empty tuple is the `None` of TFF.
      server_unit = intrinsics.federated_value((), placements.SERVER)
      # Per-client non-finite-leaf flags, totaled across clients at the
      # server; computed on the raw values *before* sampling.
      non_finite_leaves_counts = intrinsics.federated_sum(
          intrinsics.federated_map(
              _build_check_non_finite_leaves_computation(value_type), value
          )
      )
      # Reservoir-sample the client values down to `sample_size` entries.
      sampled = intrinsics.federated_aggregate(
          value,
          zero=build_initial_sample_reservoir(value_type),
          accumulate=_build_sample_value_computation(
              value_type, self._sample_size
          ),
          merge=build_merge_samples_computation(value_type, self._sample_size),
          report=_build_finalize_sample_computation(
              value_type, self._return_sampling_metadata
          ),
      )
      return measured_process.MeasuredProcessOutput(
          state=server_unit,
          result=sampled,
          measurements=non_finite_leaves_counts,
      )

    return aggregation_process.AggregationProcess(init_fn, next_fn)