Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot
ghstack-source-id: b5a98a2ece6e11b3978c71f0ca992b0a3763d826
Pull Request resolved: #53453
heitorschueroff committed Mar 31, 2021
1 parent 3bdfb7a commit 9187eba
Showing 8 changed files with 42 additions and 38 deletions.
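For reference, the user-facing migration is a straight swap: torch.chain_matmul takes its matrices as varargs, while torch.linalg.multi_dot takes a single sequence. A minimal sketch (tensor names are illustrative):

import torch

a = torch.randn(3, 4)
b = torch.randn(4, 5)
c = torch.randn(5, 2)

# Deprecated varargs API.
old = torch.chain_matmul(a, b, c)

# Replacement: accepts a sequence of tensors instead of varargs.
new = torch.linalg.multi_dot([a, b, c])

assert torch.allclose(old, new)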
1 change: 0 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -227,7 +227,6 @@ _(aten, cat) \
_(aten, cauchy) \
_(aten, ceil) \
_(aten, celu) \
_(aten, chain_matmul) \
_(aten, cholesky) \
_(aten, cholesky_inverse) \
_(aten, cholesky_solve) \
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -200,6 +200,8 @@ namespace c10 {
_(aten, linalg_det) \
_(aten, matrix_power) \
_(aten, linalg_matrix_power) \
_(aten, chain_matmul) \
_(aten, linalg_multi_dot) \
_(aten, linalg_norm) \
_(aten, linalg_vector_norm) \
_(aten, append) \
14 changes: 14 additions & 0 deletions aten/src/ATen/native/LinearAlgebra.cpp
@@ -617,6 +617,20 @@ Tensor chain_matmul(TensorList matrices) {
  return at::native::linalg_multi_dot(matrices);
}

Tensor& chain_matmul_out(TensorList matrices, Tensor& result) {
  checkAllSameDim(matrices, 2);

  TORCH_CHECK(
      matrices.size() > 0, "chain_matmul(): Expected one or more matrices");

  if (matrices.size() == 1) {
    at::native::resize_output(result, matrices[0].sizes());
    return result.copy_(matrices[0]);
  }

  return at::native::linalg_multi_dot_out(matrices, result);
}

static void check_1d(const Tensor& t, const char* arg, const char* fn) {
  TORCH_CHECK(t.dim() == 1, fn, ": Expected 1-D argument ", arg, ", but got ", t.dim(), "-D");
}
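A sketch of how the new out= overload above behaves from Python, assuming the behavior shown in the C++ (tensor values are illustrative): with two or more matrices the call is forwarded to linalg_multi_dot_out, and a single matrix is resized into and copied to the result directly, a case torch.linalg.multi_dot itself rejects.

import torch

a = torch.randn(2, 3)
b = torch.randn(3, 2)

# Two or more matrices: forwarded to torch.linalg.multi_dot's out variant.
out = torch.empty(2, 2)
torch.chain_matmul(a, b, out=out)

# A single matrix: `out` is resized to match and the matrix is copied in.
single = torch.empty(0)
torch.chain_matmul(a, out=single)
assert single.shape == a.shape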
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -881,9 +881,13 @@
  dispatch:
    CPU, CUDA: ceil_out

# alias for torch.linalg.multi_dot
- func: chain_matmul(Tensor[] matrices) -> Tensor
  variants: function

# alias for torch.linalg.multi_dot
- func: chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)

- func: unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]
  variants: function, method
  device_guard: False
15 changes: 0 additions & 15 deletions test/test_autograd.py
@@ -3021,21 +3021,6 @@ def test_igammac(self):
        gradcheck(torch.igamma, (s, x))
        gradgradcheck(torch.igamma, (s, x))

    def test_chain_matmul(self):
        def gen_matrices(p, dtype):
            matrices = []
            for (pi, pi_1) in zip(p[:-1], p[1:]):
                matrices.append(torch.randn(pi, pi_1, dtype=dtype).requires_grad_())
            return matrices

        for dtype in [torch.double, torch.cdouble]:
            gradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5], dtype))
            gradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6], dtype))
            gradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10], dtype))
            gradgradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5], dtype))
            gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6], dtype))
            gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10], dtype))

    def test_profiler_tracing(self):
        t1, t2 = torch.ones(1), torch.ones(1)
        with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
36 changes: 17 additions & 19 deletions test/test_linalg.py
@@ -3662,6 +3662,21 @@ def test_old_matrix_rank(self, device, dtype):
        self.assertEqual(torch.matrix_rank(aaT, True), np.linalg.matrix_rank(aaT.cpu().numpy(), True))
        self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True))

    @onlyOnCPUAndCUDA
    @dtypes(torch.double)
    # This tests only the cases where torch.chain_matmul differs from torch.linalg.multi_dot, which it is an "alias" for.
    def test_chain_matmul(self, device, dtype):
        # chain_matmul accepts a single input tensor while multi_dot does not
        t = make_tensor((2, 2), device, dtype)
        self.assertEqual(t, torch.chain_matmul(t))
        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
            torch.chain_matmul()

        # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
        # be either 1D or 2D
        with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
            torch.chain_matmul(make_tensor(1, device, dtype), make_tensor(1, device, dtype))

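The dimension difference pinned down by the test above can be seen directly; a small sketch with illustrative shapes:

import torch

v = torch.randn(3)
m = torch.randn(3, 3)

# multi_dot treats a 1-D first argument as a row vector and a 1-D last
# argument as a column vector, so this returns a 0-D (scalar) tensor.
s = torch.linalg.multi_dot([v, m, v])

# chain_matmul requires every input to be 2-D and raises instead.
try:
    torch.chain_matmul(v, m, v)
except RuntimeError as e:
    print(e)  # Tensor dimension is 1, expected 2 instead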
    @onlyOnCPUAndCUDA
    @dtypes(torch.double, torch.cdouble)
    def test_multi_dot(self, device, dtype):
@@ -3695,9 +3710,11 @@ def check(*shapes, discontiguous=False):

        # test large tensors
        check([10, 100], [100, 5], [5, 50])
        check([10, 20], [20, 30], [30, 5])

        # test discontiguous input
        check([3, 2], [2, 2], [2, 3], [3, 4], discontiguous=True)
        check([15, 5], [5, 10], [10, 20], [20, 25], discontiguous=True)

    @onlyOnCPUAndCUDA
    @dtypes(torch.float)
Expand Down Expand Up @@ -6192,25 +6209,6 @@ def run_test(*n):
        run_test(3, 3, 4, 4)
        run_test(3, 3, 5, 5)

    @dtypes(torch.double, torch.cdouble)
    def test_chain_matmul(self, device, dtype):
        def product(matrices):
            for mat in matrices[1:]:
                matrices[0] = matrices[0].mm(mat)
            return matrices[0]

        def run_test(p):
            matrices = []
            for (pi, pi_1) in zip(p[:-1], p[1:]):
                matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device))
            self.assertEqual(torch.chain_matmul(*matrices), product(matrices))

        run_test([10, 20, 30, 5])
        run_test([15, 5, 10, 20, 25])

        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
            torch.chain_matmul()

    @skipCUDAIfNoMagma
    @skipCUDAIfRocm
    @skipCPUIfNoLapack
6 changes: 4 additions & 2 deletions torch/functional.py
@@ -1470,17 +1470,19 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): # noqa
    else:
        return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore

def chain_matmul(*matrices):
def chain_matmul(*matrices, out=None):
r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
.. warning::
:func:`torch.chain_matmul` is deprecated, use :func:`torch.linalg.multi_dot` instead.
Args:
matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.
Returns:
Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
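The docstring's point about multiplication order is easy to make concrete. For A (10×100), B (100×5), and C (5×50), computing (AB)C costs 10·100·5 + 10·5·50 = 7,500 scalar multiplications, while A(BC) costs 100·5·50 + 10·100·50 = 75,000; the chain-order algorithm picks the cheap parenthesization automatically. A sketch of that count (the helper name is illustrative):

# Multiplying a (p x q) matrix by a (q x r) matrix takes p*q*r scalar multiplications.
def pairwise_cost(p, q, r):
    return p * q * r

# A: 10x100, B: 100x5, C: 5x50
cost_ab_first = pairwise_cost(10, 100, 5) + pairwise_cost(10, 5, 50)    # 7500
cost_bc_first = pairwise_cost(100, 5, 50) + pairwise_cost(10, 100, 50)  # 75000
print(cost_ab_first, cost_bc_first)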
2 changes: 1 addition & 1 deletion torch/overrides.py
@@ -317,7 +317,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.cdist: lambda x1, x2, p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary': -1,
        torch.ceil: lambda input, out=None: -1,
        torch.celu: lambda input, alpha=1., inplace=False: -1,
        torch.chain_matmul: lambda *matrices: -1,
        torch.chain_matmul: lambda *matrices, out=None: -1,
        torch.channel_shuffle: lambda input, groups: -1,
        torch.cholesky: lambda input, upper=False, out=None: -1,
        torch.linalg.cholesky: lambda input, out=None: -1,
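The entries in this table are dummy lambdas whose signatures must mirror the public API, which is why the out=None keyword is added here as well. A quick check using the public helper:

import inspect
import torch
from torch.overrides import get_testing_overrides

# Each entry maps a torch API to a dummy lambda with a matching signature.
dummy = get_testing_overrides()[torch.chain_matmul]
print(inspect.signature(dummy))  # (*matrices, out=None)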
