-
Notifications
You must be signed in to change notification settings - Fork 57
/
test_mlcontext.py
112 lines (81 loc) · 2.77 KB
/
test_mlcontext.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import pytest
from xdsl.dialects.builtin import UnregisteredAttr, UnregisteredOp
from xdsl.ir import MLContext, Operation, ParametrizedAttribute, TypeAttribute
from xdsl.irdl import irdl_attr_definition
class DummyOp(Operation):
    """Placeholder operation named ``dummy`` that the tests load into contexts."""

    name = "dummy"
class DummyOp2(Operation):
    """Second placeholder operation, named ``dummy2``, used as the "not loaded" case."""

    name = "dummy2"
@irdl_attr_definition
class DummyAttr(ParametrizedAttribute):
    """Parameterless placeholder attribute named ``dummy_attr``."""

    name = "dummy_attr"
@irdl_attr_definition
class DummyAttr2(ParametrizedAttribute):
    """Second parameterless placeholder attribute, named ``dummy_attr2``."""

    name = "dummy_attr2"
def test_get_op():
    """Check that `get_op` and `get_optional_op` only resolve loaded operations."""
    context = MLContext()
    context.load_op(DummyOp)

    # Strict lookup: a loaded name resolves, an unknown name raises.
    assert context.get_op("dummy") == DummyOp
    with pytest.raises(Exception):
        context.get_op("dummy2")

    # Optional lookup: a loaded name resolves, an unknown name yields None.
    assert context.get_optional_op("dummy") == DummyOp
    assert context.get_optional_op("dummy2") is None
def test_get_op_unregistered():
    """
    Check `get_op` and `get_optional_op` on a context created with
    `allow_unregistered=True`.
    """
    context = MLContext(allow_unregistered=True)
    context.load_op(DummyOp)

    # Optional lookup: a loaded name resolves to its own class, while an
    # unknown name yields an `UnregisteredOp` subclass rather than None.
    assert context.get_optional_op("dummy") == DummyOp
    optional_unknown = context.get_optional_op("dummy2")
    assert optional_unknown is not None
    assert issubclass(optional_unknown, UnregisteredOp)

    # Strict lookup behaves the same way instead of raising.
    assert context.get_op("dummy") == DummyOp
    strict_unknown = context.get_op("dummy2")
    assert issubclass(strict_unknown, UnregisteredOp)
def test_get_attr():
    """Check that `get_attr` and `get_optional_attr` only resolve loaded attributes."""
    context = MLContext()
    context.load_attr(DummyAttr)

    # Strict lookup: a loaded name resolves, an unknown name raises.
    assert context.get_attr("dummy_attr") == DummyAttr
    with pytest.raises(Exception):
        context.get_attr("dummy_attr2")

    # Optional lookup: a loaded name resolves, an unknown name yields None.
    assert context.get_optional_attr("dummy_attr") == DummyAttr
    assert context.get_optional_attr("dummy_attr2") is None
@pytest.mark.parametrize("is_type", [True, False])
def test_get_attr_unregistered(is_type: bool):
    """
    Test `get_attr` and `get_optional_attr`
    methods with the `allow_unregistered` flag.

    `is_type` is forwarded as `create_unregistered_as_type` to every lookup.
    When it is True, the class produced for an unknown attribute name must
    also be a `TypeAttribute`.
    """
    ctx = MLContext(allow_unregistered=True)
    ctx.load_attr(DummyAttr)

    # A loaded attribute resolves to its own class regardless of the flag.
    assert (
        ctx.get_optional_attr("dummy_attr", create_unregistered_as_type=is_type)
        == DummyAttr
    )

    # An unknown name yields an UnregisteredAttr subclass. The flag is
    # forwarded here too, so the TypeAttribute check below exercises it.
    attr = ctx.get_optional_attr("dummy_attr2", create_unregistered_as_type=is_type)
    assert attr is not None
    assert issubclass(attr, UnregisteredAttr)
    if is_type:
        assert issubclass(attr, TypeAttribute)

    assert ctx.get_attr("dummy_attr", create_unregistered_as_type=is_type) == DummyAttr

    # Check the class returned by `get_attr` itself, not the earlier `attr`.
    strict_attr = ctx.get_attr("dummy_attr2", create_unregistered_as_type=is_type)
    assert issubclass(strict_attr, UnregisteredAttr)
    if is_type:
        assert issubclass(strict_attr, TypeAttribute)
def test_clone_function():
    """
    Check that `clone` returns a context equal to the original, and that
    loading a new op or attribute into the clone makes the two differ.
    """
    original = MLContext()
    original.load_attr(DummyAttr)
    original.load_op(DummyOp)

    # Each step starts from a fresh clone, then extends only the clone.
    extensions = (
        lambda context: context.load_op(DummyOp2),
        lambda context: context.load_attr(DummyAttr2),
    )
    for extend in extensions:
        duplicate = original.clone()
        assert original == duplicate
        extend(duplicate)
        assert original != duplicate