Skip to content

Commit

Permalink
Merge pull request #129 from uTensor/transpose-op
Browse files Browse the repository at this point in the history
Transpose op
  • Loading branch information
dboyliao committed Dec 1, 2020
2 parents 2f4aab2 + 5ea634f commit 8b575fa
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 1 deletion.
2 changes: 2 additions & 0 deletions utensor_cgen/backend/utensor/_graph_lower/_op_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def apply(cls, ugraph):
op_info.code_gen_attributes["namespaces"] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("SinOperator"):
op_info.code_gen_attributes["namespaces"] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("TransposeOperator"):
op_info.code_gen_attributes["namespaces"] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("ReshapeOperator"):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type("MatrixMultOperator"):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def get_construct_snippet(self, op_var_name):


@OperatorFactory.register
class SinOperator(_Operator):
class _SinOperator(_Operator):
namespaces = ('ReferenceOperators',)
op_type = 'SinOperator'

Expand Down Expand Up @@ -116,6 +116,42 @@ def get_construct_snippet(self, op_var_name):
)


@OperatorFactory.register
class _TransposeOperator(_Operator):
namespaces = ('ReferenceOperators',)
op_type = "TransposeOperator"

@classmethod
def get_type_signature(cls, op_info):
return ((op_info.input_tensors[0].dtype,), tuple())

def get_declare_snippet(self, op_var_name, with_const_params=True):
return DeclareOpSnippet(
self,
templ_dtypes=[self.in_dtypes[0],],
op_var_name=op_var_name,
nested_namespaces=self.namespaces,
with_const_params=with_const_params,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return TransposeEvalSnippet(
op_info,
templ_dtypes=[self.in_dtypes[0],],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
nested_namespaces=self.namespaces
)

def get_construct_snippet(self, op_var_name):
return OpConstructSnippet(
self,
templ_dtypes=[self.in_dtypes[0],],
op_var_name=op_var_name,
nested_namespaces=self.namespaces
)


@OperatorFactory.register
class _ReshapeOperator(_Operator):
namespaces = ('ReferenceOperators',)
Expand Down
6 changes: 6 additions & 0 deletions utensor_cgen/backend/utensor/snippets/rearch/_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"AddOpEvalSnippet",
"MulOpEvalSnippet",
"SinEvalSnippet",
"TransposeEvalSnippet",
"ReshahpeEvalSnippet",
"QuantizeEvalSnippet",
"MatrixMultEvalSnippet",
Expand Down Expand Up @@ -207,6 +208,11 @@ class SinEvalSnippet(OpEvalSnippet):
__outputs__ = ["act_out"]


class TransposeEvalSnippet(OpEvalSnippet):
__inputs__ = ["input", "perm"]
__outputs__ = ["output"]


class ReshahpeEvalSnippet(OpEvalSnippet):
__inputs__ = ["input"]
__outputs__ = ["output"]
Expand Down
11 changes: 11 additions & 0 deletions utensor_cgen/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from utensor_cgen.logger import logger
from utensor_cgen.utils import topologic_order_graph

# schema: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/schema/schema.fbs
from .tflite_flatbuffer.ActivationFunctionType import ActivationFunctionType
from .tflite_flatbuffer.BuiltinOperator import BuiltinOperator
from .tflite_flatbuffer.CustomOptionsFormat import CustomOptionsFormat
Expand Down Expand Up @@ -436,6 +437,15 @@ def argmax_op_data(op, fb_mdel):

return option_dict

def transpose_op_data(op, fb_model):
option_dict = {}
from .tflite_flatbuffer.TransposeOptions import TransposeOptions
# no filed declared in the fbs file for TransposeOptions
# skipping here
# this function is here just for silencing the warning msg
pass
return option_dict

def default_op_data(op, fb_mdel):
op_type = _get_op_type(op, fb_mdel)
logger.warning('the op data parser is missing for %s', op_type)
Expand All @@ -452,6 +462,7 @@ def default_op_data(op, fb_mdel):
_OP_DATA_FUNC_MAP["FULLY_CONNECTED"] = fully_connected_op_data
_OP_DATA_FUNC_MAP["DEQUANTIZE"] = dequantize_op_data
_OP_DATA_FUNC_MAP["ARG_MAX"] = argmax_op_data
_OP_DATA_FUNC_MAP["TRANSPOSE"] = transpose_op_data

def _get_op_type(op, fb_model):
local_op_code = op.OpcodeIndex()
Expand Down
1 change: 1 addition & 0 deletions utensor_cgen/legalizer/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class _OpTypeRename(object):
"Add": "AddOperator",
"Mul": "MulOperator",
"Sin": "SinOperator",
"Transpose": "TransposeOperator",
}

@classmethod
Expand Down

0 comments on commit 8b575fa

Please sign in to comment.