-
-
Notifications
You must be signed in to change notification settings - Fork 606
/
scheduler_test.py
332 lines (251 loc) · 10.4 KB
/
scheduler_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
# Copyright 2015 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
import re
from contextlib import contextmanager
from dataclasses import dataclass
from textwrap import dedent
from typing import Any, FrozenSet
from pants.engine.internals.scheduler import ExecutionError
from pants.engine.rules import RootRule, rule
from pants.engine.selectors import Get, Params
from pants.engine.unions import UnionRule, union
from pants.testutil.engine.util import assert_equal_with_printing, remove_locations_from_traceback
from pants.testutil.test_base import TestBase
@dataclass(frozen=True)
class A:
    """Field-less, immutable marker type used as a rule product and as a root Param."""
@dataclass(frozen=True)
class B:
    """Field-less, immutable marker type; used both as a root Param and as an intermediate product."""
def fn_raises(x):
    """Unconditionally raise an Exception that names the type of ``x``.

    NB: The text of the ``raise`` line below is asserted verbatim by
    ``SchedulerWithNestedRaiseTest.test_trace_includes_rule_exception_traceback``, so keep it
    unchanged.
    """
    raise Exception(f"An exception for {type(x).__name__}")
@rule(desc="Nested raise")
def nested_raise(x: B) -> A:  # type: ignore[return]
    """Rule that always fails by delegating to `fn_raises`.

    It never returns normally, hence the ``type: ignore[return]`` on the signature. Both the
    ``desc`` and the ``fn_raises(x)`` call line appear verbatim in the traceback asserted by
    ``test_trace_includes_rule_exception_traceback``.
    """
    fn_raises(x)
@rule
def consumes_a_and_b(a: A, b: B) -> str:
    """Render both inputs into a single string, so the rule requires two Params.

    `test_use_params` compares the engine's result against a direct call of this function.
    """
    # An f-string already evaluates to `str`; the former `str(...)` wrapper was redundant.
    return f"{a} and {b}"
@dataclass(frozen=True)
class C:
    """Field-less, immutable marker type that `transitive_b_c` converts into a `B`."""
@rule
def transitive_b_c(c: C) -> B:
    """Provide a `B` whenever a `C` is available.

    The argument is intentionally unused: only the `C -> B` signature matters, letting the engine
    satisfy a `B` requirement transitively from a `C` Param.
    """
    produced = B()
    return produced
@dataclass(frozen=True)
class D:
    """Immutable product of `transitive_coroutine_rule`, wrapping the `B` it obtained."""

    # The `B` that the rule obtained via an inner Get.
    b: B
@rule
async def transitive_coroutine_rule(c: C) -> D:
    """Produce a `D` by requesting a `B` for ``c`` through an inner Get.

    In `SchedulerTest.rules()` that Get is satisfied by `transitive_b_c`; `test_transitive_params`
    exercises this path.
    """
    b = await Get(B, C, c)
    return D(b)
@union
class UnionBase:
    """Union whose members (`UnionA`, `UnionB`) are registered in `SchedulerTest.rules()`."""
@union
class UnionWithNonMemberErrorMsg:
    """Union that supplies a custom message describing a value that is not one of its members.

    `test_union_rules_no_docstring` expects the text returned by `non_member_error_message` to
    appear in the resulting ExecutionError.
    """

    @staticmethod
    def non_member_error_message(subject):
        type_name = type(subject).__name__
        return f"specific error message for {type_name} instance"
class UnionWrapper:
    """Root Param carrying an arbitrary ``inner`` value that rules forward into union Gets."""

    def __init__(self, inner):
        # Stored as-is; no type check, so tests can pass both members and non-members of a union.
        self.inner = inner
class UnionA:
    """Union member (of `UnionBase`, per `SchedulerTest.rules()`) that can produce an `A`."""

    def a(self):
        """Return a fresh `A`."""
        produced = A()
        return produced
@rule
def select_union_a(union_a: UnionA) -> A:
    """Compute an `A` from a `UnionA` member by delegating to its `a()` method."""
    # `UnionA.a` is unannotated, so mypy sees its result as Any.
    return union_a.a()  # type: ignore[no-any-return]
class UnionB:
    """Second union member (of `UnionBase`, per `SchedulerTest.rules()`) that can produce an `A`."""

    def a(self):
        """Return a fresh `A`."""
        produced = A()
        return produced
@rule
def select_union_b(union_b: UnionB) -> A:
    """Compute an `A` from a `UnionB` member by delegating to its `a()` method."""
    # `UnionB.a` is unannotated, so mypy sees its result as Any.
    return union_b.a()  # type: ignore[no-any-return]
# TODO: add GetMulti testing for unions!
@rule
async def a_union_test(union_wrapper: UnionWrapper) -> A:
    """Request an `A` for the wrapped value through the `UnionBase` union.

    Which rule serves the Get depends on the runtime type of ``union_wrapper.inner``:
    `test_union_rules` expects `UnionA` and `UnionB` to succeed, and a non-member (`A`) to fail.
    """
    union_a = await Get(A, UnionBase, union_wrapper.inner)
    return union_a
class UnionX:
    """Product type of `error_msg_test_rule`; the rule never actually constructs one."""
@rule
async def error_msg_test_rule(union_wrapper: UnionWrapper) -> UnionX:
    """Issue a union Get that must fail, surfacing `UnionWithNonMemberErrorMsg`'s custom message.

    `test_union_rules_no_docstring` asserts on that message. Reaching the final statement means
    the Get unexpectedly succeeded.
    """
    # NB: We install a UnionRule to make UnionWrapper a member of this union, but then we pass the
    # inner value, which is _not_ registered.
    _ = await Get(A, UnionWithNonMemberErrorMsg, union_wrapper.inner)
    raise AssertionError("The statement above this one should have failed!")
class BooleanDeps(FrozenSet[bool]):
    """Immutable set of booleans; the product type of `boolean_cycle`."""
@rule
async def boolean_cycle(key: bool) -> BooleanDeps:
    """A rule with exactly two instances (bool == two keys), which depend on one another weakly."""
    deps = {key}
    # The Get for the opposite key is weak, so it may yield None (presumably when it would close
    # the True <-> False cycle); handle both outcomes below.
    dep = await Get(BooleanDeps, bool, not key, weak=True)
    if dep is not None:
        deps.update(dep)
    return BooleanDeps(deps)
class TypeCheckFailWrapper:
    """Holder for a value that `a_typecheck_fail_test` passes to a Get declared to take a `B`.

    Supplying anything other than a `B` demonstrates the type check that the engine applies when
    it processes an `await Get(...)` statement.
    """

    def __init__(self, inner):
        # Kept unvalidated on purpose: the mismatch must be detected by the engine, not here.
        self.inner = inner
@rule
async def a_typecheck_fail_test(wrapper: TypeCheckFailWrapper) -> A:
    """Issue a Get declared to take a `B`, passing ``wrapper.inner`` without checking its type.

    `test_get_type_match_failure` passes an `A` here and expects an ExecutionError.
    """
    # This `await` would use the `nested_raise` rule, but it won't get to the point of raising since
    # the type check will fail at the Get.
    _ = await Get(A, B, wrapper.inner)  # noqa: F841
    return A()
@dataclass(frozen=True)
class CollectionType:
    """Frozen wrapper whose hashability is decided entirely by the ``items`` it holds."""

    # NB: We pass an unhashable type when we want this to fail at the root, and a hashable type
    # when we'd like it to succeed.
    items: Any
@rule
async def c_unhashable(_: CollectionType) -> C:
    """Issue a Get whose param is an unhashable ``list()``; the rule's own argument is unused.

    `test_unhashable_get_params_failure` expects this to raise an ExecutionError mentioning
    "unhashable type: 'list'".
    """
    # This `await` would use the `nested_raise` rule, but it won't get to the point of raising since
    # the hashability check will fail.
    _result = await Get(A, B, list())  # noqa: F841
    return C()
@rule
def boolean_and_int(i: int, b: bool) -> A:
    """Produce an `A` from an `int` Param and a `bool` Param without reading either value.

    `test_strict_equals` requests `A` from `Params(1, True)` through this rule.
    """
    produced = A()
    return produced
@contextmanager
def assert_execution_error(test_case, expected_msg):
    """Assert that the managed block raises an ExecutionError whose text contains `expected_msg`.

    Location details are stripped from the exception text (via `remove_locations_from_traceback`)
    before the substring check.
    """
    with test_case.assertRaises(ExecutionError) as raised:
        yield
    rendered = remove_locations_from_traceback(str(raised.exception))
    test_case.assertIn(expected_msg, rendered)
class SchedulerTest(TestBase):
    """Exercises Params handling, transitive rules, weak Gets, and unions through the scheduler."""

    @classmethod
    def rules(cls):
        """Install the base rules plus the root types and rules defined in this module."""
        return super().rules() + [
            RootRule(A),
            # B is both a RootRule and an intermediate product here.
            RootRule(B),
            RootRule(C),
            RootRule(UnionX),
            error_msg_test_rule,
            consumes_a_and_b,
            transitive_b_c,
            transitive_coroutine_rule,
            RootRule(UnionWrapper),
            UnionRule(UnionBase, UnionA),
            UnionRule(UnionWithNonMemberErrorMsg, UnionWrapper),
            RootRule(UnionA),
            select_union_a,
            UnionRule(union_base=UnionBase, union_member=UnionB),
            RootRule(UnionB),
            select_union_b,
            a_union_test,
            boolean_cycle,
            boolean_and_int,
            RootRule(int),
            RootRule(bool),
        ]

    def test_use_params(self):
        # Confirm that we can pass in Params in order to provide multiple inputs to an execution.
        a, b = A(), B()
        result_str = self.request_single_product(str, Params(a, b))
        self.assertEqual(result_str, consumes_a_and_b(a, b))

        # And confirm that a superset of Params is also accepted.
        result_str = self.request_single_product(str, Params(a, b, self))
        self.assertEqual(result_str, consumes_a_and_b(a, b))

        # But not a subset.
        expected_msg = f"No installed @rules can compute {str.__name__} given input Params(A), but"
        with self.assertRaisesRegex(Exception, re.escape(expected_msg)):
            self.request_single_product(str, Params(a))

    def test_transitive_params(self):
        # Test that C can be provided and implicitly converted into a B with transitive_b_c() to satisfy
        # the selectors of consumes_a_and_b().
        a, c = A(), C()
        result_str = self.request_single_product(str, Params(a, c))
        self.assertEqual(
            remove_locations_from_traceback(result_str),
            remove_locations_from_traceback(consumes_a_and_b(a, transitive_b_c(c))),
        )

        # Test that an inner Get in transitive_coroutine_rule() is able to resolve B from C due to the
        # existence of transitive_b_c().
        with self.assertDoesNotRaise():
            _ = self.request_single_product(D, Params(c))

    def test_consumed_types(self):
        # Computing `str` from roots A and C also consumes the intermediate B (which
        # transitive_b_c derives from C) and the `str` output itself.
        assert {A, B, C, str} == set(
            self.scheduler.scheduler.rule_graph_consumed_types([A, C], str)
        )

    def test_strict_equals(self):
        # With the default implementation of `__eq__` for boolean and int, `1 == True`. But in the
        # engine that behavior would be surprising, and would cause both of these Params to intern
        # to the same value, triggering an error. Instead, the engine additionally includes the
        # type of a value in equality.
        assert A() == self.request_single_product(A, Params(1, True))

    def test_weak_gets(self):
        # Each key reaches the other through boolean_cycle's weak Get, so both values are collected.
        assert {True, False} == set(self.request_single_product(BooleanDeps, True))
        assert {True, False} == set(self.request_single_product(BooleanDeps, False))

    @contextmanager
    def _assert_execution_error(self, expected_msg):
        """Method-style shorthand for the module-level `assert_execution_error(self, ...)`."""
        with assert_execution_error(self, expected_msg):
            yield

    def test_union_rules(self):
        with self.assertDoesNotRaise():
            _ = self.request_single_product(A, Params(UnionWrapper(UnionA())))
        with self.assertDoesNotRaise():
            _ = self.request_single_product(A, Params(UnionWrapper(UnionB())))
        # Fails due to no union relationship from A -> UnionBase.
        with self._assert_execution_error("Type A is not a member of the UnionBase @union"):
            self.request_single_product(A, Params(UnionWrapper(A())))

    def test_union_rules_no_docstring(self):
        """A union's `non_member_error_message` text appears in the error for a non-member."""
        with self._assert_execution_error("specific error message for UnionA instance"):
            self.request_single_product(UnionX, Params(UnionWrapper(UnionA())))
class SchedulerWithNestedRaiseTest(TestBase):
    """Exercises how failures inside rules and at the root are reported by the scheduler."""

    @classmethod
    def rules(cls):
        """Install the base rules plus the failing rules defined in this module."""
        return super().rules() + [
            RootRule(B),
            RootRule(TypeCheckFailWrapper),
            RootRule(CollectionType),
            a_typecheck_fail_test,
            c_unhashable,
            nested_raise,
        ]

    def test_get_type_match_failure(self):
        """Test that Get(...)s are type-checked during rule execution, to allow for union types."""
        with self.assertRaises(ExecutionError) as cm:
            # `a_typecheck_fail_test` above expects `wrapper.inner` to be a `B`.
            self.request_single_product(A, Params(TypeCheckFailWrapper(A())))

        expected_regex = "WithDeps.*did not declare a dependency on JustGet"
        self.assertRegex(str(cm.exception), expected_regex)

    def test_unhashable_root_params_failure(self):
        """Test that unhashable root params result in a structured error."""
        # This will fail at the rust boundary, before even entering the engine.
        with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
            self.request_single_product(C, Params(CollectionType([1, 2, 3])))

    def test_unhashable_get_params_failure(self):
        """Test that unhashable Get(...) params result in a structured error."""
        # The root Param is hashable (it holds a tuple), so this fails only inside of
        # `c_unhashable`, whose Get passes an unhashable `list()`.
        with self.assertRaisesRegex(ExecutionError, "unhashable type: 'list'"):
            self.request_single_product(C, Params(CollectionType(tuple())))

    def test_trace_includes_rule_exception_traceback(self):
        # Execute a request that will trigger the nested raise, and then directly inspect its trace.
        request = self.scheduler.execution_request([A], [B()])
        _, throws = self.scheduler.execute(request)

        with self.assertRaises(ExecutionError) as cm:
            self.scheduler._raise_on_error([t for _, t in throws])

        trace = remove_locations_from_traceback(str(cm.exception))
        # NOTE(review): whitespace inside this expected literal (indentation and blank lines) is
        # significant for the comparison; confirm it against an actual run.
        assert_equal_with_printing(
            self,
            dedent(
                """\
                1 Exception encountered:

                Engine traceback:
                  in Nested raise
                Traceback (most recent call last):
                  File LOCATION-INFO, in nested_raise
                    fn_raises(x)
                  File LOCATION-INFO, in fn_raises
                    raise Exception(f"An exception for {type(x).__name__}")
                Exception: An exception for B

                """
            ),
            trace,
        )