From 99a87b29d8a3898c6f357472f2b1941f3f591f31 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Fri, 25 Sep 2020 11:58:49 -0500 Subject: [PATCH 01/28] Added torch.kron Tests pass. The implementation is based on tensordot. --- aten/src/ATen/native/LinearAlgebra.cpp | 50 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 3 ++ docs/source/torch.rst | 1 + test/test_linalg.py | 10 +++++ 4 files changed, 64 insertions(+) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 9c3742c129de..c32a7e07e70d 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1686,5 +1686,55 @@ Tensor chain_matmul(TensorList matrices) { } } +/* +Calculates the Kronecker product between two Tensors. +*/ +Tensor kron(const Tensor& a, const Tensor& b) { + // TODO: Rewrite the following comment. + /* + We can obtain the kron result using tensordot or einsum. + In einsum notation suppose we have a with dim 4 and b with dim 2 + the result of below tensordot is einsum 0123, 45 -> 012345. + To obtain the correct kron we need to permute and reshape the array. + The permutation rule is the following: going from right to left + take axes in turn to form the permutation + with our example the correct permutation is 012435 and + the kron shape is (shape_a[0], shape_a[1], shape_a[3]*shape_b[0], shape_a[4]*shape_b[1]) + */ + std::vector a_sizes = a.sizes().vec(); + std::vector b_sizes = b.sizes().vec(); + int64_t a_ndim = a.dim(); + int64_t b_ndim = b.dim(); + + std::vector kron_permutation(a_ndim+b_ndim); + std::iota(std::begin(kron_permutation), std::end(kron_permutation), 0); + for (int64_t i = 1; i < b_ndim; i+=2) { + std::swap(kron_permutation[a_ndim+b_ndim-i-1], kron_permutation[a_ndim+b_ndim-i-2]); + i++; + } + + int64_t res_ndim = a_ndim > b_ndim ? a_ndim : b_ndim; + std::vector res_shape(res_ndim); + for (int64_t i = 0; i < res_ndim; i++) { + if (a_ndim == b_ndim) { + res_shape[res_ndim-1-i] = a_sizes[a_ndim-1-i] * b_sizes[b_ndim-1-i]; + } + else if (i >= b_ndim) { + res_shape[res_ndim-1-i] = a_sizes[a_ndim-1-i]; + } + else { + res_shape[res_ndim-1-i] = b_sizes[b_ndim-1-i]; + } + } + + Tensor result = at::tensordot(a, b, {}, {}); + // Step 2: now permute result + result = result.permute(kron_permutation); + // Step 3: reshape + result = result.reshape(res_shape); + + return result; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 78b6d3330300..153375116132 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1944,6 +1944,9 @@ CPU: kl_div_backward_cpu CUDA: kl_div_backward_cuda +- func: kron(Tensor self, Tensor other) -> Tensor + variants: function + - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) use_c10_dispatcher: full variants: function, method diff --git a/docs/source/torch.rst b/docs/source/torch.rst index beab6c449df1..f2134479c6c3 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -453,6 +453,7 @@ Other Operations flip fliplr flipud + kron rot90 gcd histc diff --git a/test/test_linalg.py b/test/test_linalg.py index d3e1905e8d24..a8c0fdeed1cd 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -181,6 +181,16 @@ def test_det(self, device, dtype): with self.assertRaises(RuntimeError): op(t) + @unittest.skipIf(not TEST_NUMPY, "NumPy not found") + @dtypes(torch.double) + def test_kron(self, device, dtype): + a = torch.rand((4,), dtype=dtype, device=device) + b = torch.rand((5,), dtype=dtype, device=device) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + actual = torch.kron(a, b) + self.assertEqual(actual, expected) + # This test confirms that torch.linalg.norm's dtype argument works # as expected, according to the function's documentation @skipCUDAIfNoMagma From 90f0159efd0b773f73c98d52ecdc76d987f4de03 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 28 Sep 2020 09:20:15 -0500 Subject: [PATCH 02/28] Updated tests --- test/test_linalg.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index a8c0fdeed1cd..f89d04bbec81 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -184,12 +184,18 @@ def test_det(self, device, dtype): @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.double) def test_kron(self, device, dtype): - a = torch.rand((4,), dtype=dtype, device=device) - b = torch.rand((5,), dtype=dtype, device=device) - expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) - actual = torch.kron(a, b) - self.assertEqual(actual, expected) + def run_test_case(a_shape, b_shape): + a = torch.rand(a_shape, dtype=dtype, device=device) + b = torch.rand(b_shape, dtype=dtype, device=device) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + actual = torch.kron(a, b) + self.assertEqual(actual, expected) + + shapes = [(4,), (2, 2,), (1,2,3,), (1,2,3,3,)] + for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): + run_test_case(a_shape, b_shape) # This test confirms that torch.linalg.norm's dtype argument works # as expected, according to the function's documentation From 98a2f49fa35bf7d71fa1a2b9c53193ad9a539740 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 28 Sep 2020 09:56:38 -0500 Subject: [PATCH 03/28] Now kron permutation and reshape is correct --- aten/src/ATen/native/LinearAlgebra.cpp | 55 ++++++++++++++++++-------- 1 file changed, 38 insertions(+), 17 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index c32a7e07e70d..5e8f7630186d 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1690,40 +1690,61 @@ Tensor chain_matmul(TensorList matrices) { Calculates the Kronecker product between two Tensors. */ Tensor kron(const Tensor& a, const Tensor& b) { - // TODO: Rewrite the following comment. /* We can obtain the kron result using tensordot or einsum. - In einsum notation suppose we have a with dim 4 and b with dim 2 - the result of below tensordot is einsum 0123, 45 -> 012345. + In einsum notation suppose we have `a` with dim 4 and `b` with dim 2 + the result of below tensordot is in einsum 0123, 45 -> 012345. To obtain the correct kron we need to permute and reshape the array. The permutation rule is the following: going from right to left take axes in turn to form the permutation with our example the correct permutation is 012435 and - the kron shape is (shape_a[0], shape_a[1], shape_a[3]*shape_b[0], shape_a[4]*shape_b[1]) + the kron shape is (shape_a[0], shape_a[1], shape_a[3]*shape_b[0], + shape_a[4]*shape_b[1]) */ std::vector a_sizes = a.sizes().vec(); std::vector b_sizes = b.sizes().vec(); int64_t a_ndim = a.dim(); int64_t b_ndim = b.dim(); + int64_t min_ndim = std::min(a_ndim, b_ndim); + int64_t ndim_diff = std::abs(a_ndim - b_ndim); - std::vector kron_permutation(a_ndim+b_ndim); - std::iota(std::begin(kron_permutation), std::end(kron_permutation), 0); - for (int64_t i = 1; i < b_ndim; i+=2) { - std::swap(kron_permutation[a_ndim+b_ndim-i-1], kron_permutation[a_ndim+b_ndim-i-2]); - i++; + std::vector kron_permutation(a_ndim + b_ndim); + + std::vector a_axes(a_ndim); + std::vector b_axes(b_ndim); + std::iota(a_axes.begin(), a_axes.end(), 0); + std::iota(b_axes.begin(), b_axes.end(), 0 + a_ndim); + + if (b_ndim <= a_ndim) { + for (int64_t i = 0; i <= ndim_diff; i++) { + kron_permutation[i] = a_axes[i]; + } + } else { + for (int64_t i = 0; i < ndim_diff; i++) { + kron_permutation[i] = b_axes[i]; + } } - int64_t res_ndim = a_ndim > b_ndim ? a_ndim : b_ndim; + for (int64_t i = 0, j = 0; i < std::min(a_ndim, b_ndim); i++, j += 2) { + kron_permutation[a_ndim + b_ndim - 1 - j] = b_axes[b_ndim - 1 - i]; + kron_permutation[a_ndim + b_ndim - 1 - j - 1] = a_axes[a_ndim - 1 - i]; + } + + int64_t res_ndim = std::max(a_ndim, b_ndim); std::vector res_shape(res_ndim); - for (int64_t i = 0; i < res_ndim; i++) { - if (a_ndim == b_ndim) { - res_shape[res_ndim-1-i] = a_sizes[a_ndim-1-i] * b_sizes[b_ndim-1-i]; + if (a_ndim > b_ndim) { + for (int64_t i = 0; i < ndim_diff; i++) { + res_shape[i] = a_sizes[i]; } - else if (i >= b_ndim) { - res_shape[res_ndim-1-i] = a_sizes[a_ndim-1-i]; + for (int64_t i = 0; i < b_ndim; i++) { + res_shape[ndim_diff + i] = a_sizes[ndim_diff + i] * b_sizes[i]; } - else { - res_shape[res_ndim-1-i] = b_sizes[b_ndim-1-i]; + } else { + for (int64_t i = 0; i < ndim_diff; i++) { + res_shape[i] = b_sizes[i]; + } + for (int64_t i = 0; i < a_ndim; i++) { + res_shape[ndim_diff + i] = b_sizes[ndim_diff + i] * a_sizes[i]; } } From 5a0dcf945adea56c3f42f8196e4a6641603cbec9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 28 Sep 2020 10:15:26 -0500 Subject: [PATCH 04/28] Rewrote using ternary --- aten/src/ATen/native/LinearAlgebra.cpp | 43 +++++++++----------------- 1 file changed, 14 insertions(+), 29 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 5e8f7630186d..629f553f9319 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1708,51 +1708,36 @@ Tensor kron(const Tensor& a, const Tensor& b) { int64_t min_ndim = std::min(a_ndim, b_ndim); int64_t ndim_diff = std::abs(a_ndim - b_ndim); - std::vector kron_permutation(a_ndim + b_ndim); - std::vector a_axes(a_ndim); std::vector b_axes(b_ndim); std::iota(a_axes.begin(), a_axes.end(), 0); std::iota(b_axes.begin(), b_axes.end(), 0 + a_ndim); - if (b_ndim <= a_ndim) { - for (int64_t i = 0; i <= ndim_diff; i++) { - kron_permutation[i] = a_axes[i]; - } - } else { - for (int64_t i = 0; i < ndim_diff; i++) { - kron_permutation[i] = b_axes[i]; - } + bool is_a_larger = a_ndim >= b_ndim; + std::vector kron_permutation(a_ndim + b_ndim); + for (int64_t i = 0; i < ndim_diff; i++) { + kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i]; } - - for (int64_t i = 0, j = 0; i < std::min(a_ndim, b_ndim); i++, j += 2) { + for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) { kron_permutation[a_ndim + b_ndim - 1 - j] = b_axes[b_ndim - 1 - i]; kron_permutation[a_ndim + b_ndim - 1 - j - 1] = a_axes[a_ndim - 1 - i]; } - int64_t res_ndim = std::max(a_ndim, b_ndim); - std::vector res_shape(res_ndim); - if (a_ndim > b_ndim) { - for (int64_t i = 0; i < ndim_diff; i++) { - res_shape[i] = a_sizes[i]; - } - for (int64_t i = 0; i < b_ndim; i++) { - res_shape[ndim_diff + i] = a_sizes[ndim_diff + i] * b_sizes[i]; - } - } else { - for (int64_t i = 0; i < ndim_diff; i++) { - res_shape[i] = b_sizes[i]; - } - for (int64_t i = 0; i < a_ndim; i++) { - res_shape[ndim_diff + i] = b_sizes[ndim_diff + i] * a_sizes[i]; - } + std::vector result_shape(std::max(a_ndim, b_ndim)); + for (int64_t i = 0; i < ndim_diff; i++) { + result_shape[i] = is_a_larger ? a_sizes[i] : b_sizes[i]; + } + for (int64_t i = 0; i < min_ndim; i++) { + result_shape[ndim_diff + i] = is_a_larger + ? a_sizes[ndim_diff + i] * b_sizes[i] + : b_sizes[ndim_diff + i] * a_sizes[i]; } Tensor result = at::tensordot(a, b, {}, {}); // Step 2: now permute result result = result.permute(kron_permutation); // Step 3: reshape - result = result.reshape(res_shape); + result = result.reshape(result_shape); return result; } From 4c30a9af4af6d44b3a7239b60478544b9a65d7c4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 28 Sep 2020 11:46:23 -0500 Subject: [PATCH 05/28] flake8 --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index f89d04bbec81..3123cab323bf 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -193,7 +193,7 @@ def run_test_case(a_shape, b_shape): actual = torch.kron(a, b) self.assertEqual(actual, expected) - shapes = [(4,), (2, 2,), (1,2,3,), (1,2,3,3,)] + shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)] for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): run_test_case(a_shape, b_shape) From 3716b6aadaaa7b25d2f7d3b6e67b5c816df480b0 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 28 Sep 2020 12:10:48 -0500 Subject: [PATCH 06/28] Added documentation --- torch/_torch_docs.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 7b00ddbd1505..5469205fdff8 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3500,6 +3500,33 @@ def merge_dicts(*dicts): RuntimeError: bool value of Tensor with no values is ambiguous """.format(**common_args)) +add_docstr(torch.kron, + r""" +kron(input, other) -> Tensor + +Computes the Kronecker product of :attr:`input` and :attr:`other`. + +If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`other` is a +:math:`(k \times l)` tensor, the result will be a :math:`(n*k \times m*l)` tensor. + +Arguments: + input (Tensor): the first tensor to be multiplied + other (Tensor): the second tensor to be multiplied + +Returns: + Tensor: A tensor made of blocks of the second tensor scaled by the first. + +Example:: + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.ones(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 1., 0., 0.], + [1., 1., 0., 0.], + [0., 0., 1., 1.], + [0., 0., 1., 1.]]) +""") + add_docstr(torch.kthvalue, r""" kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor) From 9931ad951ac2ce963a20a9ee1a6e0ee6f578acb9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 29 Sep 2020 01:57:50 -0500 Subject: [PATCH 07/28] Added fp32, complex dtypes --- test/test_linalg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 13961a053ffd..604350b27fc0 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -181,7 +181,7 @@ def test_det(self, device, dtype): op(t) @unittest.skipIf(not TEST_NUMPY, "NumPy not found") - @dtypes(torch.double) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_kron(self, device, dtype): def run_test_case(a_shape, b_shape): From 18e2767bb8085d2952b3f10898cdc11a24183701 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 29 Sep 2020 09:22:19 -0500 Subject: [PATCH 08/28] Added overrides entry for torch.kron --- torch/overrides.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/overrides.py b/torch/overrides.py index 352ba76b9593..2b02e4aacd2d 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -419,6 +419,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True, normalized=False, onesided=True, length=None, return_complex=False: -1), torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1, + torch.kron: lambda input, other: -1, torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1, torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1, torch.lcm: lambda input, other, out=None: -1, From e66ab58f262762dc7273b2a9424594e3bca99ddf Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 08:21:19 +0000 Subject: [PATCH 09/28] Renamed a, b -> self, other --- aten/src/ATen/native/LinearAlgebra.cpp | 46 +++++++++++++------------- 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 629f553f9319..794d93f8f206 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1689,51 +1689,51 @@ Tensor chain_matmul(TensorList matrices) { /* Calculates the Kronecker product between two Tensors. */ -Tensor kron(const Tensor& a, const Tensor& b) { +Tensor kron(const Tensor& self, const Tensor& other) { /* We can obtain the kron result using tensordot or einsum. - In einsum notation suppose we have `a` with dim 4 and `b` with dim 2 + In einsum notation suppose we have `self` with dim 4 and `other` with dim 2 the result of below tensordot is in einsum 0123, 45 -> 012345. To obtain the correct kron we need to permute and reshape the array. The permutation rule is the following: going from right to left take axes in turn to form the permutation with our example the correct permutation is 012435 and - the kron shape is (shape_a[0], shape_a[1], shape_a[3]*shape_b[0], - shape_a[4]*shape_b[1]) + the kron shape is (shape_self[0], shape_self[1], shape_self[3]*shape_other[0], + shape_self[4]*shape_other[1]) */ - std::vector a_sizes = a.sizes().vec(); - std::vector b_sizes = b.sizes().vec(); - int64_t a_ndim = a.dim(); - int64_t b_ndim = b.dim(); - int64_t min_ndim = std::min(a_ndim, b_ndim); - int64_t ndim_diff = std::abs(a_ndim - b_ndim); - - std::vector a_axes(a_ndim); - std::vector b_axes(b_ndim); + std::vector self_sizes = self.sizes().vec(); + std::vector other_sizes = other.sizes().vec(); + int64_t self_ndim = self.dim(); + int64_t other_ndim = other.dim(); + int64_t min_ndim = std::min(self_ndim, other_ndim); + int64_t ndim_diff = std::abs(self_ndim - other_ndim); + + std::vector a_axes(self_ndim); + std::vector b_axes(other_ndim); std::iota(a_axes.begin(), a_axes.end(), 0); - std::iota(b_axes.begin(), b_axes.end(), 0 + a_ndim); + std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim); - bool is_a_larger = a_ndim >= b_ndim; - std::vector kron_permutation(a_ndim + b_ndim); + bool is_a_larger = self_ndim >= other_ndim; + std::vector kron_permutation(self_ndim + other_ndim); for (int64_t i = 0; i < ndim_diff; i++) { kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i]; } for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) { - kron_permutation[a_ndim + b_ndim - 1 - j] = b_axes[b_ndim - 1 - i]; - kron_permutation[a_ndim + b_ndim - 1 - j - 1] = a_axes[a_ndim - 1 - i]; + kron_permutation[self_ndim + other_ndim - 1 - j] = b_axes[other_ndim - 1 - i]; + kron_permutation[self_ndim + other_ndim - 1 - j - 1] = a_axes[self_ndim - 1 - i]; } - std::vector result_shape(std::max(a_ndim, b_ndim)); + std::vector result_shape(std::max(self_ndim, other_ndim)); for (int64_t i = 0; i < ndim_diff; i++) { - result_shape[i] = is_a_larger ? a_sizes[i] : b_sizes[i]; + result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i]; } for (int64_t i = 0; i < min_ndim; i++) { result_shape[ndim_diff + i] = is_a_larger - ? a_sizes[ndim_diff + i] * b_sizes[i] - : b_sizes[ndim_diff + i] * a_sizes[i]; + ? self_sizes[ndim_diff + i] * other_sizes[i] + : other_sizes[ndim_diff + i] * self_sizes[i]; } - Tensor result = at::tensordot(a, b, {}, {}); + Tensor result = at::tensordot(self, other, {}, {}); // Step 2: now permute result result = result.permute(kron_permutation); // Step 3: reshape From dfb00d3e30d614b712a3aa9530dca5739b787804 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 09:17:22 +0000 Subject: [PATCH 10/28] Added out= variant --- aten/src/ATen/native/LinearAlgebra.cpp | 33 ++++++++++++++++++++++ aten/src/ATen/native/native_functions.yaml | 4 ++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 794d93f8f206..3b516fb82bd1 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1742,5 +1742,38 @@ Tensor kron(const Tensor& self, const Tensor& other) { return result; } +Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) { + std::vector self_sizes = self.sizes().vec(); + std::vector other_sizes = other.sizes().vec(); + int64_t self_ndim = self.dim(); + int64_t other_ndim = other.dim(); + int64_t min_ndim = std::min(self_ndim, other_ndim); + int64_t ndim_diff = std::abs(self_ndim - other_ndim); + + std::vector a_axes(self_ndim); + std::vector b_axes(other_ndim); + std::iota(a_axes.begin(), a_axes.end(), 0); + std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim); + + bool is_a_larger = self_ndim >= other_ndim; + + std::vector expected_result_shape(std::max(self_ndim, other_ndim)); + for (int64_t i = 0; i < ndim_diff; i++) { + expected_result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i]; + } + for (int64_t i = 0; i < min_ndim; i++) { + expected_result_shape[ndim_diff + i] = is_a_larger + ? self_sizes[ndim_diff + i] * other_sizes[i] + : other_sizes[ndim_diff + i] * self_sizes[i]; + } + + TORCH_CHECK(result.sizes().equals(expected_result_shape), + "Expected result tensor to have size of ", expected_result_shape, ", but got tensor of size ", result.sizes()); + + Tensor result_tmp = at::kron(self, other); + result.copy_(result_tmp); + return result; +} + } // namespace native } // namespace at diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 2c16a12d69ec..e3a5ed74169d 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -1953,7 +1953,9 @@ CUDA: kl_div_backward_cuda - func: kron(Tensor self, Tensor other) -> Tensor - variants: function + variants: function, method + +- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) use_c10_dispatcher: full From f886d197c7cbc914ab11bb9be78ad9c7d5f68695 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 09:17:55 +0000 Subject: [PATCH 11/28] Updated documentation: added note on real and complex support, added out kwarg --- torch/_torch_docs.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 5469205fdff8..d9ac6b897465 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3502,17 +3502,22 @@ def merge_dicts(*dicts): add_docstr(torch.kron, r""" -kron(input, other) -> Tensor +kron(input, other, *, out=None) -> Tensor Computes the Kronecker product of :attr:`input` and :attr:`other`. If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`other` is a :math:`(k \times l)` tensor, the result will be a :math:`(n*k \times m*l)` tensor. +Supports real and complex inputs. + Arguments: input (Tensor): the first tensor to be multiplied other (Tensor): the second tensor to be multiplied +Keyword args: + out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` + Returns: Tensor: A tensor made of blocks of the second tensor scaled by the first. From c60487b1f8a06df489e132f742d32c3fbf62e11b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 09:18:16 +0000 Subject: [PATCH 12/28] Added empty and out= test cases --- test/test_linalg.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index 604350b27fc0..c91c1bb1109c 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -180,7 +180,6 @@ def test_det(self, device, dtype): with self.assertRaises(RuntimeError): op(t) - @unittest.skipIf(not TEST_NUMPY, "NumPy not found") @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_kron(self, device, dtype): @@ -189,13 +188,37 @@ def run_test_case(a_shape, b_shape): b = torch.rand(b_shape, dtype=dtype, device=device) expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) - actual = torch.kron(a, b) - self.assertEqual(actual, expected) + result = torch.kron(a, b) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty_like(result) + ans = torch.kron(a, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)] for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): run_test_case(a_shape, b_shape) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron_empty(self, device, dtype): + + def run_test_case(empty_shape): + a = torch.eye(3, dtype=dtype, device=device) + b = torch.empty(empty_shape, dtype=dtype, device=device) + result = torch.kron(a, b) + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + self.assertEqual(result, expected) + + # NumPy doesn't work if the first argument is empty + result = torch.kron(b, a) + self.assertEqual(result.shape, expected.shape) + + empty_shapes = [(0,), (2, 0), (1, 0, 3)] + for empty_shape in empty_shapes: + run_test_case(empty_shape) + # This test confirms that torch.linalg.norm's dtype argument works # as expected, according to the function's documentation @skipCUDAIfNoMagma From 76354a57cf3d76ea77b6b2d2ece9e5419d055d87 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 09:18:31 +0000 Subject: [PATCH 13/28] Added entry to common_methods_invocations.py --- torch/testing/_internal/common_methods_invocations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index cf4ff7f31fdd..03a3f31b03aa 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -1394,6 +1394,7 @@ def method_tests(): ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3], torch.LongTensor([0, 0, 2])]),), 'adv_index_var'), ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()), + ('kron', (S, S), ((M, L),)) ] def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None): From 1f9484971ec37f6c9022788788e36c006c491ab9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 09:23:07 +0000 Subject: [PATCH 14/28] Added test for error with incorrect out= --- test/test_linalg.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index c91c1bb1109c..16f30da90387 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -219,6 +219,16 @@ def run_test_case(empty_shape): for empty_shape in empty_shapes: run_test_case(empty_shape) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron_errors(self, device, dtype): + + # out tensor should have the correct resulting shape + a = torch.eye(3, dtype=dtype, device=device) + b = torch.ones((2, 2), dtype=dtype, device=device) + out = torch.empty_like(a) + with self.assertRaisesRegex(RuntimeError, r'Expected result tensor to have size of'): + ans = torch.kron(a, b, out=out) + # This test confirms that torch.linalg.norm's dtype argument works # as expected, according to the function's documentation @skipCUDAIfNoMagma From 5e2f5e891acb5cffe0db700ef052b03cde03d1e8 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Tue, 20 Oct 2020 13:43:20 +0000 Subject: [PATCH 15/28] Added tensor_docs entry --- torch/_tensor_docs.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 208cd5805c4b..c176022b622e 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1892,6 +1892,13 @@ def add_docstr_all(method, docstr): """) +add_docstr_all('kron', + r""" +kron(tensor2) -> Tensor + +See :func:`torch.kron` +""") + add_docstr_all('kthvalue', r""" kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor) From 71ceaab9449856614d8ba47e6232879d72af14fb Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 22 Oct 2020 11:31:59 +0000 Subject: [PATCH 16/28] tensor2 -> other --- torch/_tensor_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index c176022b622e..8e1e9d699422 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1894,7 +1894,7 @@ def add_docstr_all(method, docstr): add_docstr_all('kron', r""" -kron(tensor2) -> Tensor +kron(other) -> Tensor See :func:`torch.kron` """) From 7786e4c6aa622643bab6539894ad5ecb511a687b Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 22 Oct 2020 12:28:11 +0000 Subject: [PATCH 17/28] Updated documentation Added mathematical definition of the Kronecker product. Added another example. --- torch/_torch_docs.py | 30 +++++++++++++++++++++++------- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 31b140823f2a..ec8695926dd2 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3565,16 +3565,24 @@ def merge_dicts(*dicts): r""" kron(input, other, *, out=None) -> Tensor -Computes the Kronecker product of :attr:`input` and :attr:`other`. +Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. -If :attr:`input` is a :math:`(n \times m)` tensor, :attr:`other` is a -:math:`(k \times l)` tensor, the result will be a :math:`(n*k \times m*l)` tensor. +If :attr:`input` is a :math:`(m \times n)` tensor and :attr:`other` is a +:math:`(p \times q)` tensor, the result will be a :math:`(p*m \times q*n)` block tensor: +.. math:: + \mathbf{A} \otimes \mathbf{B}=\left[\begin{array}{ccc} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} + \end{array}\right], + +where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. -Supports real and complex inputs. +Supports real-valued and complex-valued inputs. Arguments: - input (Tensor): the first tensor to be multiplied - other (Tensor): the second tensor to be multiplied + input (Tensor) + other (Tensor) Keyword args: out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` @@ -3582,7 +3590,7 @@ def merge_dicts(*dicts): Returns: Tensor: A tensor made of blocks of the second tensor scaled by the first. -Example:: +Examples:: >>> mat1 = torch.eye(2) >>> mat2 = torch.ones(2, 2) @@ -3591,6 +3599,14 @@ def merge_dicts(*dicts): [1., 1., 0., 0.], [0., 0., 1., 1.], [0., 0., 1., 1.]]) + + >>> mat1 = torch.eye(2) + >>> mat2 = torch.arange(1, 5).reshape(2, 2) + >>> torch.kron(mat1, mat2) + tensor([[1., 2., 0., 0.], + [3., 4., 0., 0.], + [0., 0., 1., 2.], + [0., 0., 3., 4.]]) """) add_docstr(torch.kthvalue, From 50cb9beefa7275afa9de2cffe8005650d52e18a9 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 22 Oct 2020 12:30:23 +0000 Subject: [PATCH 18/28] Removed returns block in docs --- torch/_torch_docs.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index ec8695926dd2..1ce18a3ddc67 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3587,9 +3587,6 @@ def merge_dicts(*dicts): Keyword args: out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None`` -Returns: - Tensor: A tensor made of blocks of the second tensor scaled by the first. - Examples:: >>> mat1 = torch.eye(2) From 15985b3b9ac91e706672e92a81e1ff1868df7985 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 22 Oct 2020 19:12:42 +0000 Subject: [PATCH 19/28] Added dispatch: Math --- aten/src/ATen/native/native_functions.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 21129048b95f..81a76188aac4 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -2134,8 +2134,12 @@ - func: kron(Tensor self, Tensor other) -> Tensor variants: function, method + dispatch: + Math: kron - func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + dispatch: + Math: kron_out - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) use_c10_dispatcher: full From 3f2e70da05a9cfad4adb2e721523ad2a1489d56f Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 22 Oct 2020 19:20:13 +0000 Subject: [PATCH 20/28] Use at::native::resize_output in kron_out --- aten/src/ATen/native/LinearAlgebra.cpp | 29 +++----------------------- test/test_linalg.py | 19 ++++++++++++----- 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 4c47a63de58e..1bb4914d44e5 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1746,34 +1746,11 @@ Tensor kron(const Tensor& self, const Tensor& other) { } Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) { - std::vector self_sizes = self.sizes().vec(); - std::vector other_sizes = other.sizes().vec(); - int64_t self_ndim = self.dim(); - int64_t other_ndim = other.dim(); - int64_t min_ndim = std::min(self_ndim, other_ndim); - int64_t ndim_diff = std::abs(self_ndim - other_ndim); - - std::vector a_axes(self_ndim); - std::vector b_axes(other_ndim); - std::iota(a_axes.begin(), a_axes.end(), 0); - std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim); - - bool is_a_larger = self_ndim >= other_ndim; - - std::vector expected_result_shape(std::max(self_ndim, other_ndim)); - for (int64_t i = 0; i < ndim_diff; i++) { - expected_result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i]; - } - for (int64_t i = 0; i < min_ndim; i++) { - expected_result_shape[ndim_diff + i] = is_a_larger - ? self_sizes[ndim_diff + i] * other_sizes[i] - : other_sizes[ndim_diff + i] * self_sizes[i]; - } - - TORCH_CHECK(result.sizes().equals(expected_result_shape), - "Expected result tensor to have size of ", expected_result_shape, ", but got tensor of size ", result.sizes()); + TORCH_CHECK(result.scalar_type() == self.scalar_type(), + "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type()); Tensor result_tmp = at::kron(self, other); + at::native::resize_output(result, result_tmp.sizes()); result.copy_(result_tmp); return result; } diff --git a/test/test_linalg.py b/test/test_linalg.py index c187475d2aa9..a3f467b2f9e5 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -1,6 +1,7 @@ import torch import unittest import itertools +import warnings from math import inf, nan, isnan from random import randrange @@ -221,14 +222,22 @@ def run_test_case(empty_shape): run_test_case(empty_shape) @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) - def test_kron_errors(self, device, dtype): - - # out tensor should have the correct resulting shape + def test_kron_errors_and_warnings(self, device, dtype): + # if non-empty out tensor with wrong shape is passed a warning is given a = torch.eye(3, dtype=dtype, device=device) b = torch.ones((2, 2), dtype=dtype, device=device) out = torch.empty_like(a) - with self.assertRaisesRegex(RuntimeError, r'Expected result tensor to have size of'): - ans = torch.kron(a, b, out=out) + with warnings.catch_warnings(record=True) as w: + # Trigger warning + torch.kron(a, b, out=out) + # Check warning occurs + self.assertEqual(len(w), 1) + self.assertTrue("An output with one or more elements was resized" in str(w[-1].message)) + + # dtypes should match + out = torch.empty_like(a).to(torch.int) + with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"): + torch.kron(a, b, out=out) # This test confirms that torch.linalg.norm's dtype argument works # as expected, according to the function's documentation From a980bb45c54c90d8abbdad44b1873e11b1bf92cb Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Thu, 22 Oct 2020 19:28:33 +0000 Subject: [PATCH 21/28] Added non-contiguous test cases --- test/test_linalg.py | 44 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/test/test_linalg.py b/test/test_linalg.py index a3f467b2f9e5..a9f18ba57824 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -203,6 +203,50 @@ def run_test_case(a_shape, b_shape): for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): run_test_case(a_shape, b_shape) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) + def test_kron_non_contiguous(self, device, dtype): + + def run_test_transposed(a_shape, b_shape): + # check for transposed case + a = torch.rand(a_shape, dtype=dtype, device=device).transpose(-2, -1) + b = torch.rand(b_shape, dtype=dtype, device=device).transpose(-2, -1) + self.assertFalse(a.is_contiguous()) + self.assertFalse(b.is_contiguous()) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + result = torch.kron(a, b) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty(result.transpose(-2, -1).shape, dtype=dtype, device=device).transpose(-2, -1) + self.assertFalse(out.is_contiguous()) + ans = torch.kron(a, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + def run_test_skipped_elements(a_shape, b_shape): + # check for transposed case + a = torch.rand(a_shape, dtype=dtype, device=device)[::2] + b = torch.rand(b_shape, dtype=dtype, device=device)[::2] + self.assertFalse(a.is_contiguous()) + self.assertFalse(b.is_contiguous()) + + expected = np.kron(a.cpu().numpy(), b.cpu().numpy()) + result = torch.kron(a, b) + self.assertEqual(result, expected) + + # check the out= variant + out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2] + self.assertFalse(out.is_contiguous()) + ans = torch.kron(a, b, out=out) + self.assertEqual(ans, out) + self.assertEqual(ans, result) + + shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)] + for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): + run_test_transposed(a_shape, b_shape) + run_test_skipped_elements(a_shape, b_shape) + @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) def test_kron_empty(self, device, dtype): From e22a079a895c1c612bb2fa874f8ece91502eb066 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sat, 24 Oct 2020 11:34:16 +0000 Subject: [PATCH 22/28] Use bmatrix instead of array in matrix example --- torch/_torch_docs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b1b7cdc06641..29553ea1484f 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3602,12 +3602,12 @@ def merge_dicts(*dicts): If :attr:`input` is a :math:`(m \times n)` tensor and :attr:`other` is a :math:`(p \times q)` tensor, the result will be a :math:`(p*m \times q*n)` block tensor: + .. math:: - \mathbf{A} \otimes \mathbf{B}=\left[\begin{array}{ccc} + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ \vdots & \ddots & \vdots \\ - a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} - \end{array}\right], + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. From fa90e7160957709b0569addb902d7c32f0578425 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Sat, 24 Oct 2020 11:39:05 +0000 Subject: [PATCH 23/28] Fixed non-contiguous tests --- test/test_linalg.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_linalg.py b/test/test_linalg.py index a9f18ba57824..b518462ef1cd 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -226,8 +226,8 @@ def run_test_transposed(a_shape, b_shape): def run_test_skipped_elements(a_shape, b_shape): # check for transposed case - a = torch.rand(a_shape, dtype=dtype, device=device)[::2] - b = torch.rand(b_shape, dtype=dtype, device=device)[::2] + a = torch.rand(2 * a_shape[0], *a_shape[1:], dtype=dtype, device=device)[::2] + b = torch.rand(2 * b_shape[0], *b_shape[1:], dtype=dtype, device=device)[::2] self.assertFalse(a.is_contiguous()) self.assertFalse(b.is_contiguous()) @@ -242,9 +242,9 @@ def run_test_skipped_elements(a_shape, b_shape): self.assertEqual(ans, out) self.assertEqual(ans, result) - shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)] + shapes = [(2, 2), (2, 2, 3), (2, 2, 3, 3)] for a_shape, b_shape in itertools.product(shapes, reversed(shapes)): - run_test_transposed(a_shape, b_shape) + # run_test_transposed(a_shape, b_shape) run_test_skipped_elements(a_shape, b_shape) @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128) From 5ad2f45d0c3cd86ed86fe5b676ec5db856d78621 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 26 Oct 2020 10:14:05 -0500 Subject: [PATCH 24/28] Updated comment in LinearAlgebra.cpp --- aten/src/ATen/native/LinearAlgebra.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp index 1bb4914d44e5..7a8213dc6fcd 100644 --- a/aten/src/ATen/native/LinearAlgebra.cpp +++ b/aten/src/ATen/native/LinearAlgebra.cpp @@ -1694,7 +1694,7 @@ Calculates the Kronecker product between two Tensors. */ Tensor kron(const Tensor& self, const Tensor& other) { /* - We can obtain the kron result using tensordot or einsum. + We can obtain the kron result using tensordot or einsum. The implementation below uses tensordot. In einsum notation suppose we have `self` with dim 4 and `other` with dim 2 the result of below tensordot is in einsum 0123, 45 -> 012345. To obtain the correct kron we need to permute and reshape the array. From 1cf37ff2412f0cd19ac3f50b779e25ab5d9f4761 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 26 Oct 2020 10:51:01 -0500 Subject: [PATCH 25/28] Updated docs for general n-dim tensors --- torch/_torch_docs.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 29553ea1484f..0e5ac9dd6b03 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3600,19 +3600,31 @@ def merge_dicts(*dicts): Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`. -If :attr:`input` is a :math:`(m \times n)` tensor and :attr:`other` is a -:math:`(p \times q)` tensor, the result will be a :math:`(p*m \times q*n)` block tensor: +If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a +:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a +:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` block tensor with the following entries: .. math:: - \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} - a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ - \vdots & \ddots & \vdots \\ - a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, -where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. +where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. +If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. Supports real-valued and complex-valued inputs. +.. note:: + The Kronecker product is commonly defined only for matrices. + If :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: + + .. math:: + \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix} + a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\ + \vdots & \ddots & \vdots \\ + a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix} + + where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`. + Arguments: input (Tensor) other (Tensor) From 60db5237ba81122237b0583a51256f7afbd144b6 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 26 Oct 2020 10:54:49 -0500 Subject: [PATCH 26/28] Changed 'block tensor' -> 'tensor' --- torch/_torch_docs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 0e5ac9dd6b03..b08c26f124a7 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3602,7 +3602,7 @@ def merge_dicts(*dicts): If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a :math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a -:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` block tensor with the following entries: +:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: .. math:: (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, From 5888b0ea6a72bf9c214699c99668c2439ea35223 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 26 Oct 2020 11:09:34 -0500 Subject: [PATCH 27/28] Fix too long line --- torch/_torch_docs.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index b08c26f124a7..f2b186984260 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3605,7 +3605,8 @@ def merge_dicts(*dicts): :math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries: .. math:: - (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, + (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} = + \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n}, where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`. If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions. From 719a68749d9271b6fae0bb834611f770b0fc53c4 Mon Sep 17 00:00:00 2001 From: Ivan Yashchuk Date: Mon, 2 Nov 2020 04:05:11 -0600 Subject: [PATCH 28/28] Changed the note according to Mike's suggestion --- torch/_torch_docs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index f2b186984260..9c220dc259a0 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3614,8 +3614,8 @@ def merge_dicts(*dicts): Supports real-valued and complex-valued inputs. .. note:: - The Kronecker product is commonly defined only for matrices. - If :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a + This function generalizes the typical definition of the Kronecker product for two matrices to two tensors, + as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix: .. math::