In [2]:
"""Regression tests comparing LOBPCG against SciPy's Lanczos (eigsh).

The goal is to validate that our GPU-friendly LOBPCG implementation
matches a classical Lanczos solver when applied to the Hessian of a
small network. We check both the randomly initialised model and the
same model after a short synthetic training run.
"""

from __future__ import annotations

import numpy as np
import pytest
import torch
from torch import nn

import sys
from pathlib import Path
# Add parent directory to Python path to import from utils
sys.path.insert(0, str(Path('..')))

from utils.lobpcg_old import torch_lobpcg


from utils.measure import create_hessian_vector_product
from importlib import reload
from utils.lobpcg import torch_lobpcg as torch_lobpcg_rewrite

try:
    from scipy.sparse.linalg import LinearOperator, eigsh
except ImportError as exc:  # pragma: no cover - pytest will surface the failure
    raise RuntimeError("SciPy is required for lobpcg vs Lanczos tests") from exc

In [3]:
def _lobpcg_vs_lanczos_on_matrix(
    A: torch.Tensor,
    k: int,
    *,
    max_iter: int = 100,
    tol: float = 1e-10,
    lanczos_maxiter: int = 500,
    which: str = "LM",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the top-``k`` eigenpairs of a dense symmetric matrix with both solvers.

    Runs ``torch_lobpcg_rewrite`` and SciPy's Lanczos solver (``eigsh``) on the
    same matrix and returns both results. No comparison happens here: callers
    decide which eigenvalue / eigenvector tolerances to assert.

    Args:
        A: Square matrix, assumed symmetric. May live on any device; its dtype is
            used for both solvers (float64 maps to NumPy float64, anything else to
            NumPy float32 on the SciPy side).
        k: Number of eigenpairs to request. ARPACK (``eigsh``) requires
            ``0 < k < n``.
        max_iter: Iteration cap forwarded to ``torch_lobpcg_rewrite``.
        tol: Convergence tolerance forwarded to both solvers.
        lanczos_maxiter: ``maxiter`` forwarded to ``eigsh``.
        which: Spectrum selector forwarded to ``eigsh``. The default ``"LM"``
            (largest magnitude) coincides with "largest eigenvalues" only for
            positive semi-definite ``A``; for indefinite matrices (e.g. Hessians)
            ``"LA"`` may be what the LOBPCG side targets -- TODO confirm against
            ``torch_lobpcg_rewrite``.

    Returns:
        ``(lobpcg_vals, lobpcg_vecs, lanczos_vals, lanczos_vecs)``, all on CPU.
        LOBPCG outputs are in whatever order ``torch_lobpcg_rewrite`` returns them;
        Lanczos outputs are sorted by descending eigenvalue and cast to
        ``A.dtype``. Eigenvector signs are not aligned between the two methods.
    """
    assert A.ndim == 2 and A.shape[0] == A.shape[1], "Matrix must be square"
    n = A.shape[0]
    assert 0 < k < n, f"k must satisfy 0 < k < n (got k={k}, n={n})"
    dtype = A.dtype
    device = A.device

    def matmul(vec: torch.Tensor) -> torch.Tensor:
        # ``@`` already handles a single vector (n,) and a block of vectors (n, m).
        return A @ vec

    init = torch.randn(n, k, device=device, dtype=dtype)
    eigvals_lobpcg, eigvecs_lobpcg, _ = torch_lobpcg_rewrite(matmul, init, max_iter=max_iter, tol=tol)

    np_dtype = np.float64 if dtype == torch.float64 else np.float32

    def scipy_matvec(vec: np.ndarray) -> np.ndarray:
        # Bridge NumPy <-> torch so both solvers apply exactly the same operator.
        torch_vec = torch.from_numpy(vec).to(device=device, dtype=dtype)
        return matmul(torch_vec).detach().cpu().numpy().astype(np_dtype, copy=False)

    linear_op = LinearOperator(shape=(n, n), matvec=scipy_matvec, dtype=np_dtype)

    # Take ARPACK's starting vector from torch's (seeded) CPU RNG so the Lanczos
    # run is reproducible under ``torch.manual_seed``; with ``v0=None`` SciPy picks
    # its own random start. It is drawn after ``init``, so the LOBPCG start vector
    # is the same as before.
    v0 = torch.randn(n, dtype=torch.float64).numpy().astype(np_dtype)
    eigvals_lanczos, eigvecs_lanczos = eigsh(
        linear_op, k=k, which=which, tol=tol, maxiter=lanczos_maxiter, v0=v0
    )

    # Sort explicitly into descending order rather than relying on eigsh's ordering.
    order = np.argsort(eigvals_lanczos)[::-1]
    eigvals_lanczos = eigvals_lanczos[order]
    eigvecs_lanczos = eigvecs_lanczos[:, order]

    return (
        eigvals_lobpcg.detach().cpu(),
        eigvecs_lobpcg.detach().cpu(),
        torch.from_numpy(eigvals_lanczos.copy()).to(dtype),
        torch.from_numpy(eigvecs_lanczos.copy()).to(dtype),
    )


def test_lobpcg_matches_lanczos_on_random_matrix() -> None:
    """Check LOBPCG against Lanczos (eigsh) on a small random SPD matrix.

    For the top ``k`` eigenpairs this asserts that:

    * each method's normalised eigenvectors satisfy ``A v ~= lambda v``;
    * LOBPCG and Lanczos eigenvalues agree;
    * the two methods' eigenvectors point in the same directions, up to sign.

    Residual and eigenvalue tolerances are ``100 * eps * ||A||_2`` (``eps`` is
    float32 machine epsilon) instead of a fixed absolute number. With
    ``||A||_2`` around 60, float32 round-off alone puts residuals in the
    1e-5..1e-4 range, so a fixed ``1e-4`` bound sits at the edge of what single
    precision can deliver.
    """
    torch.manual_seed(2352)
    n = 16
    k = 3
    dtype = torch.float32
    base = torch.randn(n, n, dtype=dtype)
    A = base @ base.T + 0.5 * torch.eye(n, dtype=dtype)

    lobpcg_vals, lobpcg_vecs, lanczos_vals, lanczos_vecs = _lobpcg_vs_lanczos_on_matrix(A, k)
    # Compare the requested k pairs in A's dtype, whatever precision LOBPCG returns.
    lobpcg_vals = lobpcg_vals[:k].to(dtype)
    lobpcg_vecs = lobpcg_vecs[:, :k].to(dtype)

    eps = torch.finfo(dtype).eps
    a_norm = torch.linalg.matrix_norm(A, ord=2).item()
    abs_tol = 100 * eps * a_norm  # eigenvalue and residual tolerance (units of A)
    align_tol = 1000 * eps  # tolerance on 1 - |cos(angle)| (dimensionless)

    print("LOBPCG eigenvalues:", lobpcg_vals)
    print("Lanczos eigenvalues:", lanczos_vals)

    results = {
        "LOBPCG": (lobpcg_vals, lobpcg_vecs),
        "Lanczos": (lanczos_vals, lanczos_vecs),
    }
    for method, (vals, vecs) in results.items():
        print(f"\nVerifying {method} eigenvectors:")
        for i in range(k):
            v = vecs[:, i] / vecs[:, i].norm()
            residual = torch.norm(A @ v - vals[i] * v).item()
            print(f"  Eigenvector {i}: ||Av - λv|| = {residual:.2e}")
            assert residual < abs_tol, (
                f"{method} eigenvector {i} has large residual: {residual:.2e} (tol {abs_tol:.2e})"
            )

    torch.testing.assert_close(lobpcg_vals, lanczos_vals, rtol=0.0, atol=abs_tol)

    # Eigenvectors are only defined up to sign, so compare |cos| of the angle between them.
    lobpcg_unit = lobpcg_vecs / lobpcg_vecs.norm(dim=0, keepdim=True)
    lanczos_unit = lanczos_vecs / lanczos_vecs.norm(dim=0, keepdim=True)
    cosines = (lobpcg_unit * lanczos_unit).sum(dim=0).abs()
    torch.testing.assert_close(cosines, torch.ones_like(cosines), rtol=0.0, atol=align_tol)


In [4]:
# Run the random-matrix check inline so its diagnostic prints show up as cell output;
# it raises AssertionError if any of its enabled checks fails.
test_lobpcg_matches_lanczos_on_random_matrix()

LOBPCG eigenvalues: tensor([60.2970, 45.4289, 34.1045])
Lanczos eigenvalues: tensor([60.2970, 45.4289, 34.1045])

Verifying LOBPCG eigenvectors:
  Eigenvector 0: ||Av - λv|| = 1.07e-05
  Eigenvector 1: ||Av - λv|| = 1.33e-04
  Eigenvector 2: ||Av - λv|| = 1.02e-04

Verifying Lanczos eigenvectors:
  Eigenvector 0: ||Av - λv|| = 1.45e-05
  Eigenvector 1: ||Av - λv|| = 1.19e-05
  Eigenvector 2: ||Av - λv|| = 9.16e-06


In [4]:
# Machine epsilon of the dtype used by the random-matrix test. The float32
# residuals that test prints (~1e-5 to 1e-4) are best read relative to eps * ||A||.
dt = torch.float32
float32_info = torch.finfo(dt)
float32_info.eps

1.1920928955078125e-07