Skip to content

Commit

Permalink
Update to support SuiteSparse:GraphBLAS 7 and 8 (#456)
Browse files Browse the repository at this point in the history
  • Loading branch information
eriknw committed Jul 5, 2023
1 parent 79cf3fa commit f14cbac
Show file tree
Hide file tree
Showing 23 changed files with 890 additions and 47 deletions.
36 changes: 21 additions & 15 deletions .github/workflows/test_and_build.yml
Expand Up @@ -131,9 +131,9 @@ jobs:
source
upstream
weights: |
1000000
1000000
1000000
1
1
1
1
- name: Setup mamba
uses: conda-incubator/setup-miniconda@v2
Expand Down Expand Up @@ -175,22 +175,22 @@ jobs:
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", "=2.0", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))')
elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.9') }} == true ]]; then
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", "=1.11", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.2", "=1.3", "=1.4", "=1.5", "=2.0", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))')
elif [[ ${{ startsWith(steps.pyver.outputs.selected, '3.10') }} == true ]]; then
npver=$(python -c 'import random ; print(random.choice(["=1.21", "=1.22", "=1.23", "=1.24", "=1.25", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.8", "=1.9", "=1.10", "=1.11", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.3", "=1.4", "=1.5", "=2.0", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))')
else # Python 3.11
npver=$(python -c 'import random ; print(random.choice(["=1.23", "=1.24", "=1.25", ""]))')
spver=$(python -c 'import random ; print(random.choice(["=1.9", "=1.10", "=1.11", ""]))')
pdver=$(python -c 'import random ; print(random.choice(["=1.5", "=2.0", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", ""]))')
akver=$(python -c 'import random ; print(random.choice(["=1.10", "=2.0", "=2.1", "=2.2", "=2.3", ""]))')
fi
if [[ ${{ steps.sourcetype.outputs.selected }} == "source" || ${{ steps.sourcetype.outputs.selected }} == "upstream" ]]; then
# TODO: there are currently issues with some numpy versions when
Expand All @@ -204,13 +204,13 @@ jobs:
# But, it's still useful for us to test with different versions!
psg=""
if [[ ${{ steps.sourcetype.outputs.selected}} == "conda-forge" ]] ; then
psgver=$(python -c 'import random ; print(random.choice(["=7.4.0", "=7.4.1", "=7.4.2", "=7.4.3.0", "=7.4.3.1", "=7.4.3.2"]))')
psgver=$(python -c 'import random ; print(random.choice(["=7.4.0", "=7.4.1", "=7.4.2", "=7.4.3.0", "=7.4.3.1", "=7.4.3.2", "=8.0.2.1", ""]))')
psg=python-suitesparse-graphblas${psgver}
elif [[ ${{ steps.sourcetype.outputs.selected}} == "wheel" ]] ; then
psgver=$(python -c 'import random ; print(random.choice(["==7.4.3.2"]))')
psgver=$(python -c 'import random ; print(random.choice(["==7.4.3.2", "==8.0.2.1", ""]))')
elif [[ ${{ steps.sourcetype.outputs.selected}} == "source" ]] ; then
# These should be exact versions
psgver=$(python -c 'import random ; print(random.choice(["==7.4.0.0", "==7.4.1.0", "==7.4.2.0", "==7.4.3.0", "==7.4.3.1", "==7.4.3.2"]))')
psgver=$(python -c 'import random ; print(random.choice(["==7.4.0.0", "==7.4.1.0", "==7.4.2.0", "==7.4.3.0", "==7.4.3.1", "==7.4.3.2", "==8.0.2.1", ""]))')
else
psgver=""
fi
Expand Down Expand Up @@ -260,17 +260,18 @@ jobs:
numba=numba${numbaver}
sparse=sparse${sparsever}
fi
echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psgver${psgver}"
echo "versions: np${npver} sp${spver} pd${pdver} ak${akver} nx${nxver} numba${numbaver} yaml${yamlver} sparse${sparsever} psg${psgver}"
set -x # echo on
$(command -v mamba || command -v conda) install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig tomli \
$(command -v mamba || command -v conda) install packaging pytest coverage coveralls=3.3.1 pytest-randomly cffi donfig tomli c-compiler make \
pyyaml${yamlver} ${sparse} pandas${pdver} scipy${spver} numpy${npver} ${awkward} \
networkx${nxver} ${numba} ${fmm} ${psg} \
${{ matrix.slowtask == 'pytest_bizarro' && 'black' || '' }} \
${{ matrix.slowtask == 'notebooks' && 'matplotlib nbconvert jupyter "ipython>=7"' || '' }} \
${{ steps.sourcetype.outputs.selected == 'upstream' && 'cython' || '' }} \
${{ steps.sourcetype.outputs.selected != 'wheel' && '"graphblas=7.4"' || '' }} \
${{ contains(steps.pyver.outputs.selected, 'pypy') && 'pypy' || '' }}
${{ steps.sourcetype.outputs.selected != 'wheel' && '"graphblas>=7.4"' || '' }} \
${{ contains(steps.pyver.outputs.selected, 'pypy') && 'pypy' || '' }} \
${{ matrix.os == 'windows-latest' && 'cmake' || 'm4' }}
- name: Build extension module
run: |
if [[ ${{ steps.sourcetype.outputs.selected }} == "wheel" ]]; then
Expand All @@ -291,6 +292,12 @@ jobs:
pip install --no-deps git+https://github.com/GraphBLAS/python-suitesparse-graphblas.git@main#egg=suitesparse-graphblas
fi
pip install --no-deps -e .
- name: python-suitesparse-graphblas tests
run: |
# Don't use our conftest.py ; allow `test_print_jit_config` to fail if it doesn't exist
(cd ..
pytest --pyargs suitesparse_graphblas -s -k test_print_jit_config || true
pytest -v --pyargs suitesparse_graphblas)
- name: Unit tests
run: |
A=${{ needs.rngs.outputs.mapnumpy == 'A' || '' }} ; B=${{ needs.rngs.outputs.mapnumpy == 'B' || '' }}
Expand Down Expand Up @@ -318,7 +325,6 @@ jobs:
if [[ $H && $normal ]] ; then if [[ $macos ]] ; then echo " $vanilla" ; elif [[ $windows ]] ; then echo " $suitesparse" ; fi ; fi)$( \
if [[ $H && $bizarro ]] ; then if [[ $macos ]] ; then echo " $suitesparse" ; elif [[ $windows ]] ; then echo " $vanilla" ; fi ; fi)
echo ${args}
(cd .. && pytest -v --pyargs suitesparse_graphblas) # Don't use our conftest.py
set -x # echo on
coverage run -m pytest --color=yes --randomly -v ${args} \
${{ matrix.slowtask == 'pytest_normal' && '--runslow' || '' }}
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Expand Up @@ -51,7 +51,7 @@ repos:
- id: isort
# Let's keep `pyupgrade` even though `ruff --fix` probably does most of it
- repo: https://github.com/asottile/pyupgrade
rev: v3.7.0
rev: v3.8.0
hooks:
- id: pyupgrade
args: [--py38-plus]
Expand All @@ -66,7 +66,7 @@ repos:
- id: black
- id: black-jupyter
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.275
rev: v0.0.277
hooks:
- id: ruff
args: [--fix-only, --show-fixes]
Expand Down Expand Up @@ -94,7 +94,7 @@ repos:
additional_dependencies: [tomli]
files: ^(graphblas|docs)/
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.275
rev: v0.0.277
hooks:
- id: ruff
- repo: https://github.com/sphinx-contrib/sphinx-lint
Expand Down
2 changes: 1 addition & 1 deletion docs/env.yml
Expand Up @@ -8,7 +8,7 @@ dependencies:
# python-graphblas dependencies
- donfig
- numba
- python-suitesparse-graphblas>=7.4.0.0,<8
- python-suitesparse-graphblas>=7.4.0.0
- pyyaml
# extra dependencies
- matplotlib
Expand Down
1 change: 1 addition & 0 deletions graphblas/binary/ss.py
@@ -1,4 +1,5 @@
from ..core import operator
from ..core.ss.binary import register_new # noqa: F401

_delayed = {}

Expand Down
2 changes: 1 addition & 1 deletion graphblas/core/dtypes.py
Expand Up @@ -22,7 +22,7 @@ def __init__(self, name, gb_obj, gb_name, c_type, numba_type, np_type):
self.gb_name = gb_name
self.c_type = c_type
self.numba_type = numba_type
self.np_type = np.dtype(np_type)
self.np_type = np.dtype(np_type) if np_type is not None else None

def __repr__(self):
return self.name
Expand Down
3 changes: 3 additions & 0 deletions graphblas/core/ss/__init__.py
@@ -0,0 +1,3 @@
import suitesparse_graphblas as _ssgb

_IS_SSGB7 = _ssgb.__version__.split(".", 1)[0] == "7"
72 changes: 72 additions & 0 deletions graphblas/core/ss/binary.py
@@ -0,0 +1,72 @@
from ... import backend
from ...dtypes import lookup_dtype
from ...exceptions import check_status_carg
from .. import NULL, ffi, lib
from ..operator.base import TypedOpBase
from ..operator.binary import BinaryOp, TypedUserBinaryOp
from . import _IS_SSGB7

ffi_new = ffi.new


class TypedJitBinaryOp(TypedOpBase):
__slots__ = "_monoid", "_jit_c_definition"
opclass = "BinaryOp"

def __init__(self, parent, name, type_, return_type, gb_obj, jit_c_definition, dtype2=None):
super().__init__(parent, name, type_, return_type, gb_obj, name, dtype2=dtype2)
self._monoid = None
self._jit_c_definition = jit_c_definition

@property
def jit_c_definition(self):
return self._jit_c_definition

monoid = TypedUserBinaryOp.monoid
commutes_to = TypedUserBinaryOp.commutes_to
_semiring_commutes_to = TypedUserBinaryOp._semiring_commutes_to
is_commutative = TypedUserBinaryOp.is_commutative
type2 = TypedUserBinaryOp.type2
__call__ = TypedUserBinaryOp.__call__


def register_new(name, jit_c_definition, left_type, right_type, ret_type):
if backend != "suitesparse": # pragma: no cover (safety)
raise RuntimeError(
"`gb.binary.ss.register_new` invalid when not using 'suitesparse' backend"
)
if _IS_SSGB7:
# JIT was introduced in SuiteSparse:GraphBLAS 8.0
import suitesparse_graphblas as ssgb

raise RuntimeError(
"JIT was added to SuiteSparse:GraphBLAS in version 8; "
f"current version is {ssgb.__version__}"
)
left_type = lookup_dtype(left_type)
right_type = lookup_dtype(right_type)
ret_type = lookup_dtype(ret_type)
name = name if name.startswith("ss.") else f"ss.{name}"
module, funcname = BinaryOp._remove_nesting(name)

rv = BinaryOp(name)
gb_obj = ffi_new("GrB_BinaryOp*")
check_status_carg(
lib.GxB_BinaryOp_new(
gb_obj,
NULL,
ret_type._carg,
left_type._carg,
right_type._carg,
ffi_new("char[]", funcname.encode()),
ffi_new("char[]", jit_c_definition.encode()),
),
"BinaryOp",
gb_obj[0],
)
op = TypedJitBinaryOp(
rv, funcname, left_type, ret_type, gb_obj[0], jit_c_definition, dtype2=right_type
)
rv._add(op)
setattr(module, funcname, rv)
return rv
16 changes: 8 additions & 8 deletions graphblas/core/ss/config.py
Expand Up @@ -65,7 +65,7 @@ def __getitem__(self, key):
raise KeyError(key)
key_obj, ctype = self._options[key]
is_bool = ctype == "bool"
if is_context := (key in self._context_keys): # pragma: no cover (suitesparse 8)
if is_context := (key in self._context_keys):
get_function_base = self._context_get_function
else:
get_function_base = self._get_function
Expand All @@ -76,14 +76,14 @@ def __getitem__(self, key):
get_function_name = f"{get_function_base}_INT64"
elif ctype.startswith("double"):
get_function_name = f"{get_function_base}_FP64"
elif ctype.startswith("char"): # pragma: no cover (suitesparse 8)
elif ctype.startswith("char"):
get_function_name = f"{get_function_base}_CHAR"
else: # pragma: no cover (sanity)
raise ValueError(ctype)
get_function = getattr(lib, get_function_name)
is_array = "[" in ctype
val_ptr = ffi.new(ctype if is_array else f"{ctype}*")
if is_context: # pragma: no cover (suitesparse 8)
if is_context:
info = get_function(self._context._carg, key_obj, val_ptr)
elif self._parent is None:
info = get_function(key_obj, val_ptr)
Expand All @@ -105,7 +105,7 @@ def __getitem__(self, key):
return rv
if is_bool:
return bool(val_ptr[0])
if ctype.startswith("char"): # pragma: no cover (suitesparse 8)
if ctype.startswith("char"):
return ffi.string(val_ptr[0]).decode()
return val_ptr[0]
raise _error_code_lookup[info](f"Failed to get info for {key!r}") # pragma: no cover
Expand All @@ -117,7 +117,7 @@ def __setitem__(self, key, val):
if key in self._read_only:
raise ValueError(f"Config option {key!r} is read-only")
key_obj, ctype = self._options[key]
if is_context := (key in self._context_keys): # pragma: no cover (suitesparse 8)
if is_context := (key in self._context_keys):
set_function_base = self._context_set_function
else:
set_function_base = self._set_function
Expand All @@ -130,7 +130,7 @@ def __setitem__(self, key, val):
set_function_name = f"{set_function_base}_INT64_ARRAY"
elif ctype.startswith("double["):
set_function_name = f"{set_function_base}_FP64_ARRAY"
elif ctype.startswith("char"): # pragma: no cover (suitesparse 8)
elif ctype.startswith("char"):
set_function_name = f"{set_function_base}_CHAR"
else: # pragma: no cover (sanity)
raise ValueError(ctype)
Expand Down Expand Up @@ -174,11 +174,11 @@ def __setitem__(self, key, val):
f"expected {size}, got {vals.size}: {val}"
)
val_obj = ffi.from_buffer(ctype, vals)
elif ctype.startswith("char"): # pragma: no cover (suitesparse 8)
elif ctype.startswith("char"):
val_obj = ffi.new("char[]", val.encode())
else:
val_obj = ffi.cast(ctype, val)
if is_context: # pragma: no cover (suitesparse 8)
if is_context:
if self._context is None:
from .context import Context

Expand Down

0 comments on commit f14cbac

Please sign in to comment.