-
Notifications
You must be signed in to change notification settings - Fork 579
/
async_execution_context.py
231 lines (198 loc) · 8.85 KB
/
async_execution_context.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A context for execution based on an embedded executor instance."""
import asyncio
from collections.abc import Callable
import contextlib
from typing import Generic, Optional, TypeVar
from tensorflow_federated.python.common_libs import py_typecheck
from tensorflow_federated.python.common_libs import retrying
from tensorflow_federated.python.common_libs import structure
from tensorflow_federated.python.common_libs import tracing
from tensorflow_federated.python.core.impl.computation import computation_base
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.execution_contexts import compiler_pipeline
from tensorflow_federated.python.core.impl.executors import cardinalities_utils
from tensorflow_federated.python.core.impl.executors import executor_base
from tensorflow_federated.python.core.impl.executors import executor_factory
from tensorflow_federated.python.core.impl.executors import executor_value_base
from tensorflow_federated.python.core.impl.executors import executors_errors
from tensorflow_federated.python.core.impl.executors import ingestable_base
from tensorflow_federated.python.core.impl.types import computation_types
from tensorflow_federated.python.core.impl.types import type_conversions
from tensorflow_federated.python.core.impl.types import typed_object
# Type variable for the kind of computation an `AsyncExecutionContext` accepts.
_Computation = TypeVar('_Computation', bound=computation_base.Computation)


def _is_retryable_error(exception):
    """Returns whether `exception` should cause an invocation to be retried.

    Only instances of `executors_errors.RetryableError` (or its subclasses)
    qualify; every other exception is propagated without retrying.
    """
    retryable = isinstance(exception, executors_errors.RetryableError)
    return retryable
class AsyncExecutionContextValue(typed_object.TypedObject):
    """Wrapper class for values produced by `ExecutionContext`.

    Pairs an arbitrary payload with the `computation_types.Type` that
    describes it.

    Args:
      value: The payload to wrap; stored as-is.
      type_spec: The `computation_types.Type` of `value`. It is validated with
        `py_typecheck.check_type` when the wrapper is constructed.
    """

    def __init__(self, value, type_spec):
        py_typecheck.check_type(type_spec, computation_types.Type)
        self._type_spec = type_spec
        self._value = value

    @property
    def type_signature(self):
        """The `computation_types.Type` supplied at construction."""
        return self._type_spec

    @property
    def value(self):
        """The wrapped payload, exactly as supplied at construction."""
        return self._value
async def _ingest(executor, val, type_spec):
    """Ingests `val` into `executor` as a value of type `type_spec`.

    Dispatch order:
      * `executor_value_base.ExecutorValue`: returned unchanged.
      * `ingestable_base.Ingestable`: its type signature is checked against
        `type_spec`, then it ingests itself via `val.ingest(executor)`.
      * `structure.Struct` with a non-federated `type_spec`: each element is
        ingested concurrently against the matching element type, and the
        results are reassembled with `executor.create_struct`.
      * Anything else: passed to `executor.create_value(val, type_spec)`.

    Args:
      executor: An instance of `executor_base.Executor`.
      val: The value to ingest.
      type_spec: The `computation_types.Type` that `val` should be ingested as.

    Returns:
      The result of the ingestion.

    Raises:
      ValueError: If `val` is a `structure.Struct` and `type_spec` is neither a
        `tff.StructType` nor a `tff.FederatedType`.
      TypeError: If a `structure.Struct` value does not match `type_spec` in
        length or element names. The type checks performed on `Ingestable`
        values (`py_typecheck.check_type`, `check_assignable_from`) may raise
        as well.
    """
    if isinstance(val, executor_value_base.ExecutorValue):
        # Already embedded in an executor; nothing to do.
        return val
    elif isinstance(val, ingestable_base.Ingestable):
        val_type = val.type_signature
        py_typecheck.check_type(val_type, computation_types.Type)
        type_spec.check_assignable_from(val_type)
        return await val.ingest(executor)
    elif isinstance(val, structure.Struct) and not isinstance(
        type_spec, computation_types.FederatedType
    ):
        if not isinstance(type_spec, computation_types.StructType):
            raise ValueError(f'Expected a `tff.StructType`, found {type_spec}.')
        v_elem = structure.to_elements(val)
        t_elem = structure.to_elements(type_spec)
        if len(v_elem) != len(t_elem):
            raise TypeError(
                f'Value {val} does not match type {type_spec}: mismatching tuple '
                'length.'
            )
        for (vk, _), (tk, _) in zip(v_elem, t_elem):
            # An unnamed value element is allowed to match any type element name.
            if vk not in (tk, None):
                raise TypeError(
                    f'Value {val} does not match type {type_spec}: mismatching '
                    f'tuple element names {vk} vs. {tk}.'
                )
        # Ingest every element concurrently, each against its own element type.
        ingested = await asyncio.gather(
            *(_ingest(executor, v, t) for (_, v), (_, t) in zip(v_elem, t_elem))
        )
        # Reassemble using the element names from the *type*, so unnamed value
        # elements pick up the names declared in `type_spec`.
        return await executor.create_struct(
            structure.Struct(
                (name, elem) for (name, _), elem in zip(t_elem, ingested)
            )
        )
    else:
        return await executor.create_value(val, type_spec)
async def _invoke(executor, comp, arg, result_type: computation_types.Type):
    """Runs `comp` on `arg` inside `executor` and returns a Python value.

    Args:
      executor: An instance of `executor_base.Executor`.
      comp: The computation to run. It is embedded into `executor` with
        `create_value`, using its own `type_signature`.
      arg: Either `None`, or an `executor_value_base.ExecutorValue` that has
        already been ingested into `executor`.
      result_type: The type signature of the result. This is used to convert
        the computed value back into the proper Python container types.

    Returns:
      The computed result, converted with
      `type_conversions.type_to_py_container` against `result_type`.
    """
    expected_value_type = executor_value_base.ExecutorValue
    if arg is not None:
        py_typecheck.check_type(arg, expected_value_type)
    embedded_comp = await executor.create_value(comp, comp.type_signature)
    call_handle = await executor.create_call(embedded_comp, arg)
    py_typecheck.check_type(call_handle, expected_value_type)
    computed = await call_handle.compute()
    return type_conversions.type_to_py_container(computed, result_type)
class AsyncExecutionContext(context_base.AsyncContext, Generic[_Computation]):
    """An asynchronous execution context backed by an `executor_base.Executor`.

    This context's `invoke` method is a coroutine function. Calling it returns
    a coroutine object that represents the actual work in the backing executor:
    optional compilation, argument ingestion and invocation.

    This context will support concurrent invocation of multiple computations if
    their arguments have the same cardinalities.
    """

    def __init__(
        self,
        executor_fn: executor_factory.ExecutorFactory,
        compiler_fn: Optional[Callable[[_Computation], object]] = None,
        *,
        cardinality_inference_fn: cardinalities_utils.CardinalityInferenceFnType = cardinalities_utils.infer_cardinalities,
    ):
        """Initializes an execution context.

        Args:
          executor_fn: Instance of `executor_factory.ExecutorFactory`.
          compiler_fn: An optional Python function used to compile a
            computation before it is invoked. It is wrapped in a
            `compiler_pipeline.CompilerPipeline`. When `None`, computations are
            invoked without compilation.
          cardinality_inference_fn: A Python function specifying how to infer
            cardinalities from arguments (and their associated types). The
            value returned by this function will be passed to the
            `create_executor` method of `executor_fn` to construct a
            `tff.framework.Executor` instance.
        """
        super().__init__()
        # Rejects anything that is not an `ExecutorFactory` (presumably with a
        # `TypeError`, per `py_typecheck.check_type`).
        py_typecheck.check_type(executor_fn, executor_factory.ExecutorFactory)
        self._executor_factory = executor_fn
        if compiler_fn is not None:
            self._compiler_pipeline = compiler_pipeline.CompilerPipeline(compiler_fn)
        else:
            self._compiler_pipeline = None
        self._cardinality_inference_fn = cardinality_inference_fn

    @contextlib.contextmanager
    def _reset_factory_on_error(self, ex_factory, cardinalities):
        """Yields an executor from `ex_factory`, cleaning it up on failure.

        If creating the executor, or the body of the `with` statement, raises an
        `Exception`, `ex_factory.clean_up_executor` is called for the same
        cardinalities before the exception is re-raised. Because only
        `Exception` is caught, `BaseException`s such as
        `asyncio.CancelledError` (Python 3.8+) do not trigger the clean-up.
        """
        try:
            # Pass copies down so the factory cannot mutate the caller's dict.
            yield ex_factory.create_executor({**cardinalities})
        except Exception:
            ex_factory.clean_up_executor({**cardinalities})
            raise

    @property
    def executor_factory(self) -> executor_factory.ExecutorFactory:
        """The `executor_factory.ExecutorFactory` this context creates executors with."""
        return self._executor_factory

    async def invoke(self, comp, arg):
        """Invokes `comp` on `arg` using an executor from `executor_factory`.

        Attempts that fail with an `executors_errors.RetryableError` are
        retried; see `_invoke_with_retries`.

        Args:
          comp: The computation to invoke. Its `type_signature` must be a
            `tff.FunctionType`.
          arg: The argument to pass to `comp`. It may be `None` for a
            no-argument computation, or a coroutine that resolves to the
            argument.

        Returns:
          The result of the invocation, converted to the Python container types
          described by `comp`'s result type as it was before compilation.

        Raises:
          ValueError: If `comp.type_signature` is not a `tff.FunctionType`.
        """
        if asyncio.iscoroutine(arg):
            # Awaiting if we are passed a coro allows us to install and use the
            # async context in conjunction with ConcreteComputations'
            # implementation of __call__. This must happen *outside* the retry
            # loop. A coroutine object can only be awaited once, so awaiting it
            # inside the retried method would raise
            # `RuntimeError: cannot reuse already awaited coroutine` on the
            # second attempt instead of retrying.
            arg = await arg
        return await self._invoke_with_retries(comp, arg)

    @retrying.retry(
        retry_on_exception_filter=_is_retryable_error,
        wait_max_ms=30 * 1000,
        wait_multiplier=2,
    )
    async def _invoke_with_retries(self, comp, arg):
        """Compiles (if configured), ingests `arg` and invokes `comp`.

        Each attempt that raises an error accepted by `_is_retryable_error` is
        retried by `retrying.retry` (`wait_multiplier=2`, `wait_max_ms` of 30
        seconds). The same `comp` and `arg` objects are reused on every attempt,
        so `arg` must already be resolved, not a coroutine.
        """
        if not isinstance(comp.type_signature, computation_types.FunctionType):
            raise ValueError(
                f'Expected a `tff.FunctionType`, found {comp.type_signature}.'
            )
        # Save the type signature before compiling. Compilation currently loses
        # container types, so we must remember them here so that they can be
        # restored in the output.
        result_type = comp.type_signature.result
        if self._compiler_pipeline is not None:
            with tracing.span('ExecutionContext', 'Compile', span=True):
                comp = self._compiler_pipeline.compile(comp)
        with tracing.span('ExecutionContext', 'Invoke', span=True):
            if arg is not None:
                cardinalities = self._cardinality_inference_fn(
                    arg, comp.type_signature.parameter
                )
            else:
                cardinalities = {}
            with self._reset_factory_on_error(
                self._executor_factory, cardinalities
            ) as executor:
                py_typecheck.check_type(executor, executor_base.Executor)
                if arg is not None:
                    arg = await tracing.wrap_coroutine_in_current_trace_context(
                        _ingest(executor, arg, comp.type_signature.parameter)
                    )
                return await tracing.wrap_coroutine_in_current_trace_context(
                    _invoke(executor, comp, arg, result_type)
                )