
Add batch support for various image_ops #14854

Merged
merged 22 commits into tensorflow:master from JoshVarty:BatchImageOps on Dec 7, 2017

Conversation

@JoshVarty
Contributor

JoshVarty commented Nov 24, 2017

Working on #8926
I used #7369 as a guide for my work here.

I have added batch support for:

  • flip_left_right
  • flip_up_down
  • random_flip_left_right
  • random_flip_up_down
  • transpose_image
  • rot90

I have corrected existing tests in image_ops_test.py and introduced a number of new tests based on existing tests for 3D inputs.

This is my first contribution to this repository and I have tried to follow the contributing guidelines. However, running pylint on image_ops_impl.py and image_ops_test.py revealed a number of pre-existing style violations. I've tried to fix the ones relevant to my work but may have missed some.
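For reference, a minimal usage sketch of the batched form these changes enable (shapes and values are illustrative only, not taken from the PR's tests):

```python
import numpy as np
import tensorflow as tf

# With batch support, flip_left_right accepts a 4-D
# [batch, height, width, channels] tensor as well as a single 3-D image.
images = tf.constant(np.arange(2 * 4 * 4 * 3).reshape([2, 4, 4, 3]),
                     dtype=tf.float32)
flipped = tf.image.flip_left_right(images)  # shape is still [2, 4, 4, 3]

with tf.Session() as sess:
  print(sess.run(flipped).shape)  # (2, 4, 4, 3)
```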

@googlebot googlebot added the cla: yes label Nov 24, 2017

@tensorflow-jenkins

Collaborator

tensorflow-jenkins commented Nov 24, 2017

Can one of the admins verify this patch?

@JoshVarty

Contributor

JoshVarty commented Nov 26, 2017

Hmm, I can successfully run the sanity checks locally via:
`tensorflow/tools/ci_build/ci_build.sh CPU tensorflow/tools/ci_build/ci_sanity.sh`

Does anyone happen to know how I can get more information on why they're failing on CI? The relevant output says:

Step 3 : COPY install/*.sh /install/
tensorflow/tools/ci_build/ci_build.sh: line 130: 26052 Terminated              docker build -t ${DOCKER_IMG_NAME} -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
ERROR: docker build failed. Dockerfile is at /var/lib/jenkins/workspace/tensorflow-pull-requests-sanity/tensorflow/tools/ci_build/Dockerfile.cpu
Build was aborted
Aborted by unknown
Unable to get pull request builder trigger!!
Setting status of abcb89bce4ce4b68cf714e1cccbc13d4eb1309b7 to FAILURE with url https://ci.tensorflow.org/job/tensorflow-pull-requests-sanity/11268/ and message: 'FAILURE
 '
Using context: Sanity Checks
Finished: ABORTED
def fix_image_flip_shape(image, result):
  """Set the shape to 3 dimensional if we don't know anything else.
def _EnsureTensorIs4D(image):
  """Converts `image` to a 4-D Tensor if it is not already one

@drpngx

drpngx Dec 4, 2017

Member

Add period at end of sentence.

@drpngx

drpngx Dec 4, 2017

Member

Thank you for that function.
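For context, a rough sketch of what a helper like `_EnsureTensorIs4D` might look like, assuming the module's usual `ops` and `array_ops` imports (the PR's actual implementation may differ):

```python
def _EnsureTensorIs4D(image):
  """Converts `image` to a 4-D Tensor if it is not already one."""
  image = ops.convert_to_tensor(image, name='image')
  ndims = image.get_shape().ndims
  if ndims == 3:
    # Promote a single [height, width, channels] image to a batch of one.
    image = array_ops.expand_dims(image, 0)
  elif ndims != 4:
    raise ValueError('\'image\' must be a 3-D or 4-D Tensor.')
  return image
```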

@drpngx drpngx requested a review from martinwicke Dec 4, 2017

@martinwicke
Member

martinwicke left a comment

Looks great in general. I have some comments of medium importance. The biggest is: Can you refactor the code so that the common functionality is not repeated?

  return fix_image_flip_shape(image, result)

  result = array_ops.where(mirror_cond, x=image,
                           y=array_ops.reverse(image, [2]))

@martinwicke

martinwicke Dec 4, 2017

Member

I'm wondering whether it makes sense to pull basically all this code out into a private function parameterized by the dimension to flip -- just for DRY reasons.

@martinwicke

martinwicke Dec 4, 2017

Member

You could even make it

_flip_image(image, axis, random=False)

and get rid of all the code in the non-random versions as well. That would be nice for maintainability.
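To make the suggestion concrete, here is a sketch of the shape such a helper could take; the axis handling, per-image random sampling, and seed parameter are assumptions for illustration, not the PR's final code, and the module's usual imports (`array_ops`, `math_ops`, `random_ops`) are assumed:

```python
def _flip_image(image, axis, random=False, seed=None):
  """Flips `image` along `axis` (in 4-D terms), per image at random if asked."""
  image = _EnsureTensorIs4D(image)
  flipped = array_ops.reverse(image, [axis])
  if random:
    # One uniform sample per image in the batch; flip with probability 0.5.
    batch_size = array_ops.shape(image)[0]
    uniform_random = random_ops.random_uniform([batch_size], 0, 1.0, seed=seed)
    mirror_cond = math_ops.less(uniform_random, .5)
    result = array_ops.where(mirror_cond, x=flipped, y=image)
  else:
    result = flipped
  return fix_image_flip_shape(image, result)
```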

@@ -737,6 +737,15 @@ def testIdempotentLeftRight(self):
      y_tf = y.eval()
      self.assertAllEqual(y_tf, x_np)

def testIdempotentLeftRightWithBatch(self):

@martinwicke

martinwicke Dec 4, 2017

Member

I don't think idempotent is the right word here (a function for which f^n(x) = f(x) for all n is idempotent). You are checking that flip_left_right is an involution (i.e. f(f(x)) = x).

                    dtype=np.uint8).reshape([2, 2, 3, 1])

    with self.test_session(use_gpu=True):
      x_tf = constant_op.constant(x_np, shape=x_np.shape).eval()

@martinwicke

martinwicke Dec 4, 2017

Member

Can you set a seed here to make sure this isn't going to be flaky?

        self.assertAllEqual(current_y_tf, current_y_np)
        count_flipped += 1
    self.assertGreaterEqual(count_flipped, 1)
    self.assertGreaterEqual(count_unflipped, 1)

@martinwicke

martinwicke Dec 4, 2017

Member

The number should be binomially distributed around 50, so the stddev of the normal approximation is 5. We should be pretty safe to assert that count_flipped > 20 and count_unflipped > 20. That's still 6 sigma.
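(For concreteness, assuming the test's 100 trials: the mean is 100 × 0.5 = 50, the standard deviation is √(100 × 0.5 × 0.5) = 5, and 50 − 6 × 5 = 20, so a run failing the >20 assertion is a 6-sigma event.)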

        count_flipped += 1
    self.assertGreaterEqual(count_flipped, 1)
    self.assertGreaterEqual(count_unflipped, 1)

def testIdempotentUpDown(self):

@martinwicke

martinwicke Dec 4, 2017

Member

I see now that you're not the one introducing the idempotent issue; can you fix it here too?

        self.assertAllEqual(current_y_tf, current_y_np)
        count_flipped += 1
    self.assertGreaterEqual(count_flipped, 1)
    self.assertGreaterEqual(count_unflipped, 1)

@martinwicke

martinwicke Dec 4, 2017

Member

Same comment, let's assert >20, and add a seed.

        count_flipped += 1
    self.assertGreaterEqual(count_flipped, 1)
    self.assertGreaterEqual(count_unflipped, 1)

def testIdempotentTranspose(self):

@martinwicke

martinwicke Dec 4, 2017

Member

Argh idempotent.

JoshVarty added some commits Dec 5, 2017

@JoshVarty

Contributor

JoshVarty commented Dec 5, 2017

Thanks for taking a look, guys. @martinwicke Let me know if you spot any other opportunities for cleanup/refactoring.

@martinwicke
Member

martinwicke left a comment

Awesome. So much less code. Thank you! Just minor things in the docstring (and one thing about the arguments to the private function).

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or
           3-D Tensor of shape `[height, width, channels]`.

@martinwicke

martinwicke Dec 6, 2017

Member

I think you need to get rid of this blank line for the comment to render properly.

    axis: A Tensor. Must be one of the following types: int32, int64. 1-D.
      The indices of the dimensions to reverse. Must be in the range
      [-rank(tensor), rank(tensor))

@martinwicke

martinwicke Dec 6, 2017

Member

You'll need to document all arguments.

    image: 4-D Tensor of shape `[batch, height, width, channels]` or
           3-D Tensor of shape `[height, width, channels]`.
    axis: A Tensor. Must be one of the following types: int32, int64. 1-D.

@martinwicke

martinwicke Dec 6, 2017

Member

Why 1D? This should be a scalar?

I understand you can pass 1D to reverse, but unless you use that, I would keep it a scalar here, and pass [axis] to reverse.

If you want to keep it 1D, you should change the docstring to explain that each component needs to be in the range given, and possibly they have to be ordered? Not sure. Definitely adds complexity you don't really need.

@JoshVarty

JoshVarty Dec 6, 2017

Contributor

I agree that it should be a scalar. (I had originally just made it identical to array_ops.reverse() but that doesn't make as much sense here)

One strange thing about this refactor is that it forces the caller to provide an axis relative to a 4-D Tensor, even if they've passed in a 3-D image.

For example random_flip_up_down() needs to provide an axis, but doesn't know whether or not the image it's working with is 3-D or a 4-D batch of images. Currently I just pass in 1 with the knowledge that a single image will be extended to 4 dimensions.

Comment addressed.

JoshVarty added some commits Dec 6, 2017

  Args:
    image: 4-D Tensor of shape `[batch, height, width, channels]` or
           3-D Tensor of shape `[height, width, channels]`.
    axis: A Python integer representing the axis on which the image(s)

@JoshVarty

JoshVarty Dec 6, 2017

Contributor

As mentioned here:

> One strange thing about this refactor is that it forces the caller to provide an axis relative to a 4-D Tensor, even if they've passed in a 3-D image.
>
> For example random_flip_up_down() needs to provide an axis, but doesn't know whether or not the image it's working with is 3-D or a 4-D batch of images. Currently I just pass in 1 with the knowledge that a single image will be extended to 4 dimensions.

@martinwicke

martinwicke Dec 6, 2017

Member

That's a good point -- but I think that's ok. It's made clear in the argument description, and in the end, this is a private utility function, so I'm not too worried about it.

@martinwicke

martinwicke Dec 6, 2017

Member

I guess if you wanted to, you could make a symbolic argument: direction='UP_DOWN' or 'LEFT_RIGHT' and interpret that. But again, the function is private; I don't think it's worth it.
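Illustratively, that variant might look like the following (hypothetical; not what the PR adopted):

```python
# Map symbolic directions to flip axes in 4-D [batch, height, width, channels].
_DIRECTION_TO_AXIS = {'UP_DOWN': 1, 'LEFT_RIGHT': 2}

def _flip_image(image, direction, random=False, seed=None):
  axis = _DIRECTION_TO_AXIS[direction]
  # ... proceed as before using `axis` ...
```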

@JoshVarty

JoshVarty Dec 6, 2017

Contributor

👍

@kokoro-team kokoro-team removed the kokoro:run label Dec 6, 2017

@martinwicke martinwicke merged commit 20aa9e0 into tensorflow:master Dec 7, 2017

14 checks passed

  • Android Demo App: Internal CI build successful
  • GPU CC: Internal CI build successful
  • GPU Python3: Internal CI build successful
  • MacOS Contrib: Internal CI build successful
  • MacOS Python2 and CC: Internal CI build successful
  • Sanity Checks: SUCCESS
  • Ubuntu CC: Internal CI build successful
  • Ubuntu Makefile: Internal CI build successful
  • Ubuntu Python2: Internal CI build successful
  • Ubuntu Python3: Internal CI build successful
  • Ubuntu Sanity: Internal CI build successful
  • Ubuntu contrib: Internal CI build successful
  • ci.tensorflow.org: SUCCESS
  • cla/google: All necessary CLAs are signed

@JoshVarty JoshVarty deleted the JoshVarty:BatchImageOps branch Dec 7, 2017

@jhseu

Member

jhseu commented Dec 13, 2017

Hey @JoshVarty, thanks for the contribution. Unfortunately, we have to revert this change because it turns out to be a bottleneck in our image models. There wasn't a reasonable way to fix the performance, because it's caused by both the conversion to 4-D and switching to tf.where. We'd want the 4-D version to perform reasonably well before we leave it in.

I added benchmarks here:
#15348

You can run it with `bazel run -c opt image_ops_test -- --benchmarks=FlipImageBenchmark`

jhseu added a commit to jhseu/tensorflow that referenced this pull request Dec 13, 2017

@JoshVarty

Contributor

JoshVarty commented Dec 14, 2017

Thanks for the heads-up, guys. Presumably the other places where 3-D tensors are expanded to 4-D tensors would also be slow, correct? (Perhaps they're not used in your image models, so it wasn't noticed until now.)

It sounds like two fixes are needed:

  1. We shouldn't always expand 3-D tensors to 4-D.
  2. We need an alternative to tf.where when operating on 4-D tensors.

I think I know how to fix the first issue.

Do you guys have any hunches on how performance for 4-D batches might be fixed? Is there a performant alternative to tf.where I should experiment with? Would a new op for 4-D tensors make sense here?

jhseu added a commit that referenced this pull request Dec 14, 2017

@jhseu

Member

jhseu commented Dec 14, 2017

Yeah, possibly other 3-D to 4-D changes are slower too, but this change was really noticeable in our internal benchmarks.

The tf.where issue is because we're flipping all images. With tf.cond on a 3-D Tensor, we only flip images as needed.
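Roughly, the difference between the two paths (simplified sketch; `mirror_cond` is a scalar bool, `batch_mirror_cond` is one bool per image, and both names are illustrative):

```python
# 3-D path with tf.cond: the reverse op only executes when the flip
# branch is actually taken.
result = control_flow_ops.cond(mirror_cond,
                               lambda: array_ops.reverse(image, [0]),
                               lambda: image)

# 4-D path with tf.where: every image is reversed up front and tf.where
# then selects per image, so the flip cost is always paid.
flipped = array_ops.reverse(images, [1])
result = array_ops.where(batch_mirror_cond, x=flipped, y=images)
```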

@JoshVarty

Contributor

JoshVarty commented Dec 15, 2017

Would you guys be comfortable with a solution that used tf.cond when operating on 3-D tensors, but tf.where when operating on 4-D tensors? Or is it best to avoid tf.where altogether?
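A possible shape for that dispatch (sketch only; the helper name and structure are hypothetical, and the module's usual imports are assumed):

```python
def _random_flip(image, flip_axis, seed=None):
  """Randomly flips `image`; `flip_axis` is in 3-D [height, width, channels] terms."""
  image = ops.convert_to_tensor(image, name='image')
  if image.get_shape().ndims == 3:
    # Single image: keep the cheap tf.cond path.
    uniform_random = random_ops.random_uniform([], 0, 1.0, seed=seed)
    return control_flow_ops.cond(
        math_ops.less(uniform_random, .5),
        lambda: array_ops.reverse(image, [flip_axis]),
        lambda: image)
  elif image.get_shape().ndims == 4:
    # Batch: select per image with tf.where, offsetting the axis by one.
    batch_size = array_ops.shape(image)[0]
    uniform_random = random_ops.random_uniform([batch_size], 0, 1.0, seed=seed)
    return array_ops.where(
        math_ops.less(uniform_random, .5),
        x=array_ops.reverse(image, [flip_axis + 1]),
        y=image)
  raise ValueError('\'image\' must be a 3-D or 4-D Tensor.')
```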

@jhseu

Member

jhseu commented Dec 15, 2017

I'm comfortable with anything where the recently added benchmarks perform about the same before and after, including with 4-D tensors.

The issue is that we'd discourage users from using 4-D if it's significantly slower than 3-D on a per-image benchmark.

@jhseu

Member

jhseu commented Dec 15, 2017

(Note the benchmark didn't include 4-D tests, but it's trivial to add.)
