Skip to content

Commit

Permalink
Add op for utensor backend: Add, Sin, Mul
Browse files Browse the repository at this point in the history
  • Loading branch information
dboyliao committed Nov 10, 2020
1 parent 3bacf8b commit c345e20
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 0 deletions.
4 changes: 4 additions & 0 deletions utensor_cgen/backend/utensor/_graph_lower/_op_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@ def apply(cls, ugraph):
# TODO: better abstraction, sth like lowering strategy
for op_info in ugraph.get_ops_by_type("AddOperator"):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("MulOperator"):
op_info.code_gen_attributes["namespaces"] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("SinOperator"):
op_info.code_gen_attributes["namespaces"] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("ReshapeOperator"):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("MatrixMultOperator"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,74 @@ def get_construct_snippet(self, op_var_name):
)


@OperatorFactory.register
class _AddOperator(_Operator):
namespaces = ('ReferenceOperators',)
op_type = 'MulOperator'

def get_declare_snippet(self, op_var_name, with_const_params=True):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
nested_namespaces=type(self).namespaces,
with_const_params=with_const_params,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return MulOpEvalSnippet(
op_info=op_info,
templ_dtypes=[self.in_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
nested_namespaces=type(self).namespaces,
)

def get_construct_snippet(self, op_var_name):
return OpConstructSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
nested_namespaces=type(self).namespaces,
)


@OperatorFactory.register
class SinOperator(_Operator):
namespaces = ('ReferenceOperators',)
op_type = 'SinOperator'

@classmethod
def get_type_signature(cls, op_info):
return ((op_info.output_tensors[0].dtype,), (op_info.input_tensors[0].dtype,))

def get_declare_snippet(self, op_var_name, with_const_params=True):
return DeclareOpSnippet(
self,
templ_dtypes=[self.out_dtypes[0], self.in_dtypes[0]],
op_var_name=op_var_name,
nested_namespaces=self.namespaces,
with_const_params=with_const_params,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return SinEvalSnippet(
op_info,
templ_dtypes=[self.out_dtypes[0], self.in_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
nested_namespaces=self.namespaces
)

def get_construct_snippet(self, op_var_name):
return OpConstructSnippet(
self,
templ_dtypes=[self.out_dtypes[0], self.in_dtypes[0]],
op_var_name=op_var_name,
nested_namespaces=self.namespaces
)


@OperatorFactory.register
class _ReshapeOperator(_Operator):
namespaces = ('ReferenceOperators',)
Expand Down
12 changes: 12 additions & 0 deletions utensor_cgen/backend/utensor/snippets/rearch/_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"DepthwiseSeperateConvOpEvalSnippet",
"QuantDepthwiseSeperateConvOpEvalSnippet",
"AddOpEvalSnippet",
"MulOpEvalSnippet",
"SinEvalSnippet",
"ReshahpeEvalSnippet",
"QuantizeEvalSnippet",
"MatrixMultEvalSnippet",
Expand Down Expand Up @@ -194,6 +196,16 @@ class AddOpEvalSnippet(OpEvalSnippet):
__outputs__ = ['c']


class MulOpEvalSnippet(OpEvalSnippet):
__inputs__ = ['a', 'b']
__outputs__ = ['c']


class SinEvalSnippet(OpEvalSnippet):
__inputs__ = ["act_in"]
__outputs__ = ["act_out"]


class ReshahpeEvalSnippet(OpEvalSnippet):
__inputs__ = ["input"]
__outputs__ = ["output"]
Expand Down
3 changes: 3 additions & 0 deletions utensor_cgen/legalizer/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class _OpTypeRename(object):
"Dequantize": "DequantizeOperator",
"Reshape": "ReshapeOperator",
"Conv2d": "Conv2dOperator",
"Add": "AddOperator",
"Mul": "MulOperator",
"Sin": "SinOperator",
}

@classmethod
Expand Down

0 comments on commit c345e20

Please sign in to comment.