Commit
Update on "Factor out numerical logic"
This change is similar to #54049 in that it helps us factor out code that can be shared between the fast and slow versions of gradcheck.
 - `compute_gradient` and `compute_numerical_jacobian_cols` have fewer responsibilities:
   - `compute_numerical_jacobian_cols` essentially only handles the complexity of complex derivatives
   - `compute_gradient` handles only finite differencing (and no longer worries about different layouts or indexing into the input tensor)
 - we again have two stages: first compute the columns separately, then combine them (a minimal sketch of this two-stage structure follows below)

[ghstack-poisoned]
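For context, the following is a minimal, self-contained sketch (not the actual gradcheck code) of the two-stage structure described above: first compute one numerical Jacobian column per input entry by central differencing, then combine the columns. The helper name `numerical_jacobian_sketch` is hypothetical and only for illustration.

import torch

def numerical_jacobian_sketch(fn, x, eps=1e-3):
    # Stage 1: perturb one entry of `x` at a time and record the resulting
    # central-difference column d(fn)/d(x[i]).
    flat_x = x.detach().clone().reshape(-1)
    cols = []
    for i in range(flat_x.numel()):
        orig = flat_x[i].item()
        flat_x[i] = orig + eps
        out_plus = fn(flat_x.reshape(x.shape)).detach().reshape(-1)
        flat_x[i] = orig - eps
        out_minus = fn(flat_x.reshape(x.shape)).detach().reshape(-1)
        flat_x[i] = orig
        cols.append((out_plus - out_minus) / (2 * eps))
    # Stage 2: combine the per-entry columns into a single Jacobian matrix
    # of shape (output_numel, input_numel).
    return torch.stack(cols, dim=1)

# fn(x) = x ** 2 has Jacobian diag(2 * x); the sketch recovers it numerically.
x = torch.tensor([1.0, 2.0, 3.0])
print(numerical_jacobian_sketch(lambda t: t ** 2, x))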
soulitzer committed Mar 26, 2021
2 parents 34b8a96 + e845cdc commit 7e19e2a
Showing 2 changed files with 34 additions and 2 deletions.
17 changes: 17 additions & 0 deletions test/test_autograd.py
@@ -4218,6 +4218,23 @@ def fn2(x, y):
        c = torch.rand(10, dtype=torch.float32).to_mkldnn().requires_grad_(True)
        self.assertTrue(gradcheck(fn, (a, c), atol=1e-1, check_batched_grad=False))

    def test_gradcheck_output_shape_or_dtype_depend_on_values(self):
        def fn(x):
            if torch.all(x >= 1):
                return torch.cat([x, x])
            else:
                return x
        a = torch.ones(1, requires_grad=True)
        with self.assertRaisesRegex(AssertionError, 'return outputs with the same shape when inputs are perturbed'):
            self.assertTrue(gradcheck(fn, (a,)))

        def fn2(x):
            if torch.all(x >= 1):
                return x.to(torch.float32)
            else:
                return x
        with self.assertRaisesRegex(AssertionError, 'return outputs with the same dtype when inputs are perturbed'):
            self.assertTrue(gradcheck(fn2, (a,)))

    def test_version_counter(self):
        x = torch.randn(1, 2)

19 changes: 17 additions & 2 deletions torch/autograd/gradcheck.py
@@ -145,7 +145,7 @@ def get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3,
    return jacobians


def compute_gradient(fn, entry, v, norm_v):
def compute_gradient(fn, entry, v, norm_v, do_checks):
    # Performs finite differencing by perturbing `entry` in-place by `v` and
    # returns the gradient of each of the outputs wrt x at idx.
    # We currently assume that the norm of delta equals eps.
@@ -160,6 +160,7 @@ def compute_gradient(fn, entry, v, norm_v):
    entry.copy_(orig)

    def compute(a, b):
        do_checks(a, b)
        ret = (b - a) / (2 * norm_v)
        return ret.detach().reshape(-1)

@@ -227,6 +228,19 @@ def prepped_input(input, input_idx, entry, entry_idx):
    return input


def check_outputs_same_dtype_and_shape_in_neighborhood(output1, output2, idx, delta):
    # Check that the returned outputs don't have different dtype or shape when you
    # perturb the input
    assert output1.shape == output2.shape, \
        (f"Expected `func` to return outputs with the same shape"
         f" when inputs are perturbed on index {idx} by {delta}, but got:"
         f" shapes {output1.shape} and {output2.shape}.")
    assert output1.dtype == output2.dtype, \
        (f"Expected `func` to return outputs with the same dtype"
         f" when inputs are perturbed on index {idx} by {delta}, but got:"
         f" dtypes {output1.dtype} and {output2.dtype}.")


def get_numerical_jacobian_for_input(fn, input, input_idx, inputs, outputs, delta, eps, grad_out):
    # Computes the numerical jacobians wrt a single input. Returns N jacobian
    # tensors, where N is the number of outputs. Input must require grad.
@@ -241,9 +255,10 @@ def wrapped_fn():
            return tuple(a.clone() for a in _as_tuple(fn(*inp)))

        entry = x[idx]
        do_checks = functools.partial(check_outputs_same_dtype_and_shape_in_neighborhood, idx=idx, delta=delta)

        def jvp_fn(delta):
            return compute_gradient(wrapped_fn, entry, delta, eps)
            return compute_gradient(wrapped_fn, entry, delta, eps, do_checks)
        jacobian_cols[d_idx] = []
        compute_numerical_jacobian_cols(jacobian_cols[d_idx], delta, jvp_fn, x.is_complex(), grad_out)

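The gradcheck.py change above threads the new neighborhood check into the finite-differencing helper by binding the perturbation context (index and delta) into a callback with functools.partial. Below is a simplified, self-contained sketch of that pattern with hypothetical names; it is not the real gradcheck implementation.

import functools
import torch

def check_same_dtype_and_shape(a, b, idx, delta):
    # Outputs produced under the +delta and -delta perturbations must agree
    # in shape and dtype for the central difference to be meaningful.
    assert a.shape == b.shape, f"shape changed when perturbing index {idx} by {delta}"
    assert a.dtype == b.dtype, f"dtype changed when perturbing index {idx} by {delta}"

def central_difference(fn, x, i, eps, do_checks):
    # The finite-differencing helper stays agnostic about what is checked; it
    # simply calls the injected `do_checks` on each pair of perturbed outputs.
    x = x.clone()
    orig = x[i].item()
    x[i] = orig + eps
    out_plus = fn(x).detach()
    x[i] = orig - eps
    out_minus = fn(x).detach()
    x[i] = orig
    do_checks(out_plus, out_minus)
    return (out_plus - out_minus) / (2 * eps)

x = torch.tensor([1.0, 2.0, 3.0])
do_checks = functools.partial(check_same_dtype_and_shape, idx=1, delta=1e-3)
print(central_difference(lambda t: t ** 2, x, 1, 1e-3, do_checks))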
