# Numba XND

The Numba integration for xnd lets you compile Python code that operates on xnd containers with Numba:

In [None]:
from numba import njit
from xnd import xnd
import numba_xnd

In [None]:
def mean(x):
    """Return the average of the entries of the 1-D xnd container ``x``.

    Written as an explicit index loop over scalar ``.value`` reads; the
    notebook times this exact function both as plain Python and after
    ``njit(mean)``. In plain Python an empty ``x`` raises
    ``ZeroDivisionError`` (0 / 0).
    """
    length = x.type.shape[0]
    running_total = 0
    for idx in range(length):
        running_total = running_total + x[idx].value
    return running_total / length

In [None]:
# An xnd container holding the integers 0..99999; it is used to compare the
# pure-Python, jitted and gumath-kernel versions of ``mean`` below.
x = xnd(list(range(100000)))

In [None]:
mean(x)

In [None]:
%%timeit
mean(x)

In [None]:
# Compile the unchanged Python ``mean`` with Numba's nopython mode (``njit``).
mean_jitted = njit(mean)

In [None]:
mean_jitted(x)

In [None]:
%%timeit
mean_jitted(x)

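We can also compile gumath kernels with Numba, so that they are broadcast over any leading dimensions. In this case, the kernel reduces the last dimension of an `int64` array to its mean and writes the result into the `float64` output `res`: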

In [None]:
@numba_xnd.gumath.register_kernel("... * N * int64 -> ... * float64")
def mean_kernel(x, res):
    """Gumath kernel: store the mean of the 1-D int64 slice ``x`` in ``res``.

    gumath broadcasts the kernel over the leading ``...`` dimensions, so each
    call sees one ``N * int64`` slice in ``x`` and a 0-d float64 output
    ``res`` (hence the ``res[()]`` scalar store).

    The sum is accumulated in a local variable and written to ``res`` once,
    instead of reading and rewriting ``res`` on every iteration. Starting
    from ``0.0`` keeps the float64 accumulation of the previous version; an
    int ``0`` would make Numba accumulate in wrapping int64.
    N == 0 is not handled specially (it divides by zero).
    """
    n = x.type.shape[0]
    total = 0.0
    for i in range(n):
        total += x[i].value
    res[()] = total / n

In [None]:
mean_kernel(x)

In [None]:
%%timeit
mean_kernel(x)

In [None]:
mean_kernel(xnd([[1, 2, 3], [3, 4, 5]]))

## Matrix Multiply

In [None]:
@numba_xnd.gumath.register_kernel(
    [
        "... * N * M * int64, ... * M * K * int64 -> ... * N * K * int64",
        "... * N * M * float64, ... * M * K * float64 -> ... * N * K * float64",
    ]
)
def simple_matrix_multiply(a, b, c):
    """Gumath kernel: naive triple-loop matrix product ``c = a @ b``.

    Registered for int64 and float64. gumath broadcasts it over the leading
    ``...`` dimensions, so each call receives one ``N x M`` matrix ``a``, one
    ``M x K`` matrix ``b`` and the ``N x K`` output ``c``.

    Each dot product is accumulated in a local and stored in ``c`` once,
    instead of reading and rewriting ``c[i, j]`` for every term.
    ``b``'s row count is not read; the previous version also never compared
    it with ``M``. This assumes gumath enforces the shared ``M`` when it
    matches the signature (confirm).
    """
    n_rows, inner = a.type.shape
    _, n_cols = b.type.shape
    for i in range(n_rows):
        for j in range(n_cols):
            # Starts as int 0; for the float64 signature Numba's type
            # inference widens ``acc`` to float64 when floats are added.
            acc = 0
            for k in range(inner):
                acc += a[i, k].value * b[k, j].value
            c[i, j] = acc

In [None]:
simple_matrix_multiply(xnd([[1, 2], [3, 4]]), xnd([[1, 2], [3, 4]]))

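The kernel also works with broadcasting: for 3-D inputs, it multiplies the corresponding 2-D matrices along the leading dimension.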

In [None]:
# A 10 x 10 x 10 xnd array of ints with x[k][j][i] == i + j + k. The leading
# dimension is the broadcast ``...`` axis for ``simple_matrix_multiply``.
x = xnd([[[i + j + k for i in range(10)] for j in range(10)] for k in range(10)])
x

In [None]:
import gumath.functions

In [None]:
# ``x + 1`` via gumath's ``add`` with the scalar ``xnd(1)`` broadcast against
# ``x``; this gives a second operand with the same shape as ``x``.
y = gumath.functions.add(x, xnd(1))

In [None]:
simple_matrix_multiply(x, y)