This repository was archived by the owner on Aug 25, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 138
/
Copy path test_input_validation.py
131 lines (116 loc) · 4.3 KB
/
test_input_validation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from dffml.df.base import op
from dffml.df.types import DataFlow, Input, Definition
from dffml.operation.output import GetSingle
from dffml.util.asynctestcase import AsyncTestCase
from dffml.df.memory import MemoryOrchestrator
from dffml.operation.mapping import MAPPING
from dffml.df.exceptions import InputValidationError
def pie_validation(x):
    """Validate a value for the ``Pie`` definition.

    Only the exact value 3.14 is accepted; it is returned unchanged.
    Any other value raises InputValidationError.
    """
    # Exact float comparison is intentional: the tests feed literally 3.14
    # (accepted) or 4 (rejected).
    if x != 3.14:
        raise InputValidationError()
    return x
# Pie is guarded by pie_validation, so only 3.14 is accepted.
Pie = Definition(
    name="pie",
    primitive="float",
    validate=pie_validation,
)
# Plain float definitions with no validator. Area is not referenced
# elsewhere in this module.
Radius = Definition(name="radius", primitive="float")
Area = Definition(name="area", primitive="float")
# Callable validator whose return value stands in for the input value;
# test_validate therefore expects the shape name upper-cased.
ShapeName = Definition(
    name="shape_name",
    primitive="str",
    validate=lambda x: x.upper(),
)
# A string (not a callable) validator: it names the operation instance
# "validate_shout_instance" that test_vaildation_by_op registers in its
# dataflow.
SHOUTIN = Definition(
    name="shout_in",
    primitive="str",
    validate="validate_shout_instance",
)
SHOUTOUT = Definition(name="shout_out", primitive="str")
@op(
    inputs={"name": ShapeName, "radius": Radius, "pie": Pie},
    outputs={"shape": MAPPING},
)
async def get_circle(name: str, radius: float, pie: float):
    """Describe a circle as a mapping under the ``shape`` output.

    The mapping holds the given ``name`` and ``radius`` plus the computed
    ``area`` (``pie * radius * radius``).
    """
    # Keep the original left-to-right multiplication order so float results
    # are bit-for-bit identical.
    area = pie * radius * radius
    shape = {"name": name, "radius": radius, "area": area}
    return {"shape": shape}
@op(
    inputs={"shout_in": SHOUTIN},
    outputs={"shout_in_validated": SHOUTIN},
    validator=True,
)
def validate_shouts(shout_in):
    """Validator operation for SHOUTIN values.

    Declared with ``validator=True``; it emits the input with the suffix
    "_validated" appended, using the same SHOUTIN definition for its output.
    """
    validated = shout_in + "_validated"
    return {"shout_in_validated": validated}
@op(
    inputs={"shout_in": SHOUTIN},
    outputs={"shout_out": SHOUTOUT},
)
def echo_shout(shout_in):
    """Forward the SHOUTIN value unchanged as the SHOUTOUT output."""
    result = {"shout_out": shout_in}
    return result
class TestDefintion(AsyncTestCase):
    """Tests for Definition-level and operation-level input validation.

    NOTE(review): the class name ("Defintion") and the method name
    ``test_vaildation_by_op`` contain typos; they are left unchanged so
    existing test selectors keep working.
    """

    async def setUp(self):
        await super().setUp()
        # Circle dataflow: get_circle builds the shape mapping and GetSingle,
        # seeded with the name of get_circle's "shape" output definition,
        # returns it as the flow's result.
        self.dataflow = DataFlow(
            operations={
                "get_circle": get_circle.op,
                "get_single": GetSingle.imp.op,
            },
            seed=[
                Input(
                    value=[get_circle.op.outputs["shape"].name],
                    definition=GetSingle.op.inputs["spec"],
                )
            ],
            implementations={"get_circle": get_circle.imp},
        )

    async def _run_dataflow(self, dataflow, test_inputs):
        """Run *dataflow* on *test_inputs* in a fresh MemoryOrchestrator.

        Returns the list of ``results`` objects yielded by ``octx.run``, one
        per completed input context. Collecting them (instead of asserting
        inside ``async for``) lets callers fail loudly when nothing is yielded.
        """
        async with MemoryOrchestrator.withconfig({}) as orchestrator:
            async with orchestrator(dataflow) as octx:
                return [
                    results async for _ctx, results in octx.run(test_inputs)
                ]

    async def test_validate(self):
        """Callable validators apply to inputs: ShapeName upper-cases the name."""
        test_inputs = {
            "area": [
                Input(value="unitcircle", definition=ShapeName),
                Input(value=1, definition=Radius),
                Input(value=3.14, definition=Pie),
            ]
        }
        all_results = await self._run_dataflow(self.dataflow, test_inputs)
        # A single input context ("area") was submitted, so exactly one
        # result set must come back; an empty run must not pass silently.
        self.assertEqual(len(all_results), 1)
        results = all_results[0]
        self.assertIn("mapping", results)
        shape = results["mapping"]
        self.assertEqual(shape["name"], "UNITCIRCLE")
        self.assertEqual(shape["area"], 3.14)
        self.assertEqual(shape["radius"], 1)

    async def test_validation_error(self):
        """A value rejected by Pie's validator raises InputValidationError."""
        # Pie's callable validator fires when the Input is constructed, so the
        # assertion is scoped to exactly that call; no dataflow run is needed.
        with self.assertRaises(InputValidationError):
            Input(value=4, definition=Pie)  # 4 != 3.14 -> validation error

    async def test_vaildation_by_op(self):
        """An operation declared with ``validator=True`` validates SHOUTIN.

        SHOUTIN's ``validate`` is the string "validate_shout_instance", the
        instance name validate_shouts is registered under below; echo_shout
        is then expected to emit the suffixed (validated) value.
        """
        test_dataflow = DataFlow(
            operations={
                "validate_shout_instance": validate_shouts.op,
                "echo_shout": echo_shout.op,
                "get_single": GetSingle.imp.op,
            },
            seed=[
                Input(
                    value=[echo_shout.op.outputs["shout_out"].name],
                    definition=GetSingle.op.inputs["spec"],
                )
            ],
            implementations={
                validate_shouts.op.name: validate_shouts.imp,
                echo_shout.op.name: echo_shout.imp,
            },
        )
        test_inputs = {
            "TestShoutOut": [
                Input(value="validation_status:", definition=SHOUTIN)
            ]
        }
        all_results = await self._run_dataflow(test_dataflow, test_inputs)
        # One input context ("TestShoutOut") -> exactly one result set.
        self.assertEqual(len(all_results), 1)
        results = all_results[0]
        self.assertIn("shout_out", results)
        self.assertEqual(results["shout_out"], "validation_status:_validated")