In [1]:
from numba import njit, jitclass, prange
import numba
import numpy as np

### How to nest numba jitclass
https://stackoverflow.com/questions/38682260/how-to-nest-numba-jitclass

### Make it possible to annotate method from a jitclass (e.g. parallel = True)
https://github.com/numba/numba/issues/3417

### `__init__` is mandatory for `jitclass`
https://github.com/numba/numba/issues/2167

In [2]:
# define log utility class

@jitclass([])
class LogUtility(object):
    """Log utility with U(a) = log(a) and its derivative dU(a) = 1/a.

    The methods call NumPy / plain arithmetic directly; numba compiles every
    jitclass method in nopython mode anyway.
    """

    # `__init__` is mandatory in `jitclass`, even with an empty spec
    # (numba issue #2167).
    def __init__(self):
        pass

    def U(self, a):
        """Return log(a)."""
        return np.log(a)

    def dU(self, a):
        """Return d/da log(a) = 1/a."""
        return 1/a
    
# Standalone @njit helper called from LogUtility_jitted. No signature is given,
# so numba compiles it lazily on first use rather than at definition time.
@njit
def inverse(a):
    """Return the reciprocal 1/a (the derivative of log(a))."""
    reciprocal = 1 / a
    return reciprocal

@njit
def jitted_log(a):
    """Return the natural logarithm of `a` (an @njit wrapper around np.log)."""
    result = np.log(a)
    return result

@jitclass([])
class LogUtility_jitted(object):
    """Same contract as LogUtility, but each method delegates to a separate
    `@njit` helper (`jitted_log`, `inverse`), to check whether calling
    pre-jitted helpers from a jitclass makes any difference.
    """

    # `__init__` is mandatory in `jitclass`, even with an empty spec
    # (numba issue #2167).
    def __init__(self):
        pass

    def U(self, a):
        """Return log(a) via the `jitted_log` helper."""
        return jitted_log(a)

    def dU(self, a):
        """Return 1/a via the `inverse` helper."""
        return inverse(a)

In [3]:
# Instantiate the plain LogUtility and evaluate U and dU at a = 2
u = LogUtility()
(u.U(2.0), u.dU(2.0))

(0.6931471805599453, 0.5)

In [4]:
# Same evaluation for the variant that delegates to @njit helpers
u_jitted = LogUtility_jitted()
(u_jitted.U(2.0), u_jitted.dU(2.0))

(0.6931471805599453, 0.5)

In [5]:
%timeit u.U(2.)

900 ns ± 13 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [6]:
%timeit u_jitted.U(2.)

924 ns ± 15.5 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [7]:
%timeit u.dU(2.)

886 ns ± 29.2 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [8]:
%timeit u_jitted.dU(2.)

902 ns ± 26.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


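### There is no measurable difference between calling separately `@njit`-compiled helpers and calling NumPy directly inside a `jitclass`!
### Methods of a `jitclass` are themselves compiled in nopython mode (like `@njit` functions), so pre-compiling the helpers gains nothing. The ~900 ns per call is presumably dominated by the Python-to-`jitclass` call overhead, which both variants share.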

In [9]:
# Numba type of a LogUtility *instance*, wrapped in a deferred type so that
# it can be used as a field type in another jitclass spec.
utility_type = numba.deferred_type()
utility_type.define(
    LogUtility.class_type.instance_type
)

In [10]:
# numba types inferred from a Python int and float (typically int64 / float64)
nb_int, nb_float = numba.typeof(0), numba.typeof(0.)

# Field spec for the `model` jitclass defined below.
spec = [
    ('a', nb_int),              # a simple integer scalar field
    ('b', nb_float),            # a simple float scalar field (not an array)
    ('utility', utility_type),  # nested LogUtility instance (deferred type)
]

@jitclass(spec)
class model(object):
    """Jitclass holding an int `a`, a float `b` and a nested utility object.

    Field types come from `spec`; `utility` is typed as a LogUtility instance.
    """

    def __init__(self, a, b, c):
        # `c` is the nested utility object stored in the `utility` field.
        self.a = a
        self.b = b
        self.utility = c

    def create(self):
        """Return np.arange(a, b); the float bound `b` gives a float array."""
        start = self.a
        stop = self.b
        return np.arange(start, stop)

In [11]:
# a=1 (int), b=4.0 (float), and the LogUtility instance `u` as the utility field
a = model(1, 4.0, u)

In [12]:
# The nested jitclass is reachable and callable from Python via the `utility` field
(a.utility.U(2.0), a.utility.dU(2.0))

(0.6931471805599453, 0.5)

In [13]:
# np.arange(self.a, self.b) with a=1 and b=4.0 -> float array
a.create()

array([1., 2., 3.])

In [14]:
# An @njit function can take a jitclass instance as an argument and use its
# fields, its methods, and the methods of its nested jitclass field.

@njit
def f(model):
    """Print `model`'s scalar fields, its utility values at 2.0, and create().

    Note: the parameter name shadows the global `model` jitclass inside here.
    """
    util = model.utility
    print(model.a, model.b)
    print(util.U(2.), util.dU(2.))
    print(model.create())

In [15]:
# Pass the jitclass instance built above into the @njit function
f(a)

1 4.0
0.6931471805599453 0.5
[1. 2. 3.]
