
Add atrous_conv2d_transpose python function #5300

Closed · wants to merge 8 commits

Conversation

@guotong1988 (Contributor) commented Oct 31, 2016

Issue description: #4668

This is the first commit for this issue.
I will add the unit test later.
I will also add the documentation comment for the method.

Thank you for any advice.

For now, I am testing it with:


import numpy as np
import tensorflow as tf

# Input, output: [batch, height, width, depth]
x_image = tf.placeholder(tf.float32, shape=[1])
x = tf.reshape(x_image, [1, 1, 1, 1])

# Filter: W [kernel_height, kernel_width, output_depth, input_depth]
W_cpu = np.array([[1, -1, 1], [1, 1, 1], [-1, 1, -1]], dtype=np.float32)
W = tf.Variable(W_cpu)
W = tf.reshape(W, [3, 3, 1, 1])

strides = [1, 1, 1, 1]
padding = 'VALID'

y = tf.nn.atrous_conv2d_transpose(x, W, [1, 5, 5, 1], 2, strides, padding)

x_data = np.array([1], dtype=np.float32)
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    x = sess.run(x, feed_dict={x_image: x_data})
    W = sess.run(W, feed_dict={x_image: x_data})
    y = sess.run(y, feed_dict={x_image: x_data})

    print "The shape of x:\t", x.shape, ",\t and the x.reshape(1) is :"
    print x.reshape(1)
    print ""

    print "The shape of W:\t", W.shape, ",\t and the W.reshape(3,3) is :"
    print W.reshape(3, 3)
    print ""

    print "The shape of y:\t", y.shape, ",\t and the y.reshape(5,5) is :"
    print y.reshape(5, 5)
    print ""

and

  def testAtrousConv2DTransposeSingleStride(self):
    with self.test_session():
      strides = [1, 1, 1, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 6, 4, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)

      output = tf.nn.atrous_conv2d_transpose(x, f, y_shape, 2, strides=strides,
                                             padding="SAME")

      value = output.eval()
      print(value)
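
A note on the first snippet above (a rough check of my own, assuming I have the transpose semantics right; it has not been verified against this PR's code): a transposed convolution of a single unit input with VALID padding should simply scatter the rate-dilated filter into the output, so y.reshape(5,5) ought to be the 3x3 filter with a zero row and column interleaved. A runnable way to build that expected array:

import numpy as np

# Expected y.reshape(5,5) for the single-unit-input example, assuming the
# transpose scatters the rate-2 dilated filter (my derivation, not the PR's).
expected = np.zeros((5, 5), dtype=np.float32)
expected[::2, ::2] = np.array([[1, -1, 1], [1, 1, 1], [-1, 1, -1]], np.float32)
print(expected)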

@tensorflow-jenkins (Collaborator)

Can one of the admins verify this patch?

@mention-bot

@guotong1988, thanks for your PR! By analyzing the history of the files in this pull request, we identified @tensorflower-gardener, @keveman and @yuefengz to be potential reviewers.

@vrv commented Oct 31, 2016

Looks nice at a high level! Let's add the tests in the same PR.

@vrv added the "stat:awaiting response" and "awaiting review" labels on Oct 31, 2016
@vrv changed the title from "Issue-4668" to "Add atrous_conv2d_transpose python function" on Oct 31, 2016
@guotong1988 (Contributor, Author)

I'm working on it.

@gpapan left a comment

Thanks @guotong1988. Please make sure the test checks that both the output tensor shape and values are correct.
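
For reference, one way such a check could look (this is only a sketch of mine, not the test from this PR; it assumes the signature used in the examples above, i.e. (value, filters, output_shape, rate, strides, padding), and the fact that with rate=1 the atrous transpose should reduce to the ordinary conv2d_transpose):

  def testAtrousConv2DTransposeRateOneMatchesConv2DTranspose(self):
    # Sketch only: with rate=1 there is no dilation, so both the shape and the
    # values should match tf.nn.conv2d_transpose on the same inputs.
    with self.test_session():
      strides = [1, 1, 1, 1]

      # Input, output: [batch, height, width, depth]
      x_shape = [2, 6, 4, 3]
      y_shape = [2, 6, 4, 2]

      # Filter: [kernel_height, kernel_width, output_depth, input_depth]
      f_shape = [3, 3, 2, 3]

      x = tf.constant(1.0, shape=x_shape, name="x", dtype=tf.float32)
      f = tf.constant(1.0, shape=f_shape, name="filter", dtype=tf.float32)

      expected = tf.nn.conv2d_transpose(x, f, y_shape, strides, padding="SAME")
      actual = tf.nn.atrous_conv2d_transpose(x, f, y_shape, 1, strides,
                                             padding="SAME")

      actual_val = actual.eval()
      self.assertAllEqual(list(actual_val.shape), y_shape)
      self.assertAllClose(actual_val, expected.eval())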

@vrv removed the "awaiting review" label on Nov 3, 2016
@sbrodehl (Contributor)

What's the current status here, @guotong1988?
Need some help? What needs to be done to get this merged?

@guotong1988 (Contributor, Author) commented Nov 29, 2016

I got a bit busy changing jobs over the past three weeks, and I have just finished that.
I am working on this now.

@sbrodehl (Contributor)

Nice! Ping me if you need help.

@guotong1988 (Contributor, Author)

@sbrodehl In fact, I'm not sure that my single test case covers that much. Please provide more test cases if you have any. Thank you.

@guotong1988 (Contributor, Author)

I wrote another example, and I think my commit may be right.

import numpy as np
import tensorflow as tf

# Input, output: [batch, height, width, depth]
x_image = tf.placeholder(tf.float32,shape=[2,2])
x = tf.reshape(x_image,[1,2,2,1])

# Filter: W [kernel_height, kernel_width, output_depth, input_depth]
W_cpu = np.array([[1,-1,1],[1,1,1],[-1,1,-1]],dtype=np.float32)
W = tf.Variable(W_cpu)
W = tf.reshape(W, [3,3,1,1])

strides=[1, 1, 1, 1]
padding='VALID'

y = tf.nn.atrous_conv2d_transpose(x, W, [1,5,5,1], 2, strides, padding)

x_data = np.array([[1,1],[1,1]],dtype=np.float32)
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    x = (sess.run(x, feed_dict={x_image: x_data}))
    W = (sess.run(W, feed_dict={x_image: x_data}))
    y = (sess.run(y, feed_dict={x_image: x_data}))

    print "The shape of x:\t", x.shape, ",\t and the x.reshape(2,2) is :"
    print x.reshape([2,2])
    print ""

    print "The shape of x:\t", W.shape, ",\t and the W.reshape(3,3) is :"
    print W.reshape(3,3)
    print ""

    print "The shape of y:\t", y.shape, ",\t and the y.reshape(6,6) is :"
    print y.reshape(6,6)
    print ""

@guotong1988 (Contributor, Author) commented Dec 1, 2016

Now I'm sure that my commit is right, because I wrote this example, which is similar to the one above.

import numpy as np
import tensorflow as tf

# Input, output: [batch, height, width, depth]
x_image = tf.placeholder(tf.float32,shape=[2,3])
x = tf.reshape(x_image,[1,2,3,1])

# Filter: W [kernel_height, kernel_width, output_depth, input_depth]
W_cpu = np.array([[1,-1,1],[1,1,1],[-1,1,-1]],dtype=np.float32)
W = tf.Variable(W_cpu)
W = tf.reshape(W, [3,3,1,1])

strides=[1, 1, 1, 1]
padding='VALID'

y = tf.nn.atrous_conv2d_transpose(x, W, [1,5,5,1], 2, strides, padding)

x_data = np.array([[1,1,1],[1,1,1]],dtype=np.float32)
with tf.Session() as sess:
    init = tf.initialize_all_variables()
    sess.run(init)

    x = (sess.run(x, feed_dict={x_image: x_data}))
    W = (sess.run(W, feed_dict={x_image: x_data}))
    y = (sess.run(y, feed_dict={x_image: x_data}))

    print "The shape of x:\t", x.shape, ",\t and the x.reshape(2,2) is :"
    print x.reshape([2,3])
    print ""

    print "The shape of x:\t", W.shape, ",\t and the W.reshape(3,3) is :"
    print W.reshape(3,3)
    print ""

    print "The shape of y:\t", y.shape, ",\t and the y.reshape(6,7) is :"
    print y.reshape(6,7)
    print ""

@guotong1988 (Contributor, Author) commented Dec 1, 2016

Then I think I can start finishing the unit test.

@Fenugreek (Contributor)

I tried running your code on MNIST. The output shape looks wrong to me:

        print(inputs.get_shape(), W.get_shape())
        outputs = atrous_conv2d_transpose(inputs, W, [100, 28, 28, 1],
                                          2, [1,1,1,1], padding='SAME')
        print(outputs.get_shape())

gives:

(100, 28, 28, 8) (5, 5, 1, 8)
(100, 36, 36, 1)

I was expecting the output shape, i.e. the last line above, to read (100, 28, 28, 1).
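
A quick sanity check on those numbers (a rough sketch of my own, assuming the standard dilated-kernel arithmetic rather than anything in this PR's code):

k, rate, h_in = 5, 2, 28
k_eff = k + (k - 1) * (rate - 1)   # effective kernel size: 5 + 4 = 9
print(h_in + (k_eff - 1))          # 36: VALID-style transpose output, matches the reported shape
print(h_in)                        # 28: what SAME padding with unit strides should give

If that arithmetic is right, the discrepancy points at how padding='SAME' is handled.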

@guotong1988 (Contributor, Author)

@Fenugreek Yes, I fixed it.

@vrv commented Dec 1, 2016

@gpapan just sent me a full PR for implementing the transpose, btw.

@guotong1988 (Contributor, Author)

@vrv I will finish the unit test.

@vrv commented Dec 2, 2016

@gpapan, since you wrote the internal one, I'll let you decide which one we should accept :)

(based on correctness / merits, of course.)

@guotong1988 (Contributor, Author)

I have nearly finished. But the code still needs review and editing, of course.

benoitsteiner pushed a commit to benoitsteiner/tensorflow that referenced this pull request Dec 3, 2016
@guotong1988 (Contributor, Author)

Thank you for your code. I learned a lot. Please close this PR.

@gpapan commented Dec 7, 2016 via email

@vrv closed this on Dec 7, 2016
copybara-service bot pushed a commit that referenced this pull request Sep 18, 2023
Imported from GitHub PR openxla/xla#5300

This is a new GPU SPMD optimization pass for the following pattern:
binary-op(all-gather(a), all-gather(b))  ->  all-gather(binary-op(a, b))

Copybara import of the project:

--
77aafc0686fb98a6e13b6664ee537ed3cde5e24f by kushanam <kahmadian@nvidia.com>:

adding a new pass to optimize reduce_scatter->all_gather->binary_op sequence

--
0b1e8eb599f8a7334b7c9826746db67e0923f2f7 by kushanam <kahmadian@nvidia.com>:

applying review refactors

--
9b181ec7487e7ded4610a779f8929d2e2a199e0d by kushanam <kahmadian@nvidia.com>:

removing reduce-scatter from the all-gather optimization

--
a8c49eb58f3b370627cd57c62f456696567ba60a by kushanam <kahmadian@nvidia.com>:

remove traversal all-gather search and rely on immediate parent

--
d90f5a148bc099455724450b84f1af8fb83ffc66 by kushanam <kahmadian@nvidia.com>:

remove extra gpu word from the directive

Merging this change closes #5300

PiperOrigin-RevId: 566298114
copybara-service bot pushed a commit that referenced this pull request Sep 18, 2023
PR #5300: A new pass to optimize the AllGather->Binary_Op order sequence

Imported from GitHub PR openxla/xla#5300

This is a new GPU SPMD optimization pass for the following pattern:
binary-op(all-gather(a), all-gather(b))  ->  all-gather(binary-op(a, b))

PiperOrigin-RevId: 566340142
copybara-service bot pushed a commit that referenced this pull request Sep 29, 2023
Imported from GitHub PR openxla/xla#5300

This is a new GPU SPMD optimization pass for the following pattern:
binary-op(all-gather(a), all-gather(b))  ->  all-gather(binary-op(a, b))

Copybara import of the project:

--
198c4b2b8b8c155b50a5643e960366bdb51aece0 by kushanam <kahmadian@nvidia.com>:

adding a new pass to optimize reduce_scatter->all_gather->binary_op sequence

--
8f8cc822229f1c6a54c969188240d5b2a421e9ee by kushanam <kahmadian@nvidia.com>:

applying review refactors

--
1beffbb5ecc007aa729b2c20e39ec95a00a73fd8 by kushanam <kahmadian@nvidia.com>:

removing reduce-scatter from the all-gather optimization

--
993f3d66e1a75b774e39aa3a55134f841d132df5 by kushanam <kahmadian@nvidia.com>:

remove traversal all-gather search and rely on immediate parent

--
c7a3bea5846220dc49ae0087039e5ad77fd308c7 by kushanam <kahmadian@nvidia.com>:

remove extra gpu word from the directive

--
2b9afd129a7759ba4e59df461b9d8d06033f7649 by kushanam <kahmadian@nvidia.com>:

fixing for disabled SPMD partitioning

--
ee412cbbbb65f322d337dd522c17414a3c6afbd6 by kushanam <kahmadian@nvidia.com>:

deferring node removal and fixing the corresponding tests

Merging this change closes #5300

PiperOrigin-RevId: 569445810
Labels: cla: yes, stat:awaiting response (Status - Awaiting response from author)
8 participants