In [1]:
import torch
from numba import cuda
import pytorch_numba_extension_jit as pnex

In [2]:
# Shared 1M-element GPU buffer that the add benchmarks below mutate in place.
# dtype is pinned to float32 on purpose: the Numba kernel is compiled for
# "void(float32[:])" and the pnex kernel declares InMut("f32", ...), so relying
# on torch's default dtype (changeable via torch.set_default_dtype) could break
# both launches.
arr = torch.ones(1_000_000, device="cuda", dtype=torch.float32)

In [3]:
def time_empty():
    """Baseline no-op: timing it measures bare Python function-call overhead."""

# Baseline: per-call cost of invoking an empty Python function.
%timeit time_empty()

11.4 ns ± 0.0707 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


In [4]:
def time_sync():
    """Only synchronize the device (presumably with nothing queued).

    Every GPU benchmark below ends with this same call, so this timing is
    the synchronization floor included in each of them.
    """
    torch.cuda.synchronize()

%timeit time_sync()

2 μs ± 7.18 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [5]:
def time_torch():
    """Run one in-place PyTorch add on the global `arr`, then wait for the GPU.

    The trailing synchronize makes %timeit measure completion of the add,
    not just its enqueue. Note that `arr` keeps growing across calls.
    """
    arr.add_(1.0)
    torch.cuda.synchronize()

# Untimed warm-up call so any one-time setup is excluded from %timeit.
time_torch()
%timeit time_torch()

7.06 μs ± 6.56 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [8]:
@cuda.jit("void(float32[:])")
def add_numba(a):
    x = cuda.grid(1)
    if x < a.size:
        a[x] += 1

n_blocks = (arr.numel() + 127) // 128
def time_numba():
    add_numba[n_blocks, 128](arr)
    torch.cuda.synchronize()

time_numba()
%timeit time_numba()

36.3 μs ± 76.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
# Same +1 kernel as add_numba, but built through pytorch_numba_extension_jit.
# NOTE(review): n_threads is evaluated once here from arr.numel(), so the launch
# size appears fixed at decoration time. cache_id presumably keys a compilation
# cache. Confirm both against the pnex documentation.
# The annotation presumably declares `a` as a mutable 1-D float32 tensor of any
# length. TODO confirm.
@pnex.jit(n_threads=arr.numel(), cache_id="overhead_measure")
def add_extension(a: pnex.InMut("f32", (None,))):
    # Bounds check so threads past the end of `a` do nothing.
    x = cuda.grid(1)
    if x < a.size:
        a[x] += 1

def time_extension():
    """Call the pnex-compiled kernel on `arr` and wait for the GPU to finish."""
    add_extension(arr)
    torch.cuda.synchronize()

# Untimed warm-up call so any one-time setup (e.g. loading) is excluded from %timeit.
time_extension()
%timeit time_extension()

13.1 μs ± 13.7 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
