Add initial LOBPCG top-k eigenvalue solver (google#3112)
This initial version is f32-only for accelerators, since it relies on an eigh call (itself limited to f32 on accelerators) in its inner loop. For details, see the jax.experimental.sparse.linalg.lobpcg_standard documentation.
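For context, a minimal usage sketch of the new solver (not part of this commit; the matrix construction, seeds, and shapes are illustrative only), assuming the jax.experimental.sparse.linalg import path added by this change:

    import jax
    import jax.numpy as jnp
    from jax.experimental.sparse import linalg

    # Build a random symmetric PSD matrix whose top eigenpairs we want.
    n, k = 100, 4
    key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
    B = jax.random.normal(key_a, (n, n), dtype=jnp.float32)
    A = B @ B.T

    # Initial search directions: any k linearly independent columns
    # (they are orthonormalized internally). Note k * 5 < n is required.
    X = jax.random.normal(key_x, (n, k), dtype=jnp.float32)

    theta, U, i = linalg.lobpcg_standard(A, X, m=100)
    # theta: (k,) top eigenvalues in descending order
    # U: (n, k) corresponding eigenvectors; i: iterations performed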
Vladimir Feinberg committed Jun 3, 2022 · 1 parent d43cb36 · commit b57b52c
Showing 3 changed files with 563 additions and 0 deletions.
jax/experimental/sparse/__init__.py
@@ -259,3 +259,5 @@
     sparsify as sparsify,
     SparseTracer as SparseTracer,
 )
+
+from jax.experimental.sparse import linalg
jax/experimental/sparse/linalg.py (new file)
@@ -0,0 +1,345 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse linear algebra routines."""

from typing import Union
import functools

import jax
import jax.numpy as jnp

def lobpcg_standard(
    A: jnp.ndarray,
    X: jnp.ndarray,
    m: int = 100,
    tol: Union[jnp.ndarray, float, None] = None):
  """Compute the top-k standard eigenvalues using the LOBPCG routine.

  LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate
  Gradient. The method enables finding top-k eigenvectors in an
  accelerator-friendly manner.

  This initial experimental version has several caveats.

    - Only the standard eigenvalue problem `A U = lambda U` is supported;
      generalized eigenvalues are not.
    - Gradient code is not available.
    - f64 will only work where jnp.linalg.eigh is supported for that type.
    - Finding the smallest eigenvectors is not yet supported. As a result,
      we don't yet support preconditioning, which is mostly needed for this
      case.

  The implementation is based on [2] and [3]; however, we deviate from these
  sources in several ways to improve robustness or facilitate implementation:

    - Despite increased iteration cost, we always maintain an orthonormal
      basis for the block search directions.
    - We change the convergence criterion; see the `tol` argument.
    - Soft locking is not implemented.

  [1]: http://ccm.ucdenver.edu/reports/rep149.pdf
  [2]: https://arxiv.org/abs/1704.07458
  [3]: https://arxiv.org/abs/0705.2626

  Args:
    A : An `(n, n)` array representing a square matrix.
    X : An `(n, k)` array representing the initial search directions for the
      `k` desired top eigenvectors. This need not be orthogonal, but must be
      linearly independent.
    m : Maximum integer iteration count; LOBPCG will only ever explore (a
      subspace of) the Krylov basis `{X, A X, A^2 X, ..., A^m X}`.
    tol : A float convergence tolerance; an eigenpair `(lambda, v)` is
      converged when its residual L2 norm `r = |A v - lambda v|` is below
      `tol * 10 * n * (lambda + |A v|)`, which roughly estimates the
      worst-case floating point error for an ideal eigenvector. If all `k`
      eigenvectors satisfy the tolerance comparison, then LOBPCG exits early.
      If left as None, then this is set to the float epsilon of `A.dtype`.

  Returns:
    `theta, U, i`, where `theta` is a `(k,)` array of eigenvalues, `U` is an
    `(n, k)` array of eigenvectors, and `i` is the number of iterations
    performed. (The private `_lobpcg_standard` entry point additionally
    accepts a `debug` flag; when it is set, a `diagnostics` dictionary of
    debug information is returned as a fourth value.)

  Raises:
    ValueError : if `A, X` dtypes or `n` dimensions do not match, or `k` is
      too large (only `k * 5 < n` is supported), or `k == 0`.
  """
  return _lobpcg_standard(A, X, m, tol, debug=False)

@functools.partial(jax.jit, static_argnames=['m', 'debug'])
def _lobpcg_standard(
    A: jnp.ndarray,
    X: jnp.ndarray,
    m: int,
    tol: Union[jnp.ndarray, float, None],
    debug: bool = False):
  """Computes lobpcg_standard(), possibly with debug diagnostics."""

  # TODO(vladf): support mixed_precision flag, which allows f64 Rayleigh-Ritz
  # with f32 inputs.
  mixed_precision = False

  n, k = X.shape
  dt = X.dtype

  _check_inputs(A, X)

  if tol is None:
    tol = jnp.finfo(dt).eps

  X = _svqb(X, mixed_precision)
  P = _extend_basis(X, X.shape[1])

  # We maintain X, our current list of best eigenvectors,
  # P, our search directions, and
  # R, our residuals, in a large joint array XPR, column-stacked, so (n, 3*k).

  AX = _mm(A, X)
  theta = jnp.sum(X * AX, axis=0, keepdims=True)
  R = AX - theta * X

  def cond(state):
    i, _X, _P, _R, converged, _ = state
    return jnp.logical_and(i < m, converged < k)

  def body(state):
    i, X, P, R, _, theta = state
    # Invariants: X, P, R are kept orthonormal.
    # Some R, P columns may be 0, but no X columns are.

    # TODO(vladf): support preconditioning for bottom-k eigenvectors
    # if M is not None:
    #   R = M(R)

    # Residual basis selection.
    R = _project_out(jnp.concatenate((X, P), axis=1), R, mixed_precision)
    XPR = jnp.concatenate((X, P, R), axis=1)

    # Projected eigensolve.
    theta, Q = _rayleigh_ritz_orth(A, XPR, mixed_precision)

    # Eigenvector X extraction.
    B = Q[:, :k]
    normB = jnp.linalg.norm(B, ord=2, axis=0, keepdims=True)
    B /= normB
    X = _mm(XPR, B)
    normX = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
    X /= normX

    # Difference terms P extraction.
    #
    # In the next step of LOBPCG, naively, we'd set
    # P = S[:, k:] @ Q[k:, :k] to achieve span(X, P) == span(X, previous X)
    # (this is not obvious; see section 4 of [1]).
    #
    # Instead we orthogonalize concat(0, Q[k:, :k]) against Q[:, :k]
    # in the standard basis before mapping with XPR. Since XPR is itself
    # orthonormal, the resulting directions are themselves orthonormalized.
    #
    # [2] leverages Q's existing orthogonality to derive
    # an analytic expression for this value based on the block Q[:k, k:]
    # (see section 4.2 of [2]).
    q, _ = jnp.linalg.qr(Q[:k, k:].T)
    diff_rayleigh_ortho = _mm(Q[:, k:], q)
    P = _mm(XPR, diff_rayleigh_ortho)
    normP = jnp.linalg.norm(P, ord=2, axis=0, keepdims=True)
    P /= jnp.where(normP == 0, 1.0, normP)

    # Compute new residuals.
    AX = _mm(A, X)
    R = AX - theta[jnp.newaxis, :k] * X
    resid_norms = jnp.linalg.norm(R, ord=2, axis=0)

    # I tried many variants of hard and soft locking [3]. All of them seemed
    # to worsen performance relative to no locking.
    #
    # Further, I found a more experimental convergence formula than what is
    # suggested in the literature, loosely based on floating-point
    # expectations.
    #
    # [2] discusses various strategies for this in Sec 5.3. The solution
    # they end up with, which estimates the operator norm |A| via Gaussian
    # products, was too crude in practice (and overly lax). The Gaussian
    # approximation seems like an estimate of the average eigenvalue.
    #
    # Instead, we test convergence via self-consistency of the eigenpair,
    # i.e., the residual norm |r| should be small relative to the floating
    # point error we'd expect from computing just the residuals given
    # candidate vectors.

    # Multiplying by n allows margin over the ~sqrt(n) random-walk error
    # from the A @ X multiply.
    reltol = jnp.linalg.norm(AX, ord=2, axis=0) + theta[:k]
    reltol *= n
    # Allow some margin for a few element-wise operations.
    reltol *= 10
    res_converged = resid_norms < tol * reltol
    converged = jnp.sum(res_converged)

    new_state = i + 1, X, P, R, converged, theta[jnp.newaxis, :k]
    if debug:
      diagnostics = _generate_diagnostics(
          XPR, X, P, R, theta, converged, resid_norms / reltol)
      new_state = (new_state, diagnostics)
    return new_state

  converged = 0
  state = (0, X, P, R, converged, theta)
  if debug:
    # In debug mode, run a fixed number of iterations under lax.scan so that
    # per-iteration diagnostics can be stacked and returned.
    state, diagnostics = jax.lax.scan(
        lambda state, _: body(state), state, xs=None, length=m)
  else:
    state = jax.lax.while_loop(cond, body, state)
  i, X, _P, _R, _converged, theta = state

  if debug:
    return theta[0, :], X, i, diagnostics
  return theta[0, :], X, i

def _check_inputs(A, X):
  n, k = X.shape
  dt = X.dtype

  if k == 0:
    raise ValueError(f'must have search dim > 0, got {k}')

  if A.dtype != dt:
    raise ValueError(f'A, X must have same dtypes (were {A.dtype}, {dt})')

  if A.shape != (n, n):
    raise ValueError(f'A must be an ({n}, {n}) matrix, got {A.shape}')

  if k * 5 >= n:
    raise ValueError(f'expected search dim * 5 < matrix dim (got {k * 5}, {n})')


def _mm(a, b, precision=jax.lax.Precision.HIGHEST):
  return jax.lax.dot(a, b, (precision, precision))


def _generate_diagnostics(prev_XPR, X, P, R, theta, converged, adj_resid):
  k = X.shape[1]
  assert X.shape == P.shape

  diagdiag = lambda x: jnp.diag(jnp.diag(x))
  abserr = lambda x: jnp.abs(x).sum() / (k ** 2)

  XTX = _mm(X.T, X)
  DX = diagdiag(XTX)
  orthX = abserr(XTX - DX)

  PTP = _mm(P.T, P)
  DP = diagdiag(PTP)
  orthP = abserr(PTP - DP)

  PX = abserr(X.T @ P)

  prev_basis = prev_XPR.shape[1] - jnp.sum(jnp.all(prev_XPR == 0.0, axis=0))

  return {
      'basis rank': prev_basis,
      'X zeros': jnp.sum(jnp.all(X == 0.0, axis=0)),
      'P zeros': jnp.sum(jnp.all(P == 0.0, axis=0)),
      'lambda history': theta[:k],
      'residual history': jnp.linalg.norm(R, axis=0, ord=2),
      'converged': converged,
      'adjusted residual max': jnp.max(adj_resid),
      'adjusted residual p50': jnp.median(adj_resid),
      'adjusted residual min': jnp.min(adj_resid),
      'X orth': orthX,
      'P orth': orthP,
      'P.X': PX}


def _eigh_possibly_mixed(A, mixed_precision):
  assert not mixed_precision, 'mixed precision not yet supported'
  w, V = jnp.linalg.eigh(A)
  # eigh returns eigenvalues in ascending order; flip to descending.
  return w[::-1], V[:, ::-1]

def _svqb(X, mixed_precision):
  # SVQB orthonormalization: https://sdm.lbl.gov/~kewu/ps/45577.html

  # Normalize columns, guarding against zero columns.
  norms = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
  X /= jnp.where(norms == 0, 1.0, norms)

  inner = _mm(X.T, X)

  w, V = _eigh_possibly_mixed(inner, mixed_precision)

  # Clamp eigenvalues of the Gram matrix below a relative threshold tau,
  # then rescale the eigenbasis by w^(-1/2) to whiten X's columns.
  tau = jnp.finfo(X.dtype).eps * w[0]
  padded = jnp.maximum(w, tau)
  sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
  scaledV = V * sqrted[jnp.newaxis, :]
  orthoX = _mm(X, scaledV)

  # Zero out columns associated with clamped (near-null) directions,
  # then re-normalize the survivors.
  keep = ((w > tau) * (jnp.diag(inner) > 0.0))[jnp.newaxis, :]
  orthoX *= keep
  norms = jnp.linalg.norm(orthoX, ord=2, axis=0, keepdims=True)
  keep *= norms > 0.0
  orthoX /= jnp.where(keep, norms, 1.0)
  return orthoX

def _project_out(basis, U, mixed_precision):
  # "Twice is enough" reorthogonalization, from shoyer's reference:
  # http://slepc.upv.es/documentation/reports/str1.pdf

  for _ in range(2):
    U -= _mm(basis, _mm(basis.T, U))
  for _ in range(2):
    U = _svqb(U, mixed_precision)

  return U

def _rayleigh_ritz_orth(A, S, mixed_precision):
  # Classical Rayleigh-Ritz returns w, V satisfying
  #   (1) S.T @ A @ S @ V ~= w * V
  # such that
  #   (2) V is (S.T @ S)-orthonormal.
  # https://www.netlib.org/lapack/lug/node54.html
  #
  # Usually this requires solving the complicated standard eigensystem
  # U^-T S^T A S U^-1 @ Q = w * Q and then backsolving V = U^-1 Q,
  # but if S is standard orthonormal then we just need the
  # eigenvalues of S.T @ A @ S.

  SAS = _mm(S.T, _mm(A, S))

  # Solve the projected subsystem.
  # If we could tell eigh to stop after the first k eigenpairs, we would.
  return _eigh_possibly_mixed(SAS, mixed_precision)

def _extend_basis(X, m):
  # Use a block Householder reflector to generate an orthogonal extension
  # of X. There's nothing too special about this, and we could choose
  # any random extension of X's basis, but this is a deterministic choice.
  #
  # https://epubs.siam.org/doi/abs/10.1137/0725014
  # https://www.jstage.jst.go.jp/article/ipsjdc/2/0/2_0_298/_article
  n, k = X.shape
  Xupper, Xlower = jnp.split(X, [k], axis=0)
  u, s, vt = jnp.linalg.svd(Xupper)
  y = jnp.concatenate([Xupper + u @ vt, Xlower], axis=0)
  other = jnp.concatenate(
      [jnp.eye(m, dtype=X.dtype),
       jnp.zeros((n - k - m, m), dtype=X.dtype)], axis=0)
  w = _mm(y, vt.T * ((2 * (1 + s)) ** (-1/2))[jnp.newaxis, :])
  h = -2 * jnp.linalg.multi_dot(
      [w, w[k:, :].T, other], precision=jax.lax.Precision.HIGHEST)
  return h.at[k:].add(other)
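As a quick sanity check of the solver above (not part of this diff; the planted spectrum is purely illustrative), the returned eigenpairs can be compared against the dense eigensolver:

    import jax
    import jax.numpy as jnp
    from jax.experimental.sparse import linalg

    # Plant a known spectrum: A = Q diag(n, n-1, ..., 1) Q^T.
    n, k = 50, 4
    Q, _ = jnp.linalg.qr(jax.random.normal(jax.random.PRNGKey(0), (n, n)))
    eigs = jnp.arange(n, 0, -1, dtype=jnp.float32)
    A = (Q * eigs[jnp.newaxis, :]) @ Q.T

    X = jax.random.normal(jax.random.PRNGKey(1), (n, k), dtype=jnp.float32)
    theta, U, _ = linalg.lobpcg_standard(A, X)

    # Top-k eigenvalues should match the planted ones to roughly f32
    # accuracy, and residuals |A u - theta u| should be small vs. theta.
    print(jnp.abs(theta - eigs[:k]).max())
    print((jnp.linalg.norm(A @ U - theta[jnp.newaxis, :] * U, axis=0)
           / theta).max())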