Skip to content

Commit

Permalink
Minor refactor: moving _MissingOperator to _base.py module
Browse files Browse the repository at this point in the history
  • Loading branch information
dboyliao committed Nov 30, 2020
1 parent 1289bc3 commit 2f4aab2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from six import with_metaclass

from utensor_cgen.backend.utensor.snippets.rearch import MissingOpEvalSnippet
from utensor_cgen.logger import logger
from utensor_cgen.utils import MUST_OVERWRITE, must_return_type

Expand All @@ -20,7 +21,6 @@ def get_opertor(cls, op_info):
codegen_namespaces = op_info.code_gen_attributes.get('namespaces', tuple())
op_cls = cls._operators.get((codegen_namespaces, op_type))
if op_cls is None:
missing_op_cls = cls._operators['_MissingOperator']
if op_info.op_type not in cls._warned_missing_ops:
op_full_name = '::'.join(
["uTensor"] + \
Expand All @@ -31,7 +31,7 @@ def get_opertor(cls, op_info):
'{} is missing, no code will be generated for it'.format(op_full_name)
)
cls._warned_missing_ops.add(op_info.op_type)
return missing_op_cls(op_info)
return _MissingOperator(op_info)
return op_cls(op_info)

@classmethod
Expand Down Expand Up @@ -198,3 +198,16 @@ def get_construct_snippet(self, op_var_name):
raise NotImplementedError(
"base get_construct_snippet invoked: {}".format(type(self))
)


class _MissingOperator(_Operator):
op_type = "_MissingOperator"

def get_declare_snippet(self, op_var_name, with_const_params=True):
return None

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return MissingOpEvalSnippet(op_info, op_var_name, tensor_var_map)

def get_construct_snippet(self, op_var_name):
return None
Original file line number Diff line number Diff line change
Expand Up @@ -815,19 +815,3 @@ def get_construct_snippet(self, op_var_name):
op_var_name=op_var_name,
nested_namespaces=type(self).namespaces,
)


class _MissingOperator(_Operator):
op_type = "_MissingOperator"

def get_declare_snippet(self, op_var_name, with_const_params=True):
return None

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return MissingOpEvalSnippet(op_info, op_var_name, tensor_var_map)

def get_construct_snippet(self, op_var_name):
return None


OperatorFactory._operators[_MissingOperator.op_type] = _MissingOperator

0 comments on commit 2f4aab2

Please sign in to comment.