
Conversation

@lezcano (Collaborator) commented May 4, 2021

Adds deprecation notes and aliases to the documentation and code of `torch` methods, pointing to their equivalents in `torch.linalg`.

It also fixes torch.chain_matmul, which was ignoring the out= parameter when one was passed.
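
As a quick sanity check of that fix, a minimal sketch (shapes and variable names are mine, not from the PR) that verifies the result is actually written into out=:

import torch

A = torch.rand(3, 2)
B = torch.rand(2, 3)

out = torch.empty(3, 3)
res = torch.chain_matmul(A, B, out=out)

# Before the fix, `out` was silently ignored; with the fix, the
# result is written into `out` as well as returned.
assert torch.equal(res, out)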

Script to check the warnings
import torch
from functools import partial

#DEVICE = "cpu"
DEVICE = "cuda"

def square():
    A = torch.rand(3, 3, device=DEVICE)
    A = A.T @ A  + torch.eye(3, device=DEVICE)
    return A

def warn(fun, n_args, has_out=True):
    args = [square() for _ in range(n_args)]
    print(40*"-" + "  NORMAL {}  ".format(fun.__name__ if hasattr(fun, "__name__") else "") + 40*"-")
    out = fun(*args)
    if has_out:
        print(40*"-" + "  OUT {}  ".format(fun.__name__ if hasattr(fun, "__name__") else "") + 40*"-")
        fun(*args, out=out)
    input()
    print()
    print()

# Shows one by one the warnings (press enter)
# The script should be run once with each of the possibilities for the global variable DEVICE 
warn(torch.cholesky, 1)
warn(torch.eig, 1, has_out=False)     # Warn in common
warn(torch.symeig, 1, has_out=False)  # Warn in common
warn(torch.svd, 1)
warn(torch.qr, 1)
warn(torch.matrix_rank, 1, has_out=False)
warn(partial(torch.matrix_rank, tol=1e-4), 1, has_out=False)  # Has a different path
warn(torch.chain_matmul, 2)
warn(torch.solve, 2)
warn(torch.lstsq, 2)
Script to check that the suggested replacements are correct
import torch
from functools import partial
from itertools import product

def square(device):
    for size in ((2, 3, 3), (3, 3)):
        A = torch.rand(*size, device=device)
        yield A.transpose(-2, -1) @ A + torch.eye(3, device=device)

def non_square(device):
    for size in ((2, 3, 4), (2, 4, 3), (3, 4), (4, 3), (2, 3, 3), (3, 3)):
        yield torch.rand(*size, device=device)

def assert_eq(f1, f2, *args):
    X1 = f1(*args)
    X2 = f2(*args)
    if isinstance(X1, tuple):
        for t1, t2 in zip(X1, X2):
            if not torch.allclose(t1, t2, atol=1e-3, rtol=1e-4):
                print(t1, t2)
                raise RuntimeError(f"{f1.__name__}\n{f2.__name__}\n{args}")
    else:
        if not torch.allclose(X1, X2, atol=1e-3, rtol=1e-4):
            print(X1, X2)
            raise RuntimeError(f"{f1.__name__}\n{f2.__name__}\n{args}")


def cholesky(device):
    def my_cholesky(A, upper):
        if upper:
            return torch.linalg.cholesky(A.transpose(-2, -1)).transpose(-2, -1)
        else:
            return torch.linalg.cholesky(A)

    for t, upper in product(square(device), [True, False]):
        assert_eq(torch.cholesky, my_cholesky, t, upper)


def symeig(device):
    def my_symeig(A, eigenvectors):
        if eigenvectors:
            L, V = torch.linalg.eigh(A)
            return L, V.abs()
        else:
            return torch.linalg.eigvalsh(A)

    def torch_symeig(A, eigenvectors):
        if eigenvectors:
            L, V = torch.symeig(A, eigenvectors)
            return L, V.abs()
        else:
            return torch.symeig(A, eigenvectors)[0]

    for t, eigenvectors in product(square(device), [True, False]):
        assert_eq(torch_symeig, my_symeig, t, eigenvectors)

def svd(device):
    def my_svd(A, some, compute_uv):
        if compute_uv:
            U, S, Vh = torch.linalg.svd(A, full_matrices=not some)
            return U.abs(), S, Vh.transpose(-2, -1).conj().abs()
        else:
            S = torch.linalg.svdvals(A)
            return S

    def torch_svd(A, some, compute_uv):
        if compute_uv:
            U, S, V = torch.svd(A, some, True)
            return U.abs(), S, V.abs()
        else:
            _, S, _ = torch.svd(A, some, True)
            return S

    for t, some, compute_uv in product(non_square(device), [True, False], [True, False]):
        assert_eq(torch_svd, my_svd, t, some, compute_uv)

def qr(device):
    def my_qr(A, some):
        return torch.linalg.qr(A, "reduced" if some else "complete")

    for t, some in product(non_square(device), [True, False]):
        assert_eq(torch.qr, my_qr, t, some)


def matrix_rank(device):
    def my_matrix_rank(A, symmetric):
        return torch.linalg.matrix_rank(A, hermitian=symmetric)

    for t, symmetric in product(square(device), [True, False]):
        assert_eq(torch.matrix_rank, my_matrix_rank, t, symmetric)

def chain_matmul(device):
    def my_multi_dot(*tensors):
        return torch.linalg.multi_dot(tensors)

    make_t = partial(torch.rand, device=device)
    assert_eq(torch.chain_matmul, my_multi_dot, make_t(3,2), make_t(2,3))
    assert_eq(torch.chain_matmul, my_multi_dot, make_t(3,3), make_t(3,3))

def solve(device):
    def my_solve(B, A):
        return torch.linalg.solve(A, B)

    def torch_solve(B, A):
        return torch.solve(B, A).solution

    for t in square(device):
        assert_eq(torch_solve, my_solve, t, t.clone())


def lstsq(device):
    def my_lstsq(B, A):
        return torch.linalg.lstsq(A, B).solution

    def torch_lstsq(B, A):
        return torch.lstsq(B, A).solution[:A.size(1)]

    for t in non_square(device):
        if t.ndim == 3:  # torch.lstsq does not support batched inputs
            continue
        if device == "cuda" and t.size(-2) < t.size(-1):  # underdetermined case not implemented on CUDA
            continue
        assert_eq(torch_lstsq, my_lstsq, t, t.clone())


# If none of them throws, we're good

# eig is too different, we will not compare it
for device in ["cpu", "cuda"]:
    cholesky(device)
    symeig(device)
    svd(device)
    qr(device)
    matrix_rank(device)
    chain_matmul(device)
    solve(device)
    lstsq(device)

@lezcano lezcano added the module: linear algebra label May 4, 2021
@lezcano lezcano requested a review from mruberry May 4, 2021 15:49
@facebook-github-bot (Contributor) commented May 4, 2021

💊 CI failures summary and remediations

As of commit 5a04311 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



❄️ 1 failure tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build pytorch_linux_xenial_py3_clang5_asan_test2 (1/1), step "Run tests":

May 06 15:34:54 unknown file: Failure
May 06 15:34:54 [       OK ] Kernel.Softmax2D (31 ms)
May 06 15:34:54 [ RUN      ] Kernel.Softmax3D
May 06 15:34:54 [       OK ] Kernel.Softmax3D (207 ms)
May 06 15:34:54 [ RUN      ] Kernel.Softmax4D
May 06 15:34:54 [       OK ] Kernel.Softmax4D (195 ms)
May 06 15:34:54 [ RUN      ] Kernel.ConstantTensors
May 06 15:34:54 [       OK ] Kernel.ConstantTensors (21 ms)
May 06 15:34:54 [ RUN      ] Kernel.ConstantTensorsNonContiguous
May 06 15:34:54 [       OK ] Kernel.ConstantTensorsNonContiguous (20 ms)
May 06 15:34:54 [ RUN      ] Kernel.CodegenInspection
May 06 15:34:54 unknown file: Failure
May 06 15:34:54 C++ exception with description "Expected to find ".text" but did not find it
May 06 15:34:54 Searched string:
May 06 15:34:54 From CHECK: .text
May 06 15:34:54 " thrown in the test body.
May 06 15:34:54 [  FAILED  ] Kernel.CodegenInspection (4 ms)
May 06 15:34:54 [----------] 14 tests from Kernel (694 ms total)
May 06 15:34:54 
May 06 15:34:54 [----------] 140 tests from LoopNest
May 06 15:34:54 [ RUN      ] LoopNest.ExprSimple01
May 06 15:34:54 [       OK ] LoopNest.ExprSimple01 (1 ms)


"torch.lstsq is deprecated in favor of torch.linalg.lstsq and will be removed in a future PyTorch release.\n",
"torch.linalg.lstsq has reversed arguments and does not return the QR decomposition in "
"the returned tuple, (it returns other information about the problem).\n",
"X, _ = torch.lstsq(B, A).solution[:A.size(1)]\n",
Collaborator:
This seems like a very specific invocation and manipulation of the result of torch.lstsq()

Maybe there's a more generic way to tell users to replace torch.lstsq() with torch.linalg.lstsq(), or maybe we need more time to deprecate torch.lstsq()?

Collaborator Author (@lezcano):

This also comes from a previous wrong design. torch.lstsq returns the solution and the residuals packed into the "solution" tensor.

Any user would have to unpack the solution and the residuals. That slicing there does the unpacking. I'll add this point to the warning.
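
For concreteness, a hedged sketch of the packing being described, assuming an overdetermined full-rank system (the shapes are my choice):

import torch

m, n = 5, 3                            # more equations than unknowns
A = torch.rand(m, n)
B = torch.rand(m, 1)

packed = torch.lstsq(B, A).solution    # shape (m, 1): solution and residual info packed together
X = packed[:n]                         # first n rows hold the least-squares solution
# For m > n, the squared entries of the remaining rows sum to the residuals.
residuals = packed[n:].pow(2).sum(dim=0)

# The replacement returns the pieces separately instead.
result = torch.linalg.lstsq(A, B)
assert torch.allclose(X, result.solution, atol=1e-5)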

TORCH_WARN_ONCE(
"torch.solve is deprecated in favor of torch.linalg.solve",
"and will be removed in a future PyTorch release.\n",
"torch.linalg.solve has its arguments reversed and does not return the LU factorization.\n",
Collaborator:

What should a user do if they want the LU factorization?

Collaborator Author (@lezcano):

Sadly, we do not have a good solution for this yet, as torch.lu is split into several functions (torch.lu and torch.lu_unpack). I would not worry about this too much, as that return value was caused by a wrong design: it's returning an internal part of the implementation that is not linked to the result (torch.solve could equally be implemented via QR or SVD).
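
For users who do want the factorization, a hedged sketch of the two-function path mentioned above, assuming a square invertible A:

import torch

A = torch.rand(3, 3) + 3 * torch.eye(3)   # well-conditioned square matrix
B = torch.rand(3, 2)

# torch.solve returned (solution, LU); with the new API the two are computed separately.
X = torch.linalg.solve(A, B)

LU, pivots = torch.lu(A)                   # compact LAPACK-style factorization
P, L, U = torch.lu_unpack(LU, pivots)
assert torch.allclose(P @ L @ U, A, atol=1e-5)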

TORCH_WARN_ONCE(
"torch.eig is deprecated in favor of torch.linalg.eig and will be removed in a future ",
"PyTorch release.\n",
"torch.linalg.eig returns complex tensors of dtype cfloat or cdouble rather than real tensors.\n",
Collaborator:

"than real tensors mimicking complex tensors."

cc @anjali411 for a language sanity check
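
For reference, a hedged sketch of the return-type difference being discussed (the example matrix is mine):

import torch

A = torch.tensor([[0., -1.],
                  [1., 0.]])   # 2D rotation; eigenvalues are +i and -i

# Old: a real (n, 2) tensor whose columns mimic the real and imaginary parts.
vals_old, _ = torch.eig(A)     # rows like [0., 1.] and [0., -1.], up to ordering

# New: a genuinely complex tensor (cfloat here, since A is float).
vals_new, vecs_new = torch.linalg.eig(A)
print(vals_new.dtype)          # torch.complex64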

@mruberry (Collaborator) left a comment:

Hey @lezcano! Language overall looks really good; I made a few suggestions inline

We'll have to separate each deprecation into its own PR, however, and be sure that those PRs eliminate all uses of the deprecated function within the PyTorch code base too (except for tests), so that calling torch.linalg.svd, for example, doesn't accidentally call into torch.svd and throw an error.

@lezcano lezcano closed this May 7, 2021
facebook-github-bot pushed a commit that referenced this pull request May 12, 2021
Summary:
This PR adds a note to the documentation that torch.svd is deprecated, together with an upgrade guide on how to use `torch.linalg.svd` and `torch.linalg.svdvals` (Lezcano's instructions from #57549).
In addition, all usage of the old svd function is replaced with the new one from the torch.linalg module, except in the `at::linalg_pinv` function, which fails the XLA CI build (pytorch/xla#2755; see the failure in draft PR #57772).

Pull Request resolved: #57981

Reviewed By: ngimel

Differential Revision: D28345558

Pulled By: mruberry

fbshipit-source-id: 02dd9ae6efe975026e80ca128e9b91dfc65d7213
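
For readers landing here from the warning, the upgrade guide amounts to something like the following hedged sketch (the V-versus-Vᴴ flip is the key difference; shapes are my choice):

import torch

A = torch.rand(5, 3)

# Old interface: returns V directly; some=True gives the reduced SVD.
U_old, S_old, V_old = torch.svd(A, some=True)

# New interface: returns V^H instead of V; full_matrices=False is the reduced SVD.
U_new, S_new, Vh_new = torch.linalg.svd(A, full_matrices=False)
V_new = Vh_new.transpose(-2, -1).conj()

# Singular values only, replacing torch.svd(A, compute_uv=False):
S_only = torch.linalg.svdvals(A)

assert torch.allclose(S_old, S_new, atol=1e-6)
assert torch.allclose(S_old, S_only, atol=1e-6)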
@lezcano lezcano deleted the lezcano-deprecation branch May 13, 2021 11:24
krshrimali pushed a commit to krshrimali/pytorch that referenced this pull request May 19, 2021