Updated derivative rules for complex QR decomposition #48489

Closed
2 changes: 1 addition & 1 deletion test/test_autograd.py
@@ -5039,7 +5039,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub',
'exp', 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul',
'addcdiv', 'linalg.tensorinv', 'matrix_exp'] + separate_complex_tests
'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', ] + separate_complex_tests

def add_test(
name,
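Adding 'qr' to this list opts the QR decomposition into the complex-valued autograd checks run by the tests above. As a rough standalone illustration of what those checks exercise (this is not the test-suite machinery itself; the input shape and the direct call to torch.qr are assumptions for the example), a numerical gradient check on a complex input looks like:

import torch

# Complex double-precision input; gradcheck compares the analytic backward of
# both outputs (Q, R) against numerically estimated derivatives.
A = torch.randn(4, 3, dtype=torch.cdouble, requires_grad=True)
assert torch.autograd.gradcheck(torch.qr, (A,))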
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -78,7 +78,7 @@
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
'matrix_exp', 'linalg_eigh',
'matrix_exp', 'linalg_eigh', 'qr',
}

# Some operators invalidate the grad_accumulator. Let's reset it.
79 changes: 35 additions & 44 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -2006,67 +2006,58 @@ Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Te
const Tensor& A,
const Tensor& Q,
const Tensor& R) -> Tensor {
// For square and deep (tall) case we refer
// Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
// Algebra Functions with Application in Optimum Experimental Design
// (Extended Version) The derivative for the QR decomposition is adapted
// from Eq. 42 of the above reference.

// Compute R (R')^{T}
// For the square and deep (tall) case we refer to:
// Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra.
// https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition)
// Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks.
// https://arxiv.org/abs/1903.09650 Section 3. QR factorization
// For derivations of the complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html

// Compute R grad_R^H
Tensor R_term;
if (grad_R.defined()) {
R_term = at::matmul(R, grad_R.transpose(-2, -1));
R_term = at::matmul(R, grad_R.conj().transpose(-2, -1));
} else {
// R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

// Compute Q^{T} Q'
// Compute grad_Q^H Q
Tensor Q_term;
if (grad_Q.defined()) {
Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q);
} else {
// Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
}

// We want to compute: (rhs_solve_1 . R^{-T})
// Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
Tensor M = R_term - Q_term;

// Replace M with tril(M) + tril(M)^H and halve its diagonal: the lower triangle of M,
// its conjugate mirrored into the upper triangle, and only the real part kept on the diagonal.
Tensor M_tril = at::tril(M);
M = M_tril + M_tril.conj().transpose(-2, -1);
M.diagonal(0, -2, -1).mul_(0.5);

Tensor rhs_term;
if (grad_Q.defined()) {
rhs_term = grad_Q + at::matmul(Q, M);
} else {
rhs_term = at::matmul(Q, M);
}

// We want to compute: (rhs_term @ R^{-H})
// Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_term^H)^H
// Since R is upper triangular, we can do this using
// triangular_solve(rhs_solve_1^{T}, R)^{T}
auto rhs_solve_1 =
R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
Tensor solve_soln_1;
std::tie(solve_soln_1, std::ignore) = at::triangular_solve(
rhs_solve_1.transpose(-2, -1),
// triangular_solve(rhs_term^H, R)^H
Tensor grad_A;
std::tie(grad_A, std::ignore) = at::triangular_solve(
rhs_term.conj().transpose(-2, -1),
R,
/*upper=*/true,
/*transpose=*/false,
/*unitriangular=*/false);
Tensor grad_A;
if (grad_R.defined()) {
grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
} else {
grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
}

// Successive computations involve computation of QQ^{T} which is identity when A is square
if (A.size(-1) != A.size(-2)) {
Tensor rhs_solve_2;
// We use the same trick from above for this computation
if (grad_Q.defined()) {
rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
} else {
rhs_solve_2 = -at::matmul(Q, Q_term);
}
Tensor solve_soln_2;
std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
/*upper=*/true, /*transpose=*/false,
/*unitriangular=*/false);
grad_A.add_(solve_soln_2.transpose(-2, -1));
}
return grad_A;
return grad_A.conj().transpose(-2, -1);
};
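For reference, the rewritten square/deep case above amounts to grad_A = (grad_Q + Q @ copyltu(R @ grad_R^H - grad_Q^H @ Q)) @ R^{-H}, where copyltu builds a Hermitian matrix from the lower triangle of its argument. The sketch below restates that formula in Python, assuming both grad_Q and grad_R are defined; copyltu and qr_backward_square_deep are illustrative names, not functions introduced by this PR.

import torch

def copyltu(M):
    # tril(M) + tril(M)^H with the diagonal halved: the lower triangle of M,
    # its conjugate mirrored into the upper triangle, and only the real part
    # of M's diagonal kept on the diagonal.
    L = M.tril()
    out = L + L.conj().transpose(-2, -1)
    out.diagonal(dim1=-2, dim2=-1).mul_(0.5)
    return out

def qr_backward_square_deep(grad_Q, grad_R, Q, R):
    # grad_A = (grad_Q + Q @ copyltu(R @ grad_R^H - grad_Q^H @ Q)) @ R^{-H}
    M = R @ grad_R.conj().transpose(-2, -1) - grad_Q.conj().transpose(-2, -1) @ Q
    rhs = grad_Q + Q @ copyltu(M)
    # Right-multiplication by R^{-H} is done as (R^{-1} @ rhs^H)^H via a
    # triangular solve, since R is upper triangular; no explicit inverse is formed.
    sol = torch.triangular_solve(rhs.conj().transpose(-2, -1), R, upper=True).solution
    return sol.conj().transpose(-2, -1)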

auto m = self.size(-2);
@@ -2087,7 +2078,7 @@ Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Te
// grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
// To obtain grad_X we reuse the gradient formula from the square case.
// Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
// where grad_Q_prime = grad_Q + Y @ grad_V.T
// where grad_Q_prime = grad_Q + Y @ grad_V^H
// and grad_Y = Q @ grad_V.
// Then concatenate grads to get grad_A = [grad_X | grad_Y].

@@ -2099,8 +2090,8 @@ Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Te
grad_V = grad_R.narrow(-1, m, n - m);
// reuse grad_R to store grad_U
grad_R = grad_R.narrow(-1, 0, m);
// grad_Q_prime starts with the value of Y @ grad_V.T
grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1));
// grad_Q_prime starts with the value of Y @ grad_V^H
grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1));
} else {
// when grad_R is not defined then grad_V and grad_Q_prime
// get initialized with zeros
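For the wide (m < n) case walked through in the comments above, a sketch in the same illustrative style (reusing the hypothetical qr_backward_square_deep helper from the previous sketch, and again assuming grad_Q and grad_R are both defined):

import torch

def qr_backward_wide(grad_Q, grad_R, A, Q, R):
    m = A.shape[-2]
    Y = A[..., :, m:]              # A = [X | Y] with X of shape m x m
    U = R[..., :, :m]              # R = [U | V] with U m x m upper triangular
    grad_U = grad_R[..., :, :m]
    grad_V = grad_R[..., :, m:]
    # Fold the contribution of V = Q^H Y into the gradient w.r.t. Q.
    grad_Q_prime = grad_Q + Y @ grad_V.conj().transpose(-2, -1)
    grad_X = qr_backward_square_deep(grad_Q_prime, grad_U, Q, U)
    grad_Y = Q @ grad_V
    # Concatenate to recover grad_A = [grad_X | grad_Y].
    return torch.cat([grad_X, grad_Y], dim=-1)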