Added atrous_conv1d and 3d. Refactored 2d. #7545

Merged · 6 commits · May 2, 2017
94 changes: 7 additions & 87 deletions tensorflow/python/ops/nn_ops.py
@@ -808,6 +808,11 @@ def op(converted_input, _, converted_padding):  # pylint: disable=missing-docstring
def atrous_conv2d(value, filters, rate, padding, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).

This function is a simpler wrapper around the more general
@{tf.nn.convolution}, and exists only for backwards compatibility. You can
use @{tf.nn.convolution} to perform 1-D, 2-D, or 3-D atrous convolution.


Computes a 2-D atrous convolution, also known as convolution with holes or
dilated convolution, given 4-D `value` and `filters` tensors. If the `rate`
parameter is equal to one, it performs regular 2-D convolution. If the `rate`
@@ -914,93 +919,8 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
    ValueError: If input/output depth does not match `filters`' shape, or if
      padding is other than `'VALID'` or `'SAME'`.
  """
  with ops.name_scope(name, "atrous_conv2d", [value, filters]) as name:
    value = ops.convert_to_tensor(value, name="value")
    filters = ops.convert_to_tensor(filters, name="filters")
    if not value.get_shape()[3].is_compatible_with(filters.get_shape()[2]):
      raise ValueError(
          "value's input channels does not match filters' input channels, "
          "{} != {}".format(value.get_shape()[3], filters.get_shape()[2]))
    if rate < 1:
      raise ValueError("rate {} cannot be less than one".format(rate))

    if rate == 1:
      value = gen_nn_ops.conv2d(input=value,
                                filter=filters,
                                strides=[1, 1, 1, 1],
                                padding=padding)
      return value

    # We have two padding contributions. The first is used for converting "SAME"
    # to "VALID". The second is required so that the height and width of the
    # zero-padded value tensor are multiples of rate.

    # Padding required to reduce to "VALID" convolution
    if padding == "SAME":
      # Handle filters whose shape is unknown during graph creation.
      if filters.get_shape().is_fully_defined():
        filter_shape = filters.get_shape().as_list()
      else:
        filter_shape = array_ops.shape(filters)
      filter_height, filter_width = filter_shape[0], filter_shape[1]

      # Spatial dimensions of the filters and the upsampled filters in which we
      # introduce (rate - 1) zeros between consecutive filter values.
      filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
      filter_width_up = filter_width + (filter_width - 1) * (rate - 1)

      pad_height = filter_height_up - 1
      pad_width = filter_width_up - 1

      # When pad_height (pad_width) is odd, we pad more to bottom (right),
      # following the same convention as conv2d().
      pad_top = pad_height // 2
      pad_bottom = pad_height - pad_top
      pad_left = pad_width // 2
      pad_right = pad_width - pad_left
    elif padding == "VALID":
      pad_top = 0
      pad_bottom = 0
      pad_left = 0
      pad_right = 0
    else:
      raise ValueError("Invalid padding")

    # Handle input whose shape is unknown during graph creation.
    if value.get_shape().is_fully_defined():
      value_shape = value.get_shape().as_list()
    else:
      value_shape = array_ops.shape(value)

    in_height = value_shape[1] + pad_top + pad_bottom
    in_width = value_shape[2] + pad_left + pad_right

    # More padding so that rate divides the height and width of the input.
    pad_bottom_extra = (rate - in_height % rate) % rate
    pad_right_extra = (rate - in_width % rate) % rate

    # The paddings argument to space_to_batch includes both padding components.
    space_to_batch_pad = [[pad_top, pad_bottom + pad_bottom_extra],
                          [pad_left, pad_right + pad_right_extra]]

    value = array_ops.space_to_batch(input=value,
                                     paddings=space_to_batch_pad,
                                     block_size=rate)

    value = gen_nn_ops.conv2d(input=value,
                              filter=filters,
                              strides=[1, 1, 1, 1],
                              padding="VALID",
                              name=name)

    # The crops argument to batch_to_space is just the extra padding component.
    batch_to_space_crop = [[0, pad_bottom_extra], [0, pad_right_extra]]

    value = array_ops.batch_to_space(input=value,
                                     crops=batch_to_space_crop,
                                     block_size=rate)

    return value
  return convolution(input=value, filter=filters, padding=padding,
                     dilation_rate=np.broadcast_to(rate, (2, )), name=name)
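
For readers of this diff, a minimal sketch of what the refactor means at call sites, assuming the TF 1.x graph-mode API of this era; the placeholder shapes below are hypothetical and chosen only for illustration. atrous_conv2d now forwards to the general tf.nn.convolution with the scalar rate broadcast to one dilation value per spatial dimension, so the two 2-D calls should build equivalent ops, and the same general op covers the 1-D and 3-D cases mentioned in the updated docstring:

import numpy as np
import tensorflow as tf

# Hypothetical example shapes, used only to illustrate the call signatures.
value = tf.placeholder(tf.float32, [1, 64, 64, 3])   # NHWC input
filters = tf.placeholder(tf.float32, [3, 3, 3, 8])   # HWIO kernel
rate = 2

# The backwards-compatible wrapper refactored in this PR.
out_wrapper = tf.nn.atrous_conv2d(value, filters, rate=rate, padding="SAME")

# The general op it now delegates to; dilation_rate holds one value per
# spatial dimension, which is why the scalar rate is broadcast to shape (2,).
out_general = tf.nn.convolution(input=value, filter=filters, padding="SAME",
                                dilation_rate=np.broadcast_to(rate, (2,)))

# The same call performs 1-D (or 3-D) atrous convolution by changing the
# tensor ranks and the length of dilation_rate.
seq = tf.placeholder(tf.float32, [1, 100, 3])         # NWC input
kernel = tf.placeholder(tf.float32, [5, 3, 8])        # WIO kernel
out_1d = tf.nn.convolution(input=seq, filter=kernel, padding="SAME",
                           dilation_rate=[rate])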


def conv2d_transpose(value,