
Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot #53453

Closed
Changes from 8 commits

Commits (18)
194e9ae
Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot
heitorschueroff Mar 6, 2021
025f1f2
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 6, 2021
0d890d4
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 9, 2021
d8f0708
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 11, 2021
8b0889f
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 12, 2021
146508f
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 17, 2021
6dd047a
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 29, 2021
67f31af
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 29, 2021
e4164d5
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 29, 2021
8abb119
Update test/test_linalg.py
heitorschueroff Mar 29, 2021
c71d440
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 29, 2021
f6381eb
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 29, 2021
c0d74fd
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 30, 2021
fc36b1f
added wrapper function to pack inputs to multi_dot in distributed tes…
heitorschueroff Mar 30, 2021
5b49fee
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 30, 2021
88c4944
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 30, 2021
14670b4
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 31, 2021
3684278
Update on "Deprecate torch.chain_matmul in favor of torch.linalg.mult…
heitorschueroff Mar 31, 2021
1 change: 0 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -227,7 +227,6 @@ _(aten, cat) \
_(aten, cauchy) \
_(aten, ceil) \
_(aten, celu) \
_(aten, chain_matmul) \
_(aten, cholesky) \
_(aten, cholesky_inverse) \
_(aten, cholesky_solve) \
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -200,6 +200,8 @@ namespace c10 {
_(aten, linalg_det) \
_(aten, matrix_power) \
_(aten, linalg_matrix_power) \
_(aten, chain_matmul) \
_(aten, linalg_multi_dot) \
_(aten, linalg_norm) \
_(aten, linalg_vector_norm) \
_(aten, append) \
14 changes: 14 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -617,6 +617,20 @@ Tensor chain_matmul(TensorList matrices) {
return at::native::linalg_multi_dot(matrices);
}

Tensor& chain_matmul_out(TensorList matrices, Tensor& result) {
checkAllSameDim(matrices, 2);

TORCH_CHECK(
matrices.size() > 0, "chain_matmul(): Expected one or more matrices");

if (matrices.size() == 1) {
at::native::resize_output(result, matrices[0].sizes());
return result.copy_(matrices[0]);
}

return at::native::linalg_multi_dot_out(matrices, result);
}

static void check_1d(const Tensor& t, const char* arg, const char* fn) {
TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
}
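The new out= overload keeps chain_matmul's single-matrix special case (the input is resized into and copied to out) and forwards everything else to the multi_dot out variant. A minimal Python sketch of the resulting behavior, assuming the Python wrapper passes out= through to this overload; shapes are arbitrary:

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
out = torch.empty(3, 5)
torch.chain_matmul(a, b, out=out)  # general case: forwarded to linalg.multi_dot's out variant

single = torch.randn(2, 2)
out1 = torch.empty(2, 2)
torch.chain_matmul(single, out=out1)  # single-matrix case: out is resized and the input is copied into it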
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -881,9 +881,13 @@
dispatch:
CPU, CUDA: ceil_out

# alias for torch.linalg.multi_dot
- func: chain_matmul(Tensor[] matrices) -> Tensor
variants: function

# alias for torch.linalg.multi_dot
- func: chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)

- func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
variants: function, method
device_guard: False
14 changes: 0 additions & 14 deletions test/test_autograd.py
@@ -3021,20 +3021,6 @@ def test_igammac(self):
gradcheck(torch.igamma, (s, x))
gradgradcheck(torch.igamma, (s, x))

def test_chain_matmul(self):
def gen_matrices(p):
matrices = []
for (pi, pi_1) in zip(p[:-1], p[1:]):
matrices.append(torch.randn(pi, pi_1).requires_grad_())
return matrices

gradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
gradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
gradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))
gradgradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))

def test_profiler_tracing(self):
t1, t2 = torch.ones(1), torch.ones(1)
with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
36 changes: 17 additions & 19 deletions test/test_linalg.py
@@ -3658,6 +3658,21 @@ def test_old_matrix_rank(self, device, dtype):
self.assertEqual(torch.matrix_rank(aaT, True), np.linalg.matrix_rank(aaT.cpu().numpy(), True))
self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True))

@onlyOnCPUAndCUDA
@dtypes(torch.double)
# This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot, which it is an "alias" for.
def test_chain_matmul(self, device, dtype):
# chain_matmul accepts a single input tensor while multi_dot does not
t = make_tensor((2, 2), device, dtype)
self.assertEqual(t, torch.chain_matmul(t))
with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
torch.chain_matmul()

# chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
# be either 1D or 2D
with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
torch.chain_matmul(make_tensor(1, device, dtype), make_tensor(1, device, dtype))

@onlyOnCPUAndCUDA
@dtypes(torch.double, torch.cdouble)
def test_multi_dot(self, device, dtype):
@@ -3691,9 +3706,11 @@ def check(*shapes, discontiguous=False):

# test large tensors
check([10, 100], [100, 5], [5, 50])
check([10, 20], [20, 30], [30, 5])

# test discontiguous input
check([3, 2], [2, 2], [2, 3], [3, 4], discontiguous=True)
check([15, 5], [5, 10], [10, 20], [20, 25], discontiguous=True)

@onlyOnCPUAndCUDA
@dtypes(torch.float)
@@ -6188,25 +6205,6 @@ def run_test(*n):
run_test(3, 3, 4, 4)
run_test(3, 3, 5, 5)

@dtypes(torch.double)
def test_chain_matmul(self, device, dtype):
def product(matrices):
for mat in matrices[1:]:
matrices[0] = matrices[0].mm(mat)
return matrices[0]

def run_test(p):
matrices = []
for (pi, pi_1) in zip(p[:-1], p[1:]):
matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device))
self.assertEqual(torch.chain_matmul(*matrices), product(matrices))

run_test([10, 20, 30, 5])
run_test([15, 5, 10, 20, 25])

with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
torch.chain_matmul()

@skipCUDAIfNoMagma
@skipCUDAIfRocm
@skipCPUIfNoLapack
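The new test above pins down the two intentional behavioral differences between chain_matmul and its replacement. A short sketch of both, with arbitrary shapes:

import torch

# multi_dot takes one sequence argument and lets the first and last operands be 1-D
v = torch.randn(5)
m = torch.randn(5, 5)
torch.linalg.multi_dot([v, m, v])   # ok: 1-D endpoints are treated as row/column vectors

# chain_matmul requires every operand to be 2-D ...
t = torch.randn(2, 2)
# torch.chain_matmul(v, v)          # raises: Tensor dimension is 1, expected 2 instead

# ... but accepts a single matrix, which multi_dot rejects
torch.chain_matmul(t)               # ok: returns t unchanged
# torch.linalg.multi_dot([t])       # raises: at least two tensors are expected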
6 changes: 4 additions & 2 deletions torch/functional.py
@@ -1470,17 +1470,19 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa
else:
return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out) # type: ignore

def chain_matmul(*matrices):
def chain_matmul(*matrices, out=None):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order that incurs the lowest cost in terms
of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.

.. warning::
:func:`torch.chain_matmul` is deprecated, use :func:`torch.linalg.multi_dot` instead.

Args:
matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.

out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.

Returns:
Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
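For users following the deprecation warning added above, the only call-site change is passing the operands as one sequence instead of variadic arguments. A minimal migration sketch, with arbitrary shapes:

import torch

A = torch.randn(10, 20)
B = torch.randn(20, 30)
C = torch.randn(30, 5)

old = torch.chain_matmul(A, B, C)        # deprecated: variadic 2-D tensors
new = torch.linalg.multi_dot([A, B, C])  # replacement: a single sequence of tensors

assert torch.allclose(old, new)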
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -316,7 +316,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
torch.ceil: lambda input, out=None: -1,
torch.celu: lambda input, alpha=1., inplace=False: -1,
torch.chain_matmul: lambda *matrices: -1,
torch.chain_matmul: lambda *matrices, out=None: -1,
torch.channel_shuffle: lambda input, groups : -1,
torch.cholesky: lambda input, upper=False, out=None: -1,
torch.linalg.cholesky: lambda input, out=None: -1,