Skip to content

Commit

Permalink
[TENSORFLOW]Conv3d Transpose OP added (apache#5775)
Browse files Browse the repository at this point in the history
* [TENSORFLOW]Conv3d Transpose OP added

* Testcase updated, tf cpu supports only ndhwc
  • Loading branch information
siju-samuel authored and Trevor Morris committed Jun 18, 2020
1 parent 1655e81 commit 7889e68
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 1 deletion.
3 changes: 2 additions & 1 deletion python/tvm/relay/frontend/tensorflow.py
Expand Up @@ -603,7 +603,7 @@ def _impl(inputs, attr, params, mod):
out = AttrCvt(
op_name=_dimension_picker('conv',
surfix="_transpose" if opname == 'conv_transpose' else ""),
ignores=['explicit_paddings'],
ignores=['explicit_paddings', 'Tshape'],
transforms={
'kernel_shape': 'kernel_size',
'data_format': 'data_layout',
Expand Down Expand Up @@ -2046,6 +2046,7 @@ def _impl(inputs, attr, params, mod):
'Conv2D' : _conv('conv'),
'Conv2DBackpropInput' : _conv('conv_transpose'),
'Conv3D' : _conv3d('conv'),
'Conv3DBackpropInputV2' : _conv3d('conv_transpose'),
'Cos' : AttrCvt('cos'),
'Cosh' : AttrCvt('cosh'),
'CropAndResize' : _crop_and_resize(),
Expand Down
87 changes: 87 additions & 0 deletions tests/python/frontend/tensorflow/test_forward.py
Expand Up @@ -533,6 +533,92 @@ def test_forward_convolution3d():
_test_convolution3d('conv', [4, 17, 17, 17, 12], [3, 3, 3, 12, 32], [1, 1, 1], [2, 2, 2], 'VALID', 'NDHWC')


#######################################################################
# Convolution3D Transpose
# -----------------------

def _test_convolution3d_transpose(data_shape, filter_shape, strides,
padding, output_shape, data_format='NCDHW'):
""" One iteration of 3D convolution transpose with given shapes and attributes """

dtype = 'float32'
data_array = np.random.uniform(size=data_shape).astype(dtype)
filter_array = np.random.uniform(size=filter_shape).astype(dtype)
if data_format == 'NDHWC':
strides = [1] + strides + [1]
else:
strides = [1, 1] + strides

with tf.Graph().as_default():
in_data = array_ops.placeholder(shape=data_shape, dtype=dtype)
in_filter = constant_op.constant(
filter_array, shape=filter_shape, dtype=dtype)

nn_ops.conv3d_transpose(in_data,
in_filter,
output_shape=output_shape,
strides=strides,
padding=padding,
data_format=data_format)

compare_tf_with_tvm(data_array, 'Placeholder:0', 'conv3d_transpose:0', cuda_layout="NDHWC")


def test_forward_convolution3d_transpose():
if is_gpu_available():
_test_convolution3d_transpose(data_shape=[1, 10, 8, 8, 8],
filter_shape=[1, 1, 1, 6, 10],
strides=[1, 1, 1],
padding='VALID',
output_shape=[1, 6, 8, 8, 8])

_test_convolution3d_transpose(data_shape=[4, 9, 8, 8, 8],
filter_shape=[1, 1, 1, 6, 9],
strides=[1, 1, 1],
padding='VALID',
output_shape=[4, 6, 8, 8, 8])

_test_convolution3d_transpose(data_shape=[1, 3, 8, 8, 8],
filter_shape=[1, 1, 1, 6, 3],
strides=[2, 2, 2],
padding='SAME',
output_shape=[1, 6, 15, 15, 15])

_test_convolution3d_transpose(data_shape=[1, 16, 8, 8, 8],
filter_shape=[3, 3, 3, 6, 16],
strides=[3, 3, 3],
padding='VALID',
output_shape=[1, 6, 24, 24, 24])

_test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 10],
filter_shape=[1, 1, 1, 6, 10],
strides=[1, 1, 1],
padding='VALID',
output_shape=[1, 8, 8, 8, 6],
data_format='NDHWC')

_test_convolution3d_transpose(data_shape=[4, 8, 8, 8, 9],
filter_shape=[1, 1, 1, 6, 9],
strides=[1, 1, 1],
padding='VALID',
output_shape=[4, 8, 8, 8, 6],
data_format='NDHWC')

_test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 3],
filter_shape=[1, 1, 1, 6, 3],
strides=[2, 2, 2],
padding='SAME',
output_shape=[1, 15, 15, 15, 6],
data_format='NDHWC')

_test_convolution3d_transpose(data_shape=[1, 8, 8, 8, 16],
filter_shape=[3, 3, 3, 6, 16],
strides=[3, 3, 3],
padding='VALID',
output_shape=[1, 24, 24, 24, 6],
data_format='NDHWC')


#######################################################################
# BiasAdd
# -----------
Expand Down Expand Up @@ -3728,6 +3814,7 @@ def test_forward_spop():
# NN
test_forward_convolution()
test_forward_convolution3d()
test_forward_convolution3d_transpose()
test_forward_pooling()
test_forward_concat_v2()
test_forward_lrn()
Expand Down

0 comments on commit 7889e68

Please sign in to comment.