Skip to content

Commit

Permalink
support matching attributes with more complext objects (apache#8240)
Browse files Browse the repository at this point in the history
  • Loading branch information
Matthew Brookhart authored and Trevor Morris committed Jun 17, 2021
1 parent f6ceaa5 commit df84e19
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 0 deletions.
11 changes: 11 additions & 0 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,17 @@ Here is another example to match an op with a specific attribute:
y = relay.var('y')
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))
Or a convolution with a specific kernel size:

.. code-block:: python
def test_match_kernel_size():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
x = relay.var('x')
y = relay.var('y')
assert is_conv2d.match(relay.op.nn.conv2d(x, y, kernel_size=[3, 3]))
Matching an Optional Op
***********************
Expand Down
9 changes: 9 additions & 0 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,8 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
return rhs.operator std::string() == val->value;
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator std::string() == val->data;
} else {
ICHECK(false) << "PatternMatcher: Unsupported TVMDataType " << lhs;
}
break;
case kTVMObjectHandle:
Expand All @@ -140,6 +142,13 @@ bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
} else if (auto* val = lhs.as<StringObj>()) {
return rhs.operator String() == val->data;
}
} else {
// Compare the objects for structural equality
static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
if ((*structural_equal)(lhs, GetRef<ObjectRef>(rhs.ptr<Object>()), false, true)) {
return true;
}
}
break;
default:
Expand Down
11 changes: 11 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,11 +478,17 @@ def test_no_match_func_attr():


def test_match_call_attr():
# String attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
x = relay.var("x")
y = relay.var("y")
assert is_conv2d.match(relay.op.nn.conv2d(x, y))

# Array attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
out = relay.op.nn.conv2d(x, y, kernel_size=[3, 3])
assert is_conv2d.match(out)

# non-operator call
attr_dict = {"call_attr": "attr"}
call_has_attr = wildcard()(wildcard()).has_attr(attr_dict)
Expand All @@ -508,6 +514,11 @@ def test_no_match_call_attr():
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

# Array attr
is_conv2d = is_op("nn.conv2d")(wildcard(), wildcard()).has_attr({"kernel_size": [3, 3]})
out = relay.op.nn.conv2d(x, y, kernel_size=[2, 1])
assert not is_conv2d.match(out)

# non-operator calls
call_has_attr = wildcard()(wildcard()).has_attr({"call_attr": "attr"})
wrong_key = tvm.ir.make_node("DictAttrs", **{"wrong": "attr"})
Expand Down

0 comments on commit df84e19

Please sign in to comment.