diff --git a/test/common_nn.py b/test/common_nn.py
index 51b883ba9ea80..48a521129601b 100644
--- a/test/common_nn.py
+++ b/test/common_nn.py
@@ -7,7 +7,7 @@
 import torch
 import torch.cuda
 from common import TestCase, to_gpu, freeze_rng_state, is_iterable
-from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors, contiguous
+from torch.autograd.gradcheck import get_numerical_jacobian, iter_tensors
 import torch.backends.cudnn
 
 # tarfile module tries to obtain a file object name in python 3.3
@@ -783,9 +783,8 @@ def fw(input):
             return self._forward(module, input).detach()
 
         res = tuple()
-        input = contiguous(input)
         if jacobian_input:
-            res += get_numerical_jacobian(fw, input, input, eps=1e-6),
+            res += get_numerical_jacobian(fw, input, eps=1e-6),
         if jacobian_parameters:
             param, _ = self._get_parameters(module)
             res += torch.cat([get_numerical_jacobian(fw, input, p, eps=1e-6) for p in param], 0),
@@ -813,8 +812,8 @@ def check_criterion_jacobian(self, criterion, input, target):
         input_t = iter_tensors(input)
         numerical_t = iter_tensors(numerical_d_x)
         for x, d_x in zip(input_t, numerical_t):
-            x = x.view(-1)
-            d_x = d_x.view(-1)
+            x = x.view(-1).data
+            d_x = d_x.view(-1).data
             for i in range(x.nelement()):
                 original = x[i].item()
                 x[i] = original + eps
diff --git a/test/test_distributions.py b/test/test_distributions.py
index 4099c57a0d153..6b95d79f0b09e 100644
--- a/test/test_distributions.py
+++ b/test/test_distributions.py
@@ -880,7 +880,6 @@ def test_multinomial_2d(self):
         self.assertEqual(Multinomial(total_count, p).sample((6,)).size(), (6, 2, 3))
         set_rng_seed(0)
         self._gradcheck_log_prob(lambda p: Multinomial(total_count, p), [p])
-        p.grad.zero_()
         self._gradcheck_log_prob(lambda p: Multinomial(total_count, None, p.log()), [p])
 
         # sample check for extreme value of probs
diff --git a/test/test_legacy_nn.py b/test/test_legacy_nn.py
index 02b4b791934a0..63833cbaa8879 100644
--- a/test/test_legacy_nn.py
+++ b/test/test_legacy_nn.py
@@ -8,7 +8,7 @@
 import torch.legacy.nn as nn
 from common_nn import NNTestCase, ModuleTest, CriterionTest, iter_tensors, \
     module_tests, criterion_tests, TEST_CUDA, PRECISION
-from torch.autograd.gradcheck import get_numerical_jacobian, contiguous
+from torch.autograd.gradcheck import get_numerical_jacobian
 from common import to_gpu, freeze_rng_state, run_tests
 from torch.autograd import Variable
 
@@ -661,10 +661,9 @@ def fw(input):
             return out
 
         res = tuple()
-        input = contiguous(input)
         if jacobian_input:
             input = require_grad(input)
-            res += get_numerical_jacobian(fw, input, input, eps=1e-6),
+            res += get_numerical_jacobian(fw, input, eps=1e-6),
         if jacobian_parameters:
             params, _ = self._get_parameters(module)
             jacobians = []
diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp
index 2041c41208fcd..c4bdf58a5af10 100644
--- a/tools/autograd/templates/Functions.cpp
+++ b/tools/autograd/templates/Functions.cpp
@@ -852,7 +852,7 @@ std::tuple<Tensor, Tensor, Tensor> prelu_double_backward(
     return std::tuple<Tensor, Tensor, Tensor>(
             ggO,
             ggW_maybe_squeeze.expand_as(gO) * gO * nonpositive_mask,
-            (ggI * gO * nonpositive_mask).sum()
+            (ggI * gO * nonpositive_mask).sum().expand_as(weight)
           );
   } else {
     // Expand ggW to match size of ggI; a simple expand doesn't work because
diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py
index 49f1180c1191b..972f2656364a2 100644
--- a/torch/autograd/gradcheck.py
+++ b/torch/autograd/gradcheck.py
@@ -2,6 +2,7 @@
 from collections import Iterable
 import torch.testing
 import sys
+from itertools import product
 
 
 def zero_gradients(x):
@@ -34,35 +35,21 @@ def make_jacobian(input, num_out):
 def iter_tensors(x, only_requiring_grad=False):
     if isinstance(x, torch.Tensor):
         if x.requires_grad or not only_requiring_grad:
-            yield x.data
+            yield x
     elif isinstance(x, Iterable):
         for elem in x:
             for result in iter_tensors(elem, only_requiring_grad):
                 yield result
 
 
-def iter_tensors_with_grad(x):
-    if isinstance(x, torch.Tensor):
-        if x.requires_grad:
-            yield (x.grad.data, x.data) if x.grad is not None else (None, None)
-    elif isinstance(x, Iterable):
-        for elem in x:
-            for result in iter_tensors_with_grad(elem):
-                yield result
-
-
-def contiguous(input):
-    if isinstance(input, torch.Tensor):
-        return input.contiguous()
-    elif isinstance(input, Iterable):
-        return type(input)(contiguous(e) for e in input)
-    return input
-
-
-def get_numerical_jacobian(fn, input, target, eps=1e-3):
-    # To be able to use .view(-1) input must be contiguous
-    input = contiguous(input)
-    target = contiguous(target)
+# `input` is input to `fn`
+# `target` is the Tensors wrt whom Jacobians are calculated (default=`input`)
+#
+# Note that `target` may not even be part of `input` to `fn`, so please be
+# **very careful** in this to not clone `target`.
+def get_numerical_jacobian(fn, input, target=None, eps=1e-3):
+    if target is None:
+        target = input
     output_size = fn(input).numel()
     jacobian = make_jacobian(target, output_size)
 
@@ -74,23 +61,25 @@ def get_numerical_jacobian(fn, input, target, eps=1e-3):
 
     # TODO: compare structure
     for x_tensor, d_tensor in zip(x_tensors, j_tensors):
-        flat_tensor = x_tensor.view(-1).detach()
-        for i in range(flat_tensor.nelement()):
-            orig = flat_tensor[i].item()
-            flat_tensor[i] = orig - eps
+        # need data here to get around the version check because without .data,
+        # the following code updates version but doesn't change content
+        x_tensor = x_tensor.data
+        for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
+            orig = x_tensor[x_idx].item()
+            x_tensor[x_idx] = orig - eps
             outa = fn(input).clone()
-            flat_tensor[i] = orig + eps
+            x_tensor[x_idx] = orig + eps
             outb = fn(input).clone()
-            flat_tensor[i] = orig
+            x_tensor[x_idx] = orig
             r = (outb - outa) / (2 * eps)
-            d_tensor[i] = r.detach().contiguous().view(-1)
+            d_tensor[d_idx] = r.detach().reshape(-1)
 
     return jacobian
 
 
 def get_analytical_jacobian(input, output):
-    input = contiguous(input)
+    diff_input_list = list(iter_tensors(input, True))
     jacobian = make_jacobian(input, output.numel())
     jacobian_reentrant = make_jacobian(input, output.numel())
     grad_output = torch.zeros_like(output)
@@ -102,9 +91,9 @@ def get_analytical_jacobian(input, output):
         flat_grad_output.zero_()
         flat_grad_output[i] = 1
         for jacobian_c in (jacobian, jacobian_reentrant):
-            zero_gradients(input)
-            output.backward(grad_output, create_graph=True)
-            for jacobian_x, (d_x, x) in zip(jacobian_c, iter_tensors_with_grad(input)):
+            grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
+                                              retain_graph=True, allow_unused=True)
+            for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list):
                 if d_x is not None and d_x.size() != x.size():
                     correct_grad_sizes = False
                 elif jacobian_x.numel() != 0:
@@ -177,10 +166,10 @@ def fail_test(msg):
             continue
 
         def fn(input):
-            return _as_tuple(func(*input))[i].data
+            return _as_tuple(func(*input))[i]
 
         analytical, reentrant, correct_grad_sizes = get_analytical_jacobian(tupled_inputs, o)
-        numerical = get_numerical_jacobian(fn, inputs, inputs, eps)
+        numerical = get_numerical_jacobian(fn, inputs, eps=eps)
 
         if not correct_grad_sizes:
             return fail_test('Analytical gradient has incorrect size')
@@ -197,21 +186,21 @@ def fn(input):
                              'although analytical gradient matches numerical gradient')
 
     # check if the backward multiplies by grad_output
-    zero_gradients(inputs)
     output = _differentiable_outputs(func(*inputs))
     if any([o.requires_grad for o in output]):
-        torch.autograd.backward(output, [torch.zeros_like(o) for o in output], create_graph=True)
-        var_inputs = list(filter(lambda i: isinstance(i, torch.Tensor), inputs))
-        if not var_inputs:
-            raise RuntimeError("no Tensors found in input")
-        for i in var_inputs:
-            if i.grad is None:
+        diff_input_list = list(iter_tensors(inputs, True))
+        if not diff_input_list:
+            raise RuntimeError("no Tensors requiring grad found in input")
+        grads_input = torch.autograd.grad(output, diff_input_list, [torch.zeros_like(o) for o in output],
+                                          allow_unused=True)
+        for gi, i in zip(grads_input, diff_input_list):
+            if gi is None:
                 continue
-            if not i.grad.data.eq(0).all():
+            if not gi.eq(0).all():
                 return fail_test('backward not multiplied by grad_output')
-            if i.grad.type() != i.type():
+            if gi.type() != i.type():
                 return fail_test("grad is incorrect type")
-            if i.grad.size() != i.size():
+            if gi.size() != i.size():
                 return fail_test('grad is incorrect size')
 
     return True
@@ -254,17 +243,19 @@ def gradgradcheck(func, inputs, grad_outputs=None, eps=1e-6, atol=1e-5, rtol=1e-
         # If grad_outputs is not specified, create random Tensors of the same
         # shape, type, and device as the outputs
        def randn_like(x):
-            var = torch.testing.randn_like(x if x.is_floating_point() else x.double())
+            y = torch.testing.randn_like(x if x.is_floating_point() else x.double())
             if gen_non_contig_grad_outputs:
-                var = torch.testing.make_non_contiguous(var)
-            var.requires_grad = True
-            return var
+                y = torch.testing.make_non_contiguous(y)
+            return y.requires_grad_()
         outputs = _as_tuple(func(*inputs))
         grad_outputs_gen = (randn_like(x) for x in outputs)
         grad_outputs = list(grad_outputs_gen) if not isinstance(inputs, tuple) else tuple(grad_outputs_gen)
 
-    def new_func(*input_args):
-        input_args = input_args[:-len(grad_outputs)]
+    num_outputs = len(grad_outputs)
+
+    def new_func(*args):
+        input_args = args[:-num_outputs]
+        grad_outputs = args[-num_outputs:]
         outputs = _differentiable_outputs(func(*input_args))
         input_args = tuple(x for x in input_args if isinstance(x, torch.Tensor) and x.requires_grad)
         grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True)