Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added atrous_conv1d and 3d. Refactored 2d. #7545

Merged
merged 6 commits into from
May 2, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tensorflow/python/ops/nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
@@depthwise_conv2d
@@depthwise_conv2d_native
@@separable_conv2d
@@atrous_conv1d
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These should be removed.

@@atrous_conv2d
@@atrous_conv3d
@@atrous_conv2d_transpose
@@conv2d_transpose
@@conv1d
Expand Down
170 changes: 87 additions & 83 deletions tensorflow/python/ops/nn_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,52 @@ def op(converted_input, _, converted_padding): # pylint: disable=missing-docstr
filter_shape=window_shape)


def atrous_conv2d(value, filters, rate, padding, name=None):
def atrous_conv1d(input, filter, rate, padding, strides=None, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).

Computes a 1-D atrous convolution, also known as convolution with holes or
dilated convolution, given 3-D `input` and `filter` tensors. If the `rate`
parameter is equal to one, it performs regular 1-D convolution. If the `rate`
parameter is greater than one, it performs convolution with holes, sampling
the input values every `rate` points.
This is equivalent to convolving the input with a set of upsampled filters,
produced by inserting `rate - 1` zeros between two consecutive values of the
filters along the `height` dimensions, hence the name atrous
convolution or convolution with holes (the French word trous means holes in
English).

See @{tf.nn.atrous_conv2d}

Args:
input: A 3-D `Tensor` of type `float`. It needs to be in the default "NHWC"
format. Its shape is `[batch, in_height, in_channels]`.
filter: A 3-D `Tensor` with the same type as `input` and shape
`[filter_height, in_channels, out_channels]`. `filter`'
`in_channels` dimension must match that of `input`. Atrous convolution is
equivalent to standard convolution with upsampled filters with effective
height `filter_height + (filter_height - 1) * (rate - 1)`, produced by
inserting `rate - 1` zeros along consecutive elements across the
`filter`' spatial dimensions.
rate: A positive int32. The stride with which we sample input values across
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`.
padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
strides: Optional. Sequence of 1 ints >= 1. Specifies the output stride.
Defaults to [1]. If any value of strides is > 1, then all values of
dilation_rate must be 1.
name: Optional name for the returned tensor.
Returns:
A `Tensor` with the same type as `value`.
Raises:
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`.

"""
return convolution(input=input, filter=filter, padding=padding, dilation_rate=np.broadcast_to(rate, (1, )), strides=strides, name=name)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line >80 chars long


def atrous_conv2d(value, filters, rate, padding, strides=None, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).

Computes a 2-D atrous convolution, also known as convolution with holes or
Expand Down Expand Up @@ -903,7 +948,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`.
called `input stride` or `dilation`. Altrenatively it could be 3 element sequence denoting dilation rate in each dimension.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/Altrenatively/Alternatively

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All lines must have length at most 80 pixels.

padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
name: Optional name for the returned tensor.

Expand All @@ -914,94 +959,53 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`.
"""
with ops.name_scope(name, "atrous_conv2d", [value, filters]) as name:
value = ops.convert_to_tensor(value, name="value")
filters = ops.convert_to_tensor(filters, name="filters")
if not value.get_shape()[3].is_compatible_with(filters.get_shape()[2]):
raise ValueError(
"value's input channels does not match filters' input channels, "
"{} != {}".format(value.get_shape()[3], filters.get_shape()[2]))
if rate < 1:
raise ValueError("rate {} cannot be less than one".format(rate))

if rate == 1:
value = gen_nn_ops.conv2d(input=value,
filter=filters,
strides=[1, 1, 1, 1],
padding=padding)
return value

# We have two padding contributions. The first is used for converting "SAME"
# to "VALID". The second is required so that the height and width of the
# zero-padded value tensor are multiples of rate.

# Padding required to reduce to "VALID" convolution
if padding == "SAME":
# Handle filters whose shape is unknown during graph creation.
if filters.get_shape().is_fully_defined():
filter_shape = filters.get_shape().as_list()
else:
filter_shape = array_ops.shape(filters)
filter_height, filter_width = filter_shape[0], filter_shape[1]

# Spatial dimensions of the filters and the upsampled filters in which we
# introduce (rate - 1) zeros between consecutive filter values.
filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
filter_width_up = filter_width + (filter_width - 1) * (rate - 1)

pad_height = filter_height_up - 1
pad_width = filter_width_up - 1

# When pad_height (pad_width) is odd, we pad more to bottom (right),
# following the same convention as conv2d().
pad_top = pad_height // 2
pad_bottom = pad_height - pad_top
pad_left = pad_width // 2
pad_right = pad_width - pad_left
elif padding == "VALID":
pad_top = 0
pad_bottom = 0
pad_left = 0
pad_right = 0
else:
raise ValueError("Invalid padding")

# Handle input whose shape is unknown during graph creation.
if value.get_shape().is_fully_defined():
value_shape = value.get_shape().as_list()
else:
value_shape = array_ops.shape(value)

in_height = value_shape[1] + pad_top + pad_bottom
in_width = value_shape[2] + pad_left + pad_right

# More padding so that rate divides the height and width of the input.
pad_bottom_extra = (rate - in_height % rate) % rate
pad_right_extra = (rate - in_width % rate) % rate
return convolution(input=value, filter=filters, padding=padding, dilation_rate=np.broadcast_to(rate, (2, )), strides=strides, name=name)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I recall correctly, we had kept the legacy atrous_conv2 as a separate op for backwards compatibility with already exported graphs. Can one of the TF admins comment if this is needed or not?

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The underlying BatchToSpace and SpaceToBatch ops are needed for backwards compatibility with exported graphs. Changing the Python wrappers (which change how new graphs are constructed) should not impact that.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line >80 chars long


# The paddings argument to space_to_batch includes both padding components.
space_to_batch_pad = [[pad_top, pad_bottom + pad_bottom_extra],
[pad_left, pad_right + pad_right_extra]]
def atrous_conv3d(input, filter, rate, padding, strides=None, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).

value = array_ops.space_to_batch(input=value,
paddings=space_to_batch_pad,
block_size=rate)
Computes a 3-D atrous convolution, also known as convolution with holes or
dilated convolution, given 5-D `input` and `filter` tensors. If the `rate`
parameter is equal to one, it performs regular 3-D convolution. If the `rate`
parameter is greater than one, it performs convolution with holes, sampling
the input values every `rate` points.
This is equivalent to convolving the input with a set of upsampled filters,
produced by inserting `rate - 1` zeros between two consecutive values of the
filters along the `height`, `width` and `depth` dimensions, hence the name atrous convolution or convolution with holes (the French word trous means holes in
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line >80 chars long

English).

value = gen_nn_ops.conv2d(input=value,
filter=filters,
strides=[1, 1, 1, 1],
padding="VALID",
name=name)
See @{tf.nn.atrous_conv2d}

# The crops argument to batch_to_space is just the extra padding component.
batch_to_space_crop = [[0, pad_bottom_extra], [0, pad_right_extra]]
Args:
input: A 5-D `Tensor` of type `float`. It needs to be in the default "NHWC"
format. Its shape is `[batch, in_height, in_width, in_depth, in_channels]`.
filter: A 5-D `Tensor` with the same type as `input` and shape
`[filter_height, filter_width, filter_depth, in_channels, out_channels]`. `filter`'
`in_channels` dimension must match that of `input`. Atrous convolution is
equivalent to standard convolution with upsampled filters with effective
height `filter_height + (filter_height - 1) * (rate - 1)`, produced by
inserting `rate - 1` zeros along consecutive elements across the
`filter`' spatial dimensions.
rate: A positive int32. The stride with which we sample input values across
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`. Altrenatively it could be 3 element sequence denoting dilation rate in each dimension.
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line >80 chars long

padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
strides: Optional. Sequence of 3 ints >= 1. Specifies the output stride.
Defaults to [1]*3. If any value of strides is > 1, then all values of
dilation_rate must be 1.
name: Optional name for the returned tensor.

value = array_ops.batch_to_space(input=value,
crops=batch_to_space_crop,
block_size=rate)
Returns:
A `Tensor` with the same type as `value`.

return value
Raises:
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`.

"""
return convolution(input=input, filter=filter, padding=padding, dilation_rate=np.broadcast_to(rate, (3, )), strides=strides, name=name)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line >80 chars long


def conv2d_transpose(value,
filter,
Expand Down