Added Kronecker product of tensors (torch.kron) (#45358)
Summary:
This PR adds a function for calculating the Kronecker product of tensors.
The implementation is based on `at::tensordot` followed by a permute and a reshape.
Tests pass.
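
For reference, a minimal Python sketch of the same tensordot → permute → reshape idea for same-ndim inputs (the helper name `kron_via_tensordot` is hypothetical and not part of this PR; the shipped code is the C++ `kron` in `LinearAlgebra.cpp` below):

```python
import torch

def kron_via_tensordot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Sketch only; assumes a.dim() == b.dim(). The C++ implementation also
    # handles inputs with different numbers of dimensions.
    n = a.dim()
    # Outer product: no dimensions are contracted, so the result has shape a.shape + b.shape.
    res = torch.tensordot(a, b, dims=([], []))
    # Interleave the axes of a and b: (a0, b0, a1, b1, ...).
    perm = [ax for pair in zip(range(n), range(n, 2 * n)) for ax in pair]
    res = res.permute(perm)
    # Merge each (a_i, b_i) pair into one axis of size a_i * b_i.
    return res.reshape([sa * sb for sa, sb in zip(a.shape, b.shape)])

# e.g. kron_via_tensordot(torch.eye(2), torch.ones(2, 2)) matches torch.kron(torch.eye(2), torch.ones(2, 2))
```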

TODO:

- [x] Add more test cases
- [x] Write documentation
- [x] Add entry to `common_methods_invocations.py`

Ref. #42666

Pull Request resolved: #45358

Reviewed By: mrshenli

Differential Revision: D24680755

Pulled By: mruberry

fbshipit-source-id: b1f8694589349986c3abfda3dc1971584932b3fa
IvanYashchuk authored and facebook-github-bot committed Nov 3, 2020
1 parent 32b66b0 commit f276ab5
Showing 8 changed files with 244 additions and 0 deletions.
66 changes: 66 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1738,5 +1738,71 @@ Tensor chain_matmul(TensorList matrices) {
}
}

/*
Calculates the Kronecker product between two Tensors.
*/
Tensor kron(const Tensor& self, const Tensor& other) {
/*
We can obtain the kron result using tensordot or einsum. The implementation below uses tensordot.
In einsum notation, suppose `self` has 4 dimensions and `other` has 2 dimensions;
the tensordot below then computes "0123, 45 -> 012345".
To obtain the correct kron we need to permute and reshape that result.
The permutation rule is the following: going from right to left,
take one axis from each tensor in turn to build the permutation.
With our example the correct permutation is (0, 1, 2, 4, 3, 5) and
the kron shape is (shape_self[0], shape_self[1], shape_self[2]*shape_other[0],
                   shape_self[3]*shape_other[1]).
*/
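/*
Concrete instance of the rule above: for `self` of shape (2, 3, 4, 5) and `other`
of shape (6, 7), tensordot produces shape (2, 3, 4, 5, 6, 7); applying the
permutation (0, 1, 2, 4, 3, 5) and reshaping gives the kron result of shape
(2, 3, 4*6, 5*7) = (2, 3, 24, 35).
*/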
std::vector<int64_t> self_sizes = self.sizes().vec();
std::vector<int64_t> other_sizes = other.sizes().vec();
int64_t self_ndim = self.dim();
int64_t other_ndim = other.dim();
int64_t min_ndim = std::min(self_ndim, other_ndim);
int64_t ndim_diff = std::abs(self_ndim - other_ndim);

std::vector<int64_t> a_axes(self_ndim);
std::vector<int64_t> b_axes(other_ndim);
std::iota(a_axes.begin(), a_axes.end(), 0);
std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim);

bool is_a_larger = self_ndim >= other_ndim;
std::vector<int64_t> kron_permutation(self_ndim + other_ndim);
for (int64_t i = 0; i < ndim_diff; i++) {
kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i];
}
for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) {
kron_permutation[self_ndim + other_ndim - 1 - j] = b_axes[other_ndim - 1 - i];
kron_permutation[self_ndim + other_ndim - 1 - j - 1] = a_axes[self_ndim - 1 - i];
}

std::vector<int64_t> result_shape(std::max(self_ndim, other_ndim));
for (int64_t i = 0; i < ndim_diff; i++) {
result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i];
}
for (int64_t i = 0; i < min_ndim; i++) {
result_shape[ndim_diff + i] = is_a_larger
? self_sizes[ndim_diff + i] * other_sizes[i]
: other_sizes[ndim_diff + i] * self_sizes[i];
}

// Step 1: compute the outer product with tensordot over no dimensions
Tensor result = at::tensordot(self, other, {}, {});
// Step 2: permute the result to interleave the axes of self and other
result = result.permute(kron_permutation);
// Step 3: reshape to the final kron shape
result = result.reshape(result_shape);

return result;
}

Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) {
TORCH_CHECK(result.scalar_type() == self.scalar_type(),
"result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());

Tensor result_tmp = at::kron(self, other);
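// resize_output emits a warning when a non-empty `result` is resized to a
// different shape; the test in test_linalg.py checks for exactly that warning.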
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
return result;
}

} // namespace native
} // namespace at
9 changes: 9 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2142,6 +2142,15 @@
CPU: kl_div_backward_cpu
CUDA: kl_div_backward_cuda

- func: kron(Tensor self, Tensor other) -> Tensor
variants: function, method
dispatch:
Math: kron

- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Math: kron_out

- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
use_c10_dispatcher: full
variants: function, method
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -459,6 +459,7 @@ Other Operations
flip
fliplr
flipud
kron
rot90
gcd
histc
101 changes: 101 additions & 0 deletions test/test_linalg.py
@@ -183,6 +183,107 @@ def test_det(self, device, dtype):
with self.assertRaises(RuntimeError):
op(t)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron(self, device, dtype):

def run_test_case(a_shape, b_shape):
a = torch.rand(a_shape, dtype=dtype, device=device)
b = torch.rand(b_shape, dtype=dtype, device=device)

expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
result = torch.kron(a, b)
self.assertEqual(result, expected)

# check the out= variant
out = torch.empty_like(result)
ans = torch.kron(a, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, result)

shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
run_test_case(a_shape, b_shape)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron_non_contiguous(self, device, dtype):

def run_test_transposed(a_shape, b_shape):
# check for transposed case
a = torch.rand(a_shape, dtype=dtype, device=device).transpose(-2, -1)
b = torch.rand(b_shape, dtype=dtype, device=device).transpose(-2, -1)
self.assertFalse(a.is_contiguous())
self.assertFalse(b.is_contiguous())

expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
result = torch.kron(a, b)
self.assertEqual(result, expected)

# check the out= variant
out = torch.empty(result.transpose(-2, -1).shape, dtype=dtype, device=device).transpose(-2, -1)
self.assertFalse(out.is_contiguous())
ans = torch.kron(a, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, result)

def run_test_skipped_elements(a_shape, b_shape):
# check for the skipped-elements (strided, non-contiguous) case
a = torch.rand(2 * a_shape[0], *a_shape[1:], dtype=dtype, device=device)[::2]
b = torch.rand(2 * b_shape[0], *b_shape[1:], dtype=dtype, device=device)[::2]
self.assertFalse(a.is_contiguous())
self.assertFalse(b.is_contiguous())

expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
result = torch.kron(a, b)
self.assertEqual(result, expected)

# check the out= variant
out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2]
self.assertFalse(out.is_contiguous())
ans = torch.kron(a, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, result)

shapes = [(2, 2), (2, 2, 3), (2, 2, 3, 3)]
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
run_test_transposed(a_shape, b_shape)
run_test_skipped_elements(a_shape, b_shape)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron_empty(self, device, dtype):

def run_test_case(empty_shape):
a = torch.eye(3, dtype=dtype, device=device)
b = torch.empty(empty_shape, dtype=dtype, device=device)
result = torch.kron(a, b)
expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
self.assertEqual(result, expected)

# np.kron fails when the first argument is empty, so only the shape is checked here
result = torch.kron(b, a)
self.assertEqual(result.shape, expected.shape)

empty_shapes = [(0,), (2, 0), (1, 0, 3)]
for empty_shape in empty_shapes:
run_test_case(empty_shape)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron_errors_and_warnings(self, device, dtype):
# if non-empty out tensor with wrong shape is passed a warning is given
a = torch.eye(3, dtype=dtype, device=device)
b = torch.ones((2, 2), dtype=dtype, device=device)
out = torch.empty_like(a)
with warnings.catch_warnings(record=True) as w:
# Trigger warning
torch.kron(a, b, out=out)
# Check warning occurs
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

# dtypes should match
out = torch.empty_like(a).to(torch.int)
with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
torch.kron(a, b, out=out)

# This test confirms that torch.linalg.norm's dtype argument works
# as expected, according to the function's documentation
@skipCUDAIfNoMagma
7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
@@ -1956,6 +1956,13 @@ def add_docstr_all(method, docstr):
""")

add_docstr_all('kron',
r"""
kron(other) -> Tensor
See :func:`torch.kron`
""")

add_docstr_all('kthvalue',
r"""
kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor)
58 changes: 58 additions & 0 deletions torch/_torch_docs.py
@@ -3669,6 +3669,64 @@ def merge_dicts(*dicts):
RuntimeError: bool value of Tensor with no values is ambiguous
""".format(**common_args))

add_docstr(torch.kron,
           r"""
kron(input, other, *, out=None) -> Tensor

Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`.

If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a
:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a
:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries:

.. math::
    (\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} =
        \text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n},

where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`.
If one tensor has fewer dimensions than the other it is unsqueezed until it has the same number of dimensions.

Supports real-valued and complex-valued inputs.

.. note::
    This function generalizes the typical definition of the Kronecker product for two matrices to two tensors,
    as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a
    :math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix:

    .. math::
        \mathbf{A} \otimes \mathbf{B}=\begin{bmatrix}
        a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\
        \vdots & \ddots & \vdots \\
        a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix}

    where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`.

Arguments:
    input (Tensor)
    other (Tensor)

Keyword args:
    out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None``

Examples::

    >>> mat1 = torch.eye(2)
    >>> mat2 = torch.ones(2, 2)
    >>> torch.kron(mat1, mat2)
    tensor([[1., 1., 0., 0.],
            [1., 1., 0., 0.],
            [0., 0., 1., 1.],
            [0., 0., 1., 1.]])

    >>> mat1 = torch.eye(2)
    >>> mat2 = torch.arange(1, 5).reshape(2, 2)
    >>> torch.kron(mat1, mat2)
    tensor([[1., 2., 0., 0.],
            [3., 4., 0., 0.],
            [0., 0., 1., 2.],
            [0., 0., 3., 4.]])
""")

add_docstr(torch.kthvalue,
r"""
kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor)
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -417,6 +417,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
normalized=False, onesided=None, length=None, return_complex=False: -1),
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
torch.kron: lambda input, other: -1,
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
torch.lcm: lambda input, other, out=None: -1,
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1462,6 +1462,7 @@ def method_tests():
('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3],
torch.LongTensor([0, 0, 2])]),), 'adv_index_var'),
('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()),
('kron', (S, S), ((M, L),))
]

def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None):
