/
iblt_factory.py
283 lines (255 loc) · 10.7 KB
/
iblt_factory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
# Copyright 2022, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory for string aggregation using IBLT."""
import collections
from typing import Any, Optional
import attrs
import numpy as np
import tensorflow as tf
from tensorflow_federated.python.aggregators import factory
from tensorflow_federated.python.aggregators import sum_factory
from tensorflow_federated.python.analytics import data_processing
from tensorflow_federated.python.analytics.heavy_hitters.iblt import chunkers
from tensorflow_federated.python.analytics.heavy_hitters.iblt import iblt_tensor
from tensorflow_federated.python.core.environments.tensorflow_frontend import tensorflow_computation
from tensorflow_federated.python.core.impl.federated_context import federated_computation
from tensorflow_federated.python.core.impl.federated_context import intrinsics
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.templates import aggregation_process
from tensorflow_federated.python.core.templates import measured_process
# Convenience Aliases
_CharacterEncoding = chunkers.CharacterEncoding
DATASET_KEY = 'key'
DATASET_VALUE = 'value'
@attrs.define(eq=False, frozen=True)
class ServerOutput:
output_strings: Any
string_values: Any
num_not_decoded: Any
@tf.function
def _parse_client_dict(
dataset: tf.data.Dataset, string_max_bytes: int
) -> tuple[tf.Tensor, tf.Tensor]:
"""Parses the dictionary in the input `dataset` to key and value lists.
Args:
dataset: A `tf.data.Dataset` that yields `OrderedDict`. In each
`OrderedDict` there are two key, value pairs: `DATASET_KEY`: A string
tensor, and `DATASET_VALUE`: A rank 1 int64 tensor representing the value
associate with the string.
string_max_bytes: The maximum length of the strings in bytes. If any string
is longer than `string_max_bytes`, a `ValueError` will be raised.
Returns:
input_strings: A rank 1 `tf.Tensor` containing the list of strings in
`dataset`.
string_values: A rank 2 `tf.Tensor` containing the values of
`input_strings`.
Raises:
ValueError: If any string in `dataset` is longer than string_max_bytes.
"""
parsed_dict = data_processing.to_stacked_tensor(dataset)
input_strings = parsed_dict[DATASET_KEY]
string_values = parsed_dict[DATASET_VALUE]
tf.debugging.Assert(
tf.math.logical_not(
tf.math.reduce_any(
tf.greater(tf.strings.length(input_strings), string_max_bytes)
)
),
data=[
(
'IbltFactory: Input strings must be truncated to'
f' {string_max_bytes=}'
),
input_strings,
],
name='CHECK_STRING_LENGTH',
)
return input_strings, string_values
class IbltFactory(factory.UnweightedAggregationFactory):
"""Factory for string and values aggregation by IBLT."""
def __init__(
self,
*,
capacity: int,
string_max_bytes: int,
encoding: _CharacterEncoding = _CharacterEncoding.UTF8,
repetitions: int,
seed: int = 0,
sketch_agg_factory: Optional[factory.UnweightedAggregationFactory] = None,
value_tensor_agg_factory: Optional[
factory.UnweightedAggregationFactory
] = None,
) -> None:
"""Initializes IbltFactory.
Args:
capacity: The capacity of the IBLT sketch. Must be positive.
string_max_bytes: The maximum length in bytes of a string in the IBLT.
Must be positive.
encoding: The character encoding of the string data to encode. For
non-character binary data or strings with unknown encoding, specify
`CharacterEncoding.UNKNOWN`.
repetitions: The number of repetitions in IBLT data structure (must be >=
3). Must be at least `3`.
seed: An integer seed for hash functions. Defaults to 0.
sketch_agg_factory: (Optional) A `UnweightedAggregationFactory` specifying
the value aggregation to sum IBLT sketches. Defaults to
`tff.aggregators.SumFactory`. If `sketch_agg_factory` is set to a
`tff.aggregators.SecureSumFactory`, then the `upper_bound_threshold`
should be at least 2 ** 32 - 1.
value_tensor_agg_factory: (Optional) A `UnweightedAggregationFactory`
specifying the value aggregation to sum value tensors. Defaults to
`tff.aggregators.SumFactory`. Note that when using `sketch_agg_factory`
is set to a `tff.aggregators.SecureSumFactory`, the value to be summed
might be clipped depends on the choices of `upper_bound_threshold` and
`lower_bound_threshold` parameters in `SecureSumFactory`.
Raises:
ValueError: if parameters don't meet expectations.
"""
if capacity < 1:
raise ValueError(f'capacity should be at least 1, got {capacity}')
if string_max_bytes < 1:
raise ValueError(
f'string_max_bytes should be at least 1, got {string_max_bytes}'
)
if repetitions < 3:
raise ValueError(f'repetitions should be at least 3, got {repetitions}')
self._sketch_agg_factory = (
sum_factory.SumFactory()
if sketch_agg_factory is None
else sketch_agg_factory
)
self._value_tensor_agg_factory = (
sum_factory.SumFactory()
if value_tensor_agg_factory is None
else value_tensor_agg_factory
)
self._capacity = capacity
self._string_max_bytes = string_max_bytes
self._encoding = encoding
self._repetitions = repetitions
self._seed = seed
def create(
self, value_type: computation_types.SequenceType
) -> aggregation_process.AggregationProcess: # pytype: disable=signature-mismatch
"""Creates an AggregationProcess using IBLT to aggregate strings.
Args:
value_type: A `tff.SequenceType` representing the type of the input
dataset, must be compatible with the following `tff.Type`:
tff.SequenceType(collections.OrderedDict([ (DATASET_KEY, np.str_),
(DATASET_VALUE, tff.TensorType(shape=[None], dtype=np.int64)), ]))
Raises:
ValueError: If `value_type` is not as expected.
Returns:
A `tff.templates.AggregationProcess` to aggregate strings and values
associate with the strings.
"""
expected_value_type = computation_types.SequenceType(
collections.OrderedDict([
(DATASET_KEY, np.str_),
(
DATASET_VALUE,
computation_types.TensorType(shape=[None], dtype=np.int64),
),
])
)
if not expected_value_type.is_assignable_from(value_type):
raise ValueError(
'value_shape must be compatible with '
f'{expected_value_type}. Found {value_type} instead.'
)
self._value_shape = tuple(value_type.element[DATASET_VALUE].shape) # pytype: disable=unsupported-operands
@tensorflow_computation.tf_computation(value_type)
def encode_iblt(dataset):
"""The TF computation to compute the IBLT frequency sketches."""
input_strings, string_values = _parse_client_dict(
dataset, self._string_max_bytes
)
iblt_encoder = iblt_tensor.IbltTensorEncoder(
capacity=self._capacity,
string_max_bytes=self._string_max_bytes,
encoding=self._encoding,
repetitions=self._repetitions,
value_shape=self._value_shape,
seed=self._seed,
)
return iblt_encoder.compute_iblt(input_strings, string_values)
@tensorflow_computation.tf_computation(encode_iblt.type_signature.result)
@tf.function
def decode_iblt(sketch, value_tensor):
"""The TF computation to decode the strings and values from IBLT."""
iblt_decoder = iblt_tensor.IbltTensorDecoder(
iblt=sketch,
iblt_values=value_tensor,
value_shape=self._value_shape,
capacity=self._capacity,
string_max_bytes=self._string_max_bytes,
encoding=self._encoding,
repetitions=self._repetitions,
seed=self._seed,
)
(output_strings, _, string_values, num_not_decoded) = (
iblt_decoder.get_freq_estimates_tf()
)
return (output_strings, string_values, num_not_decoded)
inner_aggregator_sketch = self._sketch_agg_factory.create(
encode_iblt.type_signature.result[0]
)
inner_aggregator_value_tensor = self._value_tensor_agg_factory.create(
encode_iblt.type_signature.result[1]
)
@federated_computation.federated_computation
def init_fn():
sketch_state = inner_aggregator_sketch.initialize()
value_tensor_state = inner_aggregator_value_tensor.initialize()
return intrinsics.federated_zip((sketch_state, value_tensor_state))
@federated_computation.federated_computation(
init_fn.type_signature.result,
computation_types.FederatedType(value_type, placements.CLIENTS),
)
def next_fn(state, dataset):
sketch_state, value_tensor_state = state
sketch, value_tensor = intrinsics.federated_map(encode_iblt, dataset)
sketch_output = inner_aggregator_sketch.next(sketch_state, sketch)
value_tensor_output = inner_aggregator_value_tensor.next(
value_tensor_state, value_tensor
)
summed_sketch = sketch_output.result
summed_value_tensor = value_tensor_output.result
(output_strings, string_values, num_not_decoded) = (
intrinsics.federated_map(
decode_iblt, (summed_sketch, summed_value_tensor)
)
)
result = intrinsics.federated_zip(
ServerOutput(
output_strings=output_strings,
string_values=string_values,
num_not_decoded=num_not_decoded,
)
)
updated_state = intrinsics.federated_zip(
(sketch_output.state, value_tensor_output.state)
)
updated_measurements = intrinsics.federated_zip(
collections.OrderedDict(
num_not_decoded=num_not_decoded,
sketch=sketch_output.measurements,
value_tensor=value_tensor_output.measurements,
)
)
return measured_process.MeasuredProcessOutput(
updated_state, result, updated_measurements
)
return aggregation_process.AggregationProcess(init_fn, next_fn)