Commit d173b48

Update on "Factor out numerical logic"
This change is similar to #54049 in that it helps us factor out code that can be used in both the fast and slow versions of gradcheck.
 - `compute_gradient` and `compute_numerical_jacobian_cols` have fewer responsibilities:
   - `compute_numerical_jacobian_cols` essentially only handles the complexity of complex derivatives
   - `compute_gradient` handles only finite differencing (and no longer worries about different layouts or indexing into the input tensor)
 - we have two stages again, where we first compute the jacobian columns separately, then combine them (see the sketch below)

[ghstack-poisoned]
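To make the two-stage structure concrete, here is a self-contained toy (illustrative only; the function name and signature are made up, and none of this is the PR's code):

import torch

def numerical_jacobian_toy(fn, x, eps=1e-3):
    # Stage 1: compute each jacobian column separately by central differencing.
    # (x must be a contiguous float tensor that does not require grad.)
    out_numel = fn(x).numel()
    flat = x.reshape(-1)  # a view: in-place writes are visible through x
    cols = {}
    for d_idx in range(flat.numel()):
        orig = flat[d_idx].item()
        flat[d_idx] = orig - eps
        a = fn(x).reshape(-1)
        flat[d_idx] = orig + eps
        b = fn(x).reshape(-1)
        flat[d_idx] = orig  # restore the perturbed element
        cols[d_idx] = (b - a) / (2 * eps)  # central difference
    # Stage 2: combine the columns into an (x.numel(), out_numel) jacobian.
    jacobian = torch.zeros(flat.numel(), out_numel)
    for d_idx, col in cols.items():
        jacobian[d_idx] = col
    return jacobian

For example, numerical_jacobian_toy(torch.sin, torch.ones(3)) is close to torch.diag(torch.cos(torch.ones(3))), since the jacobian of an elementwise function is diagonal.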
soulitzer committed Mar 31, 2021
2 parents 10e2278 + 1078402 commit d173b48
Showing 1 changed file with 15 additions and 13 deletions: torch/autograd/gradcheck.py
@@ -141,11 +141,11 @@ def get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3,
        target = inputs
    inp_indices = [i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad]
    for i, (inp, inp_idx) in enumerate(zip(iter_tensors(target, True), inp_indices)):
-        jacobians += [get_numerical_jacobian_for_input(fn, inp, inp_idx, inputs, outputs, eps, eps, grad_out)]
+        jacobians += [get_numerical_jacobian_wrt_specific_input(fn, inp, inp_idx, inputs, outputs, eps, eps, grad_out)]
    return jacobians


-def compute_gradient(fn, entry, v, norm_v, do_checks):
+def compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
    # Performs finite differencing by perturbing `entry` in-place by `v` and
    # returns the gradient of each of the outputs wrt x at idx.
    # We currently assume that the norm of delta equals eps.
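For readers skimming the diff, a simplified single-output sketch of this factored-out finite-differencing core (an assumed reconstruction, not the PR's exact code; the real function returns one gradient per output):

def compute_numerical_gradient_sketch(fn, entry, v, norm_v, nbhd_checks_fn):
    # Perturb `entry` in-place in both directions, evaluate, then restore.
    orig = entry.clone()
    entry.copy_(orig - v)
    outa = fn()
    entry.copy_(orig + v)
    outb = fn()
    entry.copy_(orig)
    # Verify the two evaluations are comparable before differencing.
    nbhd_checks_fn(outa, outb)
    ret = (outb - outa) / (2 * norm_v)  # central difference
    return ret.detach().reshape(-1)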
@@ -160,7 +160,7 @@ def compute_gradient(fn, entry, v, norm_v, do_checks):
    entry.copy_(orig)

    def compute(a, b):
-        do_checks(a, b)
+        nbhd_checks_fn(a, b)
        ret = (b - a) / (2 * norm_v)
        return ret.detach().reshape(-1)

@@ -208,17 +208,19 @@ def combine_jacobian_cols(jacobians_cols, outputs, input, dim):
    return jacobians


-def prepped_input(input, input_idx, entry, entry_idx):
+def prepped_input(input, maybe_perturbed_input):
    # Prepares the inputs to be passed into the function while including the new modified input.
    if input.layout == torch._mkldnn:  # type: ignore # no attr _mkldnn
        # Convert back to mkldnn
-        if input_idx == entry_idx:
-            return entry.to_mkldnn()
+        if maybe_perturbed_input is not None:
+            return maybe_perturbed_input.to_mkldnn()
        else:
            return input
    elif input.layout == torch.sparse_coo:
-        # modifications to entry are reflected in input so we could've just returned `input` here
-        # but due to an issue with coalesce, we need to do an extra clone here.
+        # Modifications to entry are reflected in input, so we could've just returned `input` here,
+        # but there is an issue where calling .coalesce on a tensor moves it off the graph when the
+        # tensor is already coalesced, so analytical would always return 0 wrt that input if it
+        # was previously used to compute the forward pass. To get around this, we do an extra clone here.
        # TODO: get rid of this extra clone once https://github.com/pytorch/pytorch/pull/52874 is landed
        # Make this new tensor require grad again in case the function has hooks
        return torch.sparse_coo_tensor(input._indices(), input._values(), input.size()).requires_grad_(True)
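Reassembled in one piece for readability (a sketch: the hunk ends before the final branch, so the dense fallback below is an assumption, consistent with the comment that in-place modifications are already visible through `input`):

def prepped_input(input, maybe_perturbed_input):
    if input.layout == torch._mkldnn:  # type: ignore
        # mkldnn inputs are perturbed as dense copies; convert back here.
        if maybe_perturbed_input is not None:
            return maybe_perturbed_input.to_mkldnn()
        return input
    elif input.layout == torch.sparse_coo:
        # Fresh tensor to work around the coalesce issue described above.
        return torch.sparse_coo_tensor(input._indices(), input._values(),
                                       input.size()).requires_grad_(True)
    # Assumed: dense inputs pass through unchanged, since in-place
    # perturbation of the entry is already reflected in `input`.
    return input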
@@ -241,7 +243,7 @@ def check_outputs_same_dtype_and_shape_in_neighborhood(output1, output2, idx, delta):
f" dtypes {output1.dtype} and {output2.dtype}.")


def get_numerical_jacobian_for_input(fn, input, input_idx, inputs, outputs, delta, eps, grad_out):
def get_numerical_jacobian_wrt_specific_input(fn, input, input_idx, inputs, outputs, delta, eps, grad_out):
# Computes the numerical jacobians wrt to a single input. Returns N jacobian
# tensors, where N is the number of outputs. Input must require grad.
assert input.requires_grad
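Only the tail of check_outputs_same_dtype_and_shape_in_neighborhood appears in the hunk above; a plausible reconstruction (the exception type and message wording are assumptions, chosen to end with the f-string shown in the diff):

def check_outputs_same_dtype_and_shape_in_neighborhood(output1, output2, idx, delta):
    # f(x - delta) and f(x + delta) must agree in shape and dtype, or the
    # central difference (b - a) / (2 * norm_v) is not well-defined.
    if output1.shape != output2.shape:
        raise RuntimeError(f"Expected outputs to have the same shape when perturbed"
                           f" by {delta} at index {idx}, but got shapes"
                           f" {output1.shape} and {output2.shape}.")
    if output1.dtype != output2.dtype:
        raise RuntimeError(f"Expected outputs to have the same dtype when perturbed"
                           f" by {delta} at index {idx}, but got"
                           f" dtypes {output1.dtype} and {output2.dtype}.")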
@@ -250,15 +252,15 @@ def get_numerical_jacobian_for_input(fn, input, input_idx, inputs, outputs, delta, eps, grad_out):

    for x, idx, d_idx in iter_tensor(input):
        def wrapped_fn():
-            inp = tuple(prepped_input(a, i, x, input_idx) if is_tensor_like(a) else a
+            inp = tuple(prepped_input(a, x if i == input_idx else None) if is_tensor_like(a) else a
                        for i, a in enumerate(_as_tuple(inputs)))
            return tuple(a.clone() for a in _as_tuple(fn(*inp)))

-        entry = x[idx]
-        do_checks = functools.partial(check_outputs_same_dtype_and_shape_in_neighborhood, idx=idx, delta=delta)
+        input_to_perturb = x[idx]
+        nbhd_checks_fn = functools.partial(check_outputs_same_dtype_and_shape_in_neighborhood, idx=idx, delta=delta)

        def jvp_fn(delta):
-            return compute_gradient(wrapped_fn, entry, delta, eps, do_checks)
+            return compute_numerical_gradient(wrapped_fn, input_to_perturb, delta, eps, nbhd_checks_fn)
        jacobian_cols[d_idx] = []
        compute_numerical_jacobian_cols(jacobian_cols[d_idx], delta, jvp_fn, x.is_complex(), grad_out)
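The commit message says compute_numerical_jacobian_cols "essentially only handles the complexity of complex derivatives". A deliberately simplified sketch of that idea (not the PR's code; it assumes grad_out == 1.0, real-valued outputs, and that jvp_fn returns a tuple of per-output columns, whereas the real code also covers grad_out == 1j and complex outputs):

def compute_numerical_jacobian_cols_sketch(cols_out, eps, jvp_fn, input_is_complex):
    if input_is_complex:
        ds_dx = jvp_fn(eps)       # perturb along the real axis
        ds_dy = jvp_fn(eps * 1j)  # perturb along the imaginary axis
        # PyTorch's convention for d(real output)/d(complex input): dL/dx + 1j * dL/dy
        cols_out.extend(dx + 1j * dy for dx, dy in zip(ds_dx, ds_dy))
    else:
        cols_out.extend(jvp_fn(eps))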

