In [1]:
import functools

import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax.experimental.sparse as jsparse
import time
import scipy
import numpy
from typing import NamedTuple
from typing import Union
from jax import Array

def _error_num_matvecs(num, maxval, minval):
    msg1 = f"Parameter 'num_matvecs'={num} exceeds the acceptable range. "
    msg2 = f"Expected: {minval} <= num_matvecs <= {maxval}."
    return msg1 + msg2

class _DecompResult(NamedTuple):
    """Container for the output of a matrix decomposition routine.

    `_hessenberg_forward` fills it with the orthogonal basis, the small
    Hessenberg matrix, the vector carried out of the last iteration and the
    inverse norm of the start vector.
    """

    # Tall matrix (or matrices) of basis vectors.
    # If an algorithm returns a single Q, place it here.
    # If it returns multiple Qs, stack them
    # into a tuple and place them here.
    Q_tall: Union[Array, tuple[Array, ...]]

    # Small projected matrix.
    # If an algorithm returns a materialized matrix,
    # place it here. If it returns a sparse representation
    # (e.g. two vectors representing diagonals), place it here
    # as a tuple of arrays.
    J_small: Union[Array, tuple[Array, ...]]

    # Vector left over after the final iteration of the decomposition.
    residual: Array
    # Reciprocal of the Euclidean norm of the initial vector.
    init_length_inv: Array



def _hessenberg_forward(matvec, num_matvecs, v, *params, reortho: str, mgs=False):
    """Run ``num_matvecs`` steps of an Arnoldi-style Hessenberg factorisation.

    Args:
        matvec: Callable ``matvec(x, *params)`` returning the operator applied to ``x``.
        num_matvecs: Number of iterations ``k``; must satisfy ``0 <= k <= len(v)``.
        v: One-dimensional start vector of length ``n``.
        *params: Extra positional arguments forwarded to ``matvec``.
        reortho: Passed on to the step function. The classical Gram-Schmidt step
            performs a second projection sweep for any value other than ``"none"``.
        mgs: If true, iterate with ``_hessenberg_forward_step_mgs`` (after storing
            the normalised start vector in column 0 of the basis); otherwise use
            ``_hessenberg_forward_step``.

    Returns:
        A ``_DecompResult`` holding the ``(n, k)`` basis, the ``(k, k)`` matrix,
        the vector carried out of the last step, and ``1 / ||v||``.

    Raises:
        ValueError: If ``num_matvecs`` lies outside ``[0, len(v)]``.
    """
    max_depth = len(v)
    if not (0 <= num_matvecs <= max_depth):
        raise ValueError(_error_num_matvecs(num_matvecs, maxval=max_depth, minval=0))

    # Allocate the outputs and measure the start vector.
    (nrows,) = jnp.shape(v)
    depth = num_matvecs
    basis = jnp.zeros((nrows, depth), dtype=v.dtype)
    hessenberg = jnp.zeros((depth, depth), dtype=v.dtype)
    initlength = jnp.sqrt(jnp.inner(v, v))

    # Nothing to iterate: hand back the empty factors and the untouched vector.
    if depth == 0:
        return _DecompResult(
            Q_tall=basis,
            J_small=hessenberg,
            residual=v,
            init_length_inv=1 / initlength,
        )

    # Pick the orthogonalisation kernel once; the loop body is otherwise identical.
    if mgs:
        # The MGS step expects the first basis vector to be in place already.
        basis = basis.at[:, 0].set(v / initlength)
        step = _hessenberg_forward_step_mgs
    else:
        step = _hessenberg_forward_step

    def forward_step(i, carry):
        return step(*carry, matvec, *params, idx=i, reortho=reortho)

    # Iterate and package the result.
    carry0 = (basis, hessenberg, v, initlength)
    basis, hessenberg, leftover, _ = jax.lax.fori_loop(0, depth, forward_step, carry0)
    return _DecompResult(
        Q_tall=basis,
        J_small=hessenberg,
        residual=leftover,
        init_length_inv=1 / initlength,
    )


def _hessenberg_forward_step(Q, H, v, length, matvec, *params, idx, reortho: str):
    # Save
    v /= length
    Q = Q.at[:, idx].set(v)

    # Evaluate
    v = matvec(v, *params)

    # Orthonormalise
    h = Q.T @ v
    v = v - Q @ h

    # Re-orthonormalise
    if reortho != "none":
        v = v - Q @ (Q.T @ v)

    # Read the length
    length = jnp.sqrt(jnp.inner(v, v))

    # Save
    h = h.at[idx + 1].set(length)
    H = H.at[:, idx].set(h)

    return Q, H, v, length

def _hessenberg_forward_step_mgs(Q, H, v, length, matvec, *params, idx, reortho: str):
    # w = Q[:, idx]
    w = matvec(v, *params)

    def body_fun(j, val):
        w, H = val
        v = Q[:, j]
        ip = jnp.dot(v, w)
        H = H.at[j, idx].add(ip)
        w = w - ip * v
        return w, H

    w, H = jax.lax.fori_loop(0, idx + 1, body_fun, (w, H))
    eta = jnp.linalg.norm(w)
    H = H.at[idx + 1, idx].set(eta)
    w = w / eta
    Q = Q.at[:, idx + 1].set(w)
    return Q, H, w, eta

def grcar(n: int, k: int = 3):
    """Generate a Grcar matrix.

    The matrix has -1 on the first sub-diagonal and 1 on the main diagonal
    and on the first ``k`` super-diagonals.

    This matrix can be generated in Matlab by calling `gallery('grcar', n, k)`.

    Args:
        n: Number of rows and columns.
        k: Number of super-diagonals filled with ones.

    Returns:
        The ``n x n`` matrix as a SciPy sparse matrix in CSR format.
    """
    offsets = list(range(-1, k + 1))
    bands = []
    for offset in offsets:
        # Only the sub-diagonal (negative offset) carries the value -1.
        sign = -1 if offset < 0 else 1
        bands.append(sign * numpy.ones(n - abs(offset)))
    return scipy.sparse.diags(bands, offsets=offsets, format="csr")


In [2]:
# Problem size: the test matrix is n x n and the start vector has length n.
n = 5000

# Sparse Grcar test matrix, first as SciPy CSR and then as a JAX BCOO matrix.
A_scipy = grcar(n)
A = jsparse.BCOO.from_scipy_sparse(A_scipy)

# Seeded random sources for JAX and NumPy.
rng = jax.random.PRNGKey(77)
numpy_rng = numpy.random.default_rng(777)

# Random float64 start vector, a NumPy copy of it, and its Euclidean norm.
b = jax.random.normal(rng, shape=(n,), dtype=jnp.float64)
b_scipy = numpy.array(b)
b_length = numpy.linalg.norm(b_scipy)

@jax.jit
def large_matvec(v):
    """Multiply the notebook-level sparse matrix ``A`` with the vector ``v``.

    ``A`` is read from the enclosing (global) scope rather than passed in.
    """
    product = A @ v
    return product

In [3]:
def trial_results(ns: list):
    """Create an empty results table keyed by the entries of ``ns``.

    Each key maps to its own dictionary with two empty lists: ``"ts"`` for
    timings and ``"loss of orthogonality"`` for the measured errors.
    """
    results = {}
    for key in ns:
        # Fresh lists per key so that appends never alias across entries.
        results[key] = {"ts": [], "loss of orthogonality": []}
    return results

In [4]:
# Warm-up benchmark: classical Gram-Schmidt with re-orthogonalisation.
num_matvecs = 900

# Build the solver from `num_matvecs` itself (not a repeated literal) so that the
# decomposition depth and the identity used in the orthogonality check below
# cannot drift apart when the value is edited.
_hessenberg_forward_cgs = functools.partial(_hessenberg_forward, num_matvecs=num_matvecs, matvec=large_matvec, reortho="full", mgs=False)
_hessenberg_forward_cgs = jax.jit(_hessenberg_forward_cgs)
for _ in range(3):
    # perf_counter is a monotonic, high-resolution clock intended for timing.
    # The first pass also includes JIT compilation.
    start = time.perf_counter()
    arnoldi_avoiding: _DecompResult = _hessenberg_forward_cgs(v=b)
    jax.block_until_ready(arnoldi_avoiding)
    end = time.perf_counter()
    Q = arnoldi_avoiding.Q_tall
    # Frobenius norm of I - Q^T Q measures the loss of orthogonality.
    ortho = jnp.linalg.norm(jnp.eye(num_matvecs, num_matvecs) - Q.T @ Q)
    print(f"Time taken: {end - start:.6f} seconds, n={num_matvecs}, orthogonality: {ortho}")

Time taken: 3.045705 seconds, n=900, orthogonality: 1.628267895911933e-14
Time taken: 0.714611 seconds, n=900, orthogonality: 1.6087962625751466e-14
Time taken: 0.722825 seconds, n=900, orthogonality: 1.6128700367225805e-14


In [5]:
# Decomposition depths (number of matrix-vector products) to benchmark with `trial`.
num_matvecss = [200, 900]

def _benchmark_variant(results, num_matvecs, *, reortho: str, mgs: bool, repeats: int = 5):
    """Time one orthogonalisation variant and record its loss of orthogonality.

    Jit-compiles ``_hessenberg_forward`` for the given settings, runs it
    ``repeats`` times on the notebook-level vector ``b`` with ``large_matvec``,
    prints one line per run, and appends the elapsed time and ``||I - Q^T Q||``
    to ``results[num_matvecs]``. The first run includes JIT compilation.
    """
    decompose = jax.jit(
        functools.partial(
            _hessenberg_forward,
            num_matvecs=num_matvecs,
            matvec=large_matvec,
            reortho=reortho,
            mgs=mgs,
        )
    )
    for _ in range(repeats):
        # perf_counter is a monotonic, high-resolution clock intended for timing.
        start = time.perf_counter()
        decomp: _DecompResult = decompose(v=b)
        jax.block_until_ready(decomp)
        elapsed = time.perf_counter() - start

        Q = decomp.Q_tall
        ortho = jnp.linalg.norm(jnp.eye(num_matvecs, num_matvecs) - Q.T @ Q)
        print(f"Time taken: {elapsed:.6f} seconds, n={num_matvecs}, orthogonality: {ortho}")
        results[num_matvecs]["ts"].append(elapsed)
        results[num_matvecs]["loss of orthogonality"].append(ortho)


def trial(num_matvecss):
    """Benchmark three orthogonalisation strategies for every requested depth.

    For each entry of ``num_matvecss`` the variants run in this order:
    modified Gram-Schmidt, classical Gram-Schmidt, and classical Gram-Schmidt
    with full re-orthogonalisation.

    Args:
        num_matvecss: Iterable of decomposition depths to benchmark.

    Returns:
        ``(ortho_mgs, ortho_cgs, ortho_cgs_reoproj)``: one results table per
        variant, each in the layout produced by ``trial_results``.
    """
    ortho_mgs = trial_results(num_matvecss)
    ortho_cgs = trial_results(num_matvecss)
    ortho_cgs_reoproj = trial_results(num_matvecss)

    # (results table, solver settings) for each variant, in execution order.
    variants = (
        (ortho_mgs, {"reortho": "none", "mgs": True}),
        (ortho_cgs, {"reortho": "none", "mgs": False}),
        (ortho_cgs_reoproj, {"reortho": "full", "mgs": False}),
    )

    for num_matvecs in num_matvecss:
        for results, settings in variants:
            _benchmark_variant(results, num_matvecs, **settings)
    return ortho_mgs, ortho_cgs, ortho_cgs_reoproj

In [6]:
# Run the benchmark on the default device and dump the three result tables.
ortho_mgs, ortho_cgs, ortho_cgs_reoproj = trial(num_matvecss)

print(ortho_mgs, ortho_cgs, ortho_cgs_reoproj, sep="\n")

Time taken: 1.161742 seconds, n=200, orthogonality: 5.726874141049164e-11
Time taken: 0.538039 seconds, n=200, orthogonality: 7.043319222224402e-11
Time taken: 0.541878 seconds, n=200, orthogonality: 5.4372677693131525e-11
Time taken: 0.551504 seconds, n=200, orthogonality: 5.5502272622528034e-11
Time taken: 0.563922 seconds, n=200, orthogonality: 6.59442356853385e-11
Time taken: 0.674655 seconds, n=200, orthogonality: 9.662870653082427e-10
Time taken: 0.023484 seconds, n=200, orthogonality: 5.025793351821561e-10
Time taken: 0.023180 seconds, n=200, orthogonality: 6.31167913636217e-11
Time taken: 0.023274 seconds, n=200, orthogonality: 5.603260074889483e-10
Time taken: 0.023574 seconds, n=200, orthogonality: 1.5587808643715953e-09
Time taken: 0.786281 seconds, n=200, orthogonality: 4.638194678692581e-15
Time taken: 0.043816 seconds, n=200, orthogonality: 4.559526024053206e-15
Time taken: 0.042853 seconds, n=200, orthogonality: 4.513366944241335e-15
Time taken: 0.042662 seconds, n=200, 

In [7]:
# Repeat the benchmark with the first CPU device set as JAX's default device.
# Note: this rebinds ortho_mgs / ortho_cgs / ortho_cgs_reoproj.
with jax.default_device(jax.devices("cpu")[0]):
    ortho_mgs, ortho_cgs, ortho_cgs_reoproj = trial(num_matvecss)

    print(ortho_mgs, ortho_cgs, ortho_cgs_reoproj, sep="\n")

Time taken: 1.791687 seconds, n=200, orthogonality: 3.1320675616833106e-10
Time taken: 1.109135 seconds, n=200, orthogonality: 3.1320675616833106e-10
Time taken: 1.110461 seconds, n=200, orthogonality: 3.1320675616833106e-10
Time taken: 1.107299 seconds, n=200, orthogonality: 3.1320675616833106e-10
Time taken: 1.106515 seconds, n=200, orthogonality: 3.1320675616833106e-10
Time taken: 1.744044 seconds, n=200, orthogonality: 8.590362726625322e-09
Time taken: 1.420714 seconds, n=200, orthogonality: 8.590362726625322e-09
Time taken: 1.782701 seconds, n=200, orthogonality: 8.590362726625322e-09
Time taken: 1.838315 seconds, n=200, orthogonality: 8.590362726625322e-09
Time taken: 1.773728 seconds, n=200, orthogonality: 8.590362726625322e-09
Time taken: 2.177146 seconds, n=200, orthogonality: 2.874116232194062e-14
Time taken: 1.657833 seconds, n=200, orthogonality: 2.874116232194062e-14
Time taken: 1.667096 seconds, n=200, orthogonality: 2.874116232194062e-14
Time taken: 1.682868 seconds, n=2