/
native_platform.py
289 lines (245 loc) · 10.8 KB
/
native_platform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
# Copyright 2022, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A federated platform implemented using native TFF components."""
import asyncio
from collections.abc import Awaitable, Callable
import functools
import typing
from typing import Optional, TypeVar, Union
from tensorflow_federated.python.common_libs import async_utils
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.execution_contexts import async_execution_context
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import placements
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.program import federated_context
from tensorflow_federated.python.program import structure_utils
from tensorflow_federated.python.program import value_reference
_T = TypeVar('_T')

# Either a bare value of type `_T` or a `tff.structure.Struct`. The intent is to
# describe values of type `_T` nested inside `tff.structure.Struct`s, but that
# nesting cannot currently be expressed in the type itself.
# TODO: b/232433269 - Update `tff.structure.Struct` to be able to define nested
# homogeneous structures of `tff.structure.Struct`s.
_StructStructure = Union[_T, structure.Struct]
class AwaitableValueReference(value_reference.MaterializableValueReference):
  """A `tff.program.MaterializableValueReference` backed by a coroutine function.

  The referenced value is obtained by awaiting the result of `fn` the first time
  `get_value` completes successfully; later calls return the cached value.
  """

  def __init__(
      self,
      fn: Callable[[], Awaitable[value_reference.MaterializedValue]],
      type_signature: value_reference.MaterializableTypeSignature,
  ):
    """Returns an initialized `tff.program.AwaitableValueReference`.

    Args:
      fn: A function that returns an `Awaitable` representing the referenced
        value.
      type_signature: The `tff.Type` of this object; must be one of the types
        in `tff.program.MaterializableTypeSignature`.

    Raises:
      TypeError: If `fn` is not callable. `type_signature` is additionally
        validated by `py_typecheck.check_type`.
    """
    if not callable(fn):
      raise TypeError(
          f'Expected a function that returns an `Awaitable`, found {type(fn)}.'
      )
    py_typecheck.check_type(
        type_signature,
        typing.get_args(value_reference.MaterializableTypeSignature),
    )
    self._fn = fn
    self._type_signature = type_signature
    # Whether `_value` holds the awaited result. Tracked separately so that a
    # referenced value of `None` is cached too, instead of re-awaiting `fn`
    # on every call to `get_value`.
    self._has_value = False
    self._value = None

  @property
  def type_signature(self) -> value_reference.MaterializableTypeSignature:
    """The `tff.program.MaterializableTypeSignature` of this object."""
    return self._type_signature

  async def get_value(self) -> value_reference.MaterializedValue:
    """Returns the referenced value, awaiting `fn` only until it succeeds once."""
    if not self._has_value:
      # If awaiting raises, nothing is cached and a later call retries.
      self._value = await self._fn()
      self._has_value = True
    return self._value

  # NOTE: Defining `__eq__` without `__hash__` makes instances of this class
  # unhashable (Python sets `__hash__` to `None` in that case).
  def __eq__(self, other: object) -> bool:
    if self is other:
      return True
    elif not isinstance(other, AwaitableValueReference):
      return NotImplemented
    # Two references are equal when they share a type signature and the same
    # (by `==`) value-producing function.
    return (
        self._type_signature == other._type_signature and self._fn == other._fn
    )
def _wrap_in_shared_awaitable(
    fn: Callable[..., Awaitable[object]]
) -> Callable[..., async_utils.SharedAwaitable]:
  """Returns a memoized variant of `fn` that yields shared awaitables.

  Every result of `fn` is wrapped in a `tff.async_utils.SharedAwaitable`, and
  the wrapper is memoized with `functools.cache`. Repeated calls with the same
  (hashable) arguments therefore return the very same `SharedAwaitable` object.

  Args:
    fn: A function that returns an `Awaitable`.

  Returns:
    A function that returns a `tff.async_utils.SharedAwaitable`.

  Raises:
    TypeError: If `fn` is not callable.
  """
  if not callable(fn):
    raise TypeError(
        f'Expected a function that returns an `Awaitable`, found {type(fn)}.'
    )

  def _share(*args: object, **kwargs: object) -> async_utils.SharedAwaitable:
    return async_utils.SharedAwaitable(fn(*args, **kwargs))

  return functools.cache(_share)
def _create_structure_of_awaitable_references(
    fn: Callable[[], Awaitable[value_reference.MaterializableStructure]],
    type_signature: computation_types.Type,
) -> _StructStructure[AwaitableValueReference]:
  """Returns a structure of `tff.program.AwaitableValueReference`s.

  `fn` is not called here; it is only called when one of the returned
  references is materialized.

  Args:
    fn: A function that returns an `Awaitable` used to create the structure of
      `tff.program.AwaitableValueReference`s.
    type_signature: The `tff.Type` of the value returned by `fn`; must contain
      only structures, server-placed values, sequences, or tensors.

  Returns:
    A `tff.structure.Struct` mirroring `type_signature` when it is a
    `tff.StructType` (server placements are unwrapped to their member type),
    otherwise a single `tff.program.AwaitableValueReference`.

  Raises:
    TypeError: If `fn` is not callable.
    NotImplementedError: If `type_signature` contains an unexpected type.
  """
  if not callable(fn):
    raise TypeError(
        f'Expected a function that returns an `Awaitable`, found {type(fn)}.'
    )
  py_typecheck.check_type(type_signature, computation_types.Type)

  # A `async_utils.SharedAwaitable` is required to materialize structures of
  # values multiple times. This happens when a value is released using multiple
  # `tff.program.ReleaseManager`s.
  fn = _wrap_in_shared_awaitable(fn)

  if isinstance(type_signature, computation_types.StructType):

    async def _to_structure(
        fn: Callable[[], Awaitable[value_reference.MaterializableStructure]]
    ) -> structure.Struct:
      value = await fn()
      return structure.from_container(value)

    fn = functools.partial(_to_structure, fn)

    # A `tff.async_utils.SharedAwaitable` is required to materialize structures
    # of values concurrently. This happens when the structure is flattened and
    # the `tff.program.AwaitableValueReference`s are materialized concurrently,
    # see `tff.program.materialize_value` for an example.
    fn = _wrap_in_shared_awaitable(fn)

    async def _get_item(
        fn: Callable[[], Awaitable[value_reference.MaterializableStructure]],
        index: int,
    ) -> value_reference.MaterializedValue:
      value = await fn()
      return value[index]

    # Each element gets its own lazily-evaluated function that awaits the
    # shared parent structure and selects one element positionally.
    elements = []
    element_types = structure.iter_elements(type_signature)
    for index, (name, element_type) in enumerate(element_types):
      element_fn = functools.partial(_get_item, fn, index)
      element = _create_structure_of_awaitable_references(
          element_fn, element_type
      )
      elements.append((name, element))
    return structure.Struct(elements)
  elif (
      isinstance(type_signature, computation_types.FederatedType)
      and type_signature.placement == placements.SERVER
  ):
    return _create_structure_of_awaitable_references(fn, type_signature.member)
  elif isinstance(
      type_signature,
      (computation_types.SequenceType, computation_types.TensorType),
  ):
    return AwaitableValueReference(fn, type_signature)
  else:
    raise NotImplementedError(f'Unexpected type found: {type_signature}.')
async def _materialize_structure_of_value_references(
    value: value_reference.MaterializableStructure,
    type_signature: computation_types.Type,
) -> _StructStructure[value_reference.MaterializedValue]:
  """Resolves every value reference in `value`, guided by `type_signature`.

  Struct elements are materialized concurrently with `asyncio.gather`. Federated
  types are unwrapped to their member type. For sequence and tensor leaves, a
  `tff.program.MaterializableValueReference` is replaced by the result of
  `get_value()`, and any other leaf is returned unchanged.

  Args:
    value: The (possibly nested) value to materialize.
    type_signature: The `tff.Type` describing `value`.

  Returns:
    A `tff.structure.Struct` for struct types, otherwise the materialized leaf.

  Raises:
    NotImplementedError: If `type_signature` contains an unexpected type.
  """
  py_typecheck.check_type(type_signature, computation_types.Type)

  if isinstance(type_signature, computation_types.StructType):
    struct_value = structure.from_container(value)
    named_types = list(structure.iter_elements(type_signature))
    materialized = await asyncio.gather(*(
        _materialize_structure_of_value_references(element, element_type)
        for element, (_, element_type) in zip(struct_value, named_types)
    ))
    return structure.Struct([
        (name, element)
        for element, (name, _) in zip(materialized, named_types)
    ])

  if isinstance(type_signature, computation_types.FederatedType):
    return await _materialize_structure_of_value_references(
        value, type_signature.member
    )

  if isinstance(
      type_signature,
      (computation_types.SequenceType, computation_types.TensorType),
  ):
    if isinstance(value, value_reference.MaterializableValueReference):
      return await value.get_value()
    return value

  raise NotImplementedError(f'Unexpected type found: {type_signature}.')
class NativeFederatedContext(federated_context.FederatedContext):
  """A `tff.program.FederatedContext` backed by an execution context."""

  def __init__(self, context: async_execution_context.AsyncExecutionContext):
    """Returns an initialized `tff.program.NativeFederatedContext`.

    Args:
      context: A `tff.framework.AsyncExecutionContext`; validated with
        `py_typecheck.check_type`.
    """
    py_typecheck.check_type(
        context, async_execution_context.AsyncExecutionContext
    )
    self._context = context

  def invoke(
      self,
      comp: computation_base.Computation,
      arg: Optional[federated_context.ComputationArgValue],
  ) -> structure_utils.Structure[AwaitableValueReference]:
    """Invokes the `comp` with the argument `arg`.

    `comp` is not executed by this method directly. The call is packaged into
    a function that is handed to `_create_structure_of_awaitable_references`,
    which wraps it in the returned value references.

    Args:
      comp: The `tff.Computation` being invoked.
      arg: The optional argument of `comp`; server-placed values must be
        represented by `tff.program.MaterializableStructure`, and client-placed
        values must be represented by structures of values returned by a
        `tff.program.FederatedDataSourceIterator`.

    Returns:
      The result of invocation; a structure of
      `tff.program.MaterializableValueReference`.

    Raises:
      ValueError: If the result type of `comp` does not contain only structures,
        server-placed values, or tensors.
    """
    py_typecheck.check_type(comp, computation_base.Computation)
    result_type = comp.type_signature.result
    if not federated_context.contains_only_server_placed_data(result_type):
      raise ValueError(
          'Expected the result type of `comp` to contain only structures, '
          f'server-placed values, or tensors, found {result_type}.'
      )

    async def _invoke(
        context: async_execution_context.AsyncExecutionContext,
        comp: computation_base.Computation,
        arg: Optional[value_reference.MaterializableStructure],
    ) -> value_reference.MaterializedStructure:
      # Resolve any value references in `arg` into concrete values before
      # handing the argument to the execution context.
      if comp.type_signature.parameter is not None:
        arg = await _materialize_structure_of_value_references(
            arg, comp.type_signature.parameter
        )
      return await context.invoke(comp, arg)

    coro_fn = functools.partial(_invoke, self._context, comp, arg)
    result = _create_structure_of_awaitable_references(coro_fn, result_type)
    # Convert the `tff.structure.Struct`s back into the Python container types
    # recorded in `result_type`.
    result = type_conversions.type_to_py_container(result, result_type)
    return result