/
context_stack_test_utils.py
141 lines (110 loc) · 4.06 KB
/
context_stack_test_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
# Copyright 2019, The TensorFlow Federated Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for testing context stacks."""
import asyncio
from collections.abc import Callable, Iterable
import contextlib
import functools
from typing import Optional, Union
from absl.testing import parameterized
from tensorflow_federated.python.core.impl.context_stack import context_base
from tensorflow_federated.python.core.impl.context_stack import context_stack_impl
# A context accepted by the helpers in this module: either an asynchronous or
# a synchronous context.
_Context = Union[context_base.AsyncContext, context_base.SyncContext]
# A zero-argument callable that constructs a `_Context`.
_ContextFactory = Callable[[], _Context]
# A zero-argument callable that returns the context managers to enter around a
# test.
_EnvironmentFactory = Callable[
[], Iterable[contextlib.AbstractContextManager[None]]
]
class TestContext(context_base.SyncContext):
  """A placeholder synchronous context for tests.

  The context exists so that tests have something to construct and install; it
  cannot actually execute computations.
  """

  def invoke(self, comp, arg):
    """Rejects every invocation.

    Args:
      comp: The computation to invoke; unused.
      arg: The argument of the invocation; unused.

    Raises:
      NotImplementedError: Always, because this context does not support
        invoking computations.
    """
    # Raise rather than return the exception class: a returned class would be
    # handed back to the caller as if it were the result of the computation.
    raise NotImplementedError
@contextlib.contextmanager
def test_environment():
  """A no-op environment for tests.

  Yields:
    `None`. Nothing is set up before, or torn down after, the managed block.
  """
  yield
def with_context(
    context_fn: _ContextFactory,
    environment_fn: Optional[_EnvironmentFactory] = None,
):
  """Builds a decorator that runs the decorated function inside a context.

  Every call of the decorated function constructs a new context with
  `context_fn`, installs it on the context stack, enters the context managers
  produced by `environment_fn` (if any), and then calls the function. The
  environment and the installed context are exited when the call finishes.

  Args:
    context_fn: A `Callable` that constructs a `tff.framework.AsyncContext` or
      `tff.framework.SyncContext` to install before invoking the decorated
      function.
    environment_fn: An optional `Callable` that constructs an iterable of
      `contextlib.AbstractContextManager` to enter before invoking the
      decorated function.

  Returns:
    A decorator. Applied to a coroutine function it produces a coroutine
    function; applied to any other callable it produces a regular function.
  """

  def decorator(fn):

    @contextlib.contextmanager
    def active_context():
      # The context is constructed per call, not once at decoration time.
      installed = context_fn()
      with context_stack_impl.context_stack.install(installed):
        # An empty `ExitStack` is a no-op, so a single code path serves both
        # the "no environment" and the "with environment" cases.
        with contextlib.ExitStack() as environments:
          if environment_fn is not None:
            for manager in environment_fn():
              environments.enter_context(manager)
          yield

    if not asyncio.iscoroutinefunction(fn):

      @functools.wraps(fn)
      def wrapper(*args, **kwargs):
        with active_context():
          return fn(*args, **kwargs)

      return wrapper

    @functools.wraps(fn)
    async def wrapper(*args, **kwargs):
      # Await inside the `with` so the context stays installed while the
      # coroutine runs.
      with active_context():
        return await fn(*args, **kwargs)

    return wrapper

  return decorator
def with_contexts(*named_contexts):
  """Builds a decorator that parameterizes a test method over contexts.

  Each named parameter supplies the arguments of `with_context` (a
  `context_fn` and, optionally, an `environment_fn`); the decorated test is run
  once per named parameter inside the corresponding context.

  Args:
    *named_contexts: Named parameters used to construct the `with_context`
      decorator; either a single iterable, or a list of `tuple`s or `dict`s.

  Returns:
    A decorator for a test method that takes only `self`; coroutine functions
    and regular functions are both supported.

  Raises:
    ValueError: If no named contexts are passed to the decorator.
  """
  if not named_contexts:
    raise ValueError('Expected at least one named parameter, found none.')

  def decorator(fn):
    keep_metadata = functools.wraps(fn)
    parameterize = parameterized.named_parameters(*named_contexts)

    if not asyncio.iscoroutinefunction(fn):

      @keep_metadata
      @parameterize
      def wrapper(
          self,
          context_fn: _ContextFactory,
          environment_fn: Optional[_EnvironmentFactory] = None,
      ):
        run_in_context = with_context(context_fn, environment_fn)(fn)
        run_in_context(self)

      return wrapper

    @keep_metadata
    @parameterize
    async def wrapper(
        self,
        context_fn: _ContextFactory,
        environment_fn: Optional[_EnvironmentFactory] = None,
    ):
      run_in_context = with_context(context_fn, environment_fn)(fn)
      await run_in_context(self)

    return wrapper

  return decorator