In [1]:
import numba as nb
import numpy as np

from math import sin, cos

# PoC: New JIT decorator with different default behavior

The following example introduces a new decorator, `cjit`, which keeps all semantics of `numba.jit` and therefore remains fully compatible with it. The only difference is that it lets us set, in one central place, default argument values that differ from those in `numba` (here `nopython = True` and `inline = 'always'`).

In [2]:
def cjit(*args, nopython = True, inline = 'always', **kwargs):
    """
    Drop-in replacement for ``numba.jit`` with project-wide defaults.

    Works both bare (``@cjit``) and with arguments (``@cjit('f8(f8)')``,
    ``@cjit(cache = True)``). Positional arguments and any extra keyword
    arguments are forwarded unchanged to ``numba.jit``; only the defaults
    of ``nopython`` and ``inline`` are set here.
    """

    # A single callable positional argument means "used as a bare decorator".
    bare = len(args) == 1 and callable(args[0])
    jit_args = () if bare else args

    def decorate(target):
        options = dict(nopython = nopython, inline = inline, **kwargs)
        return nb.jit(*jit_args, **options)(target)

    return decorate(args[0]) if bare else decorate

To confirm that it works:

In [3]:
@cjit
def foo(x, y):
    """No explicit signature; the output below shows ``foo(7, 4)`` returning the int ``3``."""
    difference = x - y
    return difference

@cjit('f8(f8,f8)')
def bar(x, y):
    """Explicit ``f8(f8,f8)`` signature; the output below shows ``bar(7, 4)`` returning the float ``3.0``."""
    return x - y

(foo(7, 4), foo, bar(7, 4), bar)

(3,
 CPUDispatcher(<function foo at 0x7f24dde5f7f0>),
 3.0,
 CPUDispatcher(<function bar at 0x7f24dde5e7a0>))

# Proposed central `poliastro._jit` module

Taking the previous idea one step further, one could introduce a single central place for switching between compiler backends across the entire package: `cpu` (single-threaded), `parallel` (multi-threaded CPU) and `cuda`. All sub-modules within `poliastro` would then import these decorators instead of importing `jit` and friends from `numba` directly.

Change the value of `TARGET`, restart the kernel (!) and re-run the whole notebook to see the effects.

In [4]:
TARGET = 'cpu'  # os.environ.get('POLIASTRO_TARGET', 'cpu')
if TARGET not in ('cpu', 'parallel', 'cuda'):
    raise ValueError(f'unknown target "{TARGET:s}"')
if TARGET == 'cuda':
    from numba import cuda  # explicit import required and only performed if target is switched to cuda

NOPYTHON = True  # only for debugging, True by default

def hjit(*args, **kwargs):
    """
    Internal, pre-configured decorator for scalar helper functions.

    On the ``cpu``/``parallel`` targets it wraps ``numba.jit`` with
    ``nopython = NOPYTHON`` and ``inline = 'always'``; on ``cuda`` it wraps
    ``cuda.jit`` with ``device = True`` and ``inline = True``. Keyword
    arguments given here override those defaults. Works bare (``@hjit``)
    or with a signature (``@hjit('f8(f8)')``).

    Functions decorated by it can only be called directly if TARGET is
    cpu or parallel.
    """

    # A single callable positional argument means "used as a bare decorator".
    bare = len(args) == 1 and callable(args[0])
    signatures = () if bare else args

    def wrapper(func):
        if TARGET == 'cuda':
            factory = cuda.jit
            defaults = {'device': True, 'inline': True}
        elif TARGET in ('cpu', 'parallel'):
            factory = nb.jit
            defaults = {'nopython': NOPYTHON, 'inline': 'always'}
        else:
            factory = nb.jit
            defaults = {}
        return factory(*signatures, **{**defaults, **kwargs})(func)

    return wrapper(args[0]) if bare else wrapper

def vjit(*args, **kwargs):
    """
    User-facing, pre-configured vectorizing decorator wrapping ``numba.vectorize``.

    The ``target`` is taken from ``TARGET``; on the ``cpu``/``parallel``
    targets ``nopython`` is additionally set from ``NOPYTHON``. Keyword
    arguments given here override those defaults. Works bare (``@vjit``)
    or with signatures (``@vjit('f8(f8)')``).

    Functions decorated by it can always be called directly if needed.
    """

    # A single callable positional argument means "used as a bare decorator".
    bare = len(args) == 1 and callable(args[0])
    signatures = () if bare else args

    def wrapper(func):
        defaults = {'target': TARGET}
        if TARGET in ('cpu', 'parallel'):
            defaults['nopython'] = NOPYTHON
        return nb.vectorize(*signatures, **{**defaults, **kwargs})(func)

    return wrapper(args[0]) if bare else wrapper

# Example usage

In [5]:
@hjit('f8(f8)')
def internal_on_scalar(scalar: float) -> float:
    """
    For idx in ``range(round(scalar))``, add ``sin(idx)`` on even idx and
    subtract ``cos(idx)`` on odd idx, starting from 0.0.
    """
    total: float = 0.0
    for idx in range(round(scalar)):
        if idx % 2:
            total -= cos(idx)
        else:
            total += sin(idx)
    return total

@vjit('f8(f8)')
def user_facing_on_array(d: float) -> float:
    """Element-wise wrapper exposing ``internal_on_scalar`` to arrays."""
    return internal_on_scalar(d)

data = np.arange(100, dtype = 'f8')  # 0.0, 1.0, ..., 99.0
result = user_facing_on_array(data)

data, result

(array([ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.,
        13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23., 24., 25.,
        26., 27., 28., 29., 30., 31., 32., 33., 34., 35., 36., 37., 38.,
        39., 40., 41., 42., 43., 44., 45., 46., 47., 48., 49., 50., 51.,
        52., 53., 54., 55., 56., 57., 58., 59., 60., 61., 62., 63., 64.,
        65., 66., 67., 68., 69., 70., 71., 72., 73., 74., 75., 76., 77.,
        78., 79., 80., 81., 82., 83., 84., 85., 86., 87., 88., 89., 90.,
        91., 92., 93., 94., 95., 96., 97., 98., 99.]),
 array([ 0.00000000e+00,  0.00000000e+00, -5.40302306e-01,  3.68995121e-01,
         1.35898762e+00,  6.02185122e-01,  3.18522937e-01,  3.91074386e-02,
        -7.14794816e-01,  2.74563431e-01,  1.18569369e+00,  6.41672582e-01,
         6.37246884e-01,  1.00673966e-01, -8.06772816e-01,  1.83834540e-01,
         9.43522453e-01,  6.55619136e-01,  9.30782474e-01,  1.79795228e-01,
        -8.08909391e-01,  1.04035860e-01,  6.51765120e