In [1]:
!uv sync --quiet
import torch
from pytorch_semifield_conv import SelectSemifield
import math
import numba
from numba import cuda

In [2]:
# Load the k_mnist dataset through the project-local loader.
from src import load_data
k_mnist = load_data.k_mnist()
# First 1024 training samples, moved to the GPU and tiled 6x along dim 1
# (presumably the channel dim of an NCHW batch — TODO confirm x_train layout).
ex_data = k_mnist.x_train[:1024].cuda().repeat((1, 6, 1, 1))

In [3]:
# Fixed seed so the random example kernel is reproducible across runs.
torch.manual_seed(0)
# Shape (6, 1, 11, 11): used with groups=6 in the convolution calls below.
ex_kernel = torch.rand((6, 1, 11, 11), device="cuda")

In [4]:
# Tropical-max semifield convolution operator with a thread block size of 256.
# NOTE(review): "lazy_fixed" presumably compiles on first call for fixed parameters — confirm
# against the pytorch_semifield_conv documentation.
op = SelectSemifield.tropical_max().lazy_fixed(thread_block_size=256)
op

CompiledConvFixedLazy()

In [5]:
# Wrap the operator with torch.compile; fullgraph=True requires the whole call to be
# captured without graph breaks.
opc = torch.compile(op, fullgraph=True)
# Smoke test: only the output shape is displayed.
opc(ex_data, ex_kernel, groups=6, padding="same", stride=2).shape

torch.Size([1024, 6, 14, 14])

In [7]:
# Benchmark forward + backward of the eager operator.
# Leaf copies that require grad, so backward() has somewhere to accumulate into.
g_inp = ex_data.clone().requires_grad_(True)
g_krn = ex_kernel.clone().requires_grad_(True)
# Random upstream gradient with the same shape as the operator's output.
g_tangent = torch.randn_like(op(ex_data, ex_kernel, groups=6, padding="same", stride=2))

def run_one():
    """Run one forward/backward pass of ``op`` and wait for the GPU to finish."""
    # NOTE(review): called without groups/padding/stride, unlike the call that sized
    # g_tangent — presumably the lazy_fixed operator reuses its first-call parameters; confirm.
    res = op(g_inp, g_krn)
    res.backward(g_tangent)
    # Synchronize so %timeit measures kernel execution, not just launch.
    torch.cuda.synchronize()

# Warm-up call before timing.
run_one()

%timeit run_one()

256 μs ± 200 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Same benchmark as the eager cell, but using the torch.compile-wrapped operator `opc`.
# This redefines g_inp, g_krn, g_tangent and run_one from the earlier cell.
g_inp = ex_data.clone().requires_grad_(True)
g_krn = ex_kernel.clone().requires_grad_(True)
# Random upstream gradient with the same shape as the operator's output.
g_tangent = torch.randn_like(op(ex_data, ex_kernel, groups=6, padding="same", stride=2))

def run_one():
    """Run one forward/backward pass of ``opc`` and wait for the GPU to finish."""
    # NOTE(review): called without groups/padding/stride, unlike the call that sized
    # g_tangent — confirm this is intended for the compiled operator.
    res = opc(g_inp, g_krn)
    res.backward(g_tangent)
    # Synchronize so %timeit measures kernel execution, not just launch.
    torch.cuda.synchronize()

# Warm-up call (also triggers compilation) before timing.
run_one()

%timeit run_one()

In [33]:
# Minimal manual CUDA-graph example: capture "out = (zeros + static_inp)[2]" once and
# replay it later with new contents written into the same static buffers.

# Static buffers: the captured graph always reads/writes these exact allocations.
static_inp = torch.arange(5, dtype=torch.float32, device="cuda")
static_out = torch.empty((), device="cuda")

# Warm up on a side stream before capturing.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(3):
        a = torch.zeros(5, device="cuda")
        a += static_inp
        static_out[...] = a[2]
        print(torch.cuda.current_stream().stream_id)
# Rejoin the current stream *before* capture so the warm-up work is ordered ahead of it
# (this wait previously came after the capture block).
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
# Tensors allocated inside the capture (here `a`) come from the graph's private pool.
with torch.cuda.graph(g):
    a = torch.zeros(5, device="cuda")
    a += static_inp
    static_out[...] = a[2]

195
195
195


In [31]:
# Value left in static_out by the last warm-up iteration.
print(static_out)
# Mutate the static input in place, then replay the captured graph: the replay reads the
# new contents of static_inp and overwrites static_out.
static_inp[2] = 5
g.replay()
print(static_out)

tensor(2., device='cuda:0')
tensor(5., device='cuda:0')


In [9]:
# Same operator compiled with the "cudagraphs" backend.
opg = torch.compile(op, fullgraph=True, backend="cudagraphs")
# cudagraph_mark_step_begin() is called before each invocation — presumably to mark a new
# step so buffers from the previous call may be reused; see the torch.compiler docs.
torch.compiler.cudagraph_mark_step_begin()
opg(ex_data+1, ex_kernel, groups=6, padding="same", stride=2)
torch.compiler.cudagraph_mark_step_begin()
# Display a 3x3 corner of the second result as a NumPy array.
opg(ex_data+2, ex_kernel, groups=6, padding="same", stride=2)[0, 0, :3, :3].numpy(force=True)

array([[2.9608638, 2.9926705, 3.3990464],
       [2.9608638, 3.2029681, 3.89892  ],
       [2.9608638, 3.5723448, 3.9608638]], dtype=float32)

In [None]:
#pre-expansion: 508 μs ± 160 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
#post-unfloat 256 μs ± 200 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
#post-expansion 168 μs ± 123 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [4]:
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def copy_test(vals, out):
    """Copy ``vals`` into ``out`` element-wise, one thread per element.

    Args:
        vals: 1-D float32 device array to read.
        out: 1-D float32 device array to write; indexed with the same positions as ``vals``.
    """
    x = cuda.grid(1)
    # ``>=`` rather than ``>``: the thread with x == vals.size is already one past the
    # last valid element and must not touch memory.
    if x >= vals.size:
        return

    out[x] = vals[x]

# 1024 * 6 * 28 * 28 float32 values 0, 1, 2, ... on the GPU, and a zeroed output buffer.
v = torch.arange(1024 * 6 * 28 * 28, dtype=torch.float32, device="cuda")
o = torch.zeros_like(v)
# Launch with 32 threads per block; the floor division covers every element only
# because v.numel() is a multiple of 32.
copy_test[v.numel() // 32, 32](v, o)
print(o)

tensor([0.0000e+00, 1.0000e+00, 2.0000e+00,  ..., 4.8169e+06, 4.8169e+06,
        4.8169e+06], device='cuda:0')


In [10]:
def run_one():
    """Launch ``copy_test`` once over ``v`` -> ``o`` and wait for completion."""
    copy_test[v.numel() // 32, 32](v, o)
    # Synchronize so the timing includes kernel execution.
    torch.cuda.synchronize()

# Warm-up call, then time.
run_one()
%timeit run_one()

121 μs ± 142 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [16]:
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def copy_nth(vals, out):
    """Kernel with the bounds guard only and no memory access.

    The timing cell that launches it shows the cost of a launch with the same grid as
    ``copy_test`` but without any loads or stores.

    Args:
        vals: 1-D float32 device array; only its size is read.
        out: 1-D float32 device array; currently untouched.
    """
    x = cuda.grid(1)
    # ``>=`` keeps the guard correct if a per-element access is added below.
    if x >= vals.size:
        return

    # out[7] = 7

# Same input size and launch grid as the copy_test cell.
v = torch.arange(1024 * 6 * 28 * 28, dtype=torch.float32, device="cuda")
o = torch.zeros_like(v)
copy_nth[v.numel() // 32, 32](v, o)
# The kernel writes nothing, so `o` is expected to stay all zeros.
print(o)

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')


In [17]:
def run_one():
    """Launch the no-op ``copy_nth`` kernel once and wait for completion."""
    copy_nth[v.numel() // 32, 32](v, o)
    # Synchronize so the timing includes kernel execution.
    torch.cuda.synchronize()

# Warm-up call, then time.
run_one()
%timeit run_one()

121 μs ± 179 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [2]:
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def bwd(vals, out):
    """Accumulate ``vals`` into 8 bins: ``out[x % 8] += vals[x]``, one thread per element.

    All threads contend on the same 8 addresses through atomic adds.

    Args:
        vals: 1-D float32 device array of values to accumulate.
        out: 1-D float32 device array with at least 8 elements, updated in place.
    """
    x = cuda.grid(1)
    # ``>=`` rather than ``>``: x == vals.size would read one element past the end.
    if x >= vals.size:
        return

    ox = x % 8
    cuda.atomic.add(out, ox, vals[x])

# Input values 0, 1, 2, ... and an 8-element accumulator initialised to zero.
v = torch.arange(1024 * 6 * 28 * 28, dtype=torch.float32, device="cuda")
o = torch.zeros(8, device="cuda")
# 32 threads per block; v.numel() is a multiple of 32, so every element is covered.
bwd[v.numel() // 32, 32](v, o)
print(o)

tensor([1.4501e+12, 1.4501e+12, 1.4501e+12, 1.4501e+12, 1.4501e+12, 1.4501e+12,
        1.4501e+12, 1.4501e+12], device='cuda:0')


In [20]:
# Number of partial accumulators per bin in the expanded output.
EXPAND_SIZE = 1024 * 6
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:, :])")
def bwd_exp(vals, out):
    """Variant of ``bwd`` with an expanded accumulator.

    Performs ``out[x % 8, x % EXPAND_SIZE] += vals[x]`` so the atomic adds are spread over
    many addresses; summing ``out`` over its second dimension gives the per-bin totals.

    Args:
        vals: 1-D float32 device array of values to accumulate.
        out: 2-D float32 device array indexed as ``(bin, part)``, updated in place.
    """
    x = cuda.grid(1)
    # ``>=`` rather than ``>``: x == vals.size would read one element past the end.
    if x >= vals.size:
        return

    ox = x % 8
    part = x % EXPAND_SIZE
    cuda.atomic.add(out, (ox, part), vals[x])

# Same input as the bwd cell; the accumulator is expanded to (8, EXPAND_SIZE).
v = torch.arange(1024 * 6 * 28 * 28, dtype=torch.float32, device="cuda")
o_exp = torch.zeros((8, EXPAND_SIZE), device="cuda")

bwd_exp[v.numel() // 32, 32](v, o_exp)
# Reduce the partial sums to get one total per bin, for comparison with the bwd output.
print(o_exp.sum(1))

tensor([1.4502e+12, 1.4502e+12, 1.4502e+12, 1.4502e+12, 1.4502e+12, 1.4502e+12,
        1.4502e+12, 1.4502e+12], device='cuda:0')


In [7]:
# Fresh 8-bin accumulator (it keeps accumulating across the timed calls).
o = torch.zeros(8, device="cuda")
def run_one():
    """Launch ``bwd`` once over ``v`` -> ``o`` and wait for completion."""
    bwd[v.numel() // 32, 32](v, o)
    torch.cuda.synchronize()

# Warm-up call, then time.
run_one()
%timeit run_one()

772 μs ± 168 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [21]:
# Fresh expanded accumulator (it keeps accumulating across the timed calls).
o_exp = torch.zeros((8, EXPAND_SIZE), device="cuda")
def run_one():
    """Launch ``bwd_exp`` once, reduce the partial sums, and wait for completion."""
    bwd_exp[v.numel() // 32, 32](v, o_exp)
    # The final reduction is included in the timing so it is comparable to bwd.
    o_exp.sum(1)
    torch.cuda.synchronize()

# Warm-up call, then time.
run_one()
%timeit run_one()

130 μs ± 161 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [3]:
# Pooling parameters read as globals by the pool_* kernels below.
STRIDE = 2  # step between the starts of consecutive windows
WINDOW_SIZE = 6  # number of input elements per window

In [4]:
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def pool_basic(vals, out):
    """1-D max pooling of ``vals + 1``, one thread per output element.

    Computes ``out[ox] = max(vals[ox*STRIDE : ox*STRIDE + WINDOW_SIZE]) + 1``, skipping
    window positions that fall past the end of ``vals``. The accumulator starts at -100.0,
    so results never go below that value.

    Args:
        vals: 1-D float32 device array of inputs.
        out: 1-D float32 device array of outputs.
    """
    idx = cuda.grid(1)
    # ``>=`` rather than ``>``: the thread with idx == out.size would write out[out.size].
    if idx >= out.size:
        return

    ox = idx

    begin_x = ox * STRIDE
    acc = numba.float32(-100.0)
    for x in range(begin_x, begin_x + WINDOW_SIZE):
        # Skip window positions past the end of the input.
        if x >= vals.size:
            continue
        val = vals[x]
        val_2 = val + numba.float32(1.0)
        if val_2 > acc:
            acc = val_2

    out[ox] = acc

In [50]:
# NOTE(review): `vec_copy4` and the 2**21-element buffers `a`, `b` are not defined in any
# visible cell of this notebook — this cell depends on hidden kernel state and will fail
# on Restart & Run All unless their definitions are restored.
def run_one():
    """Launch ``vec_copy4`` once over ``a`` -> ``b`` and wait for completion."""
    vec_copy4[2**21 // 32, 32](a, b)
    torch.cuda.synchronize()

# Warm-up call, then time.
run_one()
%timeit run_one()

82.2 μs ± 155 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
# 59.7 μs ± 69.9 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [36]:
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def pool_basic_select(vals, out):
    """Select-based (``cuda.selp``) variant of ``pool_basic``.

    Computes ``out[ox] = max(vals[ox*STRIDE : ox*STRIDE + WINDOW_SIZE]) + 1`` with one
    thread per output element; window positions past the end of ``vals`` contribute
    -1000 (+1). The accumulator starts at -100.0.

    Args:
        vals: 1-D float32 device array of inputs.
        out: 1-D float32 device array of outputs.
    """
    ox = cuda.grid(1)
    # ``>=`` rather than ``>``: the thread with ox == out.size would write out[out.size].
    if ox >= out.size:
        return

    last_x = vals.size - 1
    acc = numba.float32(-100.0)

    # The window scan is repeated WINDOW_SIZE times; max is idempotent so the result is
    # unchanged (presumably to emulate the work of a 2-D window when benchmarking).
    for _ in range(WINDOW_SIZE):
        begin_x = ox * STRIDE
        for x in range(begin_x, begin_x + WINDOW_SIZE):
            # selp evaluates both operands, so clamp the index to keep the load in bounds
            # even when the fallback value is the one selected.
            val = cuda.selp(x < vals.size, vals[min(x, last_x)], numba.float32(-1000))
            val_2 = val + numba.float32(1.0)
            acc = cuda.selp(val_2 > acc, val_2, acc)

    out[ox] = acc

# Small correctness check of pool_basic_select against torch.max_pool1d.
IN_SIZE = 21
# Output length of an unpadded pooling with the given window and stride.
OUT_SIZE = (IN_SIZE - 1 - (WINDOW_SIZE - 1)) // STRIDE + 1

torch.manual_seed(0)
ex_vals_1d = torch.rand(IN_SIZE, device="cuda")
ex_out = torch.empty(OUT_SIZE, device="cuda")
print(ex_vals_1d)
# Reference result: max pool, then +1 (the kernel adds 1 before taking the max).
ex_chec = torch.max_pool1d(ex_vals_1d.unsqueeze(0), WINDOW_SIZE, STRIDE).add(1).squeeze()
print(ex_chec)
# The grid is sized for OUT_SIZE*STRIDE threads although the kernel uses one thread per
# output; surplus threads are expected to exit at the kernel's bounds guard.
pool_basic_select[(OUT_SIZE*STRIDE + 31) // 32, 32](ex_vals_1d, ex_out)
print(ex_out)
torch.testing.assert_close(ex_chec, ex_out)

tensor([0.3990, 0.5167, 0.0249, 0.9401, 0.9459, 0.7967, 0.4150, 0.8203, 0.2290,
        0.9096, 0.1183, 0.0752, 0.4092, 0.9601, 0.2093, 0.1940, 0.8909, 0.4387,
        0.3570, 0.5454, 0.8299], device='cuda:0')
tensor([1.9459, 1.9459, 1.9459, 1.9096, 1.9601, 1.9601, 1.9601, 1.8909],
       device='cuda:0')
tensor([1.9459, 1.9459, 1.9459, 1.9096, 1.9601, 1.9601, 1.9601, 1.8909],
       device='cuda:0')


In [40]:
# Threads per block; the kernel must be launched with exactly this block size.
BLOCK_SIZE = 64
# Input elements needed by one block: the last thread's window start plus one window.
CACHE_SIZE = (BLOCK_SIZE - 1) * STRIDE + WINDOW_SIZE
# Number of BLOCK_SIZE-wide passes needed to fill the cache (the last may be partial).
FILL_STEPS = math.ceil(CACHE_SIZE / BLOCK_SIZE)
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def pool_shmm(inp, out):
    """1-D max pooling of ``inp + 1`` that stages each block's input in shared memory.

    Each block first copies the CACHE_SIZE input elements covering its windows into a
    shared array (positions past the end of ``inp`` are filled with -200), then every
    thread reduces its own window from that cache. The accumulator starts at -100.0.

    Args:
        inp: 1-D float32 device array of inputs.
        out: 1-D float32 device array of outputs, one element per thread.
    """
    ox = cuda.grid(1)

    inp_cache = cuda.shared.array(CACHE_SIZE, numba.float32)
    acc = numba.float32(-100.0)
    last_x = inp.size - 1

    cache_pos = cuda.threadIdx.x
    block_begin_x = cuda.blockIdx.x * BLOCK_SIZE * STRIDE
    # Full passes: every thread loads one element per pass.
    for _ in range(FILL_STEPS - 1):
        x = block_begin_x + cache_pos
        # selp evaluates both operands, so clamp the index to keep the load in bounds
        # even when the -200 filler is the value selected.
        inp_cache[cache_pos] = cuda.selp(x < inp.size, inp[min(x, last_x)], numba.float32(-200))

        cache_pos += BLOCK_SIZE

    # Final, possibly partial pass.
    if cache_pos < CACHE_SIZE:
        x = block_begin_x + cache_pos
        inp_cache[cache_pos] = cuda.selp(x < inp.size, inp[min(x, last_x)], numba.float32(-200))

    # All threads (including those without an output) must reach the barrier, so the
    # output bounds check comes only after it.
    cuda.syncthreads()

    if ox >= out.shape[-1]:
        return

    # The window scan is repeated 100 times; max is idempotent so the result is unchanged
    # (presumably to inflate compute for benchmarking — TODO confirm it is still wanted).
    for _ in range(100):
        cache_begin_x = cuda.threadIdx.x * STRIDE
        for cache_x in range(cache_begin_x, cache_begin_x + WINDOW_SIZE):
            val = inp_cache[cache_x]
            val_2 = val + numba.float32(1.0)
            acc = cuda.selp(val_2 > acc, val_2, acc)

    out[ox] = acc

# Small correctness check of pool_shmm against torch.max_pool1d.
# IN_SIZE = 134 gives OUT_SIZE = 65, i.e. one more output than a single 64-thread block.
IN_SIZE = 134
OUT_SIZE = (IN_SIZE - 1 - (WINDOW_SIZE - 1)) // STRIDE + 1

torch.manual_seed(0)
ex_vals_1d = torch.rand(IN_SIZE, device="cuda")
ex_out = torch.zeros(OUT_SIZE, device="cuda")
print(ex_vals_1d)
# Reference result: max pool, then +1 (the kernel adds 1 before taking the max).
ex_chec = torch.max_pool1d(ex_vals_1d.unsqueeze(0), WINDOW_SIZE, STRIDE).add(1).squeeze()
print(ex_chec)
# One thread per output, rounded up to whole blocks of BLOCK_SIZE threads.
pool_shmm[(OUT_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE, BLOCK_SIZE](ex_vals_1d, ex_out)
print(ex_out)
print(OUT_SIZE)
torch.testing.assert_close(ex_chec, ex_out)

tensor([0.3990, 0.5167, 0.0249, 0.9401, 0.9459, 0.7967, 0.4150, 0.8203, 0.2290,
        0.9096, 0.1183, 0.0752, 0.4092, 0.9601, 0.2093, 0.1940, 0.8909, 0.4387,
        0.3570, 0.5454, 0.8299, 0.2099, 0.7684, 0.4290, 0.2117, 0.6606, 0.1654,
        0.4250, 0.9927, 0.6964, 0.2472, 0.7028, 0.7494, 0.9303, 0.0494, 0.0750,
        0.7223, 0.9478, 0.3647, 0.2215, 0.7784, 0.6391, 0.2077, 0.7045, 0.9609,
        0.0594, 0.3358, 0.0616, 0.7030, 0.5642, 0.0102, 0.8551, 0.5187, 0.5017,
        0.1144, 0.2751, 0.5339, 0.8582, 0.8465, 0.1845, 0.6360, 0.6799, 0.4408,
        0.5010, 0.8097, 0.5962, 0.5514, 0.4169, 0.2961, 0.6828, 0.4728, 0.4230,
        0.7024, 0.7056, 0.5943, 0.1731, 0.6248, 0.6187, 0.7183, 0.8779, 0.9462,
        0.4853, 0.0058, 0.9289, 0.7312, 0.3061, 0.9718, 0.7474, 0.2582, 0.3683,
        0.6239, 0.0971, 0.8172, 0.6601, 0.3544, 0.5687, 0.8948, 0.2423, 0.6391,
        0.7883, 0.6621, 0.8817, 0.7484, 0.3815, 0.0312, 0.5583, 0.7475, 0.6970,
        0.8021, 0.5907, 0.0841, 0.7754, 

In [61]:
import numpy as np

In [63]:
# Shuffle mask naming all 32 lanes of a warp (0xFFFF would name only the lower 16, leaving
# lanes 16-31 calling the shuffle without their own bit set).
FULL_MASK = 0xFFFFFFFF
# Not referenced by the kernel below.
CONST_LIST = np.array([1, 2, 3], dtype=np.uint64)
# noinspection PyArgumentList
@cuda.jit("void(float32[:], float32[:])")
def pool_strided(vals, out):
    """1-D max pooling of ``vals + 1`` using STRIDE cooperating threads per output.

    Thread ``idx`` serves output ``idx // STRIDE`` and visits window positions
    ``idx % STRIDE, idx % STRIDE + STRIDE, ...``. After each step a thread merges in the
    value held by the next lane (``shfl_down_sync`` by 1); the thread with
    ``idx % STRIDE == 0`` writes the result. Because only the directly following lane is
    merged, the reduction is complete only for STRIDE == 2.

    Args:
        vals: 1-D float32 device array of inputs.
        out: 1-D float32 device array of outputs.
    """
    idx = cuda.grid(1)
    # ``>=`` rather than ``>``: thread idx == out.size * STRIDE would be a leader for
    # ox == out.size and write one element past the end of ``out``.
    if idx >= out.size * STRIDE:
        return

    ox, stride_step = divmod(idx, STRIDE)
    stride_leader = stride_step == 0
    begin_x = ox * STRIDE
    last_x = vals.shape[-1] - 1
    acc = numba.float32(-100.0)
    for i in range(math.ceil(WINDOW_SIZE / STRIDE)):
        step = stride_step + i * STRIDE
        window_valid = step < WINDOW_SIZE
        x = begin_x + step
        x_valid = x < vals.shape[-1]
        # selp evaluates both operands, so clamp the index to keep the load in bounds
        # even when the -1000 filler is the value selected.
        val = cuda.selp(window_valid and x_valid, vals[min(x, last_x)], numba.float32(-1000))
        val_2 = val + numba.float32(1.0)
        # Fetch the value computed by the next lane for the same step.
        other_val_2 = cuda.shfl_down_sync(FULL_MASK, val_2, 1)

        if other_val_2 > val_2:
            val_2 = other_val_2

        if val_2 > acc:
            acc = val_2

    # Only the first thread of each group holds the merged maximum.
    if stride_leader:
        out[ox] = acc

# Small correctness check of pool_strided against torch.max_pool1d.
IN_SIZE = 21
# Output length of an unpadded pooling with the given window and stride.
OUT_SIZE = (IN_SIZE - 1 - (WINDOW_SIZE - 1)) // STRIDE + 1

torch.manual_seed(0)
ex_vals_1d = torch.rand(IN_SIZE, device="cuda")
ex_out = torch.empty(OUT_SIZE, device="cuda")
print(ex_vals_1d)
# Reference result: max pool, then +1 (the kernel adds 1 before taking the max).
ex_chec = torch.max_pool1d(ex_vals_1d.unsqueeze(0), WINDOW_SIZE, STRIDE).add(1).squeeze()
print(ex_chec)
# STRIDE threads per output, rounded up to whole 32-thread blocks.
pool_strided[(OUT_SIZE*STRIDE + 31) // 32, 32](ex_vals_1d, ex_out)
print(ex_out)
torch.testing.assert_close(ex_chec, ex_out)

tensor([0.3990, 0.5167, 0.0249, 0.9401, 0.9459, 0.7967, 0.4150, 0.8203, 0.2290,
        0.9096, 0.1183, 0.0752, 0.4092, 0.9601, 0.2093, 0.1940, 0.8909, 0.4387,
        0.3570, 0.5454, 0.8299], device='cuda:0')
tensor([1.9459, 1.9459, 1.9459, 1.9096, 1.9601, 1.9601, 1.9601, 1.8909],
       device='cuda:0')
tensor([1.9459, 1.9459, 1.9459, 1.9096, 1.9601, 1.9601, 1.9601, 1.8909],
       device='cuda:0')


In [41]:
# Large-input benchmark for the pooling kernel selected via `pool`.
IN_SIZE = 16_000_000
OUT_SIZE = (IN_SIZE - 1 - (WINDOW_SIZE - 1)) // STRIDE + 1

pool = pool_shmm
# NOTE(review): pool_shmm sizes its shared cache from the global BLOCK_SIZE, so this value
# has to stay equal to BLOCK_SIZE when benchmarking that kernel.
block_size = 64
# Alternative grid sizes for the other kernels (one thread per output / STRIDE threads per output):
# n_blocks = (OUT_SIZE + 31) // 32
# n_blocks = (OUT_SIZE * STRIDE + 31) // 32
n_blocks = (OUT_SIZE + block_size - 1) // block_size

torch.manual_seed(0)
lg_vals_1d = torch.rand(IN_SIZE, device="cuda")
lg_out = torch.empty(OUT_SIZE, device="cuda")
# Verify against the PyTorch reference once before timing.
check_out = torch.max_pool1d(lg_vals_1d.unsqueeze(0), WINDOW_SIZE, STRIDE).add(1).squeeze()
pool[n_blocks, block_size](lg_vals_1d, lg_out)
torch.testing.assert_close(check_out, lg_out)

def run_one():
    """Launch the selected pooling kernel once and wait for completion."""
    pool[n_blocks, block_size](lg_vals_1d, lg_out)
    torch.cuda.synchronize()

# Warm-up call, then time.
run_one()
%timeit run_one()

316 μs ± 228 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# Basic (+select): 110 μs
# shmem 108?

