diff --git a/aten/src/ATen/native/LinearAlgebra.cpp b/aten/src/ATen/native/LinearAlgebra.cpp
index 8796657dc293..1777f00ee68a 100644
--- a/aten/src/ATen/native/LinearAlgebra.cpp
+++ b/aten/src/ATen/native/LinearAlgebra.cpp
@@ -1738,5 +1738,71 @@ Tensor chain_matmul(TensorList matrices) {
   }
 }
 
+/*
+Calculates the Kronecker product between two Tensors.
+*/
+Tensor kron(const Tensor& self, const Tensor& other) {
+  /*
+  We can obtain the kron result using tensordot or einsum. The implementation below uses tensordot.
+  In einsum notation, suppose `self` has dim 4 and `other` has dim 2;
+  the result of the tensordot below is then 0123, 45 -> 012345.
+  To obtain the correct kron we need to permute and reshape the array.
+  The permutation rule is the following: going from right to left,
+  take the axes of `self` and `other` in turn to form the permutation.
+  With our example the correct permutation is 012435 and
+  the kron shape is (shape_self[0], shape_self[1], shape_self[2]*shape_other[0],
+  shape_self[3]*shape_other[1]).
+  */
+  std::vector<int64_t> self_sizes = self.sizes().vec();
+  std::vector<int64_t> other_sizes = other.sizes().vec();
+  int64_t self_ndim = self.dim();
+  int64_t other_ndim = other.dim();
+  int64_t min_ndim = std::min(self_ndim, other_ndim);
+  int64_t ndim_diff = std::abs(self_ndim - other_ndim);
+
+  std::vector<int64_t> a_axes(self_ndim);
+  std::vector<int64_t> b_axes(other_ndim);
+  std::iota(a_axes.begin(), a_axes.end(), 0);
+  std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim);
+
+  bool is_a_larger = self_ndim >= other_ndim;
+  std::vector<int64_t> kron_permutation(self_ndim + other_ndim);
+  for (int64_t i = 0; i < ndim_diff; i++) {
+    kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i];
+  }
+  for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) {
+    kron_permutation[self_ndim + other_ndim - 1 - j] = b_axes[other_ndim - 1 - i];
+    kron_permutation[self_ndim + other_ndim - 1 - j - 1] = a_axes[self_ndim - 1 - i];
+  }
+
+  std::vector<int64_t> result_shape(std::max(self_ndim, other_ndim));
+  for (int64_t i = 0; i < ndim_diff; i++) {
+    result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i];
+  }
+  for (int64_t i = 0; i < min_ndim; i++) {
+    result_shape[ndim_diff + i] = is_a_larger
+        ? self_sizes[ndim_diff + i] * other_sizes[i]
+        : other_sizes[ndim_diff + i] * self_sizes[i];
+  }
+
+  // Step 1: compute the outer product of self and other (no axes are contracted)
+  Tensor result = at::tensordot(self, other, {}, {});
+  // Step 2: now permute result
+  result = result.permute(kron_permutation);
+  // Step 3: reshape
+  result = result.reshape(result_shape);
+
+  return result;
+}
+
+Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) {
+  TORCH_CHECK(result.scalar_type() == self.scalar_type(),
+              "result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());
+
+  Tensor result_tmp = at::kron(self, other);
+  at::native::resize_output(result, result_tmp.sizes());
+  result.copy_(result_tmp);
+  return result;
+}
+
 } // namespace native
 } // namespace at
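As a sanity check on the tensordot -> permute -> reshape scheme described in the comment above, here is a NumPy sketch of the same dim-4 / dim-2 example (illustrative only, not part of the patch; the shapes are arbitrary):

```python
import numpy as np

# "self" has 4 dims (einsum axes 0 1 2 3), "other" has 2 dims (einsum axes 4 5)
a = np.random.rand(2, 3, 4, 5)
b = np.random.rand(6, 7)

t = np.tensordot(a, b, axes=0)           # outer product: 0123, 45 -> 012345
t = t.transpose(0, 1, 2, 4, 3, 5)        # interleave trailing axes: 012345 -> 012435
result = t.reshape(2, 3, 4 * 6, 5 * 7)   # collapse each (self, other) axis pair

assert np.allclose(result, np.kron(a, b))
```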
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c98b9423b25c..696551a12a99 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2142,6 +2142,15 @@
     CPU: kl_div_backward_cpu
     CUDA: kl_div_backward_cuda
 
+- func: kron(Tensor self, Tensor other) -> Tensor
+  variants: function, method
+  dispatch:
+    Math: kron
+
+- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
+  dispatch:
+    Math: kron_out
+
 - func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
   use_c10_dispatcher: full
   variants: function, method
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index e36a3f944a7a..6ba16aaca3e2 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -459,6 +459,7 @@ Other Operations
     flip
     fliplr
     flipud
+    kron
     rot90
     gcd
     histc
diff --git a/test/test_linalg.py b/test/test_linalg.py
index d7de3841ab65..56c764e7fea1 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -183,6 +183,107 @@ def test_det(self, device, dtype):
         with self.assertRaises(RuntimeError):
             op(t)
 
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_kron(self, device, dtype):
+
+        def run_test_case(a_shape, b_shape):
+            a = torch.rand(a_shape, dtype=dtype, device=device)
+            b = torch.rand(b_shape, dtype=dtype, device=device)
+
+            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
+            result = torch.kron(a, b)
+            self.assertEqual(result, expected)
+
+            # check the out= variant
+            out = torch.empty_like(result)
+            ans = torch.kron(a, b, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
+        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
+            run_test_case(a_shape, b_shape)
+
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_kron_non_contiguous(self, device, dtype):
+
+        def run_test_transposed(a_shape, b_shape):
+            # check the transposed case
+            a = torch.rand(a_shape, dtype=dtype, device=device).transpose(-2, -1)
+            b = torch.rand(b_shape, dtype=dtype, device=device).transpose(-2, -1)
+            self.assertFalse(a.is_contiguous())
+            self.assertFalse(b.is_contiguous())
+
+            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
+            result = torch.kron(a, b)
+            self.assertEqual(result, expected)
+
+            # check the out= variant
+            out = torch.empty(result.transpose(-2, -1).shape, dtype=dtype, device=device).transpose(-2, -1)
+            self.assertFalse(out.is_contiguous())
+            ans = torch.kron(a, b, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        def run_test_skipped_elements(a_shape, b_shape):
+            # check the strided case (every other element along the first dim)
+            a = torch.rand(2 * a_shape[0], *a_shape[1:], dtype=dtype, device=device)[::2]
+            b = torch.rand(2 * b_shape[0], *b_shape[1:], dtype=dtype, device=device)[::2]
+            self.assertFalse(a.is_contiguous())
+            self.assertFalse(b.is_contiguous())
+
+            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
+            result = torch.kron(a, b)
+            self.assertEqual(result, expected)
+
+            # check the out= variant
+            out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2]
+            self.assertFalse(out.is_contiguous())
+            ans = torch.kron(a, b, out=out)
+            self.assertEqual(ans, out)
+            self.assertEqual(ans, result)
+
+        shapes = [(2, 2), (2, 2, 3), (2, 2, 3, 3)]
+        for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
+            run_test_transposed(a_shape, b_shape)
+            run_test_skipped_elements(a_shape, b_shape)
+
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_kron_empty(self, device, dtype):
+
+        def run_test_case(empty_shape):
+            a = torch.eye(3, dtype=dtype, device=device)
+            b = torch.empty(empty_shape, dtype=dtype, device=device)
+            result = torch.kron(a, b)
+            expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
+            self.assertEqual(result, expected)
+
+            # NumPy doesn't work if the first argument is empty
+            result = torch.kron(b, a)
+            self.assertEqual(result.shape, expected.shape)
+
+        empty_shapes = [(0,), (2, 0), (1, 0, 3)]
+        for empty_shape in empty_shapes:
+            run_test_case(empty_shape)
+
+    @dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
+    def test_kron_errors_and_warnings(self, device, dtype):
+        # if a non-empty out tensor with the wrong shape is passed, a warning is given
+        a = torch.eye(3, dtype=dtype, device=device)
+        b = torch.ones((2, 2), dtype=dtype, device=device)
+        out = torch.empty_like(a)
+        with warnings.catch_warnings(record=True) as w:
+            # Trigger warning
+            torch.kron(a, b, out=out)
+            # Check warning occurs
+            self.assertEqual(len(w), 1)
+            self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))
+
+        # dtypes should match
+        out = torch.empty_like(a).to(torch.int)
+        with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
+            torch.kron(a, b, out=out)
+
     # This test confirms that torch.linalg.norm's dtype argument works
     # as expected, according to the function's documentation
     @skipCUDAIfNoMagma
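The shapes list in test_kron pairs 1-D through 4-D operands, which also exercises the implicit left-unsqueezing of the lower-dimensional operand. A minimal standalone illustration of one such pairing (not part of the patch):

```python
import numpy as np
import torch

a = torch.arange(4, dtype=torch.float64)   # shape (4,), treated as (1, 4)
b = torch.ones(2, 2, dtype=torch.float64)  # shape (2, 2)

result = torch.kron(a, b)                  # shape (2, 8), matching np.kron
assert result.shape == (2, 8)
assert np.allclose(result.numpy(), np.kron(a.numpy(), b.numpy()))
```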
diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py
index 12dd77497454..9339d805c1b9 100644
--- a/torch/_tensor_docs.py
+++ b/torch/_tensor_docs.py
@@ -1956,6 +1956,13 @@ def add_docstr_all(method, docstr):
 """)
 
+add_docstr_all('kron',
+               r"""
+kron(other) -> Tensor
+
+See :func:`torch.kron`
+""")
+
 add_docstr_all('kthvalue',
                r"""
 kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor)
diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py
index 4f0de3335414..e3fc7acfa160 100644
--- a/torch/_torch_docs.py
+++ b/torch/_torch_docs.py
@@ -3669,6 +3669,64 @@ def merge_dicts(*dicts):
     RuntimeError: bool value of Tensor with no values is ambiguous
 """.format(**common_args))
 
+add_docstr(torch.kron,
+           r"""
+kron(input, other, *, out=None) -> Tensor
+
+Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`.
+
+If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a
+:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a
+:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries:
+
+.. math::
+    (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} =
+        \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n},
+
+where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`.
+If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions.
+
+Supports real-valued and complex-valued inputs.
+
+.. note::
+    This function generalizes the typical definition of the Kronecker product for two matrices to two tensors,
+    as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a
+    :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix:
+
+    .. math::
+        \mathbf{A} \otimes \mathbf{B} = \begin{bmatrix}
+        a_{11} \mathbf{B} & \cdots & a_{1n} \mathbf{B} \\
+        \vdots & \ddots & \vdots \\
+        a_{m1} \mathbf{B} & \cdots & a_{mn} \mathbf{B} \end{bmatrix}
+
+    where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`.
+
+Arguments:
+    input (Tensor)
+    other (Tensor)
+
+Keyword args:
+    out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None``
+
+Examples::
+
+    >>> mat1 = torch.eye(2)
+    >>> mat2 = torch.ones(2, 2)
+    >>> torch.kron(mat1, mat2)
+    tensor([[1., 1., 0., 0.],
+            [1., 1., 0., 0.],
+            [0., 0., 1., 1.],
+            [0., 0., 1., 1.]])
+
+    >>> mat1 = torch.eye(2)
+    >>> mat2 = torch.arange(1, 5).reshape(2, 2)
+    >>> torch.kron(mat1, mat2)
+    tensor([[1., 2., 0., 0.],
+            [3., 4., 0., 0.],
+            [0., 0., 1., 2.],
+            [0., 0., 3., 4.]])
+""")
+
 add_docstr(torch.kthvalue,
            r"""
 kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor)
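A quick numeric spot check of the entry formula k_t = i_t * b_t + j_t from the docstring above (illustrative values, not part of the patch):

```python
import torch

a = torch.tensor([[1., 2.], [3., 4.]])
b = torch.tensor([[0., 5.], [6., 7.]])
k = torch.kron(a, b)                           # shape (4, 4)

i, j = (1, 0), (0, 1)                          # indices into a and b
k_idx = (i[0] * 2 + j[0], i[1] * 2 + j[1])     # b_t == 2 along both dims
assert k[k_idx] == a[i] * b[j]                 # 3 * 5 == 15
```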
diff --git a/torch/overrides.py b/torch/overrides.py
index 4fd334c911d8..88cf0b10868c 100644
--- a/torch/overrides.py
+++ b/torch/overrides.py
@@ -417,6 +417,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
     torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
                   normalized=False, onesided=None, length=None, return_complex=False: -1),
     torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
+    torch.kron: lambda input, other: -1,
     torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
     torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
     torch.lcm: lambda input, other, out=None: -1,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 9b67d31e7d16..8bbd2bfdc944 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1460,6 +1460,7 @@ def method_tests():
     ('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3],
                                                          torch.LongTensor([0, 0, 2])]),), 'adv_index_var'),
     ('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()),
+    ('kron', (S, S), ((M, L),))
 ]
 
 def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None):
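The method_tests() entry registers kron with the generic autograd and method tests. A standalone gradient sanity check in the same spirit (illustrative; assumes the usual double-precision inputs that gradcheck expects):

```python
import torch

a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
b = torch.randn(2, 4, dtype=torch.float64, requires_grad=True)
assert torch.autograd.gradcheck(torch.kron, (a, b))
```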