Skip to content

Commit

Permalink
Merge in numba fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-wieser committed Jun 2, 2020
2 parents 4da4f7d + e080fa6 commit 96d7e11
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ x-clifford-templates:
IPython \
h5py;
conda install -c conda-forge sparse;
conda install -c numba "numba>=0.45.1,!=0.49.0rc1,<0.50.0";
conda install -c numba "numba>=0.45.1";
else
pip install IPython;
fi
Expand Down
8 changes: 6 additions & 2 deletions clifford/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@
import numpy as np
import numba
import sparse
try:
from numba.np import numpy_support as _numpy_support
except ImportError:
import numba.numpy_support as _numpy_support


from clifford.io import write_ga_file, read_ga_file # noqa: F401
Expand Down Expand Up @@ -143,8 +147,8 @@ def get_mult_function(mt: sparse.COO, gradeList,


def _get_mult_function_result_type(a: numba.types.Type, b: numba.types.Type, mt: np.dtype):
a_dt = numba.numpy_support.as_dtype(getattr(a, 'dtype', a))
b_dt = numba.numpy_support.as_dtype(getattr(b, 'dtype', b))
a_dt = _numpy_support.as_dtype(getattr(a, 'dtype', a))
b_dt = _numpy_support.as_dtype(getattr(b, 'dtype', b))
return np.result_type(a_dt, mt, b_dt)


Expand Down
5 changes: 3 additions & 2 deletions clifford/_bit_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import numba
import numba.extending
import numba.types
import numba.config

from . import _numba_utils


def set_bit_indices(x: int) -> Iterator[int]:
Expand All @@ -33,7 +34,7 @@ def impl(cgctx, builder, sig, args):
return sig, impl


if numba.config.DISABLE_JIT:
if _numba_utils.DISABLE_JIT:
def count_set_bits(bitmap: int) -> int:
""" Counts the number of bits set to 1 in bitmap """
count = 0
Expand Down
15 changes: 11 additions & 4 deletions clifford/_numba_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@

import numba

try:
from numba.core.config import DISABLE_JIT
import numba.core.serialize as _serialize
except ImportError:
from numba.config import DISABLE_JIT
import numba.serialize as _serialize


class pickleable_function:
"""
Expand All @@ -32,13 +39,13 @@ def __new__(cls, func):

@classmethod
def _rebuild(cls, *args):
return cls(numba.serialize._rebuild_function(*args))
return cls(_serialize._rebuild_function(*args))

def __reduce__(self):
globs = numba.serialize._get_function_globals_for_reduction(self.__func)
globs = _serialize._get_function_globals_for_reduction(self.__func)
return (
self._rebuild,
numba.serialize._reduce_function(self.__func, globs)
_serialize._reduce_function(self.__func, globs)
)

def __call__(self, *args, **kwargs):
Expand Down Expand Up @@ -69,7 +76,7 @@ def __call__(self, *args):
return func(*args)


if not numba.config.DISABLE_JIT:
if not DISABLE_JIT:
njit = numba.njit
generated_jit = numba.generated_jit
else:
Expand Down
4 changes: 2 additions & 2 deletions clifford/test/test_algebra_initialisation.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import numpy as np
import pytest
import numba

from clifford import Cl, conformalize, _powerset
from clifford._numba_utils import DISABLE_JIT

too_slow_without_jit = pytest.mark.skipif(
numba.config.DISABLE_JIT, reason="test is too slow without JIT"
DISABLE_JIT, reason="test is too slow without JIT"
)


Expand Down
9 changes: 5 additions & 4 deletions clifford/test/test_dg3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@

import pytest
import numpy as np
import numba

from clifford._numba_utils import DISABLE_JIT

too_slow_without_jit = pytest.mark.skipif(
numba.config.DISABLE_JIT, reason="test is too slow without JIT"
DISABLE_JIT, reason="test is too slow without JIT"
)


Expand Down Expand Up @@ -48,7 +49,7 @@ def test_up_down(self):
from clifford.dg3c import up, down

rng = np.random.RandomState()
for i in range(1 if numba.config.DISABLE_JIT else 100):
for i in range(1 if DISABLE_JIT else 100):
pnt_vector = rng.randn(3)
pnt = up(pnt_vector)
res = down(100*pnt)
Expand All @@ -66,7 +67,7 @@ def test_up_down_cga1(self):

rng = np.random.RandomState()
pnt_vector = rng.randn(3)
for i in range(10 if numba.config.DISABLE_JIT else 100):
for i in range(10 if DISABLE_JIT else 100):
pnt = up_cga1(pnt_vector)
res = down_cga1(100*pnt)
np.testing.assert_allclose(res, pnt_vector)
Expand Down
6 changes: 4 additions & 2 deletions clifford/test/test_g3c_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import random
from functools import reduce
import time
import functools


import numpy as np
Expand All @@ -22,10 +23,11 @@
from clifford.tools.g3c.model_matching import *
from clifford.tools.g3 import random_euc_mv
from clifford.tools.g3c.GAOnline import draw_objects, GAScene, GanjaScene
import functools
from clifford._numba_utils import DISABLE_JIT


too_slow_without_jit = pytest.mark.skipif(
numba.config.DISABLE_JIT, reason="test is too slow without JIT"
DISABLE_JIT, reason="test is too slow without JIT"
)


Expand Down
13 changes: 7 additions & 6 deletions clifford/test/test_tools.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from clifford import Cl

import unittest
import pytest

from clifford.tools import orthoFrames2Versor as of2v
import pytest
import numpy as np
from numpy import exp, float64, testing
import numba

from numpy import exp, float64, testing
from clifford import Cl
from clifford.tools import orthoFrames2Versor as of2v
from clifford._numba_utils import DISABLE_JIT


too_slow_without_jit = pytest.mark.skipif(
numba.config.DISABLE_JIT, reason="test is too slow without JIT"
DISABLE_JIT, reason="test is too slow without JIT"
)


Expand Down
3 changes: 1 addition & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
install_requires=[
'numpy',
'scipy',
# 0.50.0 is a breaking API change
'numba > 0.46, != 0.49.0rc1, < 0.50.0',
'numba > 0.46',
'h5py',
'sparse',
],
Expand Down

0 comments on commit 96d7e11

Please sign in to comment.