
Added atrous_conv1d and 3d. Refactored 2d. #7545

Merged
merged 6 commits into tensorflow:master on May 2, 2017

Conversation

nmiculinic
Contributor

This commit makes the following changes:

  • deleted most of the atrous_conv2d code to reuse the existing
    tf.nn.convolution function
  • added atrous_conv1d and atrous_conv3d with an API similar to
    atrous_conv2d
  • added support for a per-dimension rate, e.g. for atrous_conv2d
    rate=2 and rate=[2,1] do different things; the former is equal to
    rate=[2,2]. rate_i determines the dilation_rate in dimension i.
  • added strides support with the same API as the tf.nn.convolution
    function

This commit deletes more code than it adds; the per-function
documentation just makes the diff appear large.
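The per-dimension rate rule above can be sketched with NumPy broadcasting, which mirrors the np.broadcast_to(rate, (ndims,)) calls in this PR (normalize_rate is a hypothetical helper name, not part of the change):

```
import numpy as np

def normalize_rate(rate, ndims):
    """Return the dilation rate as an ndims-long per-dimension array.

    A scalar rate is broadcast to every spatial dimension, while a
    sequence of length ndims is used as-is.
    """
    return np.broadcast_to(rate, (ndims,))

print(normalize_rate(2, 2))       # scalar 2 becomes [2 2]
print(normalize_rate([2, 1], 2))  # sequence kept per dimension: [2 1]
```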

Test Plan:

Some simple tests to verify I haven't screwed something up:

```
A = np.array([1, 2, 3, 4, 5, 6], dtype=np.float32).reshape(1, 6, 1)
print(A)

kernel = np.array([100, 10, 1], dtype=np.float32).reshape(3, 1, 1)

with tf.Session() as sess:
    print(sess.run(tf.nn.atrous_conv1d(A, kernel, padding='SAME', rate=[2])))
```

```
B = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1)
kernel = np.array([1000, 100, 10, 1.0], dtype=np.float32).reshape(2, 2, 1, 1)

with tf.Session() as sess:
    a = sess.run(tf.nn.convolution(B, kernel, padding='SAME', dilation_rate=np.array([2, 2])))
    b = sess.run(tf.nn.atrous_conv2d(B, kernel, rate=2, padding='SAME'))
    print(np.allclose(a, b))
    print(a)
    print(b)
```

```
C = np.arange(4**3, dtype=np.float32).reshape(1, 4, 4, 4, 1)
kernel = (10**np.arange(8, 0, -1, dtype=np.float32)).reshape(2, 2, 2, 1, 1)

with tf.Session() as sess:
    a = sess.run(tf.nn.conv3d(C, kernel, strides=[1, 1, 1, 1, 1], padding='SAME'))
    b = sess.run(tf.nn.atrous_conv3d(C, kernel, rate=1, padding='SAME'))
    print(np.allclose(a, b))
    print(a)
    print(b)
```

Also running atrous_conv2d unit tests to verify backward compatibility.
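The filter-upsampling equivalence these tests rely on (atrous convolution equals ordinary convolution with rate - 1 zeros inserted between filter taps) can also be checked in pure NumPy, without TensorFlow; dilate_filter is a hypothetical helper, not part of the PR:

```
import numpy as np

def dilate_filter(f, rate):
    """Insert rate - 1 zeros between consecutive filter taps (the 'holes')."""
    out = np.zeros((len(f) - 1) * rate + 1, dtype=f.dtype)
    out[::rate] = f
    return out

x = np.array([1., 2., 3., 4., 5., 6.])
f = np.array([100., 10., 1.])

# 'VALID' correlation of x with the zero-upsampled filter (rate=2).
# Reversing the kernel turns np.convolve into correlation.
a = np.convolve(x, dilate_filter(f, 2)[::-1], mode='valid')

# Directly computed dilated correlation: sum_k f[k] * x[i + k*rate]
b = np.array([sum(f[k] * x[i + 2 * k] for k in range(len(f)))
              for i in range(len(x) - 2 * (len(f) - 1))])

print(np.allclose(a, b))  # True: both give [135. 246.]
```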

@tensorflow-jenkins
Collaborator

Can one of the admins verify this patch?

@vrv vrv requested a review from gpapan February 16, 2017 00:34
@vrv vrv self-assigned this Feb 16, 2017
@vrv vrv added awaiting review Pull request awaiting review stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Feb 16, 2017
@vrv vrv assigned gpapan and unassigned vrv Feb 16, 2017
@teamdandelion
Contributor

@gpapan PTAL

@@ -903,7 +948,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`.
called `input stride` or `dilation`. Altrenatively it could be 3 element sequence denoting dilation rate in each dimension.

s/Altrenatively/Alternatively

@@ -903,7 +948,7 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`.
called `input stride` or `dilation`. Altrenatively it could be 3 element sequence denoting dilation rate in each dimension.

All lines must have length at most 80 characters.

# More padding so that rate divides the height and width of the input.
pad_bottom_extra = (rate - in_height % rate) % rate
pad_right_extra = (rate - in_width % rate) % rate
return convolution(input=value, filter=filters, padding=padding, dilation_rate=np.broadcast_to(rate, (2, )), strides=strides, name=name)

If I recall correctly, we had kept the legacy atrous_conv2 as a separate op for backwards compatibility with already exported graphs. Can one of the TF admins comment if this is needed or not?


The underlying BatchToSpace and SpaceToBatch ops are needed for backwards compatibility with exported graphs. Changing the Python wrappers (which change how new graphs are constructed) should not impact that.

the input values every `rate` points.
This is equivalent to convolving the input with a set of upsampled filters,
produced by inserting `rate - 1` zeros between two consecutive values of the
filters along the `height`, `width` and `depth` dimensions, hence the name atrous convolution or convolution with holes (the French word trous means holes in

Line >80 chars long

the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`. Altrenatively it could be 3 element sequence denoting dilation rate in each dimension.

Line >80 chars long


"""
return convolution(input=input, filter=filter, padding=padding, dilation_rate=np.broadcast_to(rate, (3, )), strides=strides, name=name)

Line >80 chars long

# More padding so that rate divides the height and width of the input.
pad_bottom_extra = (rate - in_height % rate) % rate
pad_right_extra = (rate - in_width % rate) % rate
return convolution(input=value, filter=filters, padding=padding, dilation_rate=np.broadcast_to(rate, (2, )), strides=strides, name=name)

Line >80 chars long

padding is other than `'VALID'` or `'SAME'`.

"""
return convolution(input=input, filter=filter, padding=padding, dilation_rate=np.broadcast_to(rate, (1, )), strides=strides, name=name)

Line >80 chars long


@gpapan gpapan left a comment


Can you please address some minor comments?

@gpapan

gpapan commented Mar 8, 2017

@jbms

@teamdandelion teamdandelion added the stat:awaiting response Status - Awaiting response from author label Mar 8, 2017
@jbms

jbms commented Mar 8, 2017

Removing the duplicate implementation of atrous convolution in the atrous_conv2d function is a desirable change.

As far as adding atrous_conv1d and atrous_conv3d, I don't see the advantage --- better to just use the convolution function directly. atrous_conv2d existed prior to the unified tf.nn.convolution function, and is preserved for backwards compatibility.

@nmiculinic
Contributor Author

As far as adding atrous_conv1d and atrous_conv3d, I don't see the advantage --- better to just use the convolution function directly. atrous_conv2d existed prior to the unified tf.nn.convolution function, and is preserved for backwards compatibility.

For both completeness and ease of use. When I was searching for atrous_conv1d I only found atrous_conv2d and assumed only that one was implemented (I even implemented one myself, mimicking the atrous_conv2d implementation, before realizing there is a general convolution function that handles all cases). Having atrous_conv2d as a standalone function increases its salience and confuses users (me included).

@jbms

jbms commented Mar 13, 2017

I think this would be better addressed by adding a note to the documentation of atrous_conv2d explaining that it is a legacy interface and that tf.nn.convolution can be used for atrous convolutions of any number of dimensions, rather than adding additional redundant functions.

@ahundt
Contributor

ahundt commented Mar 13, 2017

Some related discussion is going on in #4742, part of which I'll repost here because it is pertinent.

conv2d_same in tensorflow/models documents both the input tensor size and the output tensor size:

def conv2d_same(inputs, num_outputs, kernel_size, stride, rate=1, scope=None):
  """[snip...]
  Args:
    inputs: A 4-D tensor of size [batch, height_in, width_in, channels].
  [snip...]
  Returns:
    output: A 4-D tensor of size [batch, height_out, width_out, channels] with
      the convolution output.
  """

I think it would be very productive to add a note in each appropriate function explaining what the output dimensions would be relative to given input dimensions in as they vary by configuration.
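As a sketch of what such a note could state: under tf.nn.convolution semantics the SAME output size is ceil(in / stride), and the VALID output size is ceil((in - (kernel - 1) * rate) / stride), since dilation enlarges the effective kernel to (kernel - 1) * rate + 1. A hypothetical helper illustrating this (out_size is not a TF function):

```
import math

def out_size(in_size, kernel, stride=1, rate=1, padding='SAME'):
    """Spatial output size of a (possibly dilated, strided) convolution."""
    if padding == 'SAME':
        # Output depends only on input size and stride.
        return math.ceil(in_size / stride)
    # 'VALID': only positions where the dilated kernel fits entirely.
    return math.ceil((in_size - (kernel - 1) * rate) / stride)

print(out_size(16, 3, stride=1, rate=2, padding='SAME'))   # 16
print(out_size(16, 3, stride=1, rate=2, padding='VALID'))  # 12
```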

@nmiculinic
Contributor Author

I think this would be better addressed by adding a note to the documentation of atrous_conv2d explaining that it is a legacy interface and that tf.nn.convolution can be used for atrous convolutions of any number of dimensions, rather than adding additional redundant functions.

What is the protocol for function deprecation? Is there some annotation, some specific string in documentation or what?

If I understand correctly, you'd like me to delete atrous 1d and 3d, and edit 2d's documentation?

@@ -903,7 +950,8 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
the `height` and `width` dimensions. Equivalently, the rate by which we
upsample the filter values by inserting zeros across the `height` and
`width` dimensions. In the literature, the same parameter is sometimes
called `input stride` or `dilation`.
called `input stride` or `dilation`. Alternatively it could be 3 element

This should be "2 element".

@jbms

jbms commented Mar 16, 2017

Yes, remove atrous 1d and 3d.

There is the @deprecated decorator in tensorflow/python/util/deprecation.py. However, since there isn't yet a specific plan/date for removing the atrous_conv2d function, you shouldn't use that in this case.

Instead, you could just add a note to the top of the atrous_conv2d, such as:

This function is a simpler wrapper around the more general @{tf.nn.convolution}, and exists only for backwards compatibility. You can use @{tf.nn.convolution} to perform 1-D, 2-D, or 3-D atrous convolution.

@martinwicke martinwicke removed awaiting review Pull request awaiting review stat:awaiting tensorflower Status - Awaiting response from tensorflower labels Mar 22, 2017
Deleted atrous_conv1d and atrous_conv3d.
Added note in documentation for atrous_conv2d that convolution should be
used instead.
@@ -31,7 +31,9 @@
@@depthwise_conv2d
@@depthwise_conv2d_native
@@separable_conv2d
@@atrous_conv1d

These should be removed.

@@ -805,9 +805,14 @@ def op(converted_input, _, converted_padding): # pylint: disable=missing-docstr
filter_shape=window_shape)


def atrous_conv2d(value, filters, rate, padding, name=None):
def atrous_conv2d(value, filters, rate, padding, strides=None, name=None):

Not clear if it makes sense to add strides, since it is supported by the more generic convolution anyway, but if it is added, it should be documented in the Args section.

@drpngx
Contributor

drpngx commented Apr 10, 2017

Jenkins, test this please.

@drpngx
Contributor

drpngx commented Apr 10, 2017

@jbms @gpapan good to go?

@martinwicke
Member

@jbms @gpapan Ping?

@martinwicke
Member

@nmiculinic there are test failures, can you address those?

@martinwicke martinwicke added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Apr 26, 2017
return value
return convolution(input=value, filter=filters, padding=padding,
dilation_rate=np.broadcast_to(rate, (2, )),
strides=strides, name=name)

Test failure is caused by reference to undefined variable strides.

@vrv

vrv commented May 1, 2017

@jbms assuming tests pass, is this approved? Thanks!

@jbms

jbms commented May 1, 2017

Yes looks good once tests pass.

@vrv

vrv commented May 1, 2017

@tensorflow-jenkins test this please

@vrv

vrv commented May 1, 2017

I made the fix for @nmiculinic; let me know if this looks good (hit approve). Thanks!

@vrv vrv merged commit 880b4ac into tensorflow:master May 2, 2017
@vrv

vrv commented May 2, 2017

I'm going to assume the build timeout is unrelated since no tests were added. Merging.
