Skip to content

Commit

Permalink
Replace generated_jit with overload
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Mar 17, 2024
1 parent 440d064 commit f9f6f64
Show file tree
Hide file tree
Showing 9 changed files with 84 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8', '3.9', '3.10', '3.11']
python-version: [ '3.8', '3.9', '3.10', '3.11', '3.12']

name: Test Interpolation.py (Python ${{ matrix.python-version }})
steps:
Expand Down
1 change: 0 additions & 1 deletion examples/example_mlinterp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import numpy as np

from numba import generated_jit
import ast

C = ((0.1, 0.2), (0.1, 0.2))
Expand Down
44 changes: 36 additions & 8 deletions interpolation/multilinear/fungen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numba
import numpy as np
from numba import float64, int64
from numba import generated_jit, njit
from numba import njit
import ast

from numba.extending import overload
Expand All @@ -25,8 +25,12 @@ def clamp(x, a, b):


# returns the index of a 1d point along a 1d dimension
@generated_jit(nopython=True)
def get_index(gc, x):
pass


@overload(get_index)
def ol_get_index(gc, x):
if gc == t_coord:
# regular coordinate
def fun(gc, x):
Expand All @@ -53,8 +57,12 @@ def fun(gc, x):


# returns number of dimension of a dimension
@generated_jit(nopython=True)
def get_size(gc):
pass


@overload(get_size)
def ol_get_size(gc):
if gc == t_coord:
# regular coordinate
def fun(gc):
Expand Down Expand Up @@ -145,8 +153,12 @@ def _map(*args):
# funzip(((1,2), (2,3), (4,3))) -> ((1,2,4),(2,3,3))


@generated_jit(nopython=True)
def funzip(t):
pass


@overload(funzip)
def ol_funzip(t):
k = t.count
assert len(set([e.count for e in t.types])) == 1
l = t.types[0].count
Expand All @@ -169,8 +181,12 @@ def print_tuple(t):
#####


@generated_jit(nopython=True)
def get_coeffs(X, I):
pass


@overload(get_coeffs)
def ol_get_coeffs(X, I):
if X.ndim > len(I):
print("not implemented yet")
else:
Expand Down Expand Up @@ -218,8 +234,12 @@ def gen_tensor_reduction(X, symbs, inds=[]):
return str.join(" + ", exprs)


@generated_jit(nopython=True)
def tensor_reduction(C, l):
pass


@overload(tensor_reduction)
def ol_tensor_reduction(C, l):
d = len(l.types)
ex = gen_tensor_reduction("C", ["l[{}]".format(i) for i in range(d)])
dd = dict()
Expand All @@ -228,8 +248,12 @@ def tensor_reduction(C, l):
return dd["tensor_reduction"]


@generated_jit(nopython=True)
def extract_row(a, n, tup):
pass


@overload(extract_row)
def ol_extract_row(a, n, tup):
d = len(tup.types)
dd = {}
s = "def extract_row(a, n, tup): return ({},)".format(
Expand All @@ -240,8 +264,12 @@ def extract_row(a, n, tup):


# find closest point inside the grid domain
@generated_jit
def project(grid, point):
pass


@overload(project)
def ol_project(grid, point):
s = "def __project(grid, point):\n"
d = len(grid.types)
for i in range(d):
Expand Down
15 changes: 10 additions & 5 deletions interpolation/multilinear/mlinterp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,24 @@
)

from numba import njit
from numba.extending import overload
from typing import Tuple

from ..compat import UniTuple, Tuple, Float, Integer, Array

Scalar = (Float, Integer)

import numpy as np
from numba import generated_jit

# logic of multilinear interpolation


@generated_jit
def mlinterp(grid, c, u):
pass


@overload(mlinterp)
def ol_mlinterp(grid, c, u):
if isinstance(u, UniTuple):

def mlininterp(grid: Tuple, c: Array, u: Tuple) -> float:
Expand Down Expand Up @@ -213,11 +217,12 @@ def {funname}(*args):
return source


from numba import generated_jit
def interp(*args):
pass


@generated_jit(nopython=True)
def interp(*args):
@overload(interp)
def ol_interp(*args):
aa = args[0].types

it = detect_types(aa)
Expand Down
18 changes: 12 additions & 6 deletions interpolation/multilinear/tests/test_multilinear.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,29 @@
from numpy import linspace, array
from numpy.random import random
from numba import typeof
from numba import njit

import numpy as np
from ..fungen import get_index


@njit
def get_index_njit(gc, x):
return get_index(gc, x)


def test_barycentric_indexes():
# irregular grid
gg = np.array([0.0, 1.0])
assert get_index(gg, -0.1) == (0, -0.1)
assert get_index(gg, 0.5) == (0, 0.5)
assert get_index(gg, 1.1) == (0, 1.1)
assert get_index_njit(gg, -0.1) == (0, -0.1)
assert get_index_njit(gg, 0.5) == (0, 0.5)
assert get_index_njit(gg, 1.1) == (0, 1.1)

# regular grid
gg = (0.0, 1.0, 2)
assert get_index(gg, -0.1) == (0, -0.1)
assert get_index(gg, 0.5) == (0, 0.5)
assert get_index(gg, 1.1) == (0, 1.1)
assert get_index_njit(gg, -0.1) == (0, -0.1)
assert get_index_njit(gg, 0.5) == (0, 0.5)
assert get_index_njit(gg, 1.1) == (0, 1.1)


# 2d-vecev-scalar
Expand Down
16 changes: 13 additions & 3 deletions interpolation/splines/eval_cubic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy

from numba import njit
from numba.extending import overload
from .eval_splines import eval_cubic

## the functions in this file provide backward compatibility calls
Expand All @@ -11,19 +13,27 @@
# Compatibility calls #
#######################

from numba import generated_jit
from .codegen import source_to_function


@generated_jit
def get_grid(a, b, n, C):
def _get_grid(a, b, n, C):
pass


@overload(_get_grid)
def ol_get_grid(a, b, n, C):
d = C.ndim
s = "({},)".format(str.join(", ", [f"(a[{k}],b[{k}],n[{k}])" for k in range(d)]))
txt = "def get_grid(a,b,n,C): return {}".format(s)
f = source_to_function(txt)
return f


@njit
def get_grid(a, b, n, C):
return _get_grid(a, b, n, C)


def eval_cubic_spline(a, b, orders, coefs, point):
"""Evaluates a cubic spline at one point
Expand Down
9 changes: 5 additions & 4 deletions interpolation/splines/eval_splines.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numpy as np
from numba import jit, generated_jit
from numpy import zeros
from numpy import floor

Expand All @@ -19,7 +18,6 @@
from interpolation.splines.codegen import get_code_spline, source_to_function
from numba.types import UniTuple, float64, Array
from interpolation.splines.codegen import source_to_function
from numba import generated_jit


from ..compat import Tuple, UniTuple
Expand Down Expand Up @@ -50,9 +48,12 @@
### eval spline (main function)


# @generated_jit(inline='always', nopython=True) # doens't work
@generated_jit(nopython=True)
def allocate_output(G, C, P, O):
pass


@overload(allocate_output)
def ol_allocate_output(G, C, P, O):
if C.ndim == len(G) + 1:
# vector valued
if P.ndim == 2:
Expand Down
7 changes: 5 additions & 2 deletions interpolation/splines/hermite.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@ def HermiteInterpolationVect(xvect, x: Vector, y: Vector, yp: Vector):

from numba import njit, types
from numba.extending import overload, register_jitable
from numba import generated_jit


def _hermite(x0, x, y, yp, out=None):
Expand All @@ -102,8 +101,12 @@ def _hermite(x0, x, y, yp, out=None):
from numba.core.types.misc import NoneType as none


@generated_jit
def hermite(x0, x, y, yp, out=None):
pass


@overload(hermite)
def ol_hermite(x0, x, y, yp, out=None):
try:
n = x0.ndim
if n == 1:
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ maintainers = [
license = "BSD-2-Clause"

[tool.poetry.dependencies]
python = ">=3.9, <=3.12"
numba = "^0.57"
python = ">=3.9"
numba = ">=0.57"
scipy = "^1.10"

[tool.poetry.dev-dependencies]
Expand Down

0 comments on commit f9f6f64

Please sign in to comment.