Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] Fixes and enhancements for NCHW support - Resne…
Browse files Browse the repository at this point in the history
…t_v2
  • Loading branch information
srkreddy1238 committed Sep 19, 2018
1 parent 8c5d3ef commit 1fe792e
Showing 1 changed file with 37 additions and 15 deletions.
52 changes: 37 additions & 15 deletions nnvm/python/nnvm/frontend/tensorflow.py
Expand Up @@ -109,29 +109,24 @@ def _elemwise(name):
def _impl(inputs, attr, *args):
assert len(inputs) == 2, "Math op take 2 inputs, {} given".format(len(inputs))
op_name = _math_name_picker(name)(attr)
axis = int(attr.get('axis', 0))
conv_ops = ["conv2d", "conv2d_transpose"]
if op_name == 'broadcast_add' and inputs[0].attr('op_name') in conv_ops:
# TODO: remove hard coded infershape
inputs[1] = _sym.expand_dims(inputs[1], axis=axis, num_newaxis=2)
return get_nnvm_op(op_name)(*inputs)
return _impl

def _pooling(name):
def _impl(inputs, attr, params):

attr['data_format'] = attr['data_format'].decode("utf-8")

if attr['data_format'] == 'NHWC':
attr['kernel_shape'] = (attr['ksize'][1], attr['ksize'][2])
# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
attr['kernel_shape'] = (attr['ksize'][2], attr['ksize'][3])
# Fix strides
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data_format type : {}".format(attr['data_format']))

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix padding
input_shapes = attr['_input_shapes'][inputs[0]]
attr['padding'] = attr['padding'].decode("utf-8")
Expand Down Expand Up @@ -171,6 +166,11 @@ def _impl(inputs, attr, params):
def _conv(opname):
def _impl(inputs, attr, params):
attr['data_format'] = attr['data_format'].decode("utf-8")
# NCHW Layout require weights transpose
if attr['data_format'] == 'NCHW':
weights = params.pop(inputs[1].list_output_names()[0]).asnumpy()
params[inputs[1].list_output_names()[0]] = np.transpose(weights, (3, 2, 0, 1))

input_shapes = attr['_input_shapes'][inputs[0]]

# Extract kernel shape from params
Expand All @@ -186,26 +186,29 @@ def _impl(inputs, attr, params):

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][0], attr['dilations'][1])

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])
elif attr['data_format'] == 'NCHW':
depth_mult, _, kernel_h, kernel_w = conv_param_weights.shape
attr['kernel_shape'] = (conv_param_weights.shape[2], conv_param_weights.shape[3])
if opname == 'conv':
attr['channels'] = conv_param_weights.shape[1]
attr['channels'] = conv_param_weights.shape[0]
else:
attr['channels'] = input_shapes[0][1] * depth_mult

if 'dilations' in attr:
attr['dilations'] = (attr['dilations'][2], attr['dilations'][3])

# Fix strides
attr['strides'] = (attr['strides'][2], attr['strides'][3])
else:
raise TypeError("Unsupported data format type : {}".format(attr['data_format']))


if opname == 'depthwise':
attr['groups'] = attr['channels']

# Fix strides
attr['strides'] = (attr['strides'][1], attr['strides'][2])

# Fix padding
attr['padding'] = attr['padding'].decode("utf-8")

Expand Down Expand Up @@ -381,10 +384,15 @@ def _fused_batch_norm():
def _impl(inputs, attr, params):
# Tensorflow: (data, gamma, beta, moving_mean, moving_variance)
# NNVM: (data, gamma, beta, moving_mean, moving_variance)
attr['data_format'] = attr['data_format'].decode("utf-8")
axis = 3
if attr['data_format'] == 'NCHW':
axis = 1

return AttrCvt(
op_name='batch_norm',
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
extras={'axis': 3}, # Fix axis
extras={'axis': axis},
ignores=['data_format'],
disables=['momentum'])(inputs, attr)
return _impl
Expand All @@ -397,10 +405,15 @@ def _impl(inputs, attr, params):
# (data, gamma, beta, moving_mean, moving_var)
new_inputs = [inputs[0], inputs[4], inputs[3], inputs[1], inputs[2]]

attr['data_format'] = attr['data_format'].decode("utf-8")
axis = 3
if attr['data_format'] == 'NCHW':
axis = 1

return AttrCvt(
op_name='batch_norm',
transforms={'scale_after_normalization':'scale', 'variance_epsilon':'epsilon'},
extras={'axis': 3}, # Fix axis
extras={'axis': axis},
ignores=['data_format'],
disables=['momentum'])(new_inputs, attr)
return _impl
Expand Down Expand Up @@ -694,6 +707,14 @@ def _impl(inputs, attr, params):
return gamma * (-alpha * _sym.relu(1 - _sym.exp(inputs[0])) + _sym.relu(inputs[0]))
return _impl

def _mean():
    """Build the converter for the TensorFlow ``Mean`` reduction operator.

    The returned ``_impl(inputs, attr, params)`` maps the node onto the
    ``mean`` operator through ``AttrCvt``: ``keep_dims`` is renamed to
    ``keepdims``, ``Tdim`` and ``Tidx`` are ignored, and the reduction axes
    are supplied as the ``axis`` attribute.
    """
    def _impl(inputs, attr, params):
        # The reduction axes arrive as the second input. Their constant value
        # is stored in ``params`` under that input's first output name; it is
        # popped so that it is no longer carried as a parameter.
        axis_name = inputs[1].list_output_names()[0]
        axis_value = params.pop(axis_name).asnumpy()

        # Keep every requested axis rather than only the first one, and
        # flatten first so a scalar (0-d) axis constant is accepted as well.
        # Plain Python ints keep the attribute independent of numpy scalars.
        axes = tuple(int(dim) for dim in axis_value.reshape(-1))

        return AttrCvt(op_name="mean", ignores=['Tdim', 'Tidx'],
                       transforms={'keep_dims': 'keepdims'},
                       extras={'axis': axes})(inputs[0], attr)
    return _impl

# compatible operators that do NOT require any conversion.
_identity_list = []

Expand Down Expand Up @@ -752,6 +773,7 @@ def _impl(inputs, attr, params):
'Rank' : _rank(),
'Transpose' : _transpose(),
'Tanh' : AttrCvt('tanh'),
'Mean' : _mean(),
}

# _convert_map_rnn defines maps of rnn operator name to
Expand Down

0 comments on commit 1fe792e

Please sign in to comment.