Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot #53453

Closed

Commits (18)
194e9ae  Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot (heitorschueroff, Mar 6, 2021)
025f1f2  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 6, 2021)
0d890d4  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 9, 2021)
d8f0708  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 11, 2021)
8b0889f  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 12, 2021)
146508f  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 17, 2021)
6dd047a  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 29, 2021)
67f31af  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 29, 2021)
e4164d5  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 29, 2021)
8abb119  Update test/test_linalg.py (heitorschueroff, Mar 29, 2021)
c71d440  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 29, 2021)
f6381eb  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 29, 2021)
c0d74fd  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 30, 2021)
fc36b1f  added wrapper function to pack inputs to multi_dot in distributed tes… (heitorschueroff, Mar 30, 2021)
5b49fee  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 30, 2021)
88c4944  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 30, 2021)
14670b4  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 31, 2021)
3684278  Update on "Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot" (heitorschueroff, Mar 31, 2021)
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -281,7 +281,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16)
KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16)
KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), fp16)
KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), fp16)
KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), fp16)
// The macro doesn't like these (I think it chokes on commas inside <>) so write them manually
m.impl(TORCH_SELECTIVE_NAME("aten::_thnn_fused_lstm_cell"),
TORCH_FN((&WrapFunction<CastPolicy::fp16,
1 change: 0 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -228,7 +228,6 @@ _(aten, cat) \
_(aten, cauchy) \
_(aten, ceil) \
_(aten, celu) \
_(aten, chain_matmul) \
_(aten, cholesky) \
_(aten, cholesky_inverse) \
_(aten, cholesky_solve) \
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -196,6 +196,8 @@ namespace c10 {
_(aten, clip_) \
_(aten, det) \
_(aten, linalg_det) \
_(aten, chain_matmul) \
_(aten, linalg_multi_dot) \
_(aten, linalg_norm) \
_(aten, append) \
_(aten, item) \
1 change: 1 addition & 0 deletions docs/source/amp.rst
@@ -104,6 +104,7 @@ Ops that can autocast to ``float16``
``baddbmm``,
``bmm``,
``chain_matmul``,
``multi_dot``,
``conv1d``,
``conv2d``,
``conv3d``,
14 changes: 0 additions & 14 deletions test/test_autograd.py
@@ -2936,20 +2936,6 @@ def test_igammac(self):
gradcheck(torch.igamma, (s, x))
gradgradcheck(torch.igamma, (s, x))

def test_chain_matmul(self):
def gen_matrices(p):
matrices = []
for (pi, pi_1) in zip(p[:-1], p[1:]):
matrices.append(torch.randn(pi, pi_1).requires_grad_())
return matrices

gradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
gradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
gradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))
gradgradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))

def test_profiler_tracing(self):
t1, t2 = torch.ones(1), torch.ones(1)
with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
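The deleted test's gradient checks do not simply vanish: equivalent coverage is picked up by the OpInfo samples added in `common_methods_invocations.py` below. A minimal sketch of that coverage, assuming double precision (which `gradcheck` requires):

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Sketch of the gradient coverage the deleted test provided, now exercised
# through the OpInfo samples for linalg.multi_dot. Sizes are illustrative.
def make_chain(sizes):
    # Adjacent size pairs become the shapes of the matrices in the chain.
    return [torch.randn(m, n, dtype=torch.double, requires_grad=True)
            for m, n in zip(sizes[:-1], sizes[1:])]

mats = make_chain([5, 10, 15, 5])
assert gradcheck(lambda *ms: torch.linalg.multi_dot(ms), mats)
assert gradgradcheck(lambda *ms: torch.linalg.multi_dot(ms), mats)
```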
6 changes: 6 additions & 0 deletions test/test_cuda.py
@@ -2674,6 +2674,12 @@ def test_autocast_nn_fp32(self):
for op, args in self.autocast_lists.nn_fp32:
self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)

@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
def test_autocast_linalg_fp16(self):
with torch.backends.cudnn.flags(enabled=True, deterministic=True):
for op, args in self.autocast_lists.linalg_fp16:
self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._linalg)

@unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
def test_autocast_methods_fp16(self):
with torch.backends.cudnn.flags(enabled=True, deterministic=True):
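The new `test_autocast_linalg_fp16` drives the ops in the `linalg_fp16` list through autocast. A minimal sketch of the user-visible behavior it checks, assuming a CUDA device and a build where `linalg_multi_dot` is on the fp16 autocast list (as this PR makes it):

```python
import torch

# fp32 CUDA inputs; shapes are illustrative.
a = torch.randn(4, 8, device="cuda")
b = torch.randn(8, 16, device="cuda")
c = torch.randn(16, 4, device="cuda")

with torch.cuda.amp.autocast():
    out = torch.linalg.multi_dot([a, b, c])

print(out.dtype)  # torch.float16: multi_dot autocasts to fp16
```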
38 changes: 17 additions & 21 deletions test/test_linalg.py
@@ -3274,8 +3274,21 @@ def test_old_matrix_rank(self, device, dtype):
self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True))

@onlyOnCPUAndCUDA
@dtypes(torch.float, torch.cfloat)
@precisionOverride({torch.float: 1e-02, torch.cfloat: 1e-02})
@dtypes(torch.double)
def test_chain_matmul(self, device, dtype):
# chain_matmul accepts a single input tensor while multi_dot does not
t = make_tensor((2, 2), device, dtype)
self.assertEqual(t, torch.chain_matmul(t))
with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
torch.chain_matmul()

# chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
# be either 1D or 2D
with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
torch.chain_matmul(make_tensor(1, device, dtype), make_tensor(1, device, dtype))

@onlyOnCPUAndCUDA
@dtypes(torch.double, torch.cdouble)
def test_multi_dot(self, device, dtype):
def check(*shapes, discontiguous=False):
tensors = [make_tensor(shape, device, dtype, discontiguous=discontiguous) for shape in shapes]
@@ -3307,9 +3320,11 @@ def check(*shapes, discontiguous=False):

# test large tensors
check([10, 100], [100, 5], [5, 50])
check([10, 20], [20, 30], [30, 5])

# test discontiguous input
check([3, 2], [2, 2], [2, 3], [3, 4], discontiguous=True)
check([15, 5], [5, 10], [10, 20], [20, 25], discontiguous=True)

@onlyOnCPUAndCUDA
@dtypes(torch.float)
@@ -5800,25 +5815,6 @@ def run_test(*n):
run_test(3, 3, 4, 4)
run_test(3, 3, 5, 5)

@dtypes(torch.double)
def test_chain_matmul(self, device, dtype):
def product(matrices):
for mat in matrices[1:]:
matrices[0] = matrices[0].mm(mat)
return matrices[0]

def run_test(p):
matrices = []
for (pi, pi_1) in zip(p[:-1], p[1:]):
matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device))
self.assertEqual(torch.chain_matmul(*matrices), product(matrices))

run_test([10, 20, 30, 5])
run_test([15, 5, 10, 20, 25])

with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
torch.chain_matmul()

@skipCUDAIfNoMagma
@skipCUDAIfRocm
@skipCPUIfNoLapack
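A short sketch of the two behavioral differences the new `test_chain_matmul` pins down (dtypes and shapes are illustrative; the failing calls are left commented out so the snippet runs as-is):

```python
import torch

# Difference 1: chain_matmul accepts a single matrix, multi_dot does not.
t = torch.randn(2, 2, dtype=torch.double)
torch.chain_matmul(t)              # OK: the single matrix is returned as-is
# torch.linalg.multi_dot([t])      # RuntimeError: expects at least 2 tensors

# Difference 2: multi_dot allows 1D first/last tensors, chain_matmul does not.
v = torch.randn(3, dtype=torch.double)
m = torch.randn(3, 3, dtype=torch.double)
torch.linalg.multi_dot([v, m, v])  # OK: treated as row/column vectors
# torch.chain_matmul(v, m, v)      # RuntimeError: every input must be 2D
```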
4 changes: 4 additions & 0 deletions torch/functional.py
@@ -1438,6 +1438,10 @@ def chain_matmul(*matrices):
needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.

.. note::
    :func:`torch.chain_matmul` is deprecated; use :func:`torch.linalg.multi_dot` instead. The differences
    are that :func:`torch.linalg.multi_dot` requires at least 2 input tensors and allows the first and
    last tensors to be either 1D or 2D.

Args:
matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
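Following the deprecation note above, a hypothetical migration sketch; note the calling-convention change from varargs to a single sequence:

```python
import torch

a, b, c = torch.randn(3, 4), torch.randn(4, 5), torch.randn(5, 2)

old = torch.chain_matmul(a, b, c)        # deprecated: matrices as varargs
new = torch.linalg.multi_dot([a, b, c])  # replacement: a sequence of tensors

assert torch.allclose(old, new)
```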
4 changes: 3 additions & 1 deletion torch/testing/_internal/autocast_test_lists.py
@@ -109,7 +109,6 @@ def __init__(self, dev):
("matmul", mat0_fp32 + mat1_fp32),
("mm", mat0_fp32 + mat1_fp32),
("mv", mat0_fp32 + pointwise0_fp32),
("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
@@ -221,6 +220,9 @@ def __init__(self, dev):
("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
]
self.linalg_fp16 = [
("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
]
self.methods_fp16 = [
("__matmul__", mat0_fp32 + mat1_fp32)
]
20 changes: 12 additions & 8 deletions torch/testing/_internal/common_methods_invocations.py
@@ -390,18 +390,22 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
kwargs=dict(dim=1)),)

def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad):
test_cases: List[List[Tuple[int, ...]]] = [
[(S,), (S,)],
[(1, S), (S, 1)],
[(S, 0), (0, S)],
[(0, S), (S, 2)],
[(2, S), (S, 2), (2, S)],
# Each test case consists of the sizes in the chain of multiplications
# e.g. [2, S, 2, S] generates matrices (2, S) @ (S, 2) @ (2, S)
test_cases = [
[Review comment from @mruberry (Collaborator), Mar 12, 2021: How much does this increase test time by? Is the net test time decreased or increased with the chain_matmul() test removal?]

[1, S, 1],
[S, 0, S],
[0, S, 2],
[2, S, 2, S],
[2, 3, 4, 5],
[5, 4, 0, 2],
[2, 4, 3, 5, 3, 2]
]

result = []
for test_case in test_cases:
for sizes in test_cases:
tensors = []
for size in test_case:
for size in zip(sizes[:-1], sizes[1:]):
t = make_tensor(size, device, dtype, requires_grad=requires_grad)
tensors.append(t)
result.append(SampleInput(tensors))
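To make the new size-chain convention concrete, a small sketch of how one test case expands into tensor shapes via the `zip(sizes[:-1], sizes[1:])` line above:

```python
# Adjacent pairs of sizes become the shapes of the matrices in the chain.
sizes = [2, 3, 4, 5]
shapes = list(zip(sizes[:-1], sizes[1:]))
print(shapes)  # [(2, 3), (3, 4), (4, 5)] -> (2, 3) @ (3, 4) @ (4, 5)
```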