forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 4
/
dnnl.py
87 lines (72 loc) · 2.88 KB
/
dnnl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, unused-argument
"""DNNL library supported operators.

There are two ways to register a function for an op to indicate whether it is
supported by DNNL.

- The first and simplest way is to use the helper, so that users only need to
  provide the operator name and a boolean value indicating whether it is
  supported. For example:

  .. code-block:: python

    add = _register_external_op_helper("add")
    add = _register_external_op_helper("add", True)
    add = _register_external_op_helper("add", False)

- The other way is to implement the function themselves to check the
  attributes of the op and decide whether it should be offloaded to DNNL.
"""
import tvm.ir
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table
def _register_external_op_helper(op_name, supported=True):
    """Register a constant DNNL support check for an operator.

    A checker is registered under the ``"target.dnnl"`` attribute of the
    operator named ``op_name``. The checker ignores the call's attributes and
    arguments and always answers with ``supported``.

    Parameters
    ----------
    op_name : str
        The name of the operator that will be registered, e.g. ``"nn.conv2d"``.
    supported : bool, optional
        The value the registered checker returns for every call of the
        operator. Defaults to ``True``.

    Returns
    -------
    f : callable
        The checker taking ``(attrs, args)``, as returned by the
        ``tvm.ir.register_op_attr`` decorator.
    """

    @tvm.ir.register_op_attr(op_name, "target.dnnl")
    def _func_wrapper(attrs, args):
        # The decision is fixed at registration time; it does not inspect
        # the op's attributes or arguments.
        return supported

    return _func_wrapper
# Operators registered as unconditionally supported by DNNL: each gets a
# "target.dnnl" checker (via the helper above) that always returns True.
# Registration order matches the original explicit call sequence.
for _op_name in (
    "nn.batch_norm",
    "nn.conv2d",
    "nn.dense",
    "nn.relu",
    "add",
    "subtract",
    "multiply",
):
    _register_external_op_helper(_op_name)
# Do not leave the loop variable behind as a module-level name.
del _op_name
def make_pattern(with_bias=True):
    """Build the dataflow pattern for a conv2d (+ optional bias add) + relu chain.

    Parameters
    ----------
    with_bias : bool, optional
        When true, the pattern is ``nn.relu(add(nn.conv2d(data, weight), bias))``;
        otherwise it is ``nn.relu(nn.conv2d(data, weight))``.
        Defaults to ``True``.

    Returns
    -------
    pattern
        The pattern rooted at the ``nn.relu`` call.
    """
    # Three free inputs, created in the same order as before; the bias
    # wildcard is simply unused when with_bias is false.
    data_pat, weight_pat, bias_pat = wildcard(), wildcard(), wildcard()

    body = is_op("nn.conv2d")(data_pat, weight_pat)
    if with_bias:
        # Insert the bias addition between the convolution and the activation.
        body = is_op("add")(body, bias_pat)

    return is_op("nn.relu")(body)
@register_pattern_table("dnnl")
def pattern_table():
    """Return the composite patterns registered under the "dnnl" pattern table.

    Returns
    -------
    list of (str, pattern)
        ``("dnnl.conv2d_bias_relu", ...)`` followed by ``("dnnl.conv2d_relu", ...)``.
        NOTE(review): the bias variant is listed first, presumably so the
        larger pattern takes priority during matching. Confirm against the
        pattern-table consumer before reordering.
    """
    return [
        ("dnnl.conv2d_bias_relu", make_pattern(with_bias=True)),
        ("dnnl.conv2d_relu", make_pattern(with_bias=False)),
    ]