# LOBPCG Rewrite vs Legacy Timing

This notebook exists to compare the speeds of lobpcg, lobpcg_old and lancsoz (of scipy)

This notebook trains a CIFAR-10 MLP for 500 SGD steps and, every 64 steps, measures the top Hessian eigenvalues using the rewritten and legacy LOBPCG implementations alongside SciPy's `eigsh`. All measurements reuse cached eigenvectors to simulate warm restarts on a warm GPU, and the runtime for each solver is recorded for comparison.

In [8]:
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch import nn

try:
    import pandas as pd
except ImportError:
    pd = None

from scipy.sparse.linalg import LinearOperator, eigsh

from utils.data import prepare_dataset
from utils.nets import prepare_net_dataset_specific, initialize_net, prepare_optimizer, SquaredLoss
from utils.measure import create_hessian_vector_product
from utils.lobpcg import torch_lobpcg as torch_lobpcg_rewrite
from utils.lobpcg_old import torch_lobpcg as torch_lobpcg_old


In [2]:
torch.manual_seed(111)
np.random.seed(111)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True

sync_cuda = torch.cuda.synchronize if torch.cuda.is_available() else (lambda: None)

DATASET_ROOT = Path(os.environ.get("DATASETS", "/scratch/gpfs/andreyev/datasets/"))

dataset_name = "cifar10"
model_name = "mlp"
num_data = 4096
batch_size = 4096
measurement_batch_size = 4096
num_training_steps = 500
measurement_every = 64
num_eigenvalues = 3
lobpcg_max_iter = 200
lobpcg_tol = 1e-9
scipy_max_iter = 400
scipy_tol = 1e-5
learning_rate = 0.01
init_scale = 0.2
dataset_seed = 111
init_seed = 8312
classes = None

print(f"Using device: {device}")

X_train, Y_train, X_test, Y_test = prepare_dataset(
    dataset_name,
    DATASET_ROOT,
    num_data,
    classes,
    dataset_seed=dataset_seed,
    loss_type="mse",
)

net = prepare_net_dataset_specific(model_name, dataset_name).to(device)
initialize_net(net, scale=init_scale, seed=init_seed)
optimizer = prepare_optimizer(net, lr=learning_rate, momentum=None, adam=False)
loss_fn = SquaredLoss()

X_train = X_train.to(device)
Y_train = Y_train.to(device)

measurement_batch_size = min(measurement_batch_size, X_train.shape[0])
train_size = X_train.shape[0]
param_count = sum(p.numel() for p in net.parameters())

print(f"Training samples: {train_size}")
print(f"Parameter count: {param_count}")


Using device: cuda
Training samples: 4096
Parameter count: 1841162


In [3]:
measurement_caches = {
    "lobpcg_rewrite": {"vectors": None},
    "lobpcg_old": {"vectors": None},
    "scipy": {"vector": None},
}

def sample_batch(inputs, targets, size):
    idx = torch.randint(0, inputs.shape[0], (size,), device=inputs.device)
    return inputs[idx], targets[idx]

def run_lobpcg_solver(lobpcg_fn, cache_key, matvec):
    cache = measurement_caches[cache_key]
    if cache["vectors"] is not None:
        init_vecs = cache["vectors"]
    else:
        init_vecs = torch.randn(param_count, num_eigenvalues, device=device)
    init_vecs = init_vecs.clone()
    sync_cuda()
    start = time.perf_counter()
    eigenvalues, eigenvectors, iterations = lobpcg_fn(
        matvec, init_vecs, max_iter=500, tol=lobpcg_tol,
        
    )
    sync_cuda()
    elapsed = time.perf_counter() - start
    cache["vectors"] = eigenvectors.detach()
    return {
        "time_sec": float(elapsed),
        "iterations": int(iterations),
        "eigvals": eigenvalues.detach().cpu().numpy(),
    }

def run_scipy_solver(matvec):
    cache = measurement_caches["scipy"]
    dtype = np.float32

    def mv(vec):
        vec_t = torch.from_numpy(np.asarray(vec, dtype=dtype)).to(device)
        result = matvec(vec_t)
        sync_cuda()
        return result.detach().cpu().numpy().astype(dtype)

    operator = LinearOperator((param_count, param_count), matvec=mv, dtype=dtype)
    v0 = cache["vector"]
    sync_cuda()
    start = time.perf_counter()
    eigenvalues, eigenvectors = eigsh(
        operator,
        k=num_eigenvalues,
        which="LA",
        v0=v0,
        maxiter=scipy_max_iter,
        tol=scipy_tol,
    )
    elapsed = time.perf_counter() - start
    order = np.argsort(eigenvalues)[::-1]
    eigenvalues = eigenvalues[order]
    eigenvectors = eigenvectors[:, order]
    cache["vector"] = eigenvectors[:, 0].astype(dtype)
    return {
        "time_sec": float(elapsed),
        "iterations": None,
        "eigvals": eigenvalues,
    }

def measure_eigenpairs(step):
    net.eval()
    with torch.enable_grad():
        meas_x, meas_y = sample_batch(X_train, Y_train, measurement_batch_size)
        loss = loss_fn(net(meas_x), meas_y)
        matvec = create_hessian_vector_product(loss, net)
        results = {
            "scipy_eigsh": run_scipy_solver(matvec),
            "lobpcg_rewrite": run_lobpcg_solver(torch_lobpcg_rewrite, "lobpcg_rewrite", matvec),
            "lobpcg_old": run_lobpcg_solver(torch_lobpcg_old, "lobpcg_old", matvec),
        }
    net.train()
    del loss, matvec
    return results


In [4]:
measurement_log = []
net.train()

for step in range(1, num_training_steps + 1):
    batch_x, batch_y = sample_batch(X_train, Y_train, batch_size)
    optimizer.zero_grad(set_to_none=True)
    preds = net(batch_x)
    loss = loss_fn(preds, batch_y)
    loss.backward()
    optimizer.step()

    if step % measurement_every == 0:
        measurements = measure_eigenpairs(step)
        measurement_log.append({"step": step, **measurements})
        rewrite_time = measurement_log[-1]["lobpcg_rewrite"]["time_sec"]
        old_time = measurement_log[-1]["lobpcg_old"]["time_sec"]
        scipy_time = measurement_log[-1]["scipy_eigsh"]["time_sec"]
        print(
            f"Step {step:4d} | loss {loss.item():.4f} | rewrite {rewrite_time:.4f}s | old {old_time:.4f}s | scipy {scipy_time:.4f}s"
        )

if num_training_steps % measurement_every != 0:
    measurements = measure_eigenpairs(num_training_steps)
    measurement_log.append({"step": num_training_steps, **measurements})
    rewrite_time = measurement_log[-1]["lobpcg_rewrite"]["time_sec"]
    old_time = measurement_log[-1]["lobpcg_old"]["time_sec"]
    scipy_time = measurement_log[-1]["scipy_eigsh"]["time_sec"]
    print(
        f"Step {num_training_steps:4d} | loss {loss.item():.4f} | rewrite {rewrite_time:.2f}s | old {old_time:.2f}s | scipy {scipy_time:.2f}s"
    )


Step   64 | loss 0.8510 | rewrite 1.8432s | old 2.5861s | scipy 2.0807s
Step  128 | loss 0.8191 | rewrite 0.4384s | old 0.9099s | scipy 1.1653s
Step  192 | loss 0.8050 | rewrite 0.6572s | old 0.9095s | scipy 2.0291s
Step  256 | loss 0.7789 | rewrite 0.4384s | old 0.6306s | scipy 2.0197s
Step  320 | loss 0.7606 | rewrite 0.2195s | old 0.3517s | scipy 1.1563s
Step  384 | loss 0.7554 | rewrite 0.4384s | old 0.6304s | scipy 1.1573s
Step  448 | loss 0.7364 | rewrite 0.2197s | old 0.3512s | scipy 1.1583s
Step  500 | loss 0.7269 | rewrite 0.44s | old 0.35s | scipy 1.16s


In [5]:
summary_rows = []
for entry in measurement_log:
    step = entry["step"]
    ref_vals = np.sort(np.asarray(entry["scipy_eigsh"]["eigvals"]))[::-1]
    for key, label in [("lobpcg_rewrite", "torch_lobpcg_rewrite"), ("lobpcg_old", "torch_lobpcg_old"), ("scipy_eigsh", "scipy_eigsh")]:
        result = entry[key]
        eigvals = np.sort(np.asarray(result["eigvals"]))[::-1]
        row = {
            "step": step,
            "method": label,
            "time_sec": result["time_sec"],
            "iterations": result["iterations"],
            "eigvals": eigvals.tolist(),
        }
        if key != "scipy_eigsh":
            row["max_abs_diff_vs_scipy"] = float(np.max(np.abs(eigvals - ref_vals)))
            row["max_rel_diff_vs_scipy"] = float(np.max(np.abs((eigvals - ref_vals) / ref_vals)))
        else:
            row["max_abs_diff_vs_scipy"] = 0.0
        summary_rows.append(row)

if pd is not None:
    df = pd.DataFrame(summary_rows)
    display(df)
else:
    for row in summary_rows:
        print(row)


Unnamed: 0,step,method,time_sec,iterations,eigvals,max_abs_diff_vs_scipy,max_rel_diff_vs_scipy
0,64,torch_lobpcg_rewrite,1.84316,8.0,"[9.108808517456055, 8.084686279296875, 7.77278...",0.02341,0.003003
1,64,torch_lobpcg_old,2.58613,9.0,"[9.108842849731445, 8.08322525024414, 7.767592...",0.028603,0.003669
2,64,scipy_eigsh,2.080705,,"[9.108960151672363, 8.088883399963379, 7.79619...",0.0,
3,128,torch_lobpcg_rewrite,0.438369,2.0,"[11.59397029876709, 11.181607246398926, 9.5937...",0.029282,0.002612
4,128,torch_lobpcg_old,0.909927,3.0,"[11.602474212646484, 11.207901000976562, 9.610...",0.002988,0.000302
5,128,scipy_eigsh,1.165338,,"[11.603415489196777, 11.210888862609863, 9.613...",0.0,
6,192,torch_lobpcg_rewrite,0.657244,3.0,"[14.945653915405273, 13.686734199523926, 10.94...",0.014599,0.001333
7,192,torch_lobpcg_old,0.909496,3.0,"[14.945651054382324, 13.686753273010254, 10.94...",0.009389,0.000857
8,192,scipy_eigsh,2.029147,,"[14.945816040039062, 13.68695068359375, 10.955...",0.0,
9,256,torch_lobpcg_rewrite,0.43841,2.0,"[18.30406951904297, 16.115856170654297, 12.701...",0.040638,0.003189


In [13]:
time_stats = {}
for method in {row["method"] for row in summary_rows}:
    method_times = [row["time_sec"] for row in summary_rows if row["method"] == method]
    time_stats[method] = (float(np.mean(method_times)), float(np.std(method_times)))

print("Mean ± std runtime (s):")
for method, (mean_val, std_val) in sorted(time_stats.items()):
    print(f"  {method:>18}: {mean_val:.3f} ± {std_val:.3f}")

diff_stats = {}
for method in ["torch_lobpcg_rewrite", "torch_lobpcg_old"]:
    diffs = [row["max_abs_diff_vs_scipy"] for row in summary_rows if row["method"] == method]
    if diffs:
        diff_stats[method] = float(np.max(diffs))

if diff_stats:
    print(
        "\nMax abs eigenvalue diff vs SciPy across measurements:"
    )
    for method, value in sorted(diff_stats.items()):
        print(f"  {method:>18}: {value:.3e}")


Mean ± std runtime (s):
         scipy_eigsh: 1.211 ± 0.300
    torch_lobpcg_old: 0.328 ± 0.205
  torch_lobpcg_rewrite: 0.193 ± 0.139

Max abs eigenvalue diff vs SciPy across measurements:
    torch_lobpcg_old: 5.746e-02
  torch_lobpcg_rewrite: 5.780e-02
