Deprecate torch.chain_matmul in favor of torch.linalg.multi_dot
ghstack-source-id: 87b18c7ea86acd38e6fea0bc1435aba3c10c9122
Pull Request resolved: #53453
heitorschueroff committed Mar 11, 2021
1 parent 917f80f commit 8f9dc98
Showing 10 changed files with 46 additions and 46 deletions.
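For downstream users, a minimal migration sketch (illustrative only, not part of the commit; it assumes a PyTorch build where torch.linalg.multi_dot is available):

import torch

A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.randn(5, 6)

# Deprecated: chain_matmul takes variadic matrix arguments
out_old = torch.chain_matmul(A, B, C)

# Replacement: multi_dot takes a single sequence of tensors
out_new = torch.linalg.multi_dot([A, B, C])

assert torch.allclose(out_old, out_new)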
2 changes: 1 addition & 1 deletion aten/src/ATen/autocast_mode.cpp
@@ -281,7 +281,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(addbmm), "addbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16)
   KERNEL(ADD_NS(baddbmm), "baddbmm", Tensor (const Tensor &, const Tensor &, const Tensor &, Scalar, Scalar), fp16)
   KERNEL(ADD_NS(bmm), "bmm", Tensor (const Tensor &, const Tensor &), fp16)
-  KERNEL(ADD_NS(chain_matmul), "chain_matmul", Tensor (TensorList), fp16)
+  KERNEL(ADD_NS(linalg_multi_dot), "linalg_multi_dot", Tensor (TensorList), fp16)
   // The macro doesn't like these (I think it chokes on commas inside <>) so write them manually
   m.impl(TORCH_SELECTIVE_NAME("aten::_thnn_fused_lstm_cell"),
          TORCH_FN((&WrapFunction<CastPolicy::fp16,
1 change: 0 additions & 1 deletion aten/src/ATen/core/aten_interned_strings.h
@@ -228,7 +228,6 @@ _(aten, cat) \
 _(aten, cauchy) \
 _(aten, ceil) \
 _(aten, celu) \
-_(aten, chain_matmul) \
 _(aten, cholesky) \
 _(aten, cholesky_inverse) \
 _(aten, cholesky_solve) \
2 changes: 2 additions & 0 deletions aten/src/ATen/core/interned_strings.h
@@ -196,6 +196,8 @@ namespace c10 {
 _(aten, clip_) \
 _(aten, det) \
 _(aten, linalg_det) \
+_(aten, chain_matmul) \
+_(aten, linalg_multi_dot) \
 _(aten, linalg_norm) \
 _(aten, append) \
 _(aten, item) \
1 change: 1 addition & 0 deletions docs/source/amp.rst
@@ -104,6 +104,7 @@ Ops that can autocast to ``float16``
 ``baddbmm``,
 ``bmm``,
 ``chain_matmul``,
+``multi_dot``,
 ``conv1d``,
 ``conv2d``,
 ``conv3d``,
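Since ``multi_dot`` now joins the float16 autocast list, a minimal usage sketch (an assumption-laden illustration, not from the commit; it assumes a CUDA device is available):

import torch

# Three chained matrices on the GPU, created in float32
mats = [torch.randn(8, 8, device="cuda") for _ in range(3)]

with torch.cuda.amp.autocast():
    out = torch.linalg.multi_dot(mats)

print(out.dtype)  # torch.float16 inside the autocast region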
14 changes: 0 additions & 14 deletions test/test_autograd.py
@@ -2936,20 +2936,6 @@ def test_igammac(self):
         gradcheck(torch.igamma, (s, x))
         gradgradcheck(torch.igamma, (s, x))
 
-    def test_chain_matmul(self):
-        def gen_matrices(p):
-            matrices = []
-            for (pi, pi_1) in zip(p[:-1], p[1:]):
-                matrices.append(torch.randn(pi, pi_1).requires_grad_())
-            return matrices
-
-        gradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
-        gradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
-        gradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))
-        gradgradcheck(torch.chain_matmul, gen_matrices([5, 10, 15, 5]))
-        gradgradcheck(torch.chain_matmul, gen_matrices([3, 5, 2, 6]))
-        gradgradcheck(torch.chain_matmul, gen_matrices([6, 2, 4, 8, 10]))
-
     def test_profiler_tracing(self):
         t1, t2 = torch.ones(1), torch.ones(1)
         with torch.autograd.profiler.profile(use_kineto=kineto_available()) as prof:
6 changes: 6 additions & 0 deletions test/test_cuda.py
@@ -2666,6 +2666,12 @@ def test_autocast_nn_fp32(self):
         for op, args in self.autocast_lists.nn_fp32:
             self._run_autocast_outofplace(op, args, torch.float32, module=torch._C._nn)
 
+    @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
+    def test_autocast_linalg_fp16(self):
+        with torch.backends.cudnn.flags(enabled=True, deterministic=True):
+            for op, args in self.autocast_lists.linalg_fp16:
+                self._run_autocast_outofplace(op, args, torch.float16, module=torch._C._linalg)
+
     @unittest.skipIf(not TEST_CUDNN, 'CUDNN not available')
     def test_autocast_methods_fp16(self):
         with torch.backends.cudnn.flags(enabled=True, deterministic=True):
38 changes: 17 additions & 21 deletions test/test_linalg.py
@@ -3274,8 +3274,21 @@ def test_old_matrix_rank(self, device, dtype):
         self.assertEqual(torch.matrix_rank(aaT, 0.01, True), np.linalg.matrix_rank(aaT.cpu().numpy(), 0.01, True))
 
     @onlyOnCPUAndCUDA
-    @dtypes(torch.float, torch.cfloat)
-    @precisionOverride({torch.float: 1e-02, torch.cfloat: 1e-02})
+    @dtypes(torch.double)
+    def test_chain_matmul(self, device, dtype):
+        # chain_matmul accepts a single input tensor while multi_dot does not
+        t = make_tensor((2, 2), device, dtype)
+        self.assertEqual(t, torch.chain_matmul(t))
+        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
+            torch.chain_matmul()
+
+        # chain_matmul expects all tensors to be 2D whereas multi_dot allows the first and last tensors to
+        # be either 1D or 2D
+        with self.assertRaisesRegex(RuntimeError, r"Tensor dimension is 1, expected 2 instead"):
+            torch.chain_matmul(make_tensor(1, device, dtype), make_tensor(1, device, dtype))
+
+    @onlyOnCPUAndCUDA
+    @dtypes(torch.double, torch.cdouble)
     def test_multi_dot(self, device, dtype):
         def check(*shapes, discontiguous=False):
             tensors = [make_tensor(shape, device, dtype, discontiguous=discontiguous) for shape in shapes]
@@ -3307,9 +3320,11 @@ def check(*shapes, discontiguous=False):

         # test large tensors
         check([10, 100], [100, 5], [5, 50])
+        check([10, 20], [20, 30], [30, 5])
 
         # test discontiguous input
         check([3, 2], [2, 2], [2, 3], [3, 4], discontiguous=True)
+        check([15, 5], [5, 10], [10, 20], [20, 25], discontiguous=True)
 
@onlyOnCPUAndCUDA
@dtypes(torch.float)
@@ -5800,25 +5815,6 @@ def run_test(*n):
         run_test(3, 3, 4, 4)
         run_test(3, 3, 5, 5)
 
-    @dtypes(torch.double)
-    def test_chain_matmul(self, device, dtype):
-        def product(matrices):
-            for mat in matrices[1:]:
-                matrices[0] = matrices[0].mm(mat)
-            return matrices[0]
-
-        def run_test(p):
-            matrices = []
-            for (pi, pi_1) in zip(p[:-1], p[1:]):
-                matrices.append(torch.randn(pi, pi_1, dtype=dtype, device=device))
-            self.assertEqual(torch.chain_matmul(*matrices), product(matrices))
-
-        run_test([10, 20, 30, 5])
-        run_test([15, 5, 10, 20, 25])
-
-        with self.assertRaisesRegex(RuntimeError, r"chain_matmul\(\): Expected one or more matrices"):
-            torch.chain_matmul()
-
     @skipCUDAIfNoMagma
     @skipCUDAIfRocm
     @skipCPUIfNoLapack
4 changes: 4 additions & 0 deletions torch/functional.py
@@ -1438,6 +1438,10 @@ def chain_matmul(*matrices):
     needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
     If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
+    .. note::
+        :func:`torch.chain_matmul` is deprecated, use :func:`torch.linalg.multi_dot` instead. The differences are
+        that :func:`torch.linalg.multi_dot` requires at least 2 input tensors and allows the first and last
+        tensor to be either 1D or 2D.
     Args:
         matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
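To make the behavioral differences from the docstring note concrete, a small sketch (illustrative only, not part of the commit):

import torch

v = torch.randn(4)    # 1D first operand: treated as a row vector
M = torch.randn(4, 5)
w = torch.randn(5)    # 1D last operand: treated as a column vector

# multi_dot allows 1D tensors at either end; the result here is a 0-dim tensor
s = torch.linalg.multi_dot([v, M, w])

# multi_dot requires at least two tensors, while the deprecated
# chain_matmul returns a single input matrix unchanged:
# torch.linalg.multi_dot([M])   # raises RuntimeError
t = torch.chain_matmul(M)       # returns M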
4 changes: 3 additions & 1 deletion torch/testing/_internal/autocast_test_lists.py
@@ -109,7 +109,6 @@ def __init__(self, dev):
("matmul", mat0_fp32 + mat1_fp32),
("mm", mat0_fp32 + mat1_fp32),
("mv", mat0_fp32 + pointwise0_fp32),
("chain_matmul", mat0_fp32 + mat1_fp32 + mat2_fp32),
("addbmm", mat0_fp32 + (torch.randn((n, n, n), device=dev, dtype=torch.float32),
torch.randn((n, n, n), device=dev, dtype=torch.float32))),
("baddbmm", (torch.randn((n, n, n), device=dev, dtype=torch.float32),
@@ -221,6 +220,9 @@ def __init__(self, dev):
("soft_margin_loss", mat0_fp16 + (torch.ones((n, n), device=dev, dtype=torch.long),)),
("multi_margin_loss", mat0_fp16 + (torch.ones((n,), device=dev, dtype=torch.long),)),
]
self.linalg_fp16 = [
("linalg_multi_dot", (mat0_fp32 + mat1_fp32 + mat2_fp32,)),
]
self.methods_fp16 = [
("__matmul__", mat0_fp32 + mat1_fp32)
]
20 changes: 12 additions & 8 deletions torch/testing/_internal/common_methods_invocations.py
@@ -390,18 +390,22 @@ def sample_inputs_tensor_split(op_info, device, dtype, requires_grad):
                  kwargs=dict(dim=1)),)
 
 def sample_inputs_linalg_multi_dot(op_info, device, dtype, requires_grad):
-    test_cases: List[List[Tuple[int, ...]]] = [
-        [(S,), (S,)],
-        [(1, S), (S, 1)],
-        [(S, 0), (0, S)],
-        [(0, S), (S, 2)],
-        [(2, S), (S, 2), (2, S)],
+    # Each test case consists of the sizes in the chain of multiplications
+    # e.g. [2, S, 2, S] generates matrices (2, S) @ (S, 2) @ (2, S)
+    test_cases = [
+        [1, S, 1],
+        [S, 0, S],
+        [0, S, 2],
+        [2, S, 2, S],
+        [2, 3, 4, 5],
+        [5, 4, 0, 2],
+        [2, 4, 3, 5, 3, 2]
     ]
 
     result = []
-    for test_case in test_cases:
+    for sizes in test_cases:
         tensors = []
-        for size in test_case:
+        for size in zip(sizes[:-1], sizes[1:]):
             t = make_tensor(size, device, dtype, requires_grad=requires_grad)
             tensors.append(t)
         result.append(SampleInput(tensors))
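The size-chain encoding above pairs adjacent entries into matrix shapes; a quick standalone illustration of that pairing idiom (not from the commit):

S = 5  # placeholder for the S constant used in the test helper
sizes = [2, S, 2, S]
shapes = list(zip(sizes[:-1], sizes[1:]))
print(shapes)  # [(2, 5), (5, 2), (2, 5)] -> matrices (2, S) @ (S, 2) @ (2, S)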
