Updated derivative rules for complex QR decomposition (#48489)
Summary:
Updated `qr_backward` to work correctly for complex-valued inputs.
Added `torch.qr` to the list of complex tests.

The previous implementation of real-valued differentiation used Eq. 42 from https://arxiv.org/abs/1001.1654
The current implementation is a bit simpler, but the result for real-valued inputs is the same and all tests still pass.
Derivation of the complex-valued QR differentiation: https://giggleliu.github.io/2019/04/02/einsumbp.html
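As a quick illustration (not part of this commit), the new complex support can be exercised directly with `torch.autograd.gradcheck`, which is essentially what the added autograd test does; the snippet below assumes a PyTorch build that already contains this change.

```python
import torch
from torch.autograd import gradcheck

# Double-precision complex input; gradcheck needs double precision.
A = torch.randn(4, 3, dtype=torch.cdouble, requires_grad=True)

# torch.qr returns (Q, R); gradcheck compares the analytic qr_backward
# against finite-difference derivatives for both outputs.
print(gradcheck(torch.qr, (A,)))  # expected to print True
```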

Ref. #33152

Pull Request resolved: #48489

Reviewed By: bdhirsh

Differential Revision: D25272344

Pulled By: albanD

fbshipit-source-id: b53c1fca1683f4aee5f4d5ce3cab9e559170e7cf
IvanYashchuk authored and facebook-github-bot committed Dec 11, 2020
1 parent e3542d2 commit 6c1b405
Showing 3 changed files with 37 additions and 46 deletions.
2 changes: 1 addition & 1 deletion test/test_autograd.py
@@ -4927,7 +4927,7 @@ def run_functional_checks(test_case, test_name, name, apply_fn, run_grad_checks,
                 'cosh', '__rmul__', 'sgn', 'abs', 'dot', 'vdot', 'tensor_split', 'matmul',
                 'bmm', 'mv', 'ger', 'diagonal', 'atan', 'angle', 'tanh', 'fill_', 'sub',
                 'exp', 'mean', 'inverse', 'triangular_solve', 'solve', 'addcmul',
-                'addcdiv', 'linalg.tensorinv', 'matrix_exp'] + separate_complex_tests
+                'addcdiv', 'linalg.tensorinv', 'matrix_exp', 'qr', ] + separate_complex_tests
 
 def add_test(
         name,
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -78,7 +78,7 @@
     'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
     'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
     'exp', 'nonzero', 'mean', 'inverse', 'solve', 'linalg_cholesky', 'addcmul', 'addcdiv',
-    'matrix_exp', 'linalg_eigh', 'cholesky_solve',
+    'matrix_exp', 'linalg_eigh', 'cholesky_solve', 'qr',
     '_fft_c2c', '_fft_r2c',
 }

79 changes: 35 additions & 44 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -2006,67 +2006,58 @@ Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Te
                                       const Tensor& A,
                                       const Tensor& Q,
                                       const Tensor& R) -> Tensor {
-    // For square and deep (tall) case we refer
-    // Walter, S.F and Lehmann, L., Algorithmic Differentiation of Linear
-    // Algebra Functions with Application in Optimum Experimental Design
-    // (Extended Version) The derivative for the QR decomposition is adapted
-    // from Eq. 42 of the above reference.
-
-    // Compute R (R')^{T}
+    // For square and deep (tall) case we refer:
+    // Matthias Seeger, Asmus Hetzel, Zhenwen Dai, Eric Meissner, Neil D. Lawrence (2018). Auto-Differentiating Linear Algebra.
+    // https://arxiv.org/abs/1710.08717 Section 4.3 LQ Decomposition (Note that LQ decomposition is the transpose of QR decomposition)
+    // Hai-Jun Liao, Jin-Guo Liu, Lei Wang, Tao Xiang (2019). Differentiable Programming Tensor Networks.
+    // https://arxiv.org/abs/1903.09650 Section 3. QR factorization
+    // For derivations of complex-valued input case, see https://giggleliu.github.io/2019/04/02/einsumbp.html
+
+    // Compute R grad_R^H
     Tensor R_term;
     if (grad_R.defined()) {
-      R_term = at::matmul(R, grad_R.transpose(-2, -1));
+      R_term = at::matmul(R, grad_R.conj().transpose(-2, -1));
     } else {
       // R is ... x N x N, grad_R is ... x N x N and grad_R.T is ... x N x N
       R_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
     }
 
-    // Compute Q^{T} Q'
+    // Compute grad_Q^H Q
     Tensor Q_term;
     if (grad_Q.defined()) {
-      Q_term = at::matmul(Q.transpose(-2, -1), grad_Q);
+      Q_term = at::matmul(grad_Q.conj().transpose(-2, -1), Q);
     } else {
       // Q is ... x M x N, Q.T is ... x N x M and grad_Q is ... x M x N
       Q_term = at::zeros_like(R, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
     }
 
-    // We want to compute: (rhs_solve_1 . R^{-T})
-    // Note that (rhs_solve_1 . R^{-T}) = (R^{-1} . rhs_solve_1^{T})^{T}
+    Tensor M = R_term - Q_term;
+
+    // Compute M = (tril(M) + tril(M).conj().transpose(-2, -1)) * 0.5 Identity
+    Tensor M_tril = at::tril(M);
+    M = M_tril + M_tril.conj().transpose(-2, -1);
+    M.diagonal(0, -2, -1).mul_(0.5);
+
+    Tensor rhs_term;
+    if (grad_Q.defined()) {
+      rhs_term = grad_Q + at::matmul(Q, M);
+    } else {
+      rhs_term = at::matmul(Q, M);
+    }
+
+    // We want to compute: (rhs_term @ R^{-H})
+    // Note that (rhs_term @ R^{-H}) = (R^{-1} @ rhs_solve_1^H)^H
     // Since R is upper triangular, we can do this using
-    // triangular_solve(rhs_solve_1^{T}, R)^{T}
-    auto rhs_solve_1 =
-        R_term - R_term.transpose(-2, -1) + Q_term - Q_term.transpose(-2, -1);
-    rhs_solve_1 = at::tril(rhs_solve_1, /*k=*/-1);
-    Tensor solve_soln_1;
-    std::tie(solve_soln_1, std::ignore) = at::triangular_solve(
-        rhs_solve_1.transpose(-2, -1),
+    // triangular_solve(rhs_term^H, R)^H
+    Tensor grad_A;
+    std::tie(grad_A, std::ignore) = at::triangular_solve(
+        rhs_term.conj().transpose(-2, -1),
         R,
         /*upper=*/true,
         /*transpose=*/false,
         /*unitriangular=*/false);
-    Tensor grad_A;
-    if (grad_R.defined()) {
-      grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1) + grad_R);
-    } else {
-      grad_A = at::matmul(Q, solve_soln_1.transpose(-2, -1));
-    }
 
-    // Successive computations involve computation of QQ^{T} which is identity when A is square
-    if (A.size(-1) != A.size(-2)) {
-      Tensor rhs_solve_2;
-      // We use the same trick from above for this computation
-      if (grad_Q.defined()) {
-        rhs_solve_2 = grad_Q - at::matmul(Q, Q_term);
-      } else {
-        rhs_solve_2 = -at::matmul(Q, Q_term);
-      }
-      Tensor solve_soln_2;
-      std::tie(solve_soln_2, std::ignore) = at::triangular_solve(rhs_solve_2.transpose(-2, -1), R,
-                                                                 /*upper=*/true, /*transpose=*/false,
-                                                                 /*unitriangular=*/false);
-      grad_A.add_(solve_soln_2.transpose(-2, -1));
-    }
-    return grad_A;
+    return grad_A.conj().transpose(-2, -1);
   };
 
   auto m = self.size(-2);
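For readers who want to sanity-check the square/deep-case path above outside of C++, below is a rough Python transcription of the same formula. The helper name `qr_backward_deep` and the small comparison harness are illustrative only (they are not part of this commit) and assume a PyTorch build that already contains this change.

```python
import torch

def qr_backward_deep(grad_Q, grad_R, Q, R):
    # Compute R @ grad_R^H and grad_Q^H @ Q, as in the C++ above.
    R_term = R @ grad_R.conj().transpose(-2, -1) if grad_R is not None else torch.zeros_like(R)
    Q_term = grad_Q.conj().transpose(-2, -1) @ Q if grad_Q is not None else torch.zeros_like(R)
    M = R_term - Q_term
    # M <- tril(M) + tril(M)^H with the diagonal halved (keeps only its real part).
    M_tril = torch.tril(M)
    M = M_tril + M_tril.conj().transpose(-2, -1)
    M.diagonal(dim1=-2, dim2=-1).mul_(0.5)
    rhs = Q @ M if grad_Q is None else grad_Q + Q @ M
    # rhs @ R^{-H} == (R^{-1} @ rhs^H)^H, computed with a triangular solve.
    sol = torch.triangular_solve(rhs.conj().transpose(-2, -1), R, upper=True).solution
    return sol.conj().transpose(-2, -1)

# Compare against what autograd produces through the updated qr_backward.
A = torch.randn(5, 3, dtype=torch.cdouble, requires_grad=True)  # deep case: m >= n
Q, R = torch.qr(A)
grad_Q, grad_R = torch.randn_like(Q), torch.randn_like(R)
expected, = torch.autograd.grad((Q, R), A, (grad_Q, grad_R))
print(torch.allclose(qr_backward_deep(grad_Q, grad_R, Q.detach(), R.detach()), expected))
# expected to print True
```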
@@ -2087,7 +2078,7 @@ Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Te
     // grad_R = [grad_U | grad_V] and grad_A = [grad_X | grad_Y].
     // To obtain grad_X we reuse the gradient formula from the square case.
     // Formulae: grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U),
-    // where grad_Q_prime = grad_Q + Y @ grad_V.T
+    // where grad_Q_prime = grad_Q + Y @ grad_V^H
     // and grad_Y = Q @ grad_V.
     // Then concatenate grads to get grad_A = [grad_X | grad_Y].
 
@@ -2099,8 +2090,8 @@ Tensor qr_backward(const std::vector<torch::autograd::Variable> &grads, const Te
       grad_V = grad_R.narrow(-1, m, n - m);
       // reuse grad_R to store grad_U
       grad_R = grad_R.narrow(-1, 0, m);
-      // grad_Q_prime starts with the value of Y @ grad_V.T
-      grad_Q_prime = at::matmul(Y, grad_V.transpose(-2, -1));
+      // grad_Q_prime starts with the value of Y @ grad_V^H
+      grad_Q_prime = at::matmul(Y, grad_V.conj().transpose(-2, -1));
     } else {
       // when grad_R is not defined then grad_V and grad_Q_prime
       // get initialized with zeros
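The wide-matrix (m < n) reduction spelled out in the comments above can also be checked numerically. The sketch below is illustrative only, not part of the commit: it relies on the observation that X = Q @ U is itself the QR factorization of the square block X, so the square-case backward with (grad_Q_prime, grad_U) can be obtained simply by back-propagating through torch.qr(X).

```python
import torch

torch.manual_seed(0)
m, n = 3, 5  # wide case: more columns than rows
A = torch.randn(m, n, dtype=torch.cdouble, requires_grad=True)
Q, R = torch.qr(A)
grad_Q, grad_R = torch.randn_like(Q), torch.randn_like(R)

# Reference gradient produced through the updated qr_backward (wide-matrix path).
grad_A_ref, = torch.autograd.grad((Q, R), A, (grad_Q, grad_R))

# The reduction from the comments: A = [X | Y] and R = [U | V].
Ad, Qd = A.detach(), Q.detach()
Y = Ad[:, m:]
grad_U, grad_V = grad_R[:, :m], grad_R[:, m:]
grad_Y = Qd @ grad_V
grad_Q_prime = grad_Q + Y @ grad_V.conj().transpose(-2, -1)

# grad_X = square_case_grad(grad_Q_prime, grad_U, Q, U); since X = Q @ U is the
# QR factorization of X itself, this equals back-propagating through torch.qr(X).
X = Ad[:, :m].clone().requires_grad_(True)
Qx, Ux = torch.qr(X)
grad_X, = torch.autograd.grad((Qx, Ux), X, (grad_Q_prime, grad_U))

print(torch.allclose(torch.cat([grad_X, grad_Y], dim=-1), grad_A_ref))
# expected to print True
```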
