Skip to content

Commit

Permalink
[REFACTOR][PY] relay.op.Op -> tvm.ir.Op (apache#5705)
Browse files Browse the repository at this point in the history
* [REFACTOR][PY] relay.op.Op -> tvm.ir.Op

* Improve the error check
  • Loading branch information
tqchen authored and Trevor Morris committed Jun 18, 2020
1 parent 1f257b2 commit 349b916
Show file tree
Hide file tree
Showing 31 changed files with 215 additions and 198 deletions.
6 changes: 3 additions & 3 deletions include/tvm/ir/op.h
Expand Up @@ -121,7 +121,7 @@ class OpNode : public RelayExprNode {
return is_primitive_ != 0;
}

static constexpr const char* _type_key = "relay.Op";
static constexpr const char* _type_key = "Op";
TVM_DECLARE_FINAL_OBJECT_INFO(OpNode, RelayExprNode);

private:
Expand Down Expand Up @@ -180,7 +180,7 @@ class Op : public RelayExpr {
* \tparam ValueType The type of the attribute.
*/
template <typename ValueType>
inline static OpAttrMap<ValueType> GetAttrMap(const std::string& attr_name);
inline static OpAttrMap<ValueType> GetAttrMap(const String& attr_name);
/*!
* \brief Checks if an attr map is present in the registry.
* \param attr_name The name of the attribute.
Expand Down Expand Up @@ -374,7 +374,7 @@ class OpAttrMap : public AttrRegistryMap<Op, ValueType> {
inline const OpNode* Op::operator->() const { return static_cast<const OpNode*>(get()); }

template <typename ValueType>
inline OpAttrMap<ValueType> Op::GetAttrMap(const std::string& key) {
inline OpAttrMap<ValueType> Op::GetAttrMap(const String& key) {
return OpAttrMap<ValueType>(Op::GetAttrMapContainer(key));
}

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Expand Up @@ -81,7 +81,7 @@ def __init__(self, graph, input_shapes, records, target_ops,
Each row of this file is an encoded record pair.
Otherwise, it is an iterator.
target_ops : List of relay.op.Op
target_ops : List of tvm.ir.Op
Target tuning operators.
target : str or tvm.target
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Expand Up @@ -38,7 +38,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
expr : tvm.relay.Expr.Function
Input relay function expression.
target_ops: List of relay.op.Op
target_ops: List of tvm.ir.Op
List of target relay ops
node_dict : dictionary from tvm.relay.Expr to int
Expand Down Expand Up @@ -157,7 +157,7 @@ def _traverse_expr(node):
elif isinstance(node, Constant):
node_entry["name"] = "Constant_" + str(node_index)
node_entry["types"] = [node.checked_type]
elif isinstance(node, relay.op.op.Op):
elif isinstance(node, tvm.ir.Op):
return
else:
raise RuntimeError("Not supported relay node type in graph tuning: %s"
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/autotvm/task/relay_integration.py
Expand Up @@ -78,7 +78,7 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
The compilation target
target_host: tvm.target.Target
The host compilation target
ops: List[relay.op.Op] or None
ops: List[tvm.ir.Op] or None
List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
Returns
Expand All @@ -105,7 +105,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
The compilation target
target_host: tvm.target.Target
The host compilation target
ops: List[relay.op.Op] or None
ops: List[tvm.ir.Op] or None
List of relay ops to be tuned. If not specified, all tunable ops will be extracted.
Returns
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/topi_integration.py
Expand Up @@ -61,7 +61,7 @@ def reset(self, wanted_relay_ops=None):
Parameters
----------
wanted_relay_ops: List of relay.op.Op
wanted_relay_ops: List of tvm.ir.Op
The relay ops to be extracted
"""
self.task_collection = []
Expand Down
1 change: 1 addition & 0 deletions python/tvm/ir/__init__.py
Expand Up @@ -23,6 +23,7 @@
from .tensor_type import TensorType
from .type_relation import TypeCall, TypeRelation
from .expr import BaseExpr, PrimExpr, RelayExpr, GlobalVar, Range
from .op import Op, register_op_attr
from .function import CallingConv, BaseFunc
from .adt import Constructor, TypeData
from .module import IRModule
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/ir/json_compact.py
Expand Up @@ -109,7 +109,7 @@ def _convert(item, nodes):
# Base IR
"SourceName": _update_global_key,
"EnvFunc": _update_global_key,
"relay.Op": _update_global_key,
"relay.Op": [_update_global_key, _rename("Op")],
"relay.TypeVar": [_ftype_var, _update_from_std_str("name_hint")],
"relay.Id": [_update_from_std_str("name_hint")],
"relay.GlobalTypeVar": [_ftype_var, _update_from_std_str("name_hint")],
Expand Down
114 changes: 114 additions & 0 deletions python/tvm/ir/op.py
@@ -0,0 +1,114 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
"""Primitive operators in the TVM IR."""
import tvm._ffi
from . expr import RelayExpr
from . import _ffi_api


@tvm._ffi.register_object("Op")
class Op(RelayExpr):
"""Primitive operator in the IR."""
def __init__(self):
raise RuntimeError("Cannot create op, use get instead")

@staticmethod
def get(op_name):
"""Get the Op for a given name
Parameters
----------
op_name : str
The operator name
Returns
-------
op : Op
The op of the corresponding name
"""
return _ffi_api.GetOp(op_name)

def get_attr(self, attr_name):
"""Get additional attribute about the operator.
Parameters
----------
attr_name : str
The attribute name.
Returns
-------
value : object
The attribute value
"""
return _ffi_api.OpGetAttr(self, attr_name)

def set_attr(self, attr_name, value, plevel=10):
"""Set attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
value : object
The attribute value
plevel : int
The priority level
"""
_ffi_api.OpSetAttr(self, attr_name, value, plevel)

def reset_attr(self, attr_name):
"""Reset attribute about the operator.
Parameters
----------
attr_name : str
The attribute name
"""
_ffi_api.OpResetAttr(self, attr_name)


def register_op_attr(op_name, attr_key, value=None, level=10):
"""Register an operator property of an operator by name.
Parameters
----------
op_name : str
The name of operator
attr_key : str
The attribute name.
value : object, optional
The value to set
level : int, optional
The priority level
Returns
-------
fregister : function
Register function if value is not specified.
"""
def _register(v):
"""internal register function"""
_ffi_api.RegisterOpAttr(op_name, attr_key, v, level)
return v
return _register(value) if value is not None else _register
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Expand Up @@ -40,7 +40,6 @@
from .backend import vm

# Root operators
from .op import Op
from .op import nn
from .op import image
from .op import annotation
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/_parser.py
Expand Up @@ -378,7 +378,7 @@ def visitProg(self, ctx: RelayParser.ProgContext) -> Union[expr.Expr, IRModule]:
return self.module

# Exprs
def visitOpIdent(self, ctx) -> op.Op:
def visitOpIdent(self, ctx) -> tvm.ir.Op:
op_name = ".".join([name.getText() for name in ctx.CNAME()])
if op_name in FUNC_OPS:
return FuncOp(FUNC_OPS[op_name])
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/analysis/annotated_regions.py
Expand Up @@ -31,9 +31,9 @@ def __init__(self, expr, region_begin_op, region_end_op):
----------
expr : tvm.relay.Expr
The expression from which to construct the regions.
region_begin_op : tvm.relay.Op
region_begin_op : tvm.ir.Op
The region begin annotation.
region_end_op : tvm.relay.Op
region_end_op : tvm.ir.Op
The region end annotation.
"""
Expand Down
7 changes: 3 additions & 4 deletions python/tvm/relay/backend/compile_engine.py
Expand Up @@ -26,7 +26,6 @@
from ... import target as _target
from ... import autotvm
from .. import function as _function
from .. import op as _op
from .. import ty as _ty
from . import _backend

Expand Down Expand Up @@ -98,7 +97,7 @@ def get_valid_implementations(op, attrs, inputs, out_type, target):
Parameters
----------
op : relay.op.Op
op : tvm.ir.Op
Relay operator.
attrs : object
Expand Down Expand Up @@ -157,7 +156,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
Parameters
----------
op : relay.op.Op
op : tvm.ir.Op
Relay operator.
attrs : object
Expand Down Expand Up @@ -215,7 +214,7 @@ def select_implementation(op, attrs, inputs, out_type, target, use_autotvm=True)
@tvm._ffi.register_func("relay.backend.lower_call")
def lower_call(call, inputs, target):
"""Lower the call expression to op implementation and tensor outputs."""
assert isinstance(call.op, _op.Op)
assert isinstance(call.op, tvm.ir.Op)
op = call.op

# Prepare the call_node->checked_type(). For the call node inputs, we ensure that
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/expr.py
Expand Up @@ -234,7 +234,7 @@ class Call(ExprWithOp):
Parameters
----------
op: tvm.relay.Op or any tvm.relay.Expr with function type.
op: tvm.ir.Op or any tvm.relay.Expr with function type.
The operation to be called.
args: List[tvm.relay.Expr]
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/expr_functor.py
Expand Up @@ -16,13 +16,13 @@
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, invalid-name
"""The expression functor of Relay."""
from tvm.ir import Op

from .function import Function
from .expr import Call, Let, Var, GlobalVar
from .expr import If, Tuple, TupleGetItem, Constant
from .expr import RefCreate, RefRead, RefWrite
from .adt import Constructor, Match, Clause
from .op import Op

class ExprFunctor:
"""
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/__init__.py
Expand Up @@ -17,9 +17,9 @@
#pylint: disable=wildcard-import, redefined-builtin
"""Relay core operators."""
# operator defs
from .op import get, register, register_compute, register_gradient, \
from .op import get, register_compute, register_gradient, \
register_pattern, register_alter_op_layout, register_legalize, \
Op, OpPattern, OpStrategy, debug, register_external_compiler
OpPattern, OpStrategy, debug, register_external_compiler
from . import strategy

# Operators
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/relay/op/contrib/dnnl.py
Expand Up @@ -32,7 +32,7 @@
- The other way is to implement the function by themselves to
check the attributes of the op and decide if it should be offloaded to DNNL.
"""
from ... import op as _op
import tvm.ir
from ...dataflow_pattern import wildcard, is_op
from .register import register_pattern_table

Expand All @@ -51,7 +51,7 @@ def _register_external_op_helper(op_name, supported=True):
f : callable
A function that returns if the operator is supported by DNNL.
"""
@_op.register(op_name, "target.dnnl")
@tvm.ir.register_op_attr(op_name, "target.dnnl")
def _func_wrapper(attrs, args):
return supported

Expand Down

0 comments on commit 349b916

Please sign in to comment.