# Imports

In [1]:
from numba import guvectorize, jit, cuda
import numpy as np

# Jit-wrapper for cross-compilation, simplified

In [2]:
TARGET = 'cuda'  # compilation target: 'cpu', 'parallel' or 'cuda'

# Validate explicitly instead of using `assert`, which is stripped under `python -O`.
if TARGET not in ('cpu', 'parallel', 'cuda'):
    raise ValueError(
        f"unsupported TARGET {TARGET!r}; expected 'cpu', 'parallel' or 'cuda'"
    )

# Select the decorator and its keyword arguments for the "helper" function(s),
# so the same helper source can be compiled for whichever target is chosen.
if TARGET == 'cuda':
    hjit = cuda.jit
    hkwargs = {'device': True, 'inline': True}
else:
    hjit = jit
    hkwargs = {'nopython': True, 'inline': 'always'}

def _parse_signature(s):  # simplify complicated signatures (tuples because, well, returning arrays does not work for device functions)
    s = s.replace('M', 'Tuple([V,V,V])')
    s = s.replace('V', 'Tuple([f,f,f])')
    return s.replace('f', 'f8')

# Demo helper function

If the target is `cuda`, it's a device function.

In [3]:
@hjit(_parse_signature('V(V,M)'), **hkwargs)
def matmul_VM_(a, b):
    """Return the product of the 3-element row vector `a` and the 3x3 matrix `b`.

    `a` is a 3-tuple and `b` a 3-tuple of 3-tuples (rows); the result is a
    3-tuple whose j-th entry is a[0]*b[0][j] + a[1]*b[1][j] + a[2]*b[2][j].
    """
    a0, a1, a2 = a
    row0, row1, row2 = b
    out0 = a0 * row0[0] + a1 * row1[0] + a2 * row2[0]
    out1 = a0 * row0[1] + a1 * row1[1] + a2 * row2[1]
    out2 = a0 * row0[2] + a1 * row1[2] + a2 * row2[2]
    return (out0, out1, out2)

# Demo generalized universal function

Works for the `cpu` and `parallel` targets, but fails on `cuda`.

In [4]:
@guvectorize(
    _parse_signature('void(f[:],f[:],f[:],f[:],f[:],f[:])'),
    # With three outputs, target `cuda` fails: 'AssertionError: only support 1 output'
    '(n),(n),(n)->(n),(n),(n)',
    # Single stacked output, also rejected on `cuda`: 'ValueError: bad token in signature "3"'
    # '(n),(n),(n)->(3,n)',
    target = TARGET,
    nopython = True,
)
def foo(a, b, c, x, y, z):
    """Multiply every vector (a[i], b[i], c[i]) by a fixed 3x3 matrix.

    `a`, `b`, `c` are the input component arrays; the components of each
    product are written in place to the output arrays `x`, `y`, `z`.
    """
    matrix = (
        (0.2, 0.8, 0.3),
        (0.3, 0.5, 0.6),
        (0.4, 0.1, 0.8),
    )
    n = a.shape[0]
    for i in range(n):
        vec = (a[i], b[i], c[i])
        x[i], y[i], z[i] = matmul_VM_(vec, matrix)

AssertionError: only support 1 output

# Demo usage

So far, this works for the `cpu` and `parallel` targets.

In [None]:
LEN = 100_000_000  # number of 3-vectors to transform

# Row k of `data` holds the k-th component of every input vector.
data = np.arange(3 * LEN, dtype = np.float64).reshape(3, LEN)
# Output buffer with the same shape and dtype as `data`, filled in place by `foo`.
res = np.zeros((3, LEN), dtype = np.float64)

# Unpack the three input rows followed by the three output rows.
foo(*data, *res)
res