Update on "Factor out numerical logic"
This change is similar to #54049 in that it helps us factor out some code that can be used in both the fast and slow versions of gradcheck.
 - `compute_gradient` and `compute_numerical_jacobian_cols` have fewer responsibilities:
   - `compute_numerical_jacobian_cols` essentially only handles the complexity of complex derivatives
   - `compute_gradient` handles only finite differencing (and doesn't worry about different layouts or indexing into the input tensor)
 - we have two stages again: we first compute the columns separately, then combine them (see the sketch after this description)

[ghstack-poisoned]
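
To make the two-stage structure concrete, here is a minimal, self-contained sketch of the idea (hypothetical names, independent of the actual gradcheck helpers): stage one computes one Jacobian column per input element with central differences, stage two combines the columns into a matrix.

```python
import torch

def jacobian_columns(fn, inp, eps=1e-3):
    # Stage 1: one Jacobian column per input element, via central differences.
    flat = inp.detach().clone().view(-1)
    cols = []
    for i in range(flat.numel()):
        orig = flat[i].item()
        flat[i] = orig + eps
        out_plus = fn(flat.view_as(inp)).detach().clone()
        flat[i] = orig - eps
        out_minus = fn(flat.view_as(inp)).detach().clone()
        flat[i] = orig
        cols.append((out_plus - out_minus).reshape(-1) / (2 * eps))
    return cols

def numerical_jacobian(fn, inp, eps=1e-3):
    # Stage 2: combine the per-element columns into an (out_numel, in_numel) matrix.
    return torch.stack(jacobian_columns(fn, inp, eps), dim=1)

# For fn(x) = sin(x) the Jacobian is diag(cos(x)).
x = torch.randn(3, dtype=torch.double)
J = numerical_jacobian(torch.sin, x)
print(torch.allclose(J, torch.diag(torch.cos(x)), atol=1e-6))
```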
soulitzer committed Apr 6, 2021
2 parents 863da78 + 06cfece commit b03cf2c
Showing 2 changed files with 37 additions and 8 deletions.
torch/autograd/gradcheck.py (39 changes: 34 additions & 5 deletions)
@@ -114,7 +114,7 @@ def get_stride(size):
yield x_tensor, x_idx, d_idx


def get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3,
def _get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3,
grad_out=1.0) -> List[Tuple[torch.Tensor, ...]]:
"""Computes the numerical jacobian for a given fn and inputs. Returns M * N jacobians
where M is the number of input tensors that require grad, and N is the number of output
@@ -136,7 +136,7 @@ def get_numerical_jacobian(fn, inputs, outputs=None, target=None, eps=1e-3,
"""
jacobians: List[Tuple[torch.Tensor, ...]] = []
if outputs is None:
outputs = _as_tuple(fn(inputs))
outputs = _as_tuple(fn(*_as_tuple(inputs)))
if target is None:
target = inputs
inp_indices = [i for i, a in enumerate(target) if is_tensor_like(a) and a.requires_grad]
@@ -145,6 +145,15 @@
return jacobians


def get_numerical_jacobian(fn, inputs, target=None, eps=1e-3, grad_out=1.0):
# Simple wrapper around _get_numerical_jacobian
warnings.warn("get_analytical_jacobian is deprecated!")
def fn_pack_inps(*inps):
return fn(inps)
jacobians = _get_numerical_jacobian(fn_pack_inps, inputs, None, target, eps, grad_out)
return jacobians[0][0]


def compute_numerical_gradient(fn, entry, v, norm_v, nbhd_checks_fn):
# Performs finite differencing by perturbing `entry` in-place by `v` and
# returns the gradient of each of the outputs wrt to x at idx.
@@ -343,6 +352,26 @@ def fail_test(msg):
return jacobians, failed


def get_analytical_jacobian(inputs, output, nondet_tol=0.0, grad_out=1.0):
# Replicates the behavior of the old get_analytical_jacobian before the refactor
warnings.warn("get_analytical_jacobian is deprecated!")

diff_input_list = list(iter_tensors(inputs, True))
def backward_fn(grad_output):
return torch.autograd.grad(output, diff_input_list, grad_output,
retain_graph=True, allow_unused=True)

jacobians_rows = compute_analytical_jacobian_rows(backward_fn, output.clone(), grad_out)
jacobians_rows_reentrant = compute_analytical_jacobian_rows(backward_fn, output.clone(), grad_out)

output_numel = output.numel()
jacobians, correct_grad_types, correct_grad_sizes = combine_jacobian_rows(jacobians_rows, inputs, output_numel)
jacobians_reentrant, _, _ = combine_jacobian_rows(jacobians_rows_reentrant, inputs, output_numel)
reentrant = check_jacobians_equal(jacobians, jacobians_reentrant, nondet_tol)

return jacobians, reentrant, correct_grad_sizes, correct_grad_types


def compute_analytical_jacobian_rows(vjp_fn, sample_output, grad_out_scale) -> List[List[Optional[torch.Tensor]]]:
# Computes Jacobian row-by-row using backward function `vjp_fn` = v^T J
# NB: this function does not assume vjp_fn(v) to return tensors with
@@ -410,7 +439,7 @@ def check_outputs(outputs) -> None:
def check_no_differentiable_outputs(fail_test, func, inputs, func_out, eps) -> bool:
# When there are no differentiable outputs, numerical gradient for a function is
# expected to be zero.
jacobians_all_inputs_outputs = get_numerical_jacobian(func, inputs, func_out, eps=eps)
jacobians_all_inputs_outputs = _get_numerical_jacobian(func, inputs, func_out, eps=eps)
for jacobians_all_outputs_and_fixed_input in jacobians_all_inputs_outputs:
for jacobian in jacobians_all_outputs_and_fixed_input:
if torch.ne(jacobian, 0).sum() > 0:
@@ -695,9 +724,9 @@ def fail_test(msg):
if not outputs:
return check_no_differentiable_outputs(fail_test, func, tupled_inputs, _as_tuple(func_out), eps)

numerical = transpose(get_numerical_jacobian(func, tupled_inputs, outputs, eps=eps))
numerical = transpose(_get_numerical_jacobian(func, tupled_inputs, outputs, eps=eps))
if any(isinstance(o, torch.Tensor) and o.is_complex() for o in _as_tuple(func_out)):
numerical_from_imag_grad_out = transpose(get_numerical_jacobian(func, tupled_inputs, outputs, eps=eps, grad_out=1j))
numerical_from_imag_grad_out = transpose(_get_numerical_jacobian(func, tupled_inputs, outputs, eps=eps, grad_out=1j))

for i, o in enumerate(outputs):
analytical, failed = check_analytical_jacobian_attributes(tupled_inputs, o, nondet_tol, 1.0,
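
For the analytical side, the get_analytical_jacobian shim added above builds each Jacobian row as a vector-Jacobian product: a one-hot grad_output v fed to torch.autograd.grad yields v^T J, and the shim computes the rows twice and compares the results to flag a non-reentrant (nondeterministic) backward. A simplified, single-input/single-output sketch of that row-by-row idea (illustrative names, not the actual helpers):

```python
import torch

def analytical_jacobian_rows(inp, out):
    # Row j of J is the VJP v^T J with v the j-th one-hot grad_output.
    flat_out = out.reshape(-1)
    rows = []
    for j in range(flat_out.numel()):
        v = torch.zeros_like(flat_out)
        v[j] = 1.0
        (row,) = torch.autograd.grad(flat_out, inp, v,
                                     retain_graph=True, allow_unused=True)
        rows.append(row.reshape(-1))
    return torch.stack(rows, dim=0)  # shape: (out_numel, in_numel)

x = torch.randn(3, dtype=torch.double, requires_grad=True)
J = analytical_jacobian_rows(x, x.sin())
print(torch.allclose(J, torch.diag(torch.cos(x)).detach()))
```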
torch/testing/_internal/common_nn.py (6 changes: 3 additions & 3 deletions)
@@ -19,7 +19,7 @@
TEST_WITH_ROCM, gradcheck, gradgradcheck
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import expectedAlertNondeterministic
from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
from torch.autograd.gradcheck import _get_numerical_jacobian, iter_tensors
from torch.autograd import Variable
from torch.types import _TensorOrTensors
import torch.backends.cudnn
@@ -5012,12 +5012,12 @@ def fw(*input):

res: Tuple[torch.Tensor, ...] = tuple()
if jacobian_input:
res += get_numerical_jacobian(fw, input, eps=1e-6),
res += _get_numerical_jacobian(fw, input, eps=1e-6),
if jacobian_parameters:
param, _ = self._get_parameters(module)
to_cat = []
for p in param:
jacobian = get_numerical_jacobian(fw, input, target=p, eps=1e-6)
jacobian = _get_numerical_jacobian(fw, input, target=p, eps=1e-6)
                # _get_numerical_jacobian returns a list of tuples but we require a tensor
to_cat.append(jacobian[0][0])
res += (torch.cat(to_cat, 0),)
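
Regarding the `[0][0]` indexing at the call sites above: `_get_numerical_jacobian` returns a list with one entry per differentiable input, and each entry is a tuple with one Jacobian per output, so the single-input/single-output case is extracted with `jacobians[0][0]`. A small usage sketch, assuming the helper behaves as shown in the gradcheck.py diff above:

```python
import torch
from torch.autograd.gradcheck import _get_numerical_jacobian

# One differentiable input, one output -> a 1-element list of 1-element tuples.
x = torch.randn(3, dtype=torch.double, requires_grad=True)
jacobians = _get_numerical_jacobian(lambda *inps: inps[0].sin(), (x,), eps=1e-6)
J = jacobians[0][0]  # the lone Jacobian tensor
print(torch.allclose(J, torch.diag(torch.cos(x)).detach(), atol=1e-5))
```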
