Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix grad of variable size on extract_image_patches #29815

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 44 additions & 2 deletions tensorflow/python/kernel_tests/extract_image_patches_grad_test.py
Expand Up @@ -100,8 +100,6 @@ def testGradient(self):

err = gradient_checker.compute_gradient_error(in_val, in_shape,
out_val, out_shape)

print('extract_image_patches gradient err: %.4e' % err)
self.assertLess(err, 1e-4)

@test_util.run_deprecated_v1
Expand All @@ -124,6 +122,50 @@ def testConstructGradientWithLargeImages(self):
# Won't time out.
self.assertIsNotNone(gradients)

def _VariableShapeGradient(self, test_shape_pattern):
"""Use test_shape_pattern to infer which dimensions are of
variable size.
"""
# Set graph seed for determinism.
random_seed = 42
random_seed_lib.set_random_seed(random_seed)

with self.test_session():
for test_case in self._TEST_CASES:
np.random.seed(random_seed)
in_shape = test_case['in_shape']
test_shape = [x if x is None else y
for x, y in zip(test_shape_pattern, in_shape)]
in_val = array_ops.placeholder(shape=test_shape, dtype=dtypes.float32)

feed_dict = {in_val: np.random.random(in_shape)}
for padding in ['VALID', 'SAME']:
out_val = array_ops.extract_image_patches(in_val, test_case['ksizes'],
test_case['strides'],
test_case['rates'], padding)
out_val_tmp = out_val.eval(feed_dict=feed_dict)
out_shape = out_val_tmp.shape

err = gradient_checker.compute_gradient_error(in_val, in_shape,
out_val, out_shape)
self.assertLess(err, 1e-4)

@test_util.run_deprecated_v1
def test_BxxC_Gradient(self):
self._VariableShapeGradient([-1, None, None, -1])

@test_util.run_deprecated_v1
def test_xHWx_Gradient(self):
self._VariableShapeGradient([None, -1, -1, None])

@test_util.run_deprecated_v1
def test_BHWC_Gradient(self):
self._VariableShapeGradient([-1, -1, -1, -1])

@test_util.run_deprecated_v1
def test_AllNone_Gradient(self):
self._VariableShapeGradient([None, None, None, None])


if __name__ == '__main__':
test.main()
12 changes: 5 additions & 7 deletions tensorflow/python/ops/array_grad.py
Expand Up @@ -831,12 +831,9 @@ def _QuantizeAndDequantizeV3Grad(_, grad):

@ops.RegisterGradient("ExtractImagePatches")
def _ExtractImagePatchesGrad(op, grad):
batch_size, rows_in, cols_in, channels = [
dim.value for dim in op.inputs[0].shape.dims
]
input_bhwc = array_ops.shape(op.inputs[0])
batch_size = input_bhwc[0]
channels = input_bhwc[3]
input_bhwc = array_ops.shape(op.inputs[0], out_type=dtypes.int64)
batch_size, rows_in, cols_in, channels = input_bhwc[0], input_bhwc[1], \
input_bhwc[2], input_bhwc[3]

# Create indices matrix for input tensor.
# Note that 0 is preserved for padding location,
Expand All @@ -853,7 +850,8 @@ def _ExtractImagePatchesGrad(op, grad):
op.get_attr("padding"))

# Create indices matrix for output tensor.
_, rows_out, cols_out, _ = [dim.value for dim in op.outputs[0].shape.dims]
output_bhwc = array_ops.shape(op.outputs[0], out_type=dtypes.int64)
rows_out, cols_out = output_bhwc[1], output_bhwc[2]
_, ksize_r, ksize_c, _ = op.get_attr("ksizes")
# Indices for output start from 0.
output_indices_num = rows_out * cols_out * ksize_r * ksize_c
Expand Down