diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 27c2566805c5..88c78a836a10 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -335,6 +335,42 @@ Tensor& sum_out(Tensor& result, const Tensor& self, IntArrayRef dim, ScalarType return at::native::sum_out(result, self, dim, false, dtype); } +int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) { + int64_t size = 1; + if (sizes.size() == 0) { + return 1; + } + for (auto d : dim) { + d = at::maybe_wrap_dim(d, sizes.size()); + size *= sizes[d]; + } + return size; +} + +Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) { + auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); + Tensor res = t; + for (size_t i = 0; i < n_dims; i++){ + if (dims_to_unsqueeze[i]) { + res = res.unsqueeze(i); + } + } + return res; +} + +Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) { + if (!keepdim && sizes.size() > 0) { + if (dims.size()==1) { + return grad.unsqueeze(dims[0]).expand(sizes); + } else { + Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); + return res.expand(sizes); + } + } else { + return grad.expand(sizes); + } +} + Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) { return at::native::prod_out( result, self, dim, keepdim, c10::optional(dtype)); @@ -416,6 +452,16 @@ Tensor logsumexp(const Tensor &self, IntArrayRef dims, bool keepdim) { return at::native::logsumexp_out(result, self, dims, keepdim); } +Tensor logsumexp_backward(const Tensor& grad, const Tensor & self, const Tensor& res, IntArrayRef dim, bool keepdim) { + Tensor grad_input = grad; + Tensor fwd_res = res; + if (!keepdim && self.dim() != 0) { + grad_input = unsqueeze_multiple(grad, dim, self.sizes().size()); + fwd_res = unsqueeze_multiple(res, dim, self.sizes().size()); + } + return grad_input * (self - fwd_res).exp(); +} + static Tensor& norm_out(Tensor &result, const Tensor &self, optional opt_p, IntArrayRef dim, bool keepdim, optional opt_dtype) { auto p = opt_p.value_or(2.0); @@ -628,6 +674,21 @@ Tensor &var_out(Tensor &result, const Tensor &self, IntArrayRef dim, bool unbias return std_var_out(result, self, dim, unbiased, keepdim, false); } +Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) { + return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean()); +} + +Tensor var_backward(const Tensor & grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { + if (self.dim() == 0) { + return at::var_backward(grad, self, unbiased); + } + Tensor unsqueezed_grad = grad; + if (!keepdim && self.dim() > 1) { + unsqueezed_grad = unsqueeze_multiple(grad, dim, self.sizes().size()); + } + return (2.0 / (at::_safe_size(self.sizes(), dim) - unbiased)) * unsqueezed_grad * (self - self.mean(dim, true)); +} + Tensor std(const Tensor& self, bool unbiased) { AT_CHECK(self.type().backend() == Backend::CPU || self.type().backend() == Backend::CUDA, "std only supports CPU AND CUDA backend, got: ", toString(self.type().backend())); diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp index 0f58f4f76c75..04ee3dd7364e 100644 --- a/aten/src/ATen/native/TensorCompare.cpp +++ b/aten/src/ATen/native/TensorCompare.cpp @@ -34,6 +34,14 @@ namespace at { namespace native { DEFINE_DISPATCH(max_kernel); DEFINE_DISPATCH(min_kernel); +Tensor index_select_backward(const Tensor& grad, int64_t dim, const Tensor& 
indices, IntArrayRef sizes, bool keepdim) { + Tensor res = at::zeros(sizes, grad.options()); + if (!keepdim && sizes.size() > 0) { + return res.scatter_(dim, indices.unsqueeze(dim), grad.unsqueeze(dim)); + } + return res.scatter_(dim, indices, grad); +} + bool allclose(const Tensor& self, const Tensor& other, double rtol, double atol, bool equal_nan) { return at::isclose(self, other, rtol, atol, equal_nan).all().item(); } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 2e7b5e6e1ed9..bddf40a72fd2 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -384,6 +384,16 @@ Tensor permute(const Tensor& self, IntArrayRef dims) { return self.as_strided(newSizes, newStrides); } +Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { + // invert the permutation + auto ndims = fwd_dims.size(); + std::vector<int64_t> dims(ndims); + for (size_t i = 0; i < ndims; i++) { + dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i; + } + return grad.permute(dims); +} + Tensor repeat(const Tensor& self, IntArrayRef repeats) { AT_CHECK(repeats.size() >= (size_t)self.dim(), "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor"); @@ -451,6 +461,12 @@ Tensor select(const Tensor& self, int64_t dim, int64_t index) { return self.as_strided(sizes, strides, storage_offset); } +Tensor select_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.select(dim, index).copy_(grad); + return grad_input; +} + Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_t step) { int64_t ndim = self.dim(); if (ndim == 0) { @@ -484,6 +500,12 @@ Tensor slice(const Tensor& self, int64_t dim, int64_t start, int64_t end, int64_ return self.as_strided(sizes, strides, storage_offset); } +Tensor slice_backward(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + auto grad_input = at::zeros(input_sizes, grad.options()); + grad_input.slice(dim, start, end, step).copy_(grad); + return grad_input; +} + std::vector<Tensor> split(const Tensor& self, int64_t split_size, int64_t dim) { AT_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor"); AT_CHECK(split_size >= 0, "split expects split_size be non-negative, but got split_size=", split_size); @@ -690,6 +712,28 @@ Tensor squeeze(const Tensor& self) { return self.as_strided(std::get<0>(g), std::get<1>(g)); } +Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) { + auto result = self; + + int64_t nDims = sizes.size(); + for (int64_t dim = 0; dim < nDims; dim++) { + if (sizes[dim] == 1) { + result = result.unsqueeze(dim); + } + } + return result; +} + +Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) { + dim = at::maybe_wrap_dim(dim, sizes.size()); + // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoid + // unsqueezing in the backward. 
+ if (sizes.size() > 0 && sizes[dim] == 1) { + return self.unsqueeze(dim); + } + return self; +} + Tensor squeeze(const Tensor& self, int64_t dim) { int64_t dims = self.dim(); dim = maybe_wrap_dim(dim, dims); diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 7f2d1a2be3ff..386b60803e10 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -59,6 +59,10 @@ dispatch: CUDA: _cudnn_init_dropout_state +- func: index_select_backward(Tensor grad, int64_t dim, Tensor indices, int[] sizes, bool keepdim) -> Tensor + +- func: select_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t index) -> Tensor + - func: _fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) matches_jit_signature: True variants: function @@ -1320,6 +1324,8 @@ - func: logsumexp(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True +- func: logsumexp_backward(Tensor grad, Tensor self, Tensor res, int[1] dim, bool keepdim) -> Tensor + - func: margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor matches_jit_signature: True @@ -1392,6 +1398,8 @@ - func: mean(Tensor self, int[1] dim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) matches_jit_signature: True +- func: sum_backward(Tensor grad, int[] sizes, int[] dims, bool keepdim) -> Tensor + - func: median(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) matches_jit_signature: True variants: function, method @@ -1611,6 +1619,8 @@ matches_jit_signature: True variants: method # This is method-only to match the previous tensor API. In the future we could make this a function too. +- func: permute_backwards(Tensor grad, int[] fwd_dims) -> Tensor + - func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor matches_jit_signature: True @@ -1885,11 +1895,15 @@ variants: function, method device_guard: False +- func: _safe_size(int[] sizes, int[] dim) -> int64_t + - func: slice(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> Tensor(a) matches_jit_signature: True variants: function, method device_guard: False +- func: slice_backward(Tensor grad, int[] input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) -> Tensor + - func: slogdet(Tensor self) -> (Tensor, Tensor) matches_jit_signature: True variants: function, method @@ -1980,6 +1994,10 @@ variants: function, method device_guard: False +- func: unsqueeze_to(Tensor self, int[] sizes) -> Tensor + +- func: unsqueeze_to(Tensor self, int64_t dim, int[] sizes) -> Tensor + - func: squeeze_(Tensor(a!) self) -> Tensor(a!) matches_jit_signature: True variants: method @@ -2276,6 +2294,10 @@ - func: var(Tensor self, int[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) 
matches_jit_signature: True +- func: var_backward(Tensor grad, Tensor self, bool unbiased) -> Tensor + +- func: var_backward(Tensor grad, Tensor self, int[] dim, bool unbiased, bool keepdim) -> Tensor + - func: view_as(Tensor self, Tensor other) -> Tensor matches_jit_signature: True variants: method diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 6f151142fc75..0c8be4ced4e8 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -209,6 +209,7 @@ def method_tests(): ('expand', (1, S), (1, 1, S), 'new_dim_front_old_front_1'), ('expand', (), (dont_convert(()),), 'scalar_to_scalar'), ('expand', (), (1, 3, 2), 'scalar_to_dims'), + ('expand_as', (S, 1, 1), (torch.rand(S, S, S),)), ('exp', (S, S, S), NO_ARGS), ('exp', (), NO_ARGS, 'scalar'), ('expm1', (S, S, S), NO_ARGS), @@ -1020,6 +1021,8 @@ def unpack_variables(args): 'test_det_dim2_null', 'test_det_rank1', 'test_det_rank2', + # `other` in expand_as(self, other) is not used in autograd. + 'test_expand_as', 'test_logdet', 'test_logdet_1x1', 'test_logdet_symmetric', diff --git a/test/cpp/jit/test_misc.h b/test/cpp/jit/test_misc.h index 9039dc7b7e3c..6f504691ded8 100644 --- a/test/cpp/jit/test_misc.h +++ b/test/cpp/jit/test_misc.h @@ -848,7 +848,24 @@ void testDifferentiate(std::ostream& out = std::cout) { auto grad_spec = differentiate(graph); std::vector expected_captured_inputs = {0, 1}; - std::vector expected_captured_outputs = {1, 2}; + // With add/mul implemented using TorchScript, we pass the sizes of + // self & other instead of passing the tensors themselves. + // The forward graph is now + //graph(%0 : Float(2, 3, 4) + // %1 : Float(2, 3, 4)) { + // %2 : Float(2, 3, 4) = aten::mul(%0, %1) + // %self_size.4 : int[] = aten::size(%0) + // %other_size.4 : int[] = aten::size(%1) + // %3 : Float(2, 3, 4) = aten::mul(%2, %0) + // %self_size.2 : int[] = aten::size(%2) + // %4 : int = prim::Constant[value=1]() + // %7 : int[] = aten::size(%3) + // %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) + // return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7); + //} + // Thus all the size info added to the forward outputs is saved + // in grad_spec.df_input_captured_outputs. + std::vector expected_captured_outputs = {1, 2, 3, 4, 5}; std::vector expected_input_vjps = {0, 1}; std::vector expected_output_vjps = {0, 1}; ASSERT_EQ(grad_spec.f_real_outputs, 1); @@ -880,12 +897,29 @@ void testDifferentiateWithRequiresGrad(std::ostream& out = std::cout) { PropagateInputShapes(graph); PropagateRequiresGrad(graph); + // With add/mul implemented using TorchScript, we pass the sizes of + // self & other instead of passing the tensors themselves. 
+ // The forward graph is now + // graph(%0 : Float(*) + // %1 : Float(*)) { + // %2 : Float(*) = aten::mul(%1, %1) + // %3 : int = prim::Constant[value=1]() + // %4 : Float(*) = aten::add(%2, %1, %3) + // %39 : int[] = aten::size(%0) + // %6 : Float(*) = aten::add(%4, %0, %3) + // %7 : Float(*) = aten::mul(%6, %0) + // %self_size.2 : int[] = aten::size(%6) + // %11 : int[] = aten::size(%7) + // %9 : Float(*) = aten::add(%7, %1, %3) + // return (%4, %9, %39, %6, %self_size.2, %11); + // } + auto grad_spec = differentiate(graph); - std::vector expected_input_vjps = {1, 2}; // for e and %4 = (d + a) + std::vector expected_input_vjps = {1, 3}; // for e and %6 = (d + a) std::vector expected_output_vjps = {0}; // only a requires grad ASSERT_EQ(grad_spec.f_real_outputs, 2); ASSERT_EQ(grad_spec.df_input_captured_inputs, std::vector({0})); - ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2, 3})); + ASSERT_EQ(grad_spec.df_input_captured_outputs, std::vector({2, 3, 4, 5})); ASSERT_EQ(grad_spec.df_input_vjps, expected_input_vjps); ASSERT_EQ(grad_spec.df_output_vjps, expected_output_vjps); out << "testDifferentiateWithRequiresGrad\n"; diff --git a/test/expect/TestFuser.test_lstm_cuda-backward.expect b/test/expect/TestFuser.test_lstm_cuda-backward.expect index 8c43a7eaf251..dda0c4f892da 100644 --- a/test/expect/TestFuser.test_lstm_cuda-backward.expect +++ b/test/expect/TestFuser.test_lstm_cuda-backward.expect @@ -22,20 +22,20 @@ graph(%0 : Float(*, *), %forgetgate : Float(*, *), %cellgate : Float(*, *), %outgate : Float(*, *), - %24 : int[], - %25 : int[], - %26 : Float(*, *)): - %27 : int = prim::Constant[value=1]() - %28 : int[] = aten::size(%outgate) - %29 : int[] = aten::size(%26) - %30 : int[] = aten::size(%ingate) - %31 : int[] = aten::size(%cellgate) - %32 : int[] = aten::size(%forgetgate) - %33 : int[] = aten::size(%9) - %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %26, %28) - %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %26, %0, %outgate, %33, %32, %24, %31, %30, %25, %29) + %self_size.5 : int[], + %other_size.5 : int[], + %self_size.3 : int[], + %other_size.3 : int[], + %28 : int[], + %29 : int[], + %30 : Float(*, *), + %self_size.1 : int[], + %other_size.1 : int[]): + %33 : int = prim::Constant[value=1]() + %34 : Tensor = prim::FusionGroup_0(%outgate, %0, %30, %self_size.1) + %grad_other.5 : Tensor, %36 : Tensor, %37 : Tensor, %38 : Tensor = prim::FusionGroup_1(%forgetgate, %9, %ingate, %cellgate, %1, %30, %0, %outgate, %other_size.5, %self_size.5, %28, %other_size.3, %self_size.3, %29, %other_size.1) %39 : Tensor[] = prim::ListConstruct(%38, %36, %37, %34) - %40 : Tensor = aten::cat(%39, %27) + %40 : Tensor = aten::cat(%39, %33) %41 : Tensor = aten::_grad_sum_to_size(%40, %19) %42 : Tensor = aten::_grad_sum_to_size(%40, %17) %43 : Tensor = aten::_grad_sum_to_size(%40, %14) @@ -44,13 +44,13 @@ graph(%0 : Float(*, *), %46 : Float(*, *) = aten::mm(%44, %45) %47 : Float(*, *) = aten::t(%10) %48 : Float(*, *) = aten::mm(%47, %44) - %49 : Float(*, *) = aten::t(%48) + %grad_self.7 : Float(*, *) = aten::t(%48) %50 : Float(*, *) = aten::t(%12) %51 : Float(*, *) = aten::mm(%43, %50) %52 : Float(*, *) = aten::t(%11) %53 : Float(*, *) = aten::mm(%52, %43) - %54 : Float(*, *) = aten::t(%53) - return (%grad_other.5, %41, %42, %46, %49, %51, %54) + %grad_self.9 : Float(*, *) = aten::t(%53) + return (%grad_other.5, %41, %42, %46, %grad_self.7, %51, %grad_self.9) with prim::FusionGroup_0 = graph(%0 : Float(*, 
*), %1 : Float(*, *), %2 : Float(*, *), diff --git a/test/expect/TestFuser.test_lstm_cuda-forward.expect b/test/expect/TestFuser.test_lstm_cuda-forward.expect index e53deb8cf844..414a0413c636 100644 --- a/test/expect/TestFuser.test_lstm_cuda-forward.expect +++ b/test/expect/TestFuser.test_lstm_cuda-forward.expect @@ -28,14 +28,16 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *), %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor = prim::ListUnpack(%16) %21 : int[] = prim::BroadcastSizes(%11, %12) %22 : int[] = prim::BroadcastSizes(%21, %13) - %hy : Float(*, *), %24 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17) - %30 : int[] = aten::size(%0) - %31 : int[] = aten::size(%cellgate.1) - %32 : int[] = aten::size(%forgetgate.1) - %33 : int[] = aten::size(%ingate.1) - %34 : int[] = prim::BroadcastSizes(%32, %30) - %35 : int[] = prim::BroadcastSizes(%33, %31) - return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %35, %24) + %other_size.6 : int[] = aten::size(%0) + %hy : Float(*, *), %25 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %20, %19, %18, %17) + %31 : int[] = aten::size(%25) + %32 : int[] = aten::size(%outgate.1) + %33 : int[] = aten::size(%cellgate.1) + %34 : int[] = aten::size(%forgetgate.1) + %35 : int[] = aten::size(%ingate.1) + %36 : int[] = prim::BroadcastSizes(%34, %other_size.6) + %37 : int[] = prim::BroadcastSizes(%35, %33) + return (%hy, %cy, %7, %9, %11, %12, %21, %13, %22, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %34, %other_size.6, %35, %33, %36, %37, %25, %32, %31) with prim::FusionGroup_0 = graph(%0 : Float(*, *), %1 : Tensor, %2 : Tensor, diff --git a/test/expect/TestFuser.test_milstm_cuda-backward.expect b/test/expect/TestFuser.test_milstm_cuda-backward.expect index eb8e61270c97..590d427323ea 100644 --- a/test/expect/TestFuser.test_milstm_cuda-backward.expect +++ b/test/expect/TestFuser.test_milstm_cuda-backward.expect @@ -17,50 +17,50 @@ graph(%0 : Float(*, *), %Wx : Float(*, *), %Uz : Float(*, *), %18 : Float(*, *), - %19 : int[], - %20 : int[], - %21 : int[], - %22 : int[], - %23 : int[], + %self_size.13 : int[], + %other_size.13 : int[], + %self_size.11 : int[], + %other_size.11 : int[], + %self_size.9 : int[], %24 : int[], + %25 : int[], + %self_size.7 : int[], + %27 : int[], + %28 : int[], + %29 : int[], + %30 : int[], %ingate : Float(*, *), %forgetgate : Float(*, *), %cellgate : Float(*, *), %outgate : Float(*, *), - %29 : int[], - %30 : int[], - %31 : Float(*, *)): - %32 : int = prim::Constant[value=1]() - %33 : int[] = aten::size(%outgate) - %34 : int[] = aten::size(%31) - %35 : int[] = aten::size(%ingate) - %36 : int[] = aten::size(%cellgate) - %37 : int[] = aten::size(%forgetgate) - %38 : Tensor = prim::FusionGroup_0(%outgate, %0, %31, %33) - %39 : Tensor, %40 : Tensor, %41 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %31, %0, %outgate, %forgetgate, %37, %29, %36, %35, %30, %34) - %42 : Tensor[] = prim::ListConstruct(%41, %39, %40, %38) - %43 : Tensor = aten::cat(%42, %32) - %44 : Tensor = aten::_grad_sum_to_size(%43, %24) - %45 : Tensor = aten::_grad_sum_to_size(%43, %22) - %46 : int[] = aten::size(%11) - %grad_self.7 : Tensor = prim::FusionGroup_2(%45, %Uz, %46) - %48 : int[] = aten::size(%Uz) - %49 : 
Tensor = aten::_grad_sum_to_size(%43, %19) - %50 : Tensor = aten::_grad_sum_to_size(%43, %20) - %51 : int[] = aten::size(%12) - %grad_self.9 : Tensor = prim::FusionGroup_3(%50, %Wx, %51) - %53 : int[] = aten::size(%Wx) - %54 : int[] = aten::size(%18) - %55 : Tensor = prim::FusionGroup_4(%49, %18, %45, %11, %48) - %56 : int[] = aten::size(%13) - %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %49, %Uz, %50, %12, %56, %53, %54) + %self_size.5 : int[], + %self_size.3 : int[], + %other_size.3 : int[], + %38 : int[], + %39 : int[], + %40 : Float(*, *), + %self_size.1 : int[], + %other_size.1 : int[]): + %43 : int = prim::Constant[value=1]() + %44 : Tensor = prim::FusionGroup_0(%outgate, %0, %40, %self_size.1) + %45 : Tensor, %46 : Tensor, %47 : Tensor = prim::FusionGroup_1(%10, %ingate, %cellgate, %1, %40, %0, %outgate, %forgetgate, %self_size.5, %38, %other_size.3, %self_size.3, %39, %other_size.1) + %48 : Tensor[] = prim::ListConstruct(%47, %45, %46, %44) + %49 : Tensor = aten::cat(%48, %43) + %50 : Tensor = aten::_grad_sum_to_size(%49, %30) + %51 : Tensor = aten::_grad_sum_to_size(%49, %28) + %grad_self.7 : Tensor = prim::FusionGroup_2(%51, %Uz, %self_size.7) + %53 : Tensor = aten::_grad_sum_to_size(%49, %24) + %54 : Tensor = aten::_grad_sum_to_size(%49, %25) + %grad_self.9 : Tensor = prim::FusionGroup_3(%54, %Wx, %self_size.9) + %56 : Tensor = prim::FusionGroup_4(%53, %18, %51, %11, %other_size.11) + %grad_self.13 : Tensor, %58 : Tensor = prim::FusionGroup_5(%Wx, %13, %53, %Uz, %54, %12, %self_size.13, %other_size.13, %self_size.11) %59 : Float(*, *) = aten::t(%14) - %60 : Float(*, *) = aten::mm(%59, %55) - %61 : Float(*, *) = aten::t(%60) + %60 : Float(*, *) = aten::mm(%59, %56) + %grad_self.15 : Float(*, *) = aten::t(%60) %62 : Float(*, *) = aten::t(%15) %63 : Float(*, *) = aten::mm(%62, %58) - %64 : Float(*, *) = aten::t(%63) - return (%44, %grad_self.7, %grad_self.9, %grad_self.13, %61, %64) + %grad_self.17 : Float(*, *) = aten::t(%63) + return (%50, %grad_self.7, %grad_self.9, %grad_self.13, %grad_self.15, %grad_self.17) with prim::FusionGroup_0 = graph(%0 : Float(*, *), %1 : Float(*, *), %2 : Float(*, *), diff --git a/test/expect/TestFuser.test_milstm_cuda-forward.expect b/test/expect/TestFuser.test_milstm_cuda-forward.expect index a8b3b77a6a2b..019db7d19330 100644 --- a/test/expect/TestFuser.test_milstm_cuda-forward.expect +++ b/test/expect/TestFuser.test_milstm_cuda-forward.expect @@ -24,28 +24,31 @@ with prim::DifferentiableGraph_0 = graph(%0 : Float(*, *), %11 : Float(*, *) = aten::t(%6) %Uz.1 : Float(*, *) = aten::mm(%5, %11) %13 : Float(*, *) = aten::mul(%4, %Wx.1) - %14 : int[] = aten::size(%1) - %15 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1) - %16 : Tensor[] = aten::broadcast_tensors(%15) - %17 : Tensor, %18 : Tensor, %19 : Tensor, %20 : Tensor, %21 : Tensor, %22 : Tensor = prim::ListUnpack(%16) - %23 : int[] = aten::size(%3) - %24 : int[] = aten::size(%Wx.1) - %25 : int[] = prim::BroadcastSizes(%23, %24) - %26 : int[] = aten::size(%13) - %27 : int[] = aten::size(%Uz.1) - %28 : int[] = prim::BroadcastSizes(%26, %27) - %29 : int[] = aten::size(%2) - %30 : int[] = prim::BroadcastSizes(%29, %27) - %31 : int[] = prim::BroadcastSizes(%28, %25) - %32 : int[] = prim::BroadcastSizes(%31, %30) - %hy : Float(*, *), %34 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %22, %21, %20, %19, %18, %17) - %40 : int[] = aten::size(%0) - 
%41 : int[] = aten::size(%cellgate.1) - %42 : int[] = aten::size(%forgetgate.1) - %43 : int[] = aten::size(%ingate.1) - %44 : int[] = prim::BroadcastSizes(%42, %40) - %45 : int[] = prim::BroadcastSizes(%43, %41) - return (%hy, %cy, %Wx.1, %Uz.1, %13, %28, %25, %31, %30, %32, %14, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %44, %45, %34) + %self_size.14 : int[] = aten::size(%4) + %other_size.14 : int[] = aten::size(%Wx.1) + %self_size.12 : int[] = aten::size(%13) + %other_size.12 : int[] = aten::size(%Uz.1) + %self_size.10 : int[] = aten::size(%3) + %self_size.8 : int[] = aten::size(%2) + %20 : int[] = aten::size(%1) + %21 : Tensor[] = prim::ListConstruct(%1, %2, %Uz.1, %13, %3, %Wx.1) + %22 : Tensor[] = aten::broadcast_tensors(%21) + %23 : Tensor, %24 : Tensor, %25 : Tensor, %26 : Tensor, %27 : Tensor, %28 : Tensor = prim::ListUnpack(%22) + %29 : int[] = prim::BroadcastSizes(%self_size.10, %other_size.14) + %30 : int[] = prim::BroadcastSizes(%self_size.12, %other_size.12) + %31 : int[] = prim::BroadcastSizes(%self_size.8, %other_size.12) + %32 : int[] = prim::BroadcastSizes(%30, %29) + %33 : int[] = prim::BroadcastSizes(%32, %31) + %hy : Float(*, *), %35 : Float(*, *), %cy : Float(*, *), %outgate.1 : Float(*, *), %cellgate.1 : Float(*, *), %forgetgate.1 : Float(*, *), %ingate.1 : Float(*, *) = prim::FusionGroup_0(%0, %28, %27, %26, %25, %24, %23) + %41 : int[] = aten::size(%0) + %42 : int[] = aten::size(%35) + %43 : int[] = aten::size(%outgate.1) + %44 : int[] = aten::size(%cellgate.1) + %45 : int[] = aten::size(%forgetgate.1) + %46 : int[] = aten::size(%ingate.1) + %47 : int[] = prim::BroadcastSizes(%45, %41) + %48 : int[] = prim::BroadcastSizes(%46, %44) + return (%hy, %cy, %Wx.1, %Uz.1, %13, %self_size.14, %other_size.14, %self_size.12, %other_size.12, %self_size.10, %30, %29, %self_size.8, %32, %31, %33, %20, %ingate.1, %forgetgate.1, %cellgate.1, %outgate.1, %45, %46, %44, %47, %48, %35, %43, %42) with prim::FusionGroup_0 = graph(%0 : Float(*, *), %1 : Tensor, %2 : Tensor, diff --git a/test/expect/TestJit.test_cpp_cuda.expect b/test/expect/TestJit.test_cpp_cuda.expect index e15ab3c6fe28..5e465c3766b4 100644 --- a/test/expect/TestJit.test_cpp_cuda.expect +++ b/test/expect/TestJit.test_cpp_cuda.expect @@ -85,47 +85,48 @@ testDifferentiate graph(%0 : Float(2, 3, 4), %1 : Float(2, 3, 4)): %2 : Float(2, 3, 4) = aten::mul(%0, %1) + %self_size.4 : int[] = aten::size(%0) + %other_size.4 : int[] = aten::size(%1) %3 : Float(2, 3, 4) = aten::mul(%2, %0) + %self_size.2 : int[] = aten::size(%2) %4 : int = prim::Constant[value=1]() %7 : int[] = aten::size(%3) %5 : Float(2, 3, 4) = aten::add(%3, %1, %4) - return (%5, %2, %7) + return (%5, %2, %self_size.4, %other_size.4, %self_size.2, %7) graph(%0 : Float(2, 3, 4), %1 : Float(2, 3, 4), %2 : Float(2, 3, 4), %3 : Float(2, 3, 4), %4 : Float(2, 3, 4), - %5 : int[]): - %7 : int = prim::Constant[value=1]() - %6 : int[] = aten::size(%3) - %8 : Tensor, %9 : Tensor = prim::GradOf[name="aten::add"](%0) + %self_size.3 : int[], + %other_size.3 : int[], + %self_size.1 : int[], + %8 : int[]): + %9 : int = prim::Constant[value=1]() + %10 : Tensor, %11 : Tensor = prim::GradOf[name="aten::add"](%0) block0(): - %10 : Tensor = aten::_grad_sum_to_size(%0, %5) - %11 : Float(2, 3, 4) = aten::mul(%0, %7) - %12 : Tensor = aten::_grad_sum_to_size(%11, %6) - -> (%10, %12) - %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%8) + %12 : Tensor = aten::_grad_sum_to_size(%0, %8) + %13 : Float(2, 3, 4) = aten::mul(%0, %9) + %14 : Tensor = 
aten::_grad_sum_to_size(%13, %other_size.3) + -> (%12, %14) + %grad_self.2 : Tensor, %grad_other.2 : Tensor = prim::GradOf[name="aten::mul"](%10) block0(): - %15 : Tensor = aten::mul(%8, %2) - %16 : int[] = aten::size(%4) - %grad_self.1 : Tensor = aten::_grad_sum_to_size(%15, %16) - %18 : Tensor = aten::mul(%8, %4) - %19 : int[] = aten::size(%2) - %grad_other.1 : Tensor = aten::_grad_sum_to_size(%18, %19) + %17 : Tensor = aten::mul(%10, %2) + %grad_self.1 : Tensor = aten::_grad_sum_to_size(%17, %self_size.1) + %19 : Tensor = aten::mul(%10, %4) + %grad_other.1 : Tensor = aten::_grad_sum_to_size(%19, %self_size.3) -> (%grad_self.1, %grad_other.1) %21 : Tensor = prim::AutogradAdd(%1, %grad_self.2) %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%21) block0(): %24 : Tensor = aten::mul(%21, %3) - %25 : int[] = aten::size(%2) - %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %25) - %27 : Tensor = aten::mul(%21, %2) - %28 : int[] = aten::size(%3) - %grad_other.3 : Tensor = aten::_grad_sum_to_size(%27, %28) + %grad_self.3 : Tensor = aten::_grad_sum_to_size(%24, %self_size.3) + %26 : Tensor = aten::mul(%21, %2) + %grad_other.3 : Tensor = aten::_grad_sum_to_size(%26, %other_size.3) -> (%grad_self.3, %grad_other.3) - %30 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self) - %31 : Tensor = prim::AutogradAdd(%9, %grad_other) - return (%30, %31) + %28 : Tensor = prim::AutogradAdd(%grad_other.2, %grad_self) + %29 : Tensor = prim::AutogradAdd(%11, %grad_other) + return (%28, %29) testDifferentiateWithRequiresGrad graph(%0 : Float(*), @@ -133,37 +134,38 @@ graph(%0 : Float(*), %2 : Float(*) = aten::mul(%1, %1) %3 : int = prim::Constant[value=1]() %4 : Float(*) = aten::add(%2, %1, %3) + %39 : int[] = aten::size(%0) %6 : Float(*) = aten::add(%4, %0, %3) %7 : Float(*) = aten::mul(%6, %0) + %self_size.2 : int[] = aten::size(%6) %11 : int[] = aten::size(%7) %9 : Float(*) = aten::add(%7, %1, %3) - return (%4, %9, %6, %11) + return (%4, %9, %39, %6, %self_size.2, %11) graph(%0 : Float(*), %1 : Float(*), %2 : Float(*), - %3 : Float(*), - %4 : int[]): - %6 : int = prim::Constant[value=1]() - %5 : int[] = aten::size(%2) - %7 : Tensor = prim::GradOf[name="aten::add"](%0) + %3 : int[], + %4 : Float(*), + %self_size.1 : int[], + %6 : int[]): + %7 : int = prim::Constant[value=1]() + %8 : Tensor = prim::GradOf[name="aten::add"](%0) block0(): - %8 : Tensor = aten::_grad_sum_to_size(%0, %4) - -> (%8) - %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%7) + %9 : Tensor = aten::_grad_sum_to_size(%0, %6) + -> (%9) + %grad_self : Tensor, %grad_other : Tensor = prim::GradOf[name="aten::mul"](%8) block0(): - %11 : Tensor = aten::mul(%7, %2) - %12 : int[] = aten::size(%3) - %grad_self.1 : Tensor = aten::_grad_sum_to_size(%11, %12) - %14 : Tensor = aten::mul(%7, %3) - %15 : int[] = aten::size(%2) - %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %15) + %12 : Tensor = aten::mul(%8, %2) + %grad_self.1 : Tensor = aten::_grad_sum_to_size(%12, %self_size.1) + %14 : Tensor = aten::mul(%8, %4) + %grad_other.1 : Tensor = aten::_grad_sum_to_size(%14, %3) -> (%grad_self.1, %grad_other.1) - %17 : Tensor = prim::AutogradAdd(%1, %grad_self) - %18 : Tensor = prim::GradOf[name="aten::add"](%17) + %16 : Tensor = prim::AutogradAdd(%1, %grad_self) + %17 : Tensor = prim::GradOf[name="aten::add"](%16) block0(): - %19 : Tensor = aten::mul(%17, %6) - %20 : Tensor = aten::_grad_sum_to_size(%19, %5) - -> (%20) - %21 : Tensor = prim::AutogradAdd(%grad_other, %18) - return (%21) + %18 : 
Tensor = aten::mul(%16, %7) + %19 : Tensor = aten::_grad_sum_to_size(%18, %3) + -> (%19) + %20 : Tensor = prim::AutogradAdd(%grad_other, %17) + return (%20) diff --git a/test/test_jit.py b/test/test_jit.py index c2b2238456b1..926501fdee86 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -4998,6 +4998,39 @@ def test_copy_behavior(t, non_blocking=False): self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype) self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device) + # Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor + t = torch.tensor(5).float().requires_grad_() + out_ref = t.to(torch.float32) + out = s(t, "t.to(torch.float32)") + self.assertEqual(out_ref, out) + + grad_ref = torch.autograd.grad(out_ref.sum(), t) + grad = torch.autograd.grad(out.sum(), t) + self.assertEqual(grad_ref, grad) + + # Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor + out_ref = t.to('cpu') + out = s(t, "t.to('cpu')") + self.assertEqual(out_ref, out) + + grad_ref = torch.autograd.grad(out_ref.sum(), t) + grad = torch.autograd.grad(out.sum(), t) + self.assertEqual(grad_ref, grad) + + # Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor + @torch.jit.script + def func2(t, t_ref): + return t.to(t_ref) + + func2.debug_disable_autodiff_subgraph_inlining() + + t_ref = torch.tensor(4).double() + out_ref = t.to(t_ref) + out = func2(t, t_ref) + grad_ref = torch.autograd.grad(out_ref.sum(), t) + grad = torch.autograd.grad(out.sum(), t) + self.assertEqual(grad_ref, grad) + @unittest.skipIf(not RUN_CUDA, "No CUDA") def test_tensor_number_math_cuda(self): self._test_tensor_number_math(device='cuda') @@ -10637,6 +10670,7 @@ def forward(self, x, y): DISABLE_AUTODIFF_SUBGRAPH_INLINING = { 'test_nn_avg_pool2d', 'test_nn_adaptive_avg_pool2d', + 'test_nn_embedding', 'test_nn_log_softmax', 'test_nn_threshold', 'test_nn_nll_loss', diff --git a/third_party/onnx b/third_party/onnx index 822d8df0a2a3..15c33c945851 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 822d8df0a2a32233c6022f50a158817a0f19bdc7 +Subproject commit 15c33c945851907411619f599900c3852108e7e3 diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 641b55d5ed55..be8ac7b64866 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -452,7 +452,7 @@ self: zeros_like(grad) - name: logsumexp(Tensor self, IntArrayRef dim, bool keepdim) - self: logsumexp_backward(grad, self, result, dim, keepdim) + self: at::logsumexp_backward(grad, self, result, dim, keepdim) - name: lt_(Tensor self, Scalar other) self: zeros_like(self) @@ -495,13 +495,13 @@ self: grad.expand(self.sizes()).to(self.type().scalarType()) / self.numel() - name: mean(Tensor self, IntArrayRef dim, bool keepdim) - self: sum_backward(grad, self.sizes(), dim, keepdim) / _safe_size(self.sizes(), dim) + self: sum_backward(grad, self.sizes(), dim, keepdim) / at::_safe_size(self.sizes(), dim) - name: mean(Tensor self, IntArrayRef dim, ScalarType dtype) - self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / _safe_size(self.sizes(), dim) + self: sum_backward(grad, self.sizes(), dim, false).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim) - name: mean(Tensor self, IntArrayRef dim, bool keepdim, ScalarType dtype) - self: sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / _safe_size(self.sizes(), dim) + self: 
sum_backward(grad, self.sizes(), dim, keepdim).to(self.type().scalarType()) / at::_safe_size(self.sizes(), dim) - name: median(Tensor self) self: select_equals_backward(grad, self, result) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index d44ec5af26de..63b44163f3d6 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -76,18 +76,6 @@ Tensor maybe_multiply(const Tensor & t, const Scalar & s) { } } -int64_t _safe_size(IntArrayRef sizes, IntArrayRef dim) { - int64_t size = 1; - if (sizes.size() == 0) { - return 1; - } - for (auto d : dim) { - d = at::maybe_wrap_dim(d, sizes.size()); - size *= sizes[d]; - } - return size; -} - Tensor norm_backward(const Tensor & grad, const Tensor & self, const optional & p_, const Tensor & norm) { double p = p_.value_or(2.0).toDouble(); Tensor self_scaled; @@ -160,40 +148,6 @@ Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) { return grad * args.digamma_().sum(-1); } -Tensor permute_backwards(const Tensor & grad, IntArrayRef fwd_dims) { - // invert the permutation - auto ndims = fwd_dims.size(); - std::vector dims(ndims); - for (size_t i = 0; i < ndims; i++) { - dims[at::maybe_wrap_dim(fwd_dims[i], ndims)] = i; - } - return grad.permute(dims); -} - -Tensor unsqueeze_multiple(const Tensor & t, IntArrayRef dim, size_t n_dims) { - auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims); - Tensor res = t; - for (size_t i = 0; i < n_dims; i++){ - if (dims_to_unsqueeze[i]) { - res = res.unsqueeze(i); - } - } - return res; -} - -Tensor sum_backward(const Tensor & grad, IntArrayRef sizes, IntArrayRef dims, bool keepdim) { - if (!keepdim && sizes.size() > 0) { - if (dims.size()==1) { - return grad.unsqueeze(dims[0]).expand(sizes); - } else { - Tensor res = unsqueeze_multiple(grad, dims, sizes.size()); - return res.expand(sizes); - } - } else { - return grad.expand(sizes); - } -} - std::vector reverse_list(const IntArrayRef list) { auto result = std::vector(); result.reserve(list.size()); @@ -429,13 +383,13 @@ Tensor cumsum_backward(const Tensor &x, int64_t dim, ScalarType input_dtype) { return cumsum_backward(x.to(input_dtype), dim); } -Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) { - if (!keepdim && self.dim() != 0) { - grad = unsqueeze_multiple(grad, dim, self.sizes().size()); - result = unsqueeze_multiple(result, dim, self.sizes().size()); - } - return grad * (self - result).exp(); -} +//Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) { +// if (!keepdim && self.dim() != 0) { +// grad = unsqueeze_multiple(grad, dim, self.sizes().size()); +// result = unsqueeze_multiple(result, dim, self.sizes().size()); +// } +// return grad * (self - result).exp(); +//} Tensor unbind_backward(const variable_list& grads, int64_t dim) { IntArrayRef sizes; @@ -454,28 +408,6 @@ Tensor unbind_backward(const variable_list& grads, int64_t dim) { return at::stack(grads_tensors, dim); } -Tensor unsqueeze_to(const Tensor & self, IntArrayRef sizes) { - auto result = self; - - int64_t nDims = sizes.size(); - for (int64_t dim = 0; dim < nDims; dim++) { - if (sizes[dim] == 1) { - result = result.unsqueeze(dim); - } - } - return result; -} - -Tensor unsqueeze_to(const Tensor & self, int64_t dim, IntArrayRef sizes) { - dim = at::maybe_wrap_dim(dim, sizes.size()); - // in NumPy it's not an error to unsqueeze a scalar, but we still need to avoided - // unsqueezing 
in the backward. - if (sizes.size() > 0 && sizes[dim] == 1) { - return self.unsqueeze(dim); - } - return self; -} - std::vector cat_tensors_backward(const Tensor & grad, const std::vector> &sizes, int64_t dim) { dim = at::legacy_cat_wrap_dim(dim, sizes); std::vector grad_inputs(sizes.size()); @@ -606,26 +538,6 @@ Tensor select_equals_backward(Tensor grad, const Tensor & input, const Tensor & return grad_input; } -Tensor index_select_backward(Tensor grad, int64_t dim, Tensor indices, IntArrayRef sizes, bool keepdim) { - if (!keepdim && sizes.size() > 0) { - grad = grad.unsqueeze(dim); - indices = indices.unsqueeze(dim); - } - return at::zeros(sizes, grad.options()).scatter_(dim, indices, grad); -} - -Tensor slice_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.slice(dim, start, end, step).copy_(grad); - return grad_input; -} - -Tensor select_backward(Tensor grad, IntArrayRef input_sizes, int64_t dim, int64_t index) { - auto grad_input = at::zeros(input_sizes, grad.options()); - grad_input.select(dim, index).copy_(grad); - return grad_input; -} - Tensor trace_backward(const Tensor & grad, IntArrayRef sizes) { if (sizes.size() != 2) { throw std::runtime_error("expected matrix input"); @@ -651,19 +563,6 @@ Tensor unfold_backward(const Tensor & grad, IntArrayRef input_sizes, int64_t dim return grad_input.view(input_sizes); } -Tensor var_backward(const Tensor & grad, const Tensor & self, bool unbiased) { - return (2.0 / (self.numel() - unbiased)) * grad * (self - self.mean()); -} - -Tensor var_backward(Tensor grad, const Tensor & self, IntArrayRef dim, bool unbiased, bool keepdim) { - if (self.dim() == 0) { - return var_backward(grad, self, unbiased); - } - if (!keepdim && self.dim() > 1) { - grad = unsqueeze_multiple(grad, dim, self.sizes().size()); - } - return (2.0 / (_safe_size(self.sizes(), dim) - unbiased)) * grad * (self - self.mean(dim, true)); -} Tensor masked_scatter_backward(const Tensor & grad, const Tensor & mask, IntArrayRef sizes) { int64_t numel = 1; diff --git a/torch/csrc/jit/autodiff.cpp b/torch/csrc/jit/autodiff.cpp index c43669d8d621..b63f0698cb1d 100644 --- a/torch/csrc/jit/autodiff.cpp +++ b/torch/csrc/jit/autodiff.cpp @@ -27,6 +27,19 @@ void wrapDim(int64_t& dim, const std::vector& sizes) { } } +// kthvalue returns (kthvalue, index of kthvalue), currently autodiff only +// supports at most one output that requires grad. Thus we need to remove +// the grad for index that doesn't require grad. 
+bool needTrimGrad(Node* n) { + static OperatorSet need_trim_grad_ops = { + "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)", + }; + if (need_trim_grad_ops.find(n)) { + return true; + } + return false; +} + bool isDifferentiable(Node* n) { // TODO: scalar-tensor ops should be canonicalized static OperatorSet differentiable_ops = { @@ -194,15 +207,20 @@ static c10::optional> build_script_grad( auto fw_graph = compiled_graphs->forward; new_outputs = inlineCallTo( *graph, *fw_graph, node->inputs(), /*unpack_outputs=*/true); - for (size_t i = 0; i < node->outputs().size(); ++i) { - new_outputs.at(i)->setType(node->outputs()[i]->type()); - new_outputs.at(i)->replaceAllUsesWith(node->outputs()[i]); + auto outputs = node->outputs(); + AT_ASSERT(new_outputs.size() == outputs.size() + 1); + for (size_t i = 0; i < outputs.size(); ++i) { + new_outputs.at(i)->setType(outputs[i]->type()); + outputs[i]->replaceAllUsesWith(new_outputs.at(i)); } } // Use backward graph to construct reverse_block auto bw_graph = compiled_graphs->backward; auto grad_vec = grads.vec(); + if (needTrimGrad(node)) { + grad_vec.erase(grad_vec.begin()+1, grad_vec.end()); + } auto it = grad_vec.begin(); grad_vec.insert(it, new_outputs.back()); ArrayRef grad(grad_vec); @@ -578,35 +596,6 @@ class GradientHelper { return {sizes.at(dim) > 1 ? grads.at(0) : grads.at(0).unsqueeze(dim), nullptr}; - } else if (node->matches( - "aten::cat(Tensor[] tensors, int dim) -> Tensor", - /*const_inputs=*/attr::dim)) { - int dim = *node->get(attr::dim); - auto tensor_inputs = inputs; - tensor_inputs.pop_back(); - const auto& first_sizes = tensor_inputs.at(0).sizes(); - const auto has_first_sizes = [&first_sizes](SymbolicVariable var) { - return var.sizes() == first_sizes; - }; - - // NB: this is a specialization for the common case where all inputs are - // of equal sizes. We can use a single split operation to handle that. 
- if (std::all_of( - tensor_inputs.begin(), tensor_inputs.end(), has_first_sizes)) { - auto tensor_grads = grads.at(0).chunk(tensor_inputs.size(), dim); - tensor_grads.emplace_back(nullptr); // for attr::dim - return tensor_grads; - } else { - size_t offset = 0; - auto grad = grads.at(0); - std::vector tensor_grads; - for (auto input : tensor_inputs) { - tensor_grads.push_back(grad.narrow(dim, offset, input.sizes()[dim])); - offset += input.sizes()[dim]; - } - tensor_grads.emplace_back(nullptr); // for attr::dim - return tensor_grads; - } } else if (comparison_ops.find(node)) { return {nullptr, nullptr}; @@ -775,6 +764,11 @@ static std::vector linearGradientForNode( Node* node, ArrayRef grad_values) { auto& graph = *node->owningGraph(); + + // FIXME: In case forward has multi outputs, we only support one requires grad + if (needTrimGrad(node)) { + grad_values = grad_values.at(0); + } auto linear = graph.insertNode(graph.create(prim::GradOf, {grad_values}, 0)); // to make reading gradient graphs easier, remember the name of the forward op linear->s_(attr::name, node->kind().toDisplayString()); diff --git a/torch/csrc/jit/graph_executor.cpp b/torch/csrc/jit/graph_executor.cpp index 2beafd4bf925..2332dd12ae71 100644 --- a/torch/csrc/jit/graph_executor.cpp +++ b/torch/csrc/jit/graph_executor.cpp @@ -108,7 +108,9 @@ struct DifferentiableGraphBackward : public autograd::Function { variable_list outputs; outputs.reserve(num_outputs()); for (size_t i = 0; i < num_outputs(); ++i) { - if (should_compute_output(i)) { + // Input grad can also be None even if it requires grad + // Example: `other` in expand_as(self, other) + if (should_compute_output(i) && !stack[i].isNone()) { auto output = std::move(stack[i]).toTensor(); const auto& edge = next_edge(i); if (output.defined()) { diff --git a/torch/csrc/jit/passes/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp index 4f43440f3334..5ebb16926fb0 100644 --- a/torch/csrc/jit/passes/graph_fuser.cpp +++ b/torch/csrc/jit/passes/graph_fuser.cpp @@ -1272,6 +1272,7 @@ bool trackSingleGradSumToSizeToOutputs( "aten::div(Tensor self, Scalar other) -> Tensor", "aten::neg(Tensor self) -> Tensor", "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor", // add this used to be prim::AutogradAdd }}; diff --git a/torch/csrc/jit/symbolic_script.cpp b/torch/csrc/jit/symbolic_script.cpp index 327626d7e896..bfbf4aea86e3 100644 --- a/torch/csrc/jit/symbolic_script.cpp +++ b/torch/csrc/jit/symbolic_script.cpp @@ -6,13 +6,303 @@ namespace { std::mutex lock; const std::vector functions = { R"( + + def _dim_arange(like, + dim: int): + def backward(grad_output): + return None, None + + return torch._dim_arange(like, dim), backward + + def contiguous(self): + def backward(grad_output): + return None + + return self.contiguous(), backward + + def erf(self): + def backward(grad_output): + # Precomputed constant C = 2.0 / math.sqrt(math.pi) + C = 1.1283791670955126 + grad_self = C * torch.exp(- self.pow(2)) * grad_output + return grad_self + + return torch.erf(self), backward + + def expand(self, + size: List[int], + implicit: bool=False): + self_size = self.size() + def backward(grad_output): + grad_self = torch._grad_sum_to_size(grad_output, self_size) + return grad_self, None, None + + return torch.expand(self, size, implicit=implicit), backward + + def expand_as(self, other): + self_size = self.size() + def backward(grad_output): + grad_self = grad_output._grad_sum_to_size(self_size) + return 
grad_self, None + + return torch.expand_as(self, other), backward + + def full_like(self, + fill_value: float): + def backward(grad_output): + return None, None + + return torch.full_like(self, fill_value), backward + + def kthvalue(self, + k: int, + dim: int, + keepdim: bool): + result0, result1 = torch.kthvalue(self, k, dim, keepdim) + self_size = self.size() + def backward(grad_output): + grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, keepdim) + return grad_self, None, None, None + + return result0, result1, backward + + def logsumexp(self, + dim: List[int], + keepdim: bool): + result = torch.logsumexp(self, dim, keepdim) + self_dim = self.dim() + def backward(grad_output): + grad_self = torch.logsumexp_backward(grad_output, self, result, dim, keepdim) + return grad_self, None, None + + return result, backward + + def mean_0(self): + self_size = self.size() + self_numel = self.numel() + def backward(grad_output): + grad_self = grad_output.expand(self_size) / self_numel + return grad_self + + return torch.mean(self), backward + + def mean_1(self, + dim: List[int], + keepdim: bool): + self_size = self.size() + def backward(grad_output): + grad_self = torch.sum_backward(grad_output, self_size, dim, keepdim) / torch._safe_size(self_size, dim) + return grad_self, None, None + + return torch.mean(self, dim, keepdim), backward + def mul(self, other): + self_size = self.size() + other_size = other.size() def backward(grad_output): - grad_self = (grad_output * other)._grad_sum_to_size(self.size()) - grad_other = (grad_output * self)._grad_sum_to_size(other.size()) + grad_self = (grad_output * other)._grad_sum_to_size(self_size) + grad_other = (grad_output * self)._grad_sum_to_size(other_size) return grad_self, grad_other + return self * other, backward + def nonzero(self): + def backward(grad_output): + return None + + return torch.nonzero(self), backward + + def ones_like(self): + def backward(grad_output): + return None + + return torch.ones_like(self), backward + + def permute(self, + dims: List[int]): + def backward(grad_output): + grad_self = torch.permute_backwards(grad_output, dims) + return grad_self, None + + return torch.permute(self, dims), backward + + def pow_0(self, + exponent: float): + def backward(grad_output): + grad_self = torch.where(torch.tensor(exponent == 0.0), torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1)) + return grad_self, None + + return torch.pow(self, exponent), backward + + def pow_1(self, exponent): + self_size = self.size() + exponent_size = exponent.size() + def backward(grad_output): + grad_self = torch.where(exponent == 0.0, torch.zeros_like(self), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size) + grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size) + return grad_self, grad_exponent + + return torch.pow(self, exponent), backward + + def pow_2(self: float, + exponent): + def backward(grad_output): + grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(torch.tensor(self)) + return None, grad_exponent + + return torch.pow(self, exponent), backward + + def rsub_0(self, other, + alpha: float = 1.0): + self_size = self.size() + other_size = other.size() + def backward(grad_output): + grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size) + grad_other = (grad_output)._grad_sum_to_size(other_size) + return grad_self, grad_other, None + + return torch.rsub(self, other, alpha), backward + + 
def rsub_1(self, + other: float, + alpha: float = 1.0): + self_size = self.size() + def backward(grad_output): + grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size) + return grad_self, None, None + + return torch.rsub(self, other, alpha), backward + + def select(self, + dim: int, + index: int): + self_size = self.size() + def backward(grad_output): + grad_self = torch.select_backward(grad_output, self_size, dim, index) + return grad_self, None, None + + return torch.select(self, dim, index), backward + + def sqrt(self): + result = torch.sqrt(self) + def backward(grad_output): + grad_self = grad_output / (2 * result) + return grad_self + + return result, backward + + def squeeze_0(self): + self_size = self.size() + def backward(grad_output): + grad_self = torch.unsqueeze_to(grad_output, self_size) + return grad_self + + return torch.squeeze(self), backward + + def squeeze_1(self, + dim: int): + self_size = self.size() + def backward(grad_output): + grad_self = torch.unsqueeze_to(grad_output, dim, self_size) + return grad_self, None + + return torch.squeeze(self, dim), backward + + def t(self): + def backward(grad_output): + grad_self = torch.t(grad_output) + return grad_self + + return torch.t(self), backward + + def to_0(self, + device: Optional[Device], + dtype: Optional[int], + non_blocking: bool=False, + copy: bool=False): + self_device = self.device + self_dtype = self.dtype + if device is not None: + result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy) + else: + result = self.to(dtype, non_blocking=non_blocking, copy=copy) + def backward(grad_output): + grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy) + return grad_self, None, None, None, None + + return result, backward + + + def to_1(self, + dtype: int, + non_blocking: bool=False, + copy: bool=False): + self_dtype = self.dtype + def backward(grad_output): + grad_self = grad_output.to(self_dtype, non_blocking, copy) + return grad_self, None, None, None + + return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward + + def to_2(self, + other, + non_blocking: bool=False, + copy: bool=False): + def backward(grad_output): + grad_self = grad_output.to(self, non_blocking, copy) + return grad_self, None, None, None + + return self.to(other, non_blocking=non_blocking, copy=copy), backward + + def topk(self, + k, + dim: int = -1, + largest: bool = True, + sorted: bool = True): + result0, result1 = torch.topk(self, k, dim, largest, sorted) + self_size = self.size() + def backward(grad_output): + grad_self = torch.index_select_backward(grad_output, dim, result1, self_size, True) + return grad_self, None, None, None, None + + return result0, result1, backward + + def transpose(self, + dim0: int, + dim1: int): + def backward(grad_output): + grad_self = torch.transpose(grad_output, dim0, dim1) + return grad_self, None, None + + return torch.transpose(self, dim0, dim1), backward + + def var_0(self, + unbiased: bool=True): + def backward(grad_output): + grad_self = torch.var_backward(grad_output, self, unbiased) + return grad_self, None + + return torch.var(self, unbiased), backward + + def var_1(self, + dim: List[int], + unbiased: bool, + keepdim: bool): + def backward(grad_output): + grad_self = torch.var_backward(grad_output, self, dim, unbiased, keepdim) + return grad_self, None, None, None + + return torch.var(self, dim, unbiased, keepdim), backward + + def view(self, + size: List[int]): + self_size = self.size() + def backward(grad_output): + 
grad_self = grad_output.reshape(self_size) + return grad_self, None + + return torch.view(self, size), backward + def adaptive_avg_pool2d(self, output_size: List[int]): def backward(grad_output): @@ -20,6 +310,19 @@ const std::vector functions = { return grad_self, None return torch.adaptive_avg_pool2d(self, output_size), backward + + def embedding(weight, + indices, + padding_idx: int, + scale_grad_by_freq: bool, + sparse: bool): + weight_size_0 = weight.size()[0] + def backward(grad_output): + grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse) + return grad_weight, None, None, None, None + + return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward + )"}; std::unordered_map schema_to_graphs; @@ -51,6 +354,24 @@ Argument originalReturnType(const TupleTypePtr& tup) { return Argument("", TupleType::create(std::move(types))); } +// In torchscript AD formulas, we define {func_0, func_1, ...} as +// overloaded functions of `func`. +// Remove the suffix before adding the schema string to map +// schema_to_graphs. +std::string overloadedSchemaString(const FunctionSchema& schema) { + const auto& schema_name = schema.name(); + auto pos = schema_name.find_last_of('_'); + auto schema_name_suffix = schema_name.substr(pos + 1); + std::string schema_string = canonicalSchemaString(schema); + if (!schema_name_suffix.empty() + && schema_name_suffix.find_first_not_of("0123456789") == string::npos) { + schema_string.replace(schema_string.find(schema_name), + schema_name.length(), + schema_name.substr(0, pos)); + } + return schema_string; +} + void loadModule(const std::shared_ptr& module) { for (const auto& method_ : module->get_methods()) { const auto& method = method_.value(); @@ -90,8 +411,12 @@ void loadModule(const std::shared_ptr& module) { Symbol::aten(loaded_schema.name()), loaded_schema.arguments(), {originalReturnType(new_tuple->type()->expect())}); - std::string key = canonicalSchemaString(actual_schema); - schema_to_graphs[key] = std::move(pair); + + // modify canonical string for function overloading + // prefer not to modify the schema name + auto schema_string = overloadedSchemaString(actual_schema); + + schema_to_graphs[schema_string] = std::move(pair); } } @@ -114,6 +439,14 @@ c10::optional gradientInfoForSchema( return cache_it->second; } else { auto schema_str = canonicalSchemaString(schema); + // JIT doesn't support keyword only arguments. + // Remove ' *,' in schema before looking up + // TODO: #16921 properly support keyword only arguments in JIT. + auto n = schema_str.find("*, "); + if (n != std::string::npos) { + schema_str = schema_str.erase(n, 3); + } + auto sym_script_it = schema_to_graphs.find(schema_str); if (sym_script_it != schema_to_graphs.end()) { cached_gradient_pairs.emplace_hint( diff --git a/torch/nn/functional.py b/torch/nn/functional.py index 95945916e902..3cd11d2945fa 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2920,10 +2920,10 @@ def normalize(input, p=2, dim=1, eps=1e-12, out=None): operation won't be differentiable. """ if out is None: - denom = input.norm(p, dim, True).clamp(min=eps).expand_as(input) + denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input) ret = input / denom else: - denom = input.norm(p, dim, True).clamp_(min=eps).expand_as(input) + denom = input.norm(p, dim, True).clamp_min(eps).expand_as(input) ret = torch.div(input, denom, out=out) return ret
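Editor's note (illustrative sketch, not part of the patch above): the helpers this patch moves into ATen — _safe_size, unsqueeze_multiple and sum_backward — implement the usual "re-insert the reduced dimensions, then broadcast" pattern used by the TorchScript AD formulas such as mean_1. The plain-Python rendering below is an assumption-laden sketch of that behaviour written against the public PyTorch API (helper names, shapes, and the lack of negative-dim wrapping are local to this example); it can be checked against autograd.

    import torch

    def _safe_size(sizes, dims):
        # product of the reduced dimensions; 1 for a 0-dim tensor
        if len(sizes) == 0:
            return 1
        n = 1
        for d in dims:
            n *= sizes[d]
        return n

    def sum_backward(grad, sizes, dims, keepdim):
        # mirror unsqueeze_multiple: re-insert the reduced dims, then expand.
        # Unlike the ATen helper, negative dims are not wrapped here.
        if not keepdim and len(sizes) > 0:
            for d in sorted(dims):
                grad = grad.unsqueeze(d)
        return grad.expand(sizes)

    x = torch.randn(2, 3, 4, requires_grad=True)
    x.mean(dim=[0, 2]).backward(torch.ones(3))

    # gradient of mean over dims [0, 2]: broadcast the incoming grad back to
    # x's shape and divide by the number of reduced elements
    manual = sum_backward(torch.ones(3), list(x.shape), [0, 2], False) / _safe_size(list(x.shape), [0, 2])
    print(torch.allclose(x.grad, manual))  # expected: True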