Added Kronecker product of tensors (torch.kron) #45358

Closed
wants to merge 32 commits
Changes from all commits
Commits (32)
99a87b2
Added torch.kron
IvanYashchuk Sep 25, 2020
90f0159
Updated tests
IvanYashchuk Sep 28, 2020
98a2f49
Now kron permutation and reshape is correct
IvanYashchuk Sep 28, 2020
5a0dcf9
Rewrote using ternary
IvanYashchuk Sep 28, 2020
4c30a9a
flake8
IvanYashchuk Sep 28, 2020
3716b6a
Added documentation
IvanYashchuk Sep 28, 2020
8f3fb2a
Merge remote-tracking branch 'upstream/master' into linalg-kron
IvanYashchuk Sep 29, 2020
9931ad9
Added fp32, complex dtypes
IvanYashchuk Sep 29, 2020
18e2767
Added overrides entry for torch.kron
IvanYashchuk Sep 29, 2020
e66ab58
Renamed a, b -> self, other
IvanYashchuk Oct 20, 2020
dfb00d3
Added out= variant
IvanYashchuk Oct 20, 2020
f886d19
Updated documentation: added note on real and complex support, added …
IvanYashchuk Oct 20, 2020
c60487b
Added empty and out= test cases
IvanYashchuk Oct 20, 2020
76354a5
Added entry to common_methods_invocations.py
IvanYashchuk Oct 20, 2020
1f94849
Added test for error with incorrect out=
IvanYashchuk Oct 20, 2020
9bd2f3f
Merge remote-tracking branch 'upstream/master' into linalg-kron
IvanYashchuk Oct 20, 2020
5e2f5e8
Added tensor_docs entry
IvanYashchuk Oct 20, 2020
71ceaab
tensor2 -> other
IvanYashchuk Oct 22, 2020
7786e4c
Updated documentation
IvanYashchuk Oct 22, 2020
50cb9be
Removed returns block in docs
IvanYashchuk Oct 22, 2020
15985b3
Added dispatch: Math
IvanYashchuk Oct 22, 2020
3f2e70d
Use at::native::resize_output in kron_out
IvanYashchuk Oct 22, 2020
a980bb4
Added non-contiguous test cases
IvanYashchuk Oct 22, 2020
b45d6c8
Merge remote-tracking branch 'upstream/master' into linalg-kron
IvanYashchuk Oct 24, 2020
e22a079
Use bmatrix instead of array in matrix example
IvanYashchuk Oct 24, 2020
fa90e71
Fixed non-contiguous tests
IvanYashchuk Oct 24, 2020
5ad2f45
Updated comment in LinearAlgebra.cpp
IvanYashchuk Oct 26, 2020
1cf37ff
Updated docs for general n-dim tensors
IvanYashchuk Oct 26, 2020
60db523
Changed 'block tensor' -> 'tensor'
IvanYashchuk Oct 26, 2020
5888b0e
Fix too long line
IvanYashchuk Oct 26, 2020
719a687
Changed the note according to Mike's suggestion
IvanYashchuk Nov 2, 2020
a1b3255
Merge remote-tracking branch 'upstream/master' into linalg-kron
IvanYashchuk Nov 2, 2020
66 changes: 66 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -1738,5 +1738,71 @@ Tensor chain_matmul(TensorList matrices) {
}
}

/*
Calculates the Kronecker product between two Tensors.
*/
Tensor kron(const Tensor& self, const Tensor& other) {
/*
We can obtain the kron result using tensordot or einsum; the implementation below uses tensordot.
In einsum notation, suppose `self` has dim 4 and `other` has dim 2.
The tensordot below computes, in einsum terms, 0123, 45 -> 012345.
To obtain the correct kron, that result must be permuted and reshaped.
The permutation interleaves the trailing axes of the two operands, going from right to left;
for this example the correct permutation is 012435, and the kron shape is
(shape_self[0], shape_self[1], shape_self[2]*shape_other[0], shape_self[3]*shape_other[1]).
*/
std::vector<int64_t> self_sizes = self.sizes().vec();
std::vector<int64_t> other_sizes = other.sizes().vec();
int64_t self_ndim = self.dim();
int64_t other_ndim = other.dim();
int64_t min_ndim = std::min(self_ndim, other_ndim);
int64_t ndim_diff = std::abs(self_ndim - other_ndim);

std::vector<int64_t> a_axes(self_ndim);
std::vector<int64_t> b_axes(other_ndim);
std::iota(a_axes.begin(), a_axes.end(), 0);
std::iota(b_axes.begin(), b_axes.end(), 0 + self_ndim);

bool is_a_larger = self_ndim >= other_ndim;
std::vector<int64_t> kron_permutation(self_ndim + other_ndim);
for (int64_t i = 0; i < ndim_diff; i++) {
kron_permutation[i] = is_a_larger ? a_axes[i] : b_axes[i];
}
for (int64_t i = 0, j = 0; i < min_ndim; i++, j += 2) {
kron_permutation[self_ndim + other_ndim - 1 - j] = b_axes[other_ndim - 1 - i];
kron_permutation[self_ndim + other_ndim - 1 - j - 1] = a_axes[self_ndim - 1 - i];
}

std::vector<int64_t> result_shape(std::max(self_ndim, other_ndim));
for (int64_t i = 0; i < ndim_diff; i++) {
result_shape[i] = is_a_larger ? self_sizes[i] : other_sizes[i];
}
for (int64_t i = 0; i < min_ndim; i++) {
result_shape[ndim_diff + i] = is_a_larger
? self_sizes[ndim_diff + i] * other_sizes[i]
: other_sizes[ndim_diff + i] * self_sizes[i];
}

// Step 1: compute the outer product with tensordot (no axes are contracted)
Tensor result = at::tensordot(self, other, {}, {});
// Step 2: interleave the axes of self and other
result = result.permute(kron_permutation);
// Step 3: reshape so that each pair of interleaved axes is multiplied out
result = result.reshape(result_shape);

return result;
}

Tensor& kron_out(Tensor& result, const Tensor& self, const Tensor& other) {
TORCH_CHECK(result.scalar_type() == self.scalar_type(),
"result dtype ", result.scalar_type(), " does not match self dtype ", self.scalar_type());

Tensor result_tmp = at::kron(self, other);
at::native::resize_output(result, result_tmp.sizes());
result.copy_(result_tmp);
return result;
}

} // namespace native
} // namespace at
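For readers following the permute/reshape bookkeeping above, here is a small Python sketch (my own restatement, not code from the PR) of the same algorithm, checked against `np.kron`. It assumes `a` has at least as many dimensions as `b`; the C++ code handles the opposite case by swapping roles.

```python
import numpy as np
import torch

def kron_via_tensordot(a, b):
    # Step 1: outer product; the result's dims are a's dims followed by b's dims.
    res = torch.tensordot(a, b, dims=([], []))
    ndim_diff = a.dim() - b.dim()
    # Leading axes of `a` that have no partner in `b` stay in front.
    perm = list(range(ndim_diff))
    # Interleave the trailing axes of `a` with the axes of `b`.
    for i in range(b.dim()):
        perm += [ndim_diff + i, a.dim() + i]
    res = res.permute(perm)
    # Multiply paired sizes together to get the Kronecker shape.
    shape = list(a.shape[:ndim_diff]) + [
        a.shape[ndim_diff + i] * b.shape[i] for i in range(b.dim())
    ]
    return res.reshape(shape)

a = torch.rand(2, 3, 4, 5)   # the dim-4 `self` from the comment above
b = torch.rand(6, 7)         # the dim-2 `other`
assert np.allclose(kron_via_tensordot(a, b).numpy(), np.kron(a.numpy(), b.numpy()))
```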
9 changes: 9 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -2142,6 +2142,15 @@
CPU: kl_div_backward_cpu
CUDA: kl_div_backward_cuda

- func: kron(Tensor self, Tensor other) -> Tensor
variants: function, method
dispatch:
Math: kron

- func: kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Math: kron_out

- func: kthvalue(Tensor self, int k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)
use_c10_dispatcher: full
variants: function, method
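The `Math` dispatch entry registers `kron` as a composite kernel: it is implemented entirely in terms of existing operations, so it runs on any backend and autograd differentiates through it without a dedicated derivative formula. A quick sketch of the consequence (my own illustration, assuming a build that includes this PR):

```python
import torch

a = torch.rand(2, 2, dtype=torch.double, requires_grad=True)
b = torch.rand(3, 3, dtype=torch.double)

# No hand-written backward exists for kron; the gradient flows through the
# tensordot/permute/reshape ops that the composite kernel calls.
torch.kron(a, b).sum().backward()
print(a.grad.shape)  # torch.Size([2, 2])
```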
1 change: 1 addition & 0 deletions docs/source/torch.rst
@@ -459,6 +459,7 @@ Other Operations
flip
fliplr
flipud
kron
rot90
gcd
histc
101 changes: 101 additions & 0 deletions test/test_linalg.py
@@ -183,6 +183,107 @@ def test_det(self, device, dtype):
with self.assertRaises(RuntimeError):
op(t)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron(self, device, dtype):

def run_test_case(a_shape, b_shape):
a = torch.rand(a_shape, dtype=dtype, device=device)
b = torch.rand(b_shape, dtype=dtype, device=device)

expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
result = torch.kron(a, b)
self.assertEqual(result, expected)

# check the out= variant
out = torch.empty_like(result)
ans = torch.kron(a, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, result)

shapes = [(4,), (2, 2), (1, 2, 3), (1, 2, 3, 3)]
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
run_test_case(a_shape, b_shape)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron_non_contiguous(self, device, dtype):

def run_test_transposed(a_shape, b_shape):
# check for transposed case
a = torch.rand(a_shape, dtype=dtype, device=device).transpose(-2, -1)
b = torch.rand(b_shape, dtype=dtype, device=device).transpose(-2, -1)
self.assertFalse(a.is_contiguous())
self.assertFalse(b.is_contiguous())

expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
result = torch.kron(a, b)
self.assertEqual(result, expected)

# check the out= variant
out = torch.empty(result.transpose(-2, -1).shape, dtype=dtype, device=device).transpose(-2, -1)
self.assertFalse(out.is_contiguous())
ans = torch.kron(a, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, result)

def run_test_skipped_elements(a_shape, b_shape):
# check for non-contiguous tensors produced by skipping elements
a = torch.rand(2 * a_shape[0], *a_shape[1:], dtype=dtype, device=device)[::2]
b = torch.rand(2 * b_shape[0], *b_shape[1:], dtype=dtype, device=device)[::2]
self.assertFalse(a.is_contiguous())
self.assertFalse(b.is_contiguous())

expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
result = torch.kron(a, b)
self.assertEqual(result, expected)

# check the out= variant
out = torch.empty(2 * result.shape[0], *result.shape[1:], dtype=dtype, device=device)[::2]
self.assertFalse(out.is_contiguous())
ans = torch.kron(a, b, out=out)
self.assertEqual(ans, out)
self.assertEqual(ans, result)

shapes = [(2, 2), (2, 2, 3), (2, 2, 3, 3)]
for a_shape, b_shape in itertools.product(shapes, reversed(shapes)):
run_test_transposed(a_shape, b_shape)
run_test_skipped_elements(a_shape, b_shape)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron_empty(self, device, dtype):

def run_test_case(empty_shape):
a = torch.eye(3, dtype=dtype, device=device)
b = torch.empty(empty_shape, dtype=dtype, device=device)
result = torch.kron(a, b)
expected = np.kron(a.cpu().numpy(), b.cpu().numpy())
self.assertEqual(result, expected)

# np.kron doesn't work when the first argument is empty, so only the result's shape is checked here
result = torch.kron(b, a)
self.assertEqual(result.shape, expected.shape)

empty_shapes = [(0,), (2, 0), (1, 0, 3)]
for empty_shape in empty_shapes:
run_test_case(empty_shape)

@dtypes(torch.float32, torch.float64, torch.complex64, torch.complex128)
def test_kron_errors_and_warnings(self, device, dtype):
# a warning is given if a non-empty out tensor with the wrong shape is passed
a = torch.eye(3, dtype=dtype, device=device)
b = torch.ones((2, 2), dtype=dtype, device=device)
out = torch.empty_like(a)
with warnings.catch_warnings(record=True) as w:
# Trigger warning
torch.kron(a, b, out=out)
# Check warning occurs
self.assertEqual(len(w), 1)
self.assertTrue("An output with one or more elements was resized" in str(w[-1].message))

# dtypes should match
out = torch.empty_like(a).to(torch.int)
with self.assertRaisesRegex(RuntimeError, "result dtype Int does not match self dtype"):
torch.kron(a, b, out=out)

# This test confirms that torch.linalg.norm's dtype argument works
# as expected, according to the function's documentation
@skipCUDAIfNoMagma
7 changes: 7 additions & 0 deletions torch/_tensor_docs.py
@@ -1956,6 +1956,13 @@ def add_docstr_all(method, docstr):

""")

add_docstr_all('kron',
r"""
kron(other) -> Tensor

See :func:`torch.kron`
""")

add_docstr_all('kthvalue',
r"""
kthvalue(k, dim=None, keepdim=False) -> (Tensor, LongTensor)
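Since the docstring above only cross-references :func:`torch.kron`, here is a short sketch (not part of the PR) of the method spelling this entry documents:

```python
import torch

a = torch.eye(2)
b = torch.ones(2, 2)
# Tensor.kron is just the method spelling of torch.kron
assert torch.equal(a.kron(b), torch.kron(a, b))
```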
58 changes: 58 additions & 0 deletions torch/_torch_docs.py
@@ -3669,6 +3669,64 @@ def merge_dicts(*dicts):
RuntimeError: bool value of Tensor with no values is ambiguous
""".format(**common_args))

add_docstr(torch.kron,
r"""
kron(input, other, *, out=None) -> Tensor

Computes the Kronecker product, denoted by :math:`\otimes`, of :attr:`input` and :attr:`other`.

If :attr:`input` is a :math:`(a_0 \times a_1 \times \dots \times a_n)` tensor and :attr:`other` is a
:math:`(b_0 \times b_1 \times \dots \times b_n)` tensor, the result will be a
:math:`(a_0*b_0 \times a_1*b_1 \times \dots \times a_n*b_n)` tensor with the following entries:

.. math::
(\text{input} \otimes \text{other})_{k_0, k_1, \dots, k_n} =
\text{input}_{i_0, i_1, \dots, i_n} * \text{other}_{j_0, j_1, \dots, j_n},

where :math:`k_t = i_t * b_t + j_t` for :math:`0 \leq t \leq n`.
If one tensor has fewer dimensions than the other, it is unsqueezed until it has the same number of dimensions.

Supports real-valued and complex-valued inputs.

.. note::
This function generalizes the typical definition of the Kronecker product for two matrices to two tensors,
as described above. When :attr:`input` is a :math:`(m \times n)` matrix and :attr:`other` is a
:math:`(p \times q)` matrix, the result will be a :math:`(p*m \times q*n)` block matrix:

.. math::
\mathbf{A} \otimes \mathbf{B}=\begin{bmatrix}
a_{11} \mathbf{B} & \cdots & a_{1 n} \mathbf{B} \\
\vdots & \ddots & \vdots \\
a_{m 1} \mathbf{B} & \cdots & a_{m n} \mathbf{B} \end{bmatrix}

where :attr:`input` is :math:`\mathbf{A}` and :attr:`other` is :math:`\mathbf{B}`.

Arguments:
input (Tensor)
other (Tensor)

Keyword args:
out (Tensor, optional): The output tensor. Ignored if ``None``. Default: ``None``

Examples::

>>> mat1 = torch.eye(2)
>>> mat2 = torch.ones(2, 2)
>>> torch.kron(mat1, mat2)
tensor([[1., 1., 0., 0.],
[1., 1., 0., 0.],
[0., 0., 1., 1.],
[0., 0., 1., 1.]])

>>> mat1 = torch.eye(2)
>>> mat2 = torch.arange(1, 5).reshape(2, 2)
>>> torch.kron(mat1, mat2)
tensor([[1., 2., 0., 0.],
[3., 4., 0., 0.],
[0., 0., 1., 2.],
[0., 0., 3., 4.]])
""")

add_docstr(torch.kthvalue,
r"""
kthvalue(input, k, dim=None, keepdim=False, *, out=None) -> (Tensor, LongTensor)
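As a sanity check of the entry formula :math:`k_t = i_t * b_t + j_t` in the docstring above, a short sketch (my own, assuming a build that includes this PR):

```python
import torch

a = torch.rand(2, 3)
b = torch.rand(4, 5)
k = torch.kron(a, b)

# (a ⊗ b)[i0*4 + j0, i1*5 + j1] == a[i0, i1] * b[j0, j1]
for i0 in range(2):
    for i1 in range(3):
        for j0 in range(4):
            for j1 in range(5):
                assert torch.isclose(k[i0 * 4 + j0, i1 * 5 + j1], a[i0, i1] * b[j0, j1])
```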
1 change: 1 addition & 0 deletions torch/overrides.py
@@ -417,6 +417,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.istft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
normalized=False, onesided=None, length=None, return_complex=False: -1),
torch.kl_div: lambda input, target, size_average=None, reduce=None, reduction='mean', log_target=False: -1,
torch.kron: lambda input, other: -1,
torch.kthvalue: lambda input, k, dim=None, keepdim=False, out=None: -1,
torch.layer_norm: lambda input, normalized_shape, weight=None, bias=None, esp=1e-05, cudnn_enabled=True: -1,
torch.lcm: lambda input, other, out=None: -1,
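The new entry makes `torch.kron` part of the overridable API and supplies a dummy lambda whose signature mirrors the real one; the dummy is only used by the override-coverage tests. A small sketch (mine, not from the PR) of how the entry can be inspected:

```python
import torch
from torch.overrides import get_testing_overrides

overrides = get_testing_overrides()
assert torch.kron in overrides               # kron is now listed as overridable
dummy = overrides[torch.kron]
# The dummy only mirrors the (input, other) signature; it always returns -1.
assert dummy(torch.eye(2), torch.ones(2, 2)) == -1
```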
1 change: 1 addition & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1460,6 +1460,7 @@ def method_tests():
('__getitem__', torch.randn(S, S, S), (dont_convert([[0, 2, 3], [1, 3, 3],
torch.LongTensor([0, 0, 2])]),), 'adv_index_var'),
('to_sparse', (S, S), (), '', (), (), [], lambda x: x.to_dense()),
('kron', (S, S), ((M, L),))
]

def create_input(call_args, requires_grad=True, non_contiguous=False, call_kwargs=None, dtype=torch.double, device=None):
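The `method_tests()` entry above drives the stock autograd tests for `kron` with a `(S, S)` input and a `(M, L)` argument. A rough stand-alone equivalent (a sketch; the sizes below are stand-ins for the symbolic `S`, `M`, `L`):

```python
import torch
from torch.autograd import gradcheck

S, M, L = 3, 2, 4   # small stand-ins for the symbolic sizes in method_tests()
a = torch.rand(S, S, dtype=torch.double, requires_grad=True)
b = torch.rand(M, L, dtype=torch.double, requires_grad=True)

# gradcheck compares analytical gradients against numerical finite differences.
assert gradcheck(torch.kron, (a, b))
```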