Add torch, jax, and big marks to unit tests.
These marks make it possible to run the test suite when only one of
the backends is installed.

It also skips big operator tests on CI.

NOTE: for all future testing, neither torch, jax, cola.torch_fns, nor
cola.jax_fns may be imported at the top level of a test file.

- With the 'not jax' mark expression, all non-jax tests run.
- With the 'not torch' mark expression, all non-torch tests run.
- With the 'not big' mark expression, all tests that do not involve big linear operators run.
  (A sketch of how the marks get attached is below.)
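
The mechanism lives in the new _add_marks helper in cola/utils_test.py. A minimal standalone sketch, with a hypothetical case value (the real helper also handles tuples of parameters):

    import pytest

    case = 'torch'  # hypothetical parametrize case; any parameter whose str() mentions a backend
    marks = []
    if 'torch' in str(case):
        marks.append(pytest.mark.torch)
    if 'jax' in str(case):
        marks.append(pytest.mark.jax)
    if 'big' in str(case):
        marks.append(pytest.mark.big)
    marked = pytest.param(case, marks=marks)
    # pytest -m 'not torch' now deselects this case before the test body runs,
    # so torch is never imported on machines that do not have it.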
gpleiss committed Aug 23, 2023
1 parent 3345985 commit f6c8e13
Showing 21 changed files with 449 additions and 347 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
@@ -29,7 +29,7 @@ jobs:
          pip install -e .
      - name: Test coverage.
        run: |
          pytest --cov cola --cov-report xml:cov.xml tests/
          pytest -m 'not big' --cov cola --cov-report xml:cov.xml tests/
      - name: Upload to codecov
        uses: codecov/codecov-action@v3
        with:
56 changes: 54 additions & 2 deletions cola/utils_test.py
@@ -1,14 +1,44 @@
import pytest
import inspect
import itertools
import os
import pytest
from types import ModuleType

from cola.ops import get_library_fns
import numpy as np
import itertools


# Try importing jax and torch, for the get_xnp function
try:
    from . import jax_fns
except ImportError:
    jax_fns = None

try:
    from . import torch_fns
except ImportError:
    torch_fns = None


def strip_parens(string):
    return string.replace('(', '').replace(')', '')


def _add_marks(case):
    # This function is admittedly hacky: it adds marks based on the string forms
    # of the parameters supplied. In particular, it adds the 'torch', 'jax', and 'big' marks.
    case = case if isinstance(case, (list, tuple)) else [case]
    marks = []
    args = tuple(str(arg) for arg in case)
    if any('torch' in arg for arg in args):
        marks.append(pytest.mark.torch)
    if any('jax' in arg for arg in args):
        marks.append(pytest.mark.jax)
    if any('big' in arg for arg in args):
        marks.append(pytest.mark.big)
    return pytest.param(*case, marks=marks)


def parametrize(*cases, ids=None):
    """ Expands test cases with pytest.mark.parametrize but with argnames
    assumed and ids given by ids=[str(case) for case in cases]. """
@@ -17,6 +47,9 @@ def parametrize(*cases, ids=None):
    else:
        all_cases = cases[0]

    # Potentially add marks
    all_cases = [_add_marks(case) for case in all_cases]

    def decorator(test_fn):
        argnames = ','.join(inspect.getfullargspec(test_fn).args)
        theids = [strip_parens(str(case)) for case in all_cases] if ids is None else ids
@@ -106,3 +139,22 @@ def generate_clustered_spectrum(clusters, sizes, std=0.025, seed=None, dtype=np.
        diag.append(sub_diags)
    diag = np.concatenate(diag, axis=0)
    return np.sort(diag)[::-1]


def get_xnp(backend: str) -> ModuleType:
    match backend:
        case "torch":
            if torch_fns is None:  # There was an import error with torch
                raise RuntimeError("Could not import torch. It is likely not installed.")
            else:
                return torch_fns
        case "jax":
            if jax_fns is None:  # There was an import error with jax
                raise RuntimeError("Could not import jax. It is likely not installed.")
            else:
                from jax.config import config
                config.update('jax_platform_name', 'cpu')  # Force tests to run on CPU
                # config.update("jax_enable_x64", True)
                return jax_fns
        case _:
            raise ValueError(f"Unknown backend {backend}.")

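Taken together, parametrize and get_xnp give the pattern that the test files below migrate to. A minimal sketch of a new-style test (the body is illustrative only, not part of this commit): backends are named by strings, and the backend module is only resolved inside the test body, so merely collecting the file imports neither torch nor jax.

    from cola.utils_test import get_xnp, parametrize

    @parametrize(['torch', 'jax'])
    def test_roundtrip(backend):
        xnp = get_xnp(backend)  # resolves to torch_fns or jax_fns at run time
        x = xnp.array([1., 2., 3.], dtype=xnp.float32)
        assert x.shape == (3,)
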
6 changes: 6 additions & 0 deletions setup.cfg
@@ -7,3 +7,9 @@ BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF = False
[pycodestyle]
max-line-length = 100
ignore = E301

[tool:pytest]
markers =
    torch: mark a test that uses the PyTorch backend
    jax: mark a test that uses the JAX backend
    big: mark a test that involves a big (i.e. slow) linear operator
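
With the markers registered here, pytest no longer warns about unknown marks, and the -m expressions compose with 'and'/'or'. As a sketch, the same selection CI performs above (plus skipping jax tests, e.g. on a torch-only machine) can also be driven from Python via pytest's documented entry point; the tests/ path is assumed:

    import pytest

    # Equivalent to `pytest -m 'not big and not jax' tests/` on the command line:
    pytest.main(['-m', 'not big and not jax', 'tests/'])
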
40 changes: 20 additions & 20 deletions tests/algorithms/test_arnoldi.py
@@ -3,30 +3,27 @@
from cola.ops import Product
from cola.ops import Dense
from cola.algorithms.arnoldi import get_householder_vec
from cola import jax_fns
from cola import torch_fns
from cola.fns import lazify
from cola.algorithms.arnoldi import get_arnoldi_matrix
from cola.algorithms.arnoldi import arnoldi_eigs
from cola.algorithms.arnoldi import run_householder_arnoldi
from cola.utils_test import parametrize, relative_error
from cola.utils_test import get_xnp, parametrize, relative_error
from cola.utils_test import generate_spectrum, generate_pd_from_diag
from cola.utils_test import generate_lower_from_diag
from jax.config import config

config.update('jax_platform_name', 'cpu')
# config.update("jax_enable_x64", True)

@parametrize(['torch', 'jax'])
def test_arnoldi_vjp(backend):
    if backend == 'torch':
        import torch
        torch.manual_seed(seed=21)

@parametrize([torch_fns, jax_fns])
def test_arnoldi_vjp(xnp):
    xnp = get_xnp(backend)
    dtype = xnp.float64
    matrix = [[6., 2., 3.], [2., 3., 1.], [3., 1., 4.]]
    diag = xnp.Parameter(xnp.array(matrix, dtype=dtype))
    diag_soln = xnp.Parameter(xnp.array(matrix, dtype=dtype))
    _, unflatten = Dense(diag).flatten()
    import torch
    torch.manual_seed(seed=21)
    x0 = xnp.randn(diag.shape[0], 1)

    def f(theta):
@@ -64,8 +61,9 @@ def f_alt(theta):
    assert abs_error < 5e-5


@parametrize([torch_fns, jax_fns])
def test_arnoldi(xnp):
@parametrize(['torch', 'jax'])
def test_arnoldi(backend):
    xnp = get_xnp(backend)
    dtype = xnp.complex64
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=4, dtype=np.float32)
    A = xnp.array(generate_lower_from_diag(diag, dtype=diag.dtype, seed=48), dtype=dtype)
@@ -82,17 +80,17 @@ def test_arnoldi(xnp):
    assert rel_error < 1e-3


# @parametrize([torch_fns])
@parametrize([jax_fns])
def test_householder_arnoldi_decomp(xnp):
@parametrize(['jax'])
def test_householder_arnoldi_decomp(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=10, dtype=np.float32) - 0.5
    A = xnp.array(generate_pd_from_diag(diag, dtype=diag.dtype, seed=21), dtype=dtype)
    rhs = xnp.randn(A.shape[1], 1, dtype=dtype)
    # A_np, rhs_np = np.array(A, dtype=np.complex128), np.array(rhs[:, 0], dtype=np.complex128)
    A_np, rhs_np = np.array(A, dtype=np.float64), np.array(rhs[:, 0], dtype=np.float64)
    # Q_sol, H_sol = run_householder_arnoldi(A, rhs, A.shape[0], np.float64, xnp)
    Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64, jax_fns)
    Q_sol, H_sol = run_householder_arnoldi_np(A_np, rhs_np, A.shape[0], np.float64, xnp)

    # fn = run_householder_arnoldi
    fn = xnp.jit(run_householder_arnoldi, static_argnums=(0, 2))
@@ -103,8 +101,9 @@ def test_householder_arnoldi_decomp(xnp):
    assert rel_error < 1e-5


@parametrize([torch_fns])  # jax does not have complex128
def test_get_arnoldi_matrix(xnp):
@parametrize(['torch'])  # jax does not have complex128
def test_get_arnoldi_matrix(backend):
    xnp = get_xnp(backend)
    dtype = xnp.complex128  # double precision on real and complex coordinates to achieve 1e-12 tol
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=20, dtype=np.float32) - 0.5
    A = xnp.array(generate_pd_from_diag(diag, dtype=diag.dtype, seed=21), dtype=dtype)
@@ -133,12 +132,13 @@ def test_get_arnoldi_matrix(xnp):
    assert rel_error < 1e-12


def test_numpy_arnoldi():
@parametrize(['jax'])
def test_numpy_arnoldi(backend):
    xnp = get_xnp(backend)
    float_formatter = "{:.2f}".format
    np.set_printoptions(formatter={'float_kind': float_formatter})
    # dtype = np.complex64
    dtype = np.float64
    xnp = jax_fns
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=10, dtype=np.float32)
    A = np.array(generate_lower_from_diag(diag, dtype=diag.dtype, seed=48), dtype=dtype)
    rhs = np.random.normal(size=(A.shape[0], ))
42 changes: 22 additions & 20 deletions tests/algorithms/test_cg.py
@@ -1,27 +1,23 @@
import numpy as np
from cola import jax_fns
from cola import torch_fns
from cola.fns import lazify
from cola.ops import Identity
from cola.ops import Diagonal
from cola.algorithms.preconditioners import NystromPrecond
from cola.algorithms.cg import run_batched_cg
from cola.algorithms.cg import run_batched_tracking_cg
from cola.algorithms.cg import run_cg
from cola.utils_test import parametrize, relative_error
from cola.utils_test import get_xnp, parametrize, relative_error
from cola.utils_test import generate_spectrum, generate_pd_from_diag
from cola.utils_test import generate_diagonals
# from tests.algorithms.test_lanczos import construct_tridiagonal
from jax.config import config

config.update('jax_platform_name', 'cpu')
# config.update("jax_enable_x64", True)

_tol = 1e-7


@parametrize([torch_fns, jax_fns])
def test_cg_vjp(xnp):
@parametrize(['torch', 'jax'])
def test_cg_vjp(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    diag = xnp.Parameter(xnp.array([3., 4., 5.], dtype=dtype))
    diag_soln = xnp.Parameter(xnp.array([3., 4., 5.], dtype=dtype))
@@ -64,8 +60,9 @@ def f_alt(theta):
    assert rel_error < _tol * 10


# @parametrize([torch_fns, jax_fns])
# def test_cg_lanczos_coeffs(xnp):
# @parametrize(["torch", "jax"])
# def test_cg_lanczos_coeffs(backend):
# xnp = get_xnp(backend)
# dtype = xnp.float32
# A = xnp.diag(xnp.array([3., 4., 5.], dtype=dtype))
# rhs = xnp.ones(shape=(A.shape[0], 1), dtype=dtype)
@@ -97,8 +94,9 @@ def f_alt(theta):
# assert rel_error < _tol


@parametrize([torch_fns, jax_fns])
def test_cg_complex(xnp):
@parametrize(["torch", "jax"])
def test_cg_complex(backend):
    xnp = get_xnp(backend)
    dtype = xnp.complex64
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=25, dtype=np.float32)
    A = xnp.array(generate_diagonals(diag, seed=48), dtype=dtype)
@@ -116,8 +114,9 @@ def test_cg_complex(xnp):
    assert rel_error < 1e-5


@parametrize([torch_fns, jax_fns])
def test_cg_random(xnp):
@parametrize(["torch", "jax"])
def test_cg_random(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    diag = generate_spectrum(coeff=0.75, scale=1.0, size=25, dtype=np.float32)
    A = xnp.array(generate_pd_from_diag(diag, dtype=diag.dtype), dtype=dtype)
@@ -136,8 +135,9 @@ def test_cg_random(xnp):
    assert rel_error < 1e-6


@parametrize([torch_fns, jax_fns])
def test_cg_repeated_eig(xnp):
@parametrize(["torch", "jax"])
def test_cg_repeated_eig(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    diag = [1. for _ in range(10)] + [0.5 for _ in range(10)] + [0.25 for _ in range(10)]
    diag = np.array(diag, dtype=np.float32)
@@ -157,8 +157,9 @@ def test_cg_repeated_eig(xnp):
    assert rel_error < _tol * 10


@parametrize([torch_fns, jax_fns])
def test_cg_track_easy(xnp):
@parametrize(["torch", "jax"])
def test_cg_track_easy(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float64
    A = xnp.diag(xnp.array([3., 4., 5.], dtype=dtype))
    rhs = [[1, 3], [1, 4], [1, 5]]
@@ -176,8 +177,9 @@ def test_cg_track_easy(xnp):
    assert rel_error < _tol

# Marc: I disabled this test because it seems to test batched linear operators?
# @parametrize([torch_fns, jax_fns])
# def test_cg_easy_case(xnp):
# @parametrize(["torch", "jax"])
# def test_cg_easy_case(backend):
# xnp = get_xnp(backend)
# dtype = xnp.float64
# A = xnp.diag(xnp.array([3., 4., 5.], dtype=dtype))
# rhs = xnp.array([[1.0 for _ in range(A.shape[0])]], dtype=dtype).T
22 changes: 10 additions & 12 deletions tests/algorithms/test_gmres.py
@@ -1,21 +1,17 @@
import numpy as np
from cola import jax_fns
from cola import torch_fns
from cola.ops import Identity
from cola.ops import Diagonal
from cola.fns import lazify
from cola.linalg.inverse import inverse
from cola.algorithms.gmres import gmres
from cola.algorithms.gmres import gmres_fwd
from cola.utils_test import parametrize, relative_error
from cola.utils_test import get_xnp, parametrize, relative_error
from cola.utils_test import generate_spectrum, generate_pd_from_diag
from jax.config import config

config.update('jax_platform_name', 'cpu')


@parametrize([torch_fns, jax_fns])
def test_gmres_vjp(xnp):
@parametrize(['torch', 'jax'])
def test_gmres_vjp(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    diag = xnp.Parameter(xnp.array([3., 4., 5.], dtype=dtype))
    diag_soln = xnp.Parameter(xnp.array([3., 4., 5.], dtype=dtype))
@@ -57,8 +53,9 @@ def f_alt(theta):
    assert rel_error < 1e-6


@parametrize([torch_fns, jax_fns])
def test_gmres_random(xnp):
@parametrize(['torch', 'jax'])
def test_gmres_random(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=25, dtype=np.float32)
    A = xnp.array(generate_pd_from_diag(diag, dtype=diag.dtype), dtype=dtype)
@@ -73,8 +70,9 @@ def test_gmres_random(xnp):
    assert rel_error < 5e-4


@parametrize([torch_fns, jax_fns])
def test_gmres_easy(xnp):
@parametrize(['torch', 'jax'])
def test_gmres_easy(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    A = xnp.diag(xnp.array([3., 4., 5.], dtype=dtype))
    rhs = [[1], [1], [1]]
12 changes: 4 additions & 8 deletions tests/algorithms/test_iram.py
@@ -1,17 +1,13 @@
import numpy as np
from cola import jax_fns
from cola import torch_fns
from cola.fns import lazify
from cola.algorithms.iram import iram
from cola.utils_test import parametrize, relative_error
from cola.utils_test import get_xnp, parametrize, relative_error
from cola.utils_test import generate_spectrum, generate_pd_from_diag
from jax.config import config

config.update('jax_platform_name', 'cpu')


@parametrize([torch_fns, jax_fns])
def test_iram_random(xnp):
@parametrize(['torch', 'jax'])
def test_iram_random(backend):
    xnp = get_xnp(backend)
    dtype = xnp.float32
    np_dtype = np.float32
    diag = generate_spectrum(coeff=0.5, scale=1.0, size=10, dtype=np_dtype)