Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added atrous_conv1d and 3d. Refactored 2d. (#7545)
* Added atrous_conv1d and 3d. Refactored 2d. This commit makes following changes: * deleted most atrous_conv2d code to reuse existing tf.nn.convolution function * added atrous_conv1d and atrous_conv3d with similar API as atrous_conv2d * Added support for variable rate per dimension, e.g. for atrous_conv2d rate=2, or rate=[2,1] does different things. Former is equal to rate=[2,2]. rate_i determines dilation_rate in dimension i. * added strides support with same API as in tf.nn.convolution function This commit makes more code deletions then additions. However documentation per function makes it appear large. Test Plan: Some simple tests to verify I haven't screwed something up: ``` A = np.array([1, 2, 3, 4, 5, 6], dtype=np.float32).reshape(1, 6, 1) print(A) kernel = np.array([100, 10, 1], dtype=np.float32).reshape(3, 1, 1) with tf.Session() as sess: print(sess.run(tf.nn.atrous_conv1d(A, kernel, padding='SAME', rate=[2]))) ``` ``` B = np.arange(16, dtype=np.float32).reshape(1, 4, 4, 1) kernel = np.array([1000, 100, 10, 1.0], dtype=np.float32).reshape(2, 2, 1, 1) with tf.Session() as sess: a = sess.run(tf.nn.convolution(B, kernel, padding='SAME', dilation_rate=np.array([2, 2]))) b = sess.run(tf.nn.atrous_conv2d(B, kernel, rate=2, padding='SAME')) print(np.allclose(a, b)) print(a) print(b) ``` ``` C = np.arange(4**3, dtype=np.float32).reshape(1, 4, 4, 4, 1) kernel = (10**np.arange(8, 0, -1, dtype=np.float32)).reshape(2, 2, 2, 1, 1) with tf.Session() as sess: a = sess.run(tf.nn.conv3d(C, kernel, strides=[1, 1,1,1,1], padding='SAME')) b = sess.run(tf.nn.atrous_conv3d(C, kernel, rate=1, padding='SAME')) print(np.allclose(a, b)) print(a) print(b) ``` Also running atrous_conv2d unit tests to verify backward compatibility. * Fixed > 80 char lines, implemented suggestions * Implemented suggestions. Deleted atrous_conv1d and atrous_conv3d. Added note in documentation for atrous_conv2d that convolution should be used instead. * Restored short summary in atrous_conv2d * Fixed cleanup of leftover comments/mentions of atrous_conv1d/3d. * Fix undefined strides
- Loading branch information