# Examples

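## Matrix multiplication

A naive triple-loop matrix-multiplication kernel, registered for both `int64` and `float64` inputs.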

In [1]:
import numba_xnd
from xnd import xnd

In [2]:
@numba_xnd.gumath.register_kernel(
    [
        "... * N * M * int64, ... * M * K * int64 -> ... * N * K * int64",
        "... * N * M * float64, ... * M * K * float64 -> ... * N * K * float64",
    ]
)
def simple_matrix_multiply(a, b, c):
    """Naive matrix-product kernel: write ``a @ b`` into ``c``.

    Inside the kernel each argument is a single 2-D matrix (its shape is
    unpacked into exactly two sizes below); any leading ``...`` dimensions
    from the signatures are handled by gumath outside this function.

    a: N x M input matrix.
    b: M x K input matrix. The shared ``M`` in the signatures ties its row
       count to ``a``'s column count.
    c: N x K output matrix, filled in place.
    """
    n_rows, n_inner = a.type.shape
    # b's row count must equal n_inner (shared ``M`` in the signature), so only
    # its column count is needed here.
    _, n_cols = b.type.shape
    for i in range(n_rows):
        for j in range(n_cols):
            c[i, j] = 0
            for k in range(n_inner):
                # ``.value`` presumably extracts the element as a plain number
                # so it can take part in arithmetic.
                c[i, j] = c[i, j].value + a[i, k].value * b[k, j].value

In [3]:
squared = simple_matrix_multiply(xnd([[1, 2], [3, 4]]), xnd([[1, 2], [3, 4]]))
# Hand-computed: [[1, 2], [3, 4]] @ [[1, 2], [3, 4]] == [[7, 10], [15, 22]].
assert squared.value == [[7, 10], [15, 22]], squared
squared

xnd([[7, 10], [15, 22]], type='2 * 2 * int64')

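The kernel also works with broadcasting: the leading `...` dimensions in its signatures let it multiply stacks of matrices pairwise. Here `x` is a stack of ten $10 \times 10$ matrices.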

In [4]:
# A stack of ten 10 x 10 integer matrices; element [k][j][i] holds i + j + k.
x = xnd(
    [
        [[i + j + k for i in range(10)] for j in range(10)]
        for k in range(10)
    ]
)
x

xnd([[[0, 1, 2, 3, 4, 5, 6, 7, 8, ...],
      [1, 2, 3, 4, 5, 6, 7, 8, 9, ...],
      [2, 3, 4, 5, 6, 7, 8, 9, 10, ...],
      [3, 4, 5, 6, 7, 8, 9, 10, 11, ...],
      [4, 5, 6, 7, 8, 9, 10, 11, 12, ...],
      [5, 6, 7, 8, 9, 10, 11, 12, 13, ...],
      [6, 7, 8, 9, 10, 11, 12, 13, 14, ...],
      [7, 8, 9, 10, 11, 12, 13, 14, 15, ...],
      [8, 9, 10, 11, 12, 13, 14, 15, 16, ...],
      ...],
     [[1, 2, 3, 4, 5, 6, 7, 8, 9, ...],
      [2, 3, 4, 5, 6, 7, 8, 9, 10, ...],
      [3, 4, 5, 6, 7, 8, 9, 10, 11, ...],
      [4, 5, 6, 7, 8, 9, 10, 11, 12, ...],
      [5, 6, 7, 8, 9, 10, 11, 12, 13, ...],
      [6, 7, 8, 9, 10, 11, 12, 13, 14, ...],
      [7, 8, 9, 10, 11, 12, 13, 14, 15, ...],
      [8, 9, 10, 11, 12, 13, 14, 15, 16, ...],
      [9, 10, 11, 12, 13, 14, 15, 16, 17, ...],
      ...],
     [[2, 3, 4, 5, 6, 7, 8, 9, 10, ...],
      [3, 4, 5, 6, 7, 8, 9, 10, 11, ...],
      [4, 5, 6, 7, 8, 9, 10, 11, 12, ...],
      [5, 6, 7, 8, 9, 10, 11, 12, 13, ...],
      [6, 7, 8, 9, 10,

In [5]:
import gumath.functions

In [6]:
# Add 1 to every element of x: the scalar xnd(1) is broadcast against x,
# so y has the same shape as x (y == x + 1 element-wise).
y = gumath.functions.add(x, xnd(1))

In [7]:
# Each 10 x 10 matrix in x is multiplied by the matching matrix in y; the
# leading dimension of size 10 is matched by the kernel's `...` dimensions.
simple_matrix_multiply(x, y)

xnd([[[330, 375, 420, 465, 510, 555, 600, 645, 690, ...],
      [385, 440, 495, 550, 605, 660, 715, 770, 825, ...],
      [440, 505, 570, 635, 700, 765, 830, 895, 960, ...],
      [495, 570, 645, 720, 795, 870, 945, 1020, 1095, ...],
      [550, 635, 720, 805, 890, 975, 1060, 1145, 1230, ...],
      [605, 700, 795, 890, 985, 1080, 1175, 1270, 1365, ...],
      [660, 765, 870, 975, 1080, 1185, 1290, 1395, 1500, ...],
      [715, 830, 945, 1060, 1175, 1290, 1405, 1520, 1635, ...],
      [770, 895, 1020, 1145, 1270, 1395, 1520, 1645, 1770, ...],
      ...],
     [[440, 495, 550, 605, 660, 715, 770, 825, 880, ...],
      [505, 570, 635, 700, 765, 830, 895, 960, 1025, ...],
      [570, 645, 720, 795, 870, 945, 1020, 1095, 1170, ...],
      [635, 720, 805, 890, 975, 1060, 1145, 1230, 1315, ...],
      [700, 795, 890, 985, 1080, 1175, 1270, 1365, 1460, ...],
      [765, 870, 975, 1080, 1185, 1290, 1395, 1500, 1605, ...],
      [830, 945, 1060, 1175, 1290, 1405, 1520, 1635, 1750, ...],
      [