Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve torch.linalg.qr #50046

Closed
wants to merge 13 commits into from
4 changes: 2 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -1011,13 +1011,13 @@ std::tuple<Tensor, Tensor> _linalg_qr_helper_cpu(const Tensor& self, std::string

std::tuple<Tensor,Tensor> linalg_qr(const Tensor& self, std::string mode) {
TORCH_CHECK(self.dim() >= 2,
"qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
return at::_linalg_qr_helper(self, mode);
}

std::tuple<Tensor&,Tensor&> linalg_qr_out(Tensor& Q, Tensor& R, const Tensor& self, std::string mode) {
TORCH_CHECK(self.dim() >= 2,
"qr input should have at least 2 dimensions, but has ", self.dim(), " dimensions instead");
Tensor Q_tmp, R_tmp;
std::tie(Q_tmp, R_tmp) = at::_linalg_qr_helper(self, mode);
at::native::resize_output(Q, Q_tmp.sizes());
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/LinearAlgebraUtils.h
Expand Up @@ -206,7 +206,8 @@ static inline std::tuple<bool, bool> _parse_qr_mode(std::string mode) {
compute_q = false;
reduced = true; // this is actually irrelevant in this mode
} else {
TORCH_CHECK(false, "qr received unrecognized mode '", mode,
"' but expected one of 'reduced' (default), 'r', or 'complete'");
}
return std::make_tuple(compute_q, reduced);
}
Expand Down
11 changes: 0 additions & 11 deletions test/test_autograd.py
Expand Up @@ -4941,17 +4941,6 @@ def assert_only_first_requires_grad(res):
return_counts=return_counts)
assert_only_first_requires_grad(res)

def test_linalg_qr_r(self):
# torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
# without 'q' you cannot compute the backward pass. Check that
# linalg_qr_backward complains cleanly in that case.
inp = torch.randn((5, 7), requires_grad=True)
q, r = torch.linalg.qr(inp, mode='r')
assert q.shape == (0,) # empty tensor
b = torch.sum(r)
with self.assertRaisesRegex(RuntimeError,
"linalg_qr_backward: cannot compute backward"):
b.backward()

def index_perm_variable(shape, max_indices):
if not isinstance(shape, tuple):
Expand Down
67 changes: 58 additions & 9 deletions test/test_linalg.py
Expand Up @@ -2787,12 +2787,34 @@ def test_qr_vs_numpy(self, device, dtype):
exp_r = np.linalg.qr(np_t, mode='r')
q, r = torch.linalg.qr(t, mode='r')
# check that q is empty
self.assertEqual(q.shape, (0,))
self.assertEqual(q.dtype, t.dtype)
self.assertEqual(q.device, t.device)
# check r
self.assertEqual(r, exp_r)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float)
def test_linalg_qr_autograd_errors(self, device, dtype):
# torch.linalg.qr(mode='r') returns only 'r' and discards 'q', but
# without 'q' you cannot compute the backward pass. Check that
# linalg_qr_backward complains cleanly in that case.
inp = torch.randn((5, 7), device=device, dtype=dtype, requires_grad=True)
q, r = torch.linalg.qr(inp, mode='r')
self.assertEqual(q.shape, (0,)) # empty tensor
b = torch.sum(r)
with self.assertRaisesRegex(RuntimeError,
"The derivative of qr is not implemented when mode='r'"):
b.backward()
#
inp = torch.randn((7, 5), device=device, dtype=dtype, requires_grad=True)
q, r = torch.linalg.qr(inp, mode='complete')
b = torch.sum(r)
with self.assertRaisesRegex(RuntimeError,
"The derivative of qr is not implemented when mode='complete' and nrows > ncols"):
b.backward()

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
Expand All @@ -2806,10 +2828,17 @@ def np_qr_batched(a, mode):
all_q = []
all_r = []
for matrix in a:
    result = np.linalg.qr(matrix, mode=mode)
    if mode == 'r':
        all_r.append(result)
    else:
        q, r = result
        all_q.append(q)
        all_r.append(r)
if mode == 'r':
    return np.array(all_r)
else:
    return np.array(all_q), np.array(all_r)

t = torch.randn((3, 7, 5), device=device, dtype=dtype)
np_t = t.cpu().numpy()
Expand All @@ -2818,6 +2847,15 @@ def np_qr_batched(a, mode):
q, r = torch.linalg.qr(t, mode=mode)
self.assertEqual(q, exp_q)
self.assertEqual(r, exp_r)
# for mode='r' we need a special logic because numpy returns only r
exp_r = np_qr_batched(np_t, mode='r')
q, r = torch.linalg.qr(t, mode='r')
# check that q is empty
self.assertEqual(q.shape, (0,))
self.assertEqual(q.dtype, t.dtype)
self.assertEqual(q.device, t.device)
# check r
self.assertEqual(r, exp_r)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
Expand All @@ -2840,11 +2878,22 @@ def test_qr_out(self, device, dtype):
out = (torch.empty((0), dtype=dtype, device=device),
torch.empty((0), dtype=dtype, device=device))
q2, r2 = torch.linalg.qr(t, mode=mode, out=out)
self.assertIs(q2, out[0])
self.assertIs(r2, out[1])
self.assertEqual(q2, q)
self.assertEqual(r2, r)

@skipCUDAIfNoMagma
@skipCPUIfNoLapack
@dtypes(torch.float)
def test_qr_error_cases(self, device, dtype):
t1 = torch.randn(5, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, 'qr input should have at least 2 dimensions, but has 1 dimensions instead'):
torch.linalg.qr(t1)
t2 = torch.randn((5, 7), device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, "qr received unrecognized mode 'hello'"):
torch.linalg.qr(t2, mode='hello')

@dtypes(torch.double, torch.cdouble)
def test_einsum(self, device, dtype):
def check(equation, *operands):
Expand Down
23 changes: 11 additions & 12 deletions torch/_torch_docs.py
Expand Up @@ -6676,11 +6676,10 @@ def merge_dicts(*dicts):
If :attr:`some` is ``True``, then this function returns the thin (reduced) QR factorization.
Otherwise, if :attr:`some` is ``False``, this function returns the complete QR factorization.

.. warning:: ``torch.qr`` is deprecated. Please use ``torch.linalg.`` :func:`~torch.linalg.qr`
             instead.

**Differences with** ``torch.linalg.qr``:

* ``torch.linalg.qr`` takes a string parameter ``mode`` instead of ``some``:

Expand All @@ -6698,21 +6697,21 @@ def merge_dicts(*dicts):

.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs,
and may produce different (valid) decompositions on different device types
or different platforms.

Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of matrices of dimension :math:`m \times n`.
some (bool, optional): Set to ``True`` for reduced QR decomposition and ``False`` for
complete QR decomposition. If `k = min(m, n)` then:

* ``some=True`` : returns `(Q, R)` with dimensions (m, k), (k, n) (default)

* ``some=False``: returns `(Q, R)` with dimensions (m, m), (m, n)

Keyword args:
out (tuple, optional): tuple of `Q` and `R` tensors.
The dimensions of `Q` and `R` are detailed in the description of :attr:`some` above.

Example::

Expand Down
4 changes: 2 additions & 2 deletions torch/csrc/autograd/FunctionsManual.cpp
Expand Up @@ -2078,7 +2078,7 @@ Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, c
std::string mode, const Tensor& q, const Tensor& r){
bool compute_q, reduced;
std::tie(compute_q, reduced) = at::native::_parse_qr_mode(mode);
TORCH_CHECK(compute_q, "The derivative of qr is not implemented when mode='r'. "
"Please use torch.linalg.qr(..., mode='reduced')");

auto square_deep_case_backward = [](const Tensor& grad_Q,
Expand Down Expand Up @@ -2145,7 +2145,7 @@ Tensor linalg_qr_backward(const std::vector<torch::autograd::Variable> &grads, c

TORCH_CHECK(
((m <= n && (!reduced)) || reduced),
"The derivative of qr is not implemented when mode='complete' and nrows > ncols.");

auto grad_Q = grads[0];
auto grad_R = grads[1];
Expand Down
25 changes: 13 additions & 12 deletions torch/linalg/__init__.py
Expand Up @@ -644,16 +644,15 @@
.. note::
Backpropagation is not supported for ``mode='r'``. Use ``mode='reduced'`` instead.

Backpropagation is also not supported if the first
:math:`\min(input.size(-1), input.size(-2))` columns of any matrix
in :attr:`input` are not linearly independent. While no error will
be thrown when this occurs the values of the "gradient" produced may
be anything. This behavior may change in the future.

.. note:: This function uses LAPACK for CPU inputs and MAGMA for CUDA inputs,
and may produce different (valid) decompositions on different device types
or different platforms.

Args:
input (Tensor): the input tensor of size :math:`(*, m, n)` where `*` is zero or more
batch dimensions consisting of matrices of dimension :math:`m \times n`.
Expand All @@ -666,11 +665,8 @@
* ``'r'``: computes only `R`; returns `(Q, R)` where `Q` is empty and `R` has dimensions (k, n)

Keyword args:
out (tuple, optional): tuple of `Q` and `R` tensors.
The dimensions of `Q` and `R` are detailed in the description of :attr:`mode` above.

Example::

Expand All @@ -692,6 +688,11 @@
tensor([[ 1., 0., 0.],
[ 0., 1., -0.],
[ 0., -0., 1.]])
>>> q2, r2 = torch.linalg.qr(a, mode='r')
>>> q2
tensor([])
>>> torch.equal(r, r2)
True
>>> a = torch.randn(3, 4, 5)
>>> q, r = torch.linalg.qr(a, mode='complete')
>>> torch.allclose(torch.matmul(q, r), a)
Expand Down