diff --git a/test/test_autograd.py b/test/test_autograd.py
index 31ea038a5e330..843b623318c7a 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -465,7 +465,7 @@ def sign_mul_logdet(mat):
         s.abs_().clamp_(0.0001)
         for sign in (-1, 1):
             s[-1] = sign
-            mat = torch.chain_matmul(u, s.diag(), v.t()).requires_grad_()
+            mat = torch.linalg.multi_dot([u, s.diag(), v.t()]).requires_grad_()
             gradcheck(sign_mul_logdet, mat)
             gradgradcheck(sign_mul_logdet, mat)

diff --git a/torch/nn/utils/spectral_norm.py b/torch/nn/utils/spectral_norm.py
index 2ae7af3d78bd1..3e6743bf40291 100644
--- a/torch/nn/utils/spectral_norm.py
+++ b/torch/nn/utils/spectral_norm.py
@@ -108,7 +108,7 @@ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
         # Tries to returns a vector `v` s.t. `u = normalize(W @ v)`
         # (the invariant at top of this class) and `u @ W @ v = sigma`.
         # This uses pinverse in case W^T W is not invertible.
-        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1)
+        v = torch.linalg.multi_dot([weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)]).squeeze(1)
         return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

     @staticmethod
diff --git a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
index 3d405c3191c7b..54e8f25b9f2e1 100644
--- a/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
+++ b/torch/testing/_internal/distributed/rpc/dist_autograd_test.py
@@ -1091,7 +1091,7 @@ def test_backward_different_tensor_dims(self):
         for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.REMOTE]:
             with dist_autograd.context() as context_id:
                 val = self._exec_func(exec_mode, torch.matmul, t1, t2)
-                val = self._exec_func(exec_mode, torch.chain_matmul, [val, t3, t4])
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (val, t3, t4))
                 loss = val.sum()

                 ret = self._verify_backwards(
@@ -1132,7 +1132,7 @@ def test_backward_multiple_output_tensors(self):
                 t2 = tensor_list[2]
                 t3 = tensor_list[4]

-                val = self._exec_func(exec_mode, torch.chain_matmul, [t1, t2, t3])
+                val = self._exec_func(exec_mode, torch.linalg.multi_dot, (t1, t2, t3))
                 loss = val.sum()

                 ret = self._verify_backwards(
@@ -1368,7 +1368,7 @@ def _complex_python_udf(t1, t2):
         t3 = torch.nn.functional.linear(t1, t2)
         t4 = torch.nn.functional.linear(t2, t3)
         t5 = torch.nn.functional.linear(t3, t4)
-        return torch.chain_matmul(t1, t2, t3, t4, t5)
+        return torch.linalg.multi_dot([t1, t2, t3, t4, t5])

     @dist_init
     def test_backward_complex_python_udf(self):
@@ -1391,7 +1391,7 @@ def test_backward_complex_python_udf(self):
     def _python_udf_with_backward_error(t1, t2):
         t3 = t1 + t2
         t4 = SimulateBackwardError.apply(t3)
-        return torch.chain_matmul(t1, t2, t3, t4)
+        return torch.linalg.multi_dot([t1, t2, t3, t4])

     @staticmethod
     def _nested_rpc_call_backward_error(t1, t2, dst):
@@ -1402,7 +1402,7 @@ def _nested_rpc_call_backward_error(t1, t2, dst):
             DistAutogradTest._python_udf_with_backward_error,
             args=(t1, t2),
         )
-        return torch.chain_matmul(t1, t2, res)
+        return torch.linalg.multi_dot([t1, t2, res])

     @dist_init
     def test_backward_python_udf_error(self):
@@ -1472,7 +1472,7 @@ def _nested_python_udf(t1, t2, dst):
         t3 = t1 * t2
         t4 = t1 + t2
         res = rpc.rpc_sync(worker_name(dst), my_py_add, args=(t3, t4))
-        return torch.chain_matmul(t1, t2, t3, t4, res)
+        return torch.linalg.multi_dot([t1, t2, t3, t4, res])

     @dist_init
     def test_backwards_nested_python_udf(self):
@@ -1482,7 +1482,7 @@ def test_backwards_nested_python_udf(self):
         t3 = t1 * t2
         t4 = t1 + t2
         res = t3 + t4
-        loss = torch.chain_matmul(t1, t2, t3, t4, res).sum()
+        loss = torch.linalg.multi_dot([t1, t2, t3, t4, res]).sum()
         torch.autograd.backward([loss])

         # Now run distributed autograd.
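Note on the migration pattern applied throughout this diff: torch.chain_matmul takes the matrices as separate positional arguments, while torch.linalg.multi_dot takes a single sequence (list or tuple) of matrices; both compute the same chained matrix product. A minimal sketch of the equivalence, assuming a PyTorch version where the deprecated torch.chain_matmul is still callable (tensor shapes here are illustrative only):

    import torch

    a = torch.randn(3, 4)
    b = torch.randn(4, 5)
    c = torch.randn(5, 2)

    # Old API: matrices passed as separate positional arguments.
    old = torch.chain_matmul(a, b, c)

    # New API: matrices passed as a single sequence.
    new = torch.linalg.multi_dot([a, b, c])

    assert torch.allclose(old, new)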