Add initial LOBPCG top-k eigenvalue solver (google#3112)
This initial version is f32-only for accelerators, since it relies on an eigh call (which itself supports at most f32 on accelerators) in its inner loop.

For details, see the jax.experimental.sparse.linalg.lobpcg_standard documentation.
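A minimal usage sketch (not part of the change itself; the random PSD matrix, the shapes, and the variable names below are illustrative assumptions — any symmetric f32 matrix with search dim * 5 < matrix dim works the same way):

import jax
import jax.numpy as jnp
from jax.experimental.sparse import linalg

n, k = 100, 10
key_a, key_x = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(key_a, (n, n), dtype=jnp.float32)
A = M @ M.T  # any symmetric matrix; PSD here for simplicity
X = jax.random.normal(key_x, (n, k), dtype=jnp.float32)  # initial search block

theta, U, i = linalg.lobpcg_standard(A, X, m=100)
# theta: (k,) top eigenvalues; U: (n, k) eigenvectors; i: iterations performed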
Vladimir Feinberg committed Jun 3, 2022
1 parent d43cb36 commit b57b52c
Showing 3 changed files with 563 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/experimental/sparse/__init__.py
@@ -259,3 +259,5 @@
sparsify as sparsify,
SparseTracer as SparseTracer,
)

from jax.experimental.sparse import linalg
345 changes: 345 additions & 0 deletions jax/experimental/sparse/linalg.py
@@ -0,0 +1,345 @@
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sparse linear algebra routines."""

from typing import Union
import functools

import jax
import jax.numpy as jnp


def lobpcg_standard(
A: jnp.ndarray,
X: jnp.ndarray,
m: int = 100,
tol: Union[jnp.ndarray, float, None] = None):
"""Compute the top-k standard eigenvalues using the LOBPCG routine.
LOBPCG [1] stands for Locally Optimal Block Preconditioned Conjugate Gradient.
The method enables finding top-k eigenvectors in an accelerator-friendly
manner.
This initial experimental version has several caveats.
- Only the standard eigenvalue problem `A U = lambda U` is supported;
generalized eigenvalue problems are not.
- Gradient code is not available.
- f64 will only work where jnp.linalg.eigh is supported for that type.
- Finding the smallest eigenvectors is not yet supported. As a result,
we don't yet support preconditioning, which is mostly needed for this
case.
The implementation is based on [2] and [3]; however, we deviate from these
sources in several ways to improve robustness or facilitate implementation:
- Despite increased iteration cost, we always maintain an orthonormal basis
for the block search directions.
- We change the convergence criterion; see the `tol` argument.
- Soft locking is not implemented.
[1] http://ccm.ucdenver.edu/reports/rep149.pdf
[2] https://arxiv.org/abs/1704.07458
[3] https://arxiv.org/abs/0705.2626
Args:
A : An `(n, n)` array representing a square matrix.
X : An `(n, k)` array representing the initial search directions for the `k`
desired top eigenvectors. This need not be orthogonal, but must be
linearly independent.
m : Maximum integer iteration count; LOBPCG will only ever explore (a
subspace of) the Krylov basis `{X, A X, A^2 X, ..., A^m X}`.
tol : A float convergence tolerance; an eigenpair `(lambda, v)` is converged
when its residual L2 norm `r = |A v - lambda v|` is below
`tol * 10 * n * (lambda + |A v|)`, which
roughly estimates the worst-case floating point error for an ideal
eigenvector. If all `k` eigenvectors satisfy the tolerance
comparison, then LOBPCG exits early. If left as None, then this is set
to the float epsilon of `A.dtype`.
debug : A boolean indicating whether to return additional diagnostics.
Returns:
`theta, U, i [, diagnostics]`, where `theta` is a `(k,)` array
of eigenvalues, `U` is an `(n, k)` array of eigenvectors, `i` is the
number of iterations performed, and `diagnostics` is a dictionary with debug
information, which is only returned if `debug` is set to true.
Raises:
ValueError : if `A,X` dtypes or `n` dimensions do not match, or `k` is too
large (only `k * 5 < n` supported), or `k == 0`.
"""
return _lobpcg_standard(A, X, m, tol, debug=False)

@functools.partial(jax.jit, static_argnames=['m', 'debug'])
def _lobpcg_standard(
A: jnp.ndarray,
X: jnp.ndarray,
m: int,
tol: Union[jnp.ndarray, float, None],
debug: bool = False):
"""Computes lobpcg_standard(), possibly with debug diagnostics."""

# TODO(vladf): support mixed_precision flag, which allows f64 Rayleigh-Ritz
# with f32 inputs.
mixed_precision = False

n, k = X.shape
dt = X.dtype

_check_inputs(A, X)

if tol is None:
tol = jnp.finfo(dt).eps

X = _svqb(X, mixed_precision)
P = _extend_basis(X, X.shape[1])

# We maintain X, our current list of best eigenvectors,
# P, our search direction, and
# R, our residuals, in a large joint array XPR, column-stacked, so (n, 3*k).

AX = _mm(A, X)
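# Per-column Rayleigh quotients; these are valid eigenvalue estimates because
# the columns of X are unit-norm after _svqb.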
theta = jnp.sum(X * AX, axis=0, keepdims=True)
R = AX - theta * X

def cond(state):
i, _X, _P, _R, converged, _ = state
return jnp.logical_and(i < m, converged < k)

def body(state):
i, X, P, R, _, theta = state
# Invariants: X, P, R kept orthonormal
# Some R, P columns may be 0, but not X.

# TODO(vladf): support preconditioning for bottom-k eigenvectors
# if M is not None:
# R = M(R)

# residual basis selection
R = _project_out(jnp.concatenate((X, P), axis=1), R, mixed_precision)
XPR = jnp.concatenate((X, P, R), axis=1)

# Projected eigensolve.
theta, Q = _rayleigh_ritz_orth(A, XPR, mixed_precision)

# Eigenvector X extraction
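# (B holds the Ritz coefficients of the k best vectors in the XPR basis;
# normalizing B and then X keeps the columns near unit norm despite round-off.)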
B = Q[:, :k]
normB = jnp.linalg.norm(B, ord=2, axis=0, keepdims=True)
B /= normB
X = _mm(XPR, B)
normX = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
X /= normX

# Difference terms P extraction
#
# In next step of LOBPCG, naively, we'd set
# P = S[:, k:] @ Q[k:, :k] to achieve span(X, P) == span(X, previous X)
# (this is not obvious, see section 4 of [1]).
#
# Instead we orthogonalize concat(0, Q[k:, :k]) against Q[:, :k]
# in the standard basis before mapping with XPR. Since XPR is itself
# orthonormal, the resulting directions are themselves orthonormalized.
#
# [2] leverages Q's existing orthogonality to derive
# an analytic expression for this value based on the quadrant Q[:k,k:]
# (see section 4.2 of [2]).
q, _ = jnp.linalg.qr(Q[:k, k:].T)
diff_rayleigh_ortho = _mm(Q[:, k:], q)
P = _mm(XPR, diff_rayleigh_ortho)
normP = jnp.linalg.norm(P, ord=2, axis=0, keepdims=True)
P /= jnp.where(normP == 0, 1.0, normP)

# Compute new residuals.
AX = _mm(A, X)
R = AX - theta[jnp.newaxis, :k] * X
resid_norms = jnp.linalg.norm(R, ord=2, axis=0)

# I tried many variants of hard and soft locking [3]. All of them seemed
# to worsen performance relative to no locking.
#
# Further, I found a more experimental convergence formula compared to what
# is suggested in the literature, loosely based on floating-point
# expectations.
#
# [2] discusses various strategies for this in Sec 5.3. The solution
# they end up with, which estimates operator norm |A| via Gaussian
# products, was too crude in practice (and overly-lax). The Gaussian
# approximation seems like an estimate of the average eigenvalue.
#
# Instead, we test convergence via self-consistency of the eigenpair
# i.e., the residual norm |r| should be small, relative to the floating
# point error we'd expect from computing just the residuals given
# candidate vectors.
#
# sqrt(n) - random walk error from AX multiply
reltol = jnp.linalg.norm(AX, ord=2, axis=0) + theta[:k]
reltol *= n
# Allow some margin for a few element-wise operations.
reltol *= 10
res_converged = resid_norms < tol * reltol
converged = jnp.sum(res_converged)

new_state = i + 1, X, P, R, converged, theta[jnp.newaxis, :k]
if debug:
diagnostics = _generate_diagnostics(
XPR, X, P, R, theta, converged, resid_norms / reltol)
new_state = (new_state, diagnostics)
return new_state

converged = 0
state = (0, X, P, R, converged, theta)
if debug:
state, diagnostics = jax.lax.scan(
lambda state, _: body(state), state, xs=None, length=m)
else:
state = jax.lax.while_loop(cond, body, state)
i, X, _P, _R, _converged, theta = state

if debug:
return theta[0, :], X, i, diagnostics
return theta[0, :], X, i


def _check_inputs(A, X):
n, k = X.shape
dt = X.dtype

if k == 0:
raise ValueError(f'must have search dim > 0, got {k}')

if A.dtype != dt:
raise ValueError(f'A, X must have same dtypes (were {A.dtype}, {dt})')

if A.shape != (n, n):
raise ValueError(f'A must be a square ({n}, {n}) matrix, got {A.shape}')

if k * 5 >= n:
raise ValueError(f'expected search dim * 5 < matrix dim (got {k * 5}, {n})')


def _mm(a, b, precision=jax.lax.Precision.HIGHEST):
return jax.lax.dot(a, b, (precision, precision))

def _generate_diagnostics(prev_XPR, X, P, R, theta, converged, adj_resid):
k = X.shape[1]
assert X.shape == P.shape

diagdiag = lambda x: jnp.diag(jnp.diag(x))
abserr = lambda x: jnp.abs(x).sum() / (k ** 2)

XTX = _mm(X.T, X)
DX = diagdiag(XTX)
orthX = abserr(XTX - DX)

PTP = _mm(P.T, P)
DP = diagdiag(PTP)
orthP = abserr(PTP - DP)

PX = abserr(X.T @ P)

prev_basis = prev_XPR.shape[1] - jnp.sum(jnp.all(prev_XPR == 0.0, axis=0))

return {
'basis rank': prev_basis,
'X zeros': jnp.sum(jnp.all(X == 0.0, axis=0)),
'P zeros': jnp.sum(jnp.all(P == 0.0, axis=0)),
'lambda history': theta[:k],
'residual history': jnp.linalg.norm(R, axis=0, ord=2),
'converged': converged,
'adjusted residual max': jnp.max(adj_resid),
'adjusted residual p50': jnp.median(adj_resid),
'adjusted residual min': jnp.min(adj_resid),
'X orth': orthX,
'P orth': orthP,
'P.X': PX}

def _eigh_possibly_mixed(A, mixed_precision):
assert not mixed_precision, 'mixed precision not yet supported'
w, V = jnp.linalg.eigh(A)
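# eigh returns eigenvalues in ascending order; reverse so the largest come first.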
return w[::-1], V[:, ::-1]


def _svqb(X, mixed_precision):
# https://sdm.lbl.gov/~kewu/ps/45577.html

norms = jnp.linalg.norm(X, ord=2, axis=0, keepdims=True)
X /= jnp.where(norms == 0, 1.0, norms)

inner = _mm(X.T, X)

w, V = _eigh_possibly_mixed(inner, mixed_precision)

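# Gram-matrix eigenvalues below tau = eps * lambda_max correspond to numerically
# dependent columns; they are floored here and the columns dropped below.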
tau = jnp.finfo(X.dtype).eps * w[0]
padded = jnp.maximum(w, tau)
sqrted = jnp.where(tau > 0, padded, 1.0) ** (-0.5)
scaledV = V * sqrted[jnp.newaxis, :]
orthoX = _mm(X, scaledV)

keep = ((w > tau) * (jnp.diag(inner) > 0.0))[jnp.newaxis, :]
orthoX *= keep
norms = jnp.linalg.norm(orthoX, ord=2, axis=0, keepdims=True)
keep *= norms > 0.0
orthoX /= jnp.where(keep, norms, 1.0)
return orthoX


def _project_out(basis, U, mixed_precision):
# "twice is enough" from shoyer's reference:
# http://slepc.upv.es/documentation/reports/str1.pdf
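# Each projection pass and each SVQB re-orthonormalization below is repeated
# twice so the result stays orthogonal to `basis` to working precision.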

for _ in range(2):
U -= _mm(basis, _mm(basis.T, U))
for _ in range(2):
U = _svqb(U, mixed_precision)

return U


def _rayleigh_ritz_orth(A, S, mixed_precision):
# Classical Rayleigh-Ritz returns w, V satisfying
# (1) S.T A S @ V ~= w * V
# such that (2) V is (S.T S)-orthonormal.
# https://www.netlib.org/lapack/lug/node54.html
#
# Usually it requires solving the complicated standard eigensystem
# U^-T S^T A S U^-1 @ Q = w * Q and then backsolving V = U^-1 Q,
# but if S is standard orthonormal then we just need to find
# eigenvalues of S.T A S.

SAS = _mm(S.T, _mm(A, S))

# Solve the projected subsystem.
# If we could tell eigh to stop after the first k, we would.
return _eigh_possibly_mixed(SAS, mixed_precision)


def _extend_basis(X, m):
# Use a block Householder reflector to generate an orthogonal extension
# to X. There's nothing too special about this, and we could choose
# any random extension to X's basis, but this is a deterministic choice.
#
# https://epubs.siam.org/doi/abs/10.1137/0725014
# https://www.jstage.jst.go.jp/article/ipsjdc/2/0/2_0_298/_article
n, k = X.shape
Xupper, Xlower = jnp.split(X, [k], axis=0)
u, s, vt = jnp.linalg.svd(Xupper)
y = jnp.concatenate([Xupper + u @ vt, Xlower], axis=0)
other = jnp.concatenate(
[jnp.eye(m, dtype=X.dtype),
jnp.zeros((n - k - m, m), dtype=X.dtype)], axis=0)
w = _mm(y, vt.T * ((2 * (1 + s)) ** (-1/2))[jnp.newaxis, :])
h = -2 * jnp.linalg.multi_dot(
[w, w[k:, :].T, other], precision=jax.lax.Precision.HIGHEST)
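# The returned h.at[k:].add(other) equals (I - 2 w w^T) applied to the standard
# basis vectors e_{k+1}, ..., e_{k+m}; the reflector sends span(X) into the
# first k coordinates, so these columns are orthonormal and orthogonal to X.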
return h.at[k:].add(other)