# SuiteSparse Matrix Collection with torch-sla

This notebook demonstrates how to use **torch-sla** to solve linear systems from the [SuiteSparse Matrix Collection](https://sparse.tamu.edu/) (formerly University of Florida Sparse Matrix Collection).

## Key Features

- Download matrices directly from SuiteSparse
- Load Matrix Market (.mtx) format files
- Solve `Ax = b` with multiple backends
- Compare iterative vs direct solvers
- Full gradient support via autograd


In [None]:
import sys
sys.path.insert(0, '..')
sys.path.insert(0, '../benchmarks')

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import time

from torch_sla import SparseTensor
from torch_sla.io import load_mtx, load_mtx_info

# Pick the compute device once; later cells pass `device` to every tensor/matrix constructor
if torch.cuda.is_available():
    device = 'cuda'
    print(f"Using device: {device}")
    print(f"GPU: {torch.cuda.get_device_name()}")
else:
    device = 'cpu'
    print(f"Using device: {device}")


## 1. Downloading SuiteSparse Matrices

We provide a utility to download matrices from the SuiteSparse Matrix Collection.


In [None]:
from benchmark_suitesparse import download_matrix, list_available_matrices, POPULAR_MATRICES

# Show the first ten entries of POPULAR_MATRICES (name -> (group, <unused second field>))
print("Some popular matrices for benchmarking:")
print("="*50)
for name in list(POPULAR_MATRICES)[:10]:
    group, _ = POPULAR_MATRICES[name]
    print(f"  {name:<15} (group: {group})")


In [None]:
# Fetch bcsstk01; the returned path is reused by the loading cells below
mtx_path = download_matrix("bcsstk01")

# Inspect the file's metadata first (load_mtx_info, not the full load_mtx)
info = load_mtx_info(mtx_path)
print(f"\nMatrix info:")
for key in info:
    val = info[key]
    print(f"  {key}: {val}")


## 2. Loading and Solving with torch-sla


In [None]:
# Read the Matrix Market file into a double-precision matrix on the chosen device
A = load_mtx(mtx_path, device=device, dtype=torch.float64)

# One print call, one field per output line
print(
    f"Loaded SparseTensor:",
    f"  Shape: {A.shape}",
    f"  NNZ: {A.nnz}",
    f"  Device: {A.device}",
    f"  Dtype: {A.dtype}",
    sep="\n",
)


In [None]:
# Manufacture the right-hand side from a known solution (all ones),
# so the computed x can be compared against x_true directly
n = A.shape[0]
x_true = torch.ones(n, device=device, dtype=torch.float64)
b = A @ x_true

# No backend/method given: the choice is left to A.solve
x = A.solve(b)

# Relative distance between the computed and the known solution
diff_norm = torch.norm(x - x_true)
ref_norm = torch.norm(x_true)
error = diff_norm / ref_norm
print(f"Relative error: {error.item():.2e}")


## 3. Comparing Backends and Methods


In [None]:
def benchmark_solver(A, b, backend, method, num_runs=3):
    """Time one ``A.solve`` configuration and report its accuracy.

    Performs one untimed warm-up solve, then ``num_runs`` timed solves, and
    finally measures how well the last solution satisfies ``A @ x = b``.

    Args:
        A: Matrix object exposing ``solve(b, backend=..., method=...)`` and
            supporting ``A @ x``.
        b: Right-hand side. When it is a CUDA tensor, every timed solve is
            bracketed by ``torch.cuda.synchronize`` on its device.
        backend: Backend name forwarded to ``A.solve``.
        method: Method name forwarded to ``A.solve``.
        num_runs: Number of timed solves to average; must be at least 1.

    Returns:
        dict with keys ``'backend'``, ``'method'``, ``'time_ms'`` (mean wall
        time per solve, in milliseconds) and ``'rel_error'`` (relative
        residual ``||A @ x - b|| / ||b||`` of the last timed solve).

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        # An empty timing list would otherwise turn into a silent NaN mean
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    # Decide from the data itself instead of a notebook-level global, so the
    # function also works when called outside this notebook
    on_cuda = bool(getattr(b, 'is_cuda', False))

    def _sync():
        # CUDA work is queued asynchronously; wait for it so that the
        # wall-clock readings cover the actual solve
        if on_cuda:
            torch.cuda.synchronize(b.device)

    # Warmup (untimed)
    x = A.solve(b, backend=backend, method=method)
    _sync()

    # Benchmark
    times = []
    for _ in range(num_runs):
        _sync()
        start = time.perf_counter()
        x = A.solve(b, backend=backend, method=method)
        _sync()
        times.append((time.perf_counter() - start) * 1000)  # seconds -> ms

    # Accuracy of the last solve, measured as a relative residual
    residual = A @ x - b
    rel_error = (torch.norm(residual) / torch.norm(b)).item()

    return {
        'backend': backend,
        'method': method,
        'time_ms': np.mean(times),
        'rel_error': rel_error,
    }

# (backend, method) pairs to try; the direct backend differs between GPU and CPU,
# while the pytorch iterative methods are tried on both
configs = (
    [('cudss', 'cholesky'), ('cudss', 'lu'), ('pytorch', 'cg'), ('pytorch', 'bicgstab')]
    if device == 'cuda'
    else [('scipy', 'superlu'), ('pytorch', 'cg'), ('pytorch', 'bicgstab')]
)

# Table header and rule
print(
    f"{'Backend':<12} {'Method':<12} {'Time (ms)':<12} {'Rel Error':<12}",
    "-" * 48,
    sep="\n",
)

# A failing configuration is reported as a table row instead of aborting the sweep
results = []
for backend, method in configs:
    try:
        r = benchmark_solver(A, b, backend=backend, method=method)
        results.append(r)
        print(f"{r['backend']:<12} {r['method']:<12} {r['time_ms']:<12.2f} {r['rel_error']:<12.2e}")
    except Exception as e:
        print(f"{backend:<12} {method:<12} {'FAILED':<12} {str(e)[:20]}")


## 4. Larger Matrix: apache1 (80,800 × 80,800)


In [None]:
# Fetch apache1 and read its metadata before paying for the full load
apache_path = download_matrix("apache1")
apache_info = load_mtx_info(apache_path)

print(
    f"Matrix: apache1",
    f"  Shape: {apache_info['shape']}",
    f"  NNZ: {apache_info['nnz']}",
    f"  Symmetry: {apache_info['symmetry']}",
    sep="\n",
)

# Materialise the matrix on the selected device
A_large = load_mtx(apache_path, device=device, dtype=torch.float64)
print(f"\nLoaded: shape={A_large.shape}, nnz={A_large.nnz}")

# Same manufactured problem as before: the exact solution is the all-ones vector
n = A_large.shape[0]
x_true = torch.ones(n, device=device, dtype=torch.float64)
b_large = A_large @ x_true

# Caption, table header and rule
print(
    f"\nBenchmark on apache1 ({n:,} DOF, {A_large.nnz:,} NNZ):",
    f"{'Backend':<12} {'Method':<12} {'Time (ms)':<12} {'Rel Error':<12}",
    "-" * 48,
    sep="\n",
)

# Reuse the (backend, method) list built for the small matrix
for backend, method in configs:
    try:
        r = benchmark_solver(A_large, b_large, backend=backend, method=method, num_runs=3)
        print(f"{r['backend']:<12} {r['method']:<12} {r['time_ms']:<12.2f} {r['rel_error']:<12.2e}")
    except Exception as e:
        print(f"{backend:<12} {method:<12} {'FAILED':<12} {str(e)[:20]}")


## 5. Gradient Support (Differentiable Solving)

torch-sla supports automatic differentiation through the solve operation using the adjoint method.


In [None]:
# Small matrix for gradient demo
A_small = load_mtx(download_matrix("bcsstk01"), dtype=torch.float64, device=device)
n = A_small.shape[0]

# Create b with gradient tracking. The values come from a dedicated, seeded
# generator so that Restart & Run All reproduces the same b (and x) without
# touching the global RNG state used elsewhere.
rng = torch.Generator(device=device).manual_seed(0)
b_grad = torch.randn(
    n, generator=rng, dtype=torch.float64, device=device, requires_grad=True
)

# Solve with autograd: x stays attached to b_grad in the graph
x = A_small.solve(b_grad)

# Reduce to a scalar loss and backpropagate; this fills b_grad.grad
loss = x.sum()
loss.backward()

print(f"Input b shape: {b_grad.shape}")
print(f"Solution x shape: {x.shape}")
print(f"Gradient db shape: {b_grad.grad.shape}")
print(f"\nGradient norm: {b_grad.grad.norm().item():.4f}")


## 6. Visualization

Visualize the sparsity pattern and benchmark results.


In [None]:
# Visualize sparsity pattern
def plot_spy(A, title="Sparsity Pattern"):
    """Draw a spy plot with one square marker per (row, column) index pair of ``A``.

    Args:
        A: SparseTensor-like object providing ``row_indices``, ``col_indices``,
            ``shape`` and ``nnz``.
        title: First line of the figure title; the shape and NNZ of ``A`` are
            appended on a second line.

    Returns:
        The matplotlib figure that holds the plot.
    """
    # Bring the index arrays to the host as NumPy arrays before plotting
    rows = A.row_indices.cpu().numpy()
    cols = A.col_indices.cpu().numpy()
    n_rows, n_cols = A.shape[0], A.shape[1]

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.scatter(cols, rows, marker='s', s=2, c='navy')

    # Half-unit padding around the index range; the y-limits run high-to-low
    # so that row 0 is drawn at the top, as in a printed matrix
    ax.set_xlim(-0.5, n_cols - 0.5)
    ax.set_ylim(n_rows - 0.5, -0.5)
    ax.set_aspect('equal')
    ax.set(
        xlabel='Column',
        ylabel='Row',
        title=f"{title}\nShape: {A.shape}, NNZ: {A.nnz}",
    )
    fig.tight_layout()
    return fig

# Render the sparsity structure of the small bcsstk01 matrix (A_small)
fig = plot_spy(A_small, title="bcsstk01 - Stiffness Matrix")
plt.show()


## Summary

**Key Takeaways:**

1. **torch-sla** provides a unified interface for solving sparse linear systems
2. **Multiple backends**: scipy (CPU direct), cudss (CUDA direct), pytorch (iterative, CPU or CUDA)
3. **For small-medium problems (<2M DOF)**: Use `cudss+cholesky` for SPD matrices
4. **For large problems (>2M DOF)**: Use `pytorch+cg` - scales to 100M+ DOF
5. **Full gradient support**: Differentiable solving via adjoint method
6. **Trade-off**: Direct solvers reach machine precision (~1e-15), while iterative solvers stop at a looser tolerance (~1e-6) but are much faster


rix