# Using numba

## Installation

    pip install numba tbb

## jit and njit

numba is mainly useful for loops. In many cases a single `@njit` decorator (equivalent to `@jit(nopython=True)`) already gives a good speedup.

In [None]:
import numpy
from numba import jit, njit
from numba.typed import List

# @jit(nopython=True)
@njit
def do_trig(x, y):
    z = numpy.empty_like(x)
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            z[i, j] = numpy.sin(x[i, j]**2) + numpy.cos(y[i, j])
    return z

x = numpy.random.random((1000, 1000))
y = numpy.random.random((1000, 1000))

%timeit do_trig(x, y)

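Argument and return types can be specified in jit; the supported types are listed at https://numba.readthedocs.io/en/stable/reference/types.html#basic-types . Note that jit compiles your code for specific types, which to some extent fixes the data type of every variable. So when accelerating with jit, if you want a float array, define it with `[0.]*len(x)` rather than `[0]*len(x)`.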

```python
import numba

@numba.jit([
    "void(int64, int64, int64, int64, float32[:,:,:,:],"
    "int64, int64, int64, float32[:,:,:,:], float32[:,:,:,:])"
], nopython=True)
def conv_bp(n, n_filters, out_h, out_w, dx_padded,
            filter_height, filter_width, sd, inner_weight, delta):
    pass  # function body omitted in the original
```

In [None]:
from numba import njit, int32
 
@njit(int32(int32, int32))
def f_t(x, y):
    return x + y

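Enabling parallel execution with `parallel=True` can give a further speedup. Other options exist as well: `fastmath=True` trades a little precision for speed, and `@jit(nogil=True)` releases the GIL so the function can run concurrently in multiple threads. Choose among them based on your use case and on benchmark results.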

In [None]:
import numba
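from numba import prange  # with parallel=True, use prange instead of range for the loop numba should parallelize

# More threads is not always faster; usually stay at or below the number of CPU cores.
# set_num_threads cannot exceed the NUMBA_NUM_THREADS limit fixed at startup.
numba.set_num_threads(min(4, numba.config.NUMBA_NUM_THREADS))

@njit(parallel=True, fastmath=True)  # parallelize and allow relaxed floating-point math
def do_trig_p(x, y):
    z = numpy.empty_like(x)
    for i in prange(x.shape[0]):
        for j in range(x.shape[1]):  # only the outermost prange runs in parallel; nested ones behave like range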
            z[i, j] = numpy.sin(x[i, j]**2) + numpy.cos(y[i, j])
    return z

%timeit do_trig_p(x, y)

## vectorize

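`@vectorize` speeds up element-wise array operations and behaves like a NumPy ufunc, including broadcasting. The target can be `cpu`, `parallel`, or `cuda`; the numba documentation gives the following guidance on choosing one: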

> A general guideline is to choose different targets for different data sizes and algorithms. The “cpu” target works well for small data sizes (approx. less than 1KB) and low compute intensity algorithms. It has the least amount of overhead. The “parallel” target works well for medium data sizes (approx. less than 1MB). Threading adds a small delay. The “cuda” target works well for big data sizes (approx. greater than 1MB) and high compute intensity algorithms. Transferring memory to and from the GPU adds significant overhead.

In [None]:
import math
from numba import vectorize

# @vectorize
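# Specifying the signature can give a further speedup; some types may need to be imported from numba.
# target='cuda' requires a CUDA-capable GPU; use 'cpu' or 'parallel' otherwise.
@vectorize('float64(float64, float64)', target='cuda')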
def do_trig_vec_par(x, y):
    z = math.sin(x**2) + math.cos(y)
    return z

%timeit do_trig_vec_par(x, y)

## Optimization tips

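One thing to note in particular: a big difference between writing code for jit and writing pure NumPy is that you should not be afraid of `for` loops. In fact, generally speaking, the more the code "looks like C", the faster it runs.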

In [None]:
import numba as nb
import numpy as np

# Plain MaxPool
def max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)

# MaxPool with jit simply added on top
@nb.jit(nopython=True)
def jit_max_pool_kernel(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    window = x[i, j, p:p+pool_height, q:q+pool_width]
                    rs[i, j, p, q] += np.max(window)

# A "more C-like" MaxPool that is not afraid of for loops
@nb.jit(nopython=True)
def jit_max_pool_kernel2(x, rs, *args):
    n, n_channels, pool_height, pool_width, out_h, out_w = args
    for i in range(n):
        for j in range(n_channels):
            for p in range(out_h):
                for q in range(out_w):
                    _max = x[i, j, p, q]
                    for r in range(pool_height):
                        for s in range(pool_width):
                            _tmp = x[i, j, p+r, q+s]
                            if _tmp > _max:
                                _max = _tmp
                    rs[i, j, p, q] += _max

# Wrapper around the MaxPool kernels
def max_pool(x, kernel, args):
    n, n_channels = args[:2]
    out_h, out_w = args[-2:]
    rs = np.zeros([n, n_channels, out_h, out_w], dtype=np.float32)  # one output map per input channel
    kernel(x, rs, *args)
    return rs

# 64 input images of size 3 x 28 x 28 (MNIST-like)
x = np.random.randn(64, 3, 28, 28).astype(np.float32)
# 16 kernels of size 5 x 5
w = np.random.randn(16, x.shape[1], 5, 5).astype(np.float32)

n, n_channels, height, width = x.shape
n_filters, _, filter_height, filter_width = w.shape
out_h = height - filter_height + 1
out_w = width - filter_width + 1

pool_height, pool_width = 2, 2
n, n_channels, height, width = x.shape
out_h = height - pool_height + 1
out_w = width - pool_width + 1
args = (n, n_channels, pool_height, pool_width, out_h, out_w)

assert np.allclose(max_pool(x, max_pool_kernel, args), max_pool(x, jit_max_pool_kernel, args))
assert np.allclose(max_pool(x, jit_max_pool_kernel, args), max_pool(x, jit_max_pool_kernel2, args))
%timeit max_pool(x, max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel, args)
%timeit max_pool(x, jit_max_pool_kernel2, args)

## Structs in numba and NumPy

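numba's support is most complete for NumPy. Python data types such as list, dict, and tuple are either not optimized at all or are optimized only with certain restrictions, so it is best to represent inputs as NumPy arrays whenever possible.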

Sometimes, for convenience, we want a numba-decorated function to accept a struct-like argument. Python has no explicit struct type; the usual substitute is a class.

In [None]:
import numpy
from numba import njit

class Point():
    """    
    Arguments:
        domain: the domain of random generated coordinates x,y,z, 
                default=1.0
    
    Attributes:
        x, y, z: coordinates of the point
    """
    def __init__(self, domain=1.0):
        self.x = domain * numpy.random.random()
        self.y = domain * numpy.random.random()
        self.z = domain * numpy.random.random()
            
    def distance(self, other):
        return ((self.x - other.x)**2 + 
                (self.y - other.y)**2 + 
                (self.z - other.z)**2)**.5

class Particle(Point):
    """    
    Attributes:
        m: mass of the particle
        phi: the potential of the particle
    """
    
    def __init__(self, domain=1.0, m=1.0):
        Point.__init__(self, domain)  # call the parent-class initializer explicitly
        self.m = m
        self.phi = 0.


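Here we have defined a class called Particle; accessing its attributes directly makes it behave like a struct. Next we create 1000 such objects to represent 1000 data points and pass them to an ordinary Python function for the computation.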

In [None]:
n = 1000
particles = [Particle(m = 1 / n) for i in range(n)]
def direct_sum(particles):
    """
    Calculate the potential at each particle
    using direct summation method.

    Arguments:
        particles: the list of particles

    """
    for i, target in enumerate(particles):
        for source in (particles[:i] + particles[i+1:]):
            r = target.distance(source)
            target.phi += source.m / r

orig_time = %timeit -o direct_sum(particles)

Compiling this function with `@jit` in nopython mode fails, because numba does not recognize ordinary Python classes. How can we work around this?

NumPy has a handy feature that provides struct-like behavior, and access through it is far more efficient than through a class-based struct.

Below we define a NumPy data type `particle_dtype`, similar to a `struct` in C. `names` lists the field names: 'x', 'y', 'z', 'm', 'phi'; `formats` gives the data type of each field.

In [None]:
particle_dtype = numpy.dtype({'names':['x','y','z','m','phi'], 
                             'formats':[numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double, 
                                        numpy.double]})
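
As a quick illustration of how such a structured dtype behaves, the sketch below builds a three-element array (the variable names `demo` and `first` are illustrative). It shows the two access patterns: by field name across all records, and by index for a single record.

```python
import numpy

particle_dtype = numpy.dtype({'names': ['x', 'y', 'z', 'm', 'phi'],
                              'formats': [numpy.double] * 5})

demo = numpy.zeros(3, dtype=particle_dtype)

demo['x'] = [0.0, 1.0, 2.0]  # field access: the 'x' value of every record
demo['m'] = 0.5              # a scalar is broadcast to all records

first = demo[0]              # index access: one record with all five fields
print(first['x'], first['m'])

print(demo.dtype.names)      # field names, in declaration order
print(demo.itemsize)         # bytes per record: five 8-byte doubles
```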

Next we create an array of this structured dtype and fill it in.

In [None]:
def create_n_random_particles(n, m, domain=1):
    '''
    Creates `n` particles with mass `m` with random coordinates
    between 0 and `domain`
    '''
    parts = numpy.zeros((n), dtype=particle_dtype)
    
    parts['x'] = numpy.random.random(size=n) * domain
    parts['y'] = numpy.random.random(size=n) * domain
    parts['z'] = numpy.random.random(size=n) * domain
    parts['m'] = m
    parts['phi'] = 0.0

    return parts
parts = create_n_random_particles(1000, .001, 1)
print(parts[:5])

Now the computation can be accelerated with jit.

In [None]:
@njit
def distance(self, other):
    return ((self.x - other.x)**2 + 
            (self.y - other.y)**2 + 
            (self.z - other.z)**2)**.5


@njit
def direct_sum_jit(particles):
    for i, target in enumerate(particles):
        for j in range(particles.shape[0]):
            if i == j:
                continue
            source = particles[j]
            r = distance(target, source)
            target['phi'] += source['m'] / r
%timeit direct_sum_jit(parts)