Skip to content

Commit

Permalink
Merge pull request #135 from victorromeo/avg_pool_op
Browse files Browse the repository at this point in the history
AvgPool and MinPool support
  • Loading branch information
dboyliao committed Nov 30, 2020
2 parents ce14b06 + 847f184 commit 1289bc3
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions utensor_cgen/backend/utensor/_graph_lower/_op_lower.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def apply(cls, ugraph):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type('MaxOperator'):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type('AvgPoolOperator'):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type('MaxPoolOperator'):
op_info.code_gen_attributes['namespaces'] = ('ReferenceOperators',)
for op_info in ugraph.get_ops_by_type('MinPoolOperator'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,36 @@ def get_constructor_parameters(cls, op_info):
ksize_str = _c_arr_str(ksize)
return (ksize_str, stride_str, padding)

@OperatorFactory.register
class _AvgPoolOperator(_PoolingOperatorMixin, _Operator):
namespaces = ('ReferenceOperators',)
op_type = 'AvgPoolOperator'

def get_declare_snippet(self, op_var_name, with_const_params=True):
return DeclareOpSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
nested_namespaces=type(self).namespaces,
with_const_params=with_const_params,
)

def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
return AvgPoolEvalSnippet(
op_info=op_info,
templ_dtypes=[self.in_dtypes[0]],
op_name=op_var_name,
tensor_var_map=tensor_var_map,
nested_namespaces=type(self).namespaces,
)

def get_construct_snippet(self, op_var_name):
return OpConstructSnippet(
op=self,
templ_dtypes=[self.in_dtypes[0]],
op_var_name=op_var_name,
nested_namespaces=type(self).namespaces,
)

@OperatorFactory.register
class _MaxPoolOperator(_PoolingOperatorMixin, _Operator):
Expand Down
4 changes: 4 additions & 0 deletions utensor_cgen/backend/utensor/snippets/rearch/_snippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"ReLU6EvalSnippet",
"MinEvalSnippet",
"MaxEvalSnippet",
"AvgPoolEvalSnippet",
"MinPoolEvalSnippet",
"MaxPoolEvalSnippet",
"QuantizedFullyConnectedSnippet",
Expand Down Expand Up @@ -255,6 +256,9 @@ class MaxEvalSnippet(OpEvalSnippet):
__inputs__ = ["in"]
__outputs__ = ["out"]

class AvgPoolEvalSnippet(OpEvalSnippet):
__inputs__ = ["in"]
__outputs__ = ["out"]

class MinPoolEvalSnippet(OpEvalSnippet):
__inputs__ = ["in"]
Expand Down
2 changes: 2 additions & 0 deletions utensor_cgen/frontend/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ def default_op_data(op, fb_mdel):
_OP_DATA_FUNC_MAP["QUANTIZE"] = quantize_op_data
_OP_DATA_FUNC_MAP["DEPTHWISE_CONV_2D"] = depthwise_conv2d_op_data
_OP_DATA_FUNC_MAP["CONV_2D"] = conv_2d_op_data
_OP_DATA_FUNC_MAP["AVG_POOL_2D"] = pool2d_op_data
_OP_DATA_FUNC_MAP["MIN_POOL_2D"] = pool2d_op_data
_OP_DATA_FUNC_MAP["MAX_POOL_2D"] = pool2d_op_data
_OP_DATA_FUNC_MAP["RESHAPE"] = reshape_op_data
_OP_DATA_FUNC_MAP["FULLY_CONNECTED"] = fully_connected_op_data
Expand Down
2 changes: 2 additions & 0 deletions utensor_cgen/legalizer/tflite.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ class _OpTypeRename(object):
"FullyConnected": "FullyConnectedOperator",
"Quantize": "QuantizeOperator",
"DepthwiseConv2d": "DepthwiseSeparableConvOperator",
"AvgPool2d": "AvgPoolOperator",
"MinPool2d": "MinPoolOperator",
"MaxPool2d": "MaxPoolOperator",
"Dequantize": "DequantizeOperator",
"Reshape": "ReshapeOperator",
Expand Down

0 comments on commit 1289bc3

Please sign in to comment.