
Commit 0aef44c

IvanYashchuk authored and facebook-github-bot committed
Add forward AD for torch.linalg.eigh (#62163)
Summary:
This PR adds forward mode differentiation for `torch.linalg.eigh` and a few other functions required for the tests to pass. For some reason, running tests for `torch.linalg.eigvalsh` and complex `torch.linalg.eigh` hangs; these tests are skipped for now.

cc ezyang albanD zou3519 gqchen pearu nikitaved soulitzer Lezcano Varal7 jianyuh mruberry heitorschueroff walterddr IvanYashchuk xwang233

Pull Request resolved: #62163

Reviewed By: jbschlosser

Differential Revision: D30903988

Pulled By: albanD

fbshipit-source-id: d6a74adb9e6d2f4be8ac707848ecabf06d629823
1 parent 35c82db commit 0aef44c
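
Forward-mode AD for the new rule is exercised through the dual-tensor API in torch.autograd.forward_ad. The following is a minimal sketch, not part of this commit, of what the added formulas enable; it assumes a real symmetric input, since the complex path currently raises NotImplementedError:

import torch
import torch.autograd.forward_ad as fwAD

# Real symmetric input and a symmetrized tangent (direction of perturbation);
# the complex path is skipped in this PR.
a = torch.randn(3, 3, dtype=torch.float64)
a = 0.5 * (a + a.transpose(-2, -1))
t = torch.randn(3, 3, dtype=torch.float64)
t = 0.5 * (t + t.transpose(-2, -1))

with fwAD.dual_level():
    dual_a = fwAD.make_dual(a, t)
    eigenvalues, eigenvectors = torch.linalg.eigh(dual_a)
    # unpack_dual returns (primal, tangent) pairs for each output
    w, dw = fwAD.unpack_dual(eigenvalues)
    v, dv = fwAD.unpack_dual(eigenvectors)

Here dw and dv are the directional derivatives of the eigenvalues and eigenvectors along t, computed by the eigh_jvp_* helpers added in this commit.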

File tree

4 files changed, +74 -3 lines

tools/autograd/derivatives.yaml
torch/csrc/autograd/FunctionsManual.cpp
torch/csrc/autograd/FunctionsManual.h
torch/testing/_internal/common_methods_invocations.py

tools/autograd/derivatives.yaml

Lines changed: 5 additions & 0 deletions
@@ -479,6 +479,7 @@
 
 - name: diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)
   self: diagonal_backward(grad, self.sizes(), offset, dim1, dim2)
+  result: auto_linear
 
 - name: dist(Tensor self, Tensor other, Scalar p=2) -> Tensor
   self: norm_backward(grad, self - other, p, result)
@@ -579,10 +580,12 @@
 
 - name: fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)
   self: zeros_like(grad)
+  result: self_t.fill_(0)
 
 - name: fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)
   self: zeros_like(grad)
   value: grad.sum()
+  result: self_t.fill_(value_t)
 
 - name: floor(Tensor self) -> Tensor
   self: zeros_like(grad)
@@ -1338,6 +1341,8 @@
 
 - name: linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors)
   self: eigh_backward(grads, self, /*eigenvectors=*/true, eigenvalues, eigenvectors)
+  eigenvalues: eigh_jvp_eigenvalues(self_t, eigenvalues, eigenvectors)
+  eigenvectors: eigh_jvp_eigenvectors(self_t, eigenvalues, eigenvectors)
 
 - name: linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)
   self: linalg_eig_backward(grads, self, eigenvalues, eigenvectors)
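
For context on these derivatives.yaml entries: `result:` specifies the forward-derivative (JVP) formula for an output, `self_t` and `value_t` denote the tangents of the corresponding inputs, and `auto_linear` tells the code generator that the op is linear, so the output tangent is simply the op applied to the input tangent. A quick sketch, not part of this commit, of what that means for `diagonal`:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(4, 4)
t = torch.randn(4, 4)  # tangent attached to x

with fwAD.dual_level():
    y = torch.diagonal(fwAD.make_dual(x, t))
    _, y_t = fwAD.unpack_dual(y)

# Because diagonal is linear, its forward derivative is diagonal applied to the tangent.
assert torch.allclose(y_t, torch.diagonal(t))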

torch/csrc/autograd/FunctionsManual.cpp

Lines changed: 48 additions & 0 deletions
@@ -2409,6 +2409,54 @@ Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads,
   }
 }
 
+// jvp functions for eigenvalues and eigenvectors are separate
+// because currently forward AD only works with one rule per output
+Tensor eigh_jvp_eigenvalues(
+    const Tensor& input_tangent,
+    const Tensor& eigenvalues,
+    const Tensor& eigenvectors) {
+  // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation
+  // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
+  // Section 3.1 Eigenvalues and eigenvectors
+
+  // TODO: gradcheck from test_ops.py hangs with complex inputs
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      !input_tangent.is_complex(),
+      "the derivative for 'eigh' with complex inputs is not implemented.");
+
+  // see the note in the implementation of eigh_backward that tangent should be Hermitian
+  auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj());
+
+  auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
+  auto eigenvalues_tangent = tmp.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1);
+  if (eigenvalues_tangent.is_complex()) {
+    return at::real(eigenvalues_tangent);
+  }
+  return eigenvalues_tangent;
+}
+
+Tensor eigh_jvp_eigenvectors(
+    const Tensor& input_tangent,
+    const Tensor& eigenvalues,
+    const Tensor& eigenvectors) {
+  // An extended collection of matrix derivative results for forward and reverse mode automatic differentiation
+  // https://ora.ox.ac.uk/objects/uuid:8d0c0a29-c92b-4153-a1d2-38b276e93124
+  // Section 3.1 Eigenvalues and eigenvectors
+
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      !input_tangent.is_complex(),
+      "the derivative for 'eigh' with complex inputs is not implemented.");
+
+  auto E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1);
+  E.diagonal(/*offset=*/0, /*dim1=*/-2, /*dim2=*/-1).fill_(INFINITY);
+
+  // see the note in the implementation of eigh_backward that tangent should be Hermitian
+  auto hermitian_tangent = 0.5*(input_tangent + input_tangent.transpose(-2, -1).conj());
+
+  auto tmp = at::matmul(at::matmul(eigenvectors.transpose(-2, -1).conj(), hermitian_tangent), eigenvectors);
+  return at::matmul(eigenvectors, tmp.div(E));
+}
+
 Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                      bool eigenvectors, const Tensor& L, const Tensor& V) {
   // This function is used for both torch.symeig and torch.linalg.eigh.
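
The two helpers implement the standard JVP of the eigendecomposition A = V diag(Λ) Vᴴ from Section 3.1 of the referenced paper: for a Hermitian perturbation dA, dΛ = diag(Vᴴ dA V) and dV = V (F ∘ (Vᴴ dA V)), where F_ij = 1/(λ_j − λ_i) off the diagonal and 0 on it. A rough Python transcription, not part of this commit and intended only for sanity-checking the formulas, could look like:

import torch

def eigh_jvp_eigenvalues_sketch(dA, eigenvalues, eigenvectors):
    # Symmetrize the tangent, mirroring the Hermitian assumption in eigh_backward.
    dA = 0.5 * (dA + dA.transpose(-2, -1).conj())
    tmp = eigenvectors.transpose(-2, -1).conj() @ dA @ eigenvectors
    w_t = tmp.diagonal(offset=0, dim1=-2, dim2=-1)  # dLambda = diag(V^H dA V)
    return w_t.real if w_t.is_complex() else w_t

def eigh_jvp_eigenvectors_sketch(dA, eigenvalues, eigenvectors):
    # E[..., i, j] = lambda_j - lambda_i; the diagonal is set to inf so that
    # dividing by it zeroes out the diagonal of the result.
    E = eigenvalues.unsqueeze(-2) - eigenvalues.unsqueeze(-1)
    E.diagonal(offset=0, dim1=-2, dim2=-1).fill_(float("inf"))
    dA = 0.5 * (dA + dA.transpose(-2, -1).conj())
    tmp = eigenvectors.transpose(-2, -1).conj() @ dA @ eigenvectors
    return eigenvectors @ (tmp / E)  # dV = V (F * (V^H dA V))

Like the C++ code, this sketch assumes distinct eigenvalues; the division by λ_j − λ_i becomes ill-conditioned for (nearly) degenerate spectra.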

torch/csrc/autograd/FunctionsManual.h

Lines changed: 2 additions & 0 deletions
@@ -157,6 +157,8 @@ Tensor slice_backward_wrapper(
     int64_t step);
 Tensor linalg_eig_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                            const Tensor& L, const Tensor& V);
+Tensor eigh_jvp_eigenvectors(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
+Tensor eigh_jvp_eigenvalues(const Tensor& input_tangent, const Tensor& eigenvalues, const Tensor& eigenvectors);
 Tensor eigh_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
                      bool eigenvectors, const Tensor& L, const Tensor& V);
 std::tuple<Tensor, Tensor> triangular_solve_backward(

torch/testing/_internal/common_methods_invocations.py

Lines changed: 19 additions & 3 deletions
@@ -6583,6 +6583,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
     OpInfo('diagonal',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_diagonal_diag_embed),
     OpInfo('eq',
            dtypes=all_types_and_complex_and(torch.bool, torch.bfloat16, torch.float16),
@@ -6969,16 +6970,25 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            aten_name='linalg_eigh',
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_linalg_eigh,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck for complex hangs for this function, therefore it raises NotImplementedError for now
+               SkipInfo('TestGradients', 'test_forward_mode_AD', dtypes=complex_types()),),
+           ),
     OpInfo('linalg.eigvalsh',
            aten_name='linalg_eigvalsh',
            dtypes=floating_and_complex_types(),
            check_batched_gradgrad=False,
            sample_inputs_func=sample_inputs_linalg_eigh,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],),
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck hangs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD'),),
+           ),
     OpInfo('linalg.householder_product',
            aten_name='linalg_householder_product',
            op=torch.linalg.householder_product,
@@ -8429,7 +8439,11 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            check_batched_gradgrad=False,
            sample_inputs_func=sample_inputs_linalg_pinv_hermitian,
            gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
-           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack]),
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # Gradcheck hangs for this function
+               SkipInfo('TestGradients', 'test_forward_mode_AD'),),
+           ),
     OpInfo('eig',
            op=torch.eig,
            dtypes=floating_and_complex_types(),
@@ -8448,6 +8462,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            backward_dtypesIfCUDA=floating_and_complex_types_and(torch.half,
                                                                 *[torch.bfloat16] if (SM60OrLater and CUDA11OrLater) else []),
            supports_out=False,
+           supports_forward_ad=True,
            sample_inputs_func=sample_inputs_einsum,
            skips=(
                # test does not work with passing lambda for op
@@ -8877,6 +8892,7 @@ def wrapper(x: np.ndarray, *args, **kwargs):
            op=lambda x, scalar: torch.fill_(x.clone(), scalar),
            method_variant=None,
            inplace_variant=torch.Tensor.fill_,
+           supports_forward_ad=True,
            dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
            supports_out=False,
            skips=(
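
The `SkipInfo('TestGradients', 'test_forward_mode_AD', ...)` entries above disable the automatic forward-mode gradcheck for the cases that hang. For reference, roughly the same check can be run by hand with torch.autograd.gradcheck; this is a sketch, not part of this commit, and it assumes the `check_forward_ad` flag is available in gradcheck in this build:

import torch
from torch.autograd import gradcheck

a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)

def f(x):
    # Mirror gradcheck_wrapper_hermitian_input: feed eigh a symmetric matrix.
    x = 0.5 * (x + x.transpose(-2, -1))
    return torch.linalg.eigh(x)

gradcheck(f, (a,), check_forward_ad=True)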
