feat: parallel-scan algorithm for first-order filter (#11)
* feat: parallel scan functions

* feat: use parallel scan in certain conditions

* test: gradient using parallel scan
yoyololicon committed Jul 10, 2024
1 parent 28f6b16 commit f66cc0c
Showing 4 changed files with 256 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_grad.py
@@ -108,3 +108,35 @@ def test_float64_vs_32_cuda():
assert torch.allclose(y64, y32.double(), atol=1e-6), torch.max(
torch.abs(y64 - y32.double())
)


@pytest.mark.parametrize(
"x_requires_grad",
[True],
)
@pytest.mark.parametrize(
"a_requires_grad",
[True, False],
)
@pytest.mark.parametrize(
"zi_requires_grad",
[True, False],
)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_cuda_parallel_scan(
x_requires_grad: bool,
a_requires_grad: bool,
zi_requires_grad: bool,
):
batch_size = 2
samples = 123
x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")

A.requires_grad = a_requires_grad
x.requires_grad = x_requires_grad
zi.requires_grad = zi_requires_grad

    assert gradcheck(sample_wise_lpc, (x, A, zi))
    assert gradgradcheck(sample_wise_lpc, (x, A, zi))
5 changes: 5 additions & 0 deletions torchlpc/__init__.py
@@ -2,6 +2,8 @@
from typing import Optional

from .core import LPC
from .parallel_scan import WARPSIZE
from .recurrence import RecurrenceCUDA

__all__ = ["sample_wise_lpc"]

@@ -35,4 +37,7 @@ def sample_wise_lpc(
else:
assert zi.shape == (B, order)

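    # First-order filters on long CUDA sequences (B * WARPSIZE < T) take the parallel-scan
    # path; negating `a` rewrites y[t] = x[t] - a[t] * y[t-1] in the scan's
    # h[t] = decay[t] * h[t-1] + impulse[t] form.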
if order == 1 and x.is_cuda and B * WARPSIZE < T:
return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))

return LPC.apply(x, a, zi)
160 changes: 160 additions & 0 deletions torchlpc/parallel_scan.py
@@ -0,0 +1,160 @@
from numba import cuda

WARPSIZE = 32

# implementation was translated from https://github.com/eamartin/parallelizing_linear_rnns/blob/master/linear_recurrent_net/linear_recurrence.cu
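# The three kernels below evaluate the first-order recurrence
#     h[i, t] = decay[i, t] * h[i, t-1] + impulses[i, t],    h[i, -1] = initial_state[i]
# in parallel over t: `reduction_kernel` folds each lane's sub-chunk of time steps into a
# (cumulative decay, partial state) pair and each block's chunk into a block total,
# `block_scan_kernel` chains the block totals sequentially, and `warp_scan_kernel`
# propagates the resulting prefix into every sub-chunk and rescans it to write the output.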


@cuda.jit(device=True)
def divide_work(n_jobs, n_workers, worker_idx) -> tuple:
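    # Split n_jobs as evenly as possible among n_workers: the first `doing_cd` workers
    # take ceil(n_jobs / n_workers) jobs each, the rest take the floor; returns the
    # half-open [start, stop) range assigned to `worker_idx`.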
cd = (n_jobs + n_workers - 1) // n_workers
d, doing_cd = divmod(n_jobs, n_workers)
if worker_idx < doing_cd:
x = cd * worker_idx
y = x + cd
else:
x = cd * doing_cd + d * (worker_idx - doing_cd)
y = x + d
return x, y


@cuda.jit(device=True)
def compute_warp_start_stop(blockIdx, warp_idx, n_blocks, n_steps):
block_start, block_stop = divide_work(n_steps, n_blocks, blockIdx)
block_jobs = block_stop - block_start

warp_start, warp_stop = divide_work(block_jobs, WARPSIZE, warp_idx)
warp_start += block_start
warp_stop += block_start

return warp_start, warp_stop


@cuda.jit
def reduction_kernel(
decay, impulses, initial_state, decay_storage, h_storage, n_dims, n_steps
):
warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)

storage_offset = cuda.blockIdx.x * (WARPSIZE + 1)

warp_start, warp_stop = compute_warp_start_stop(
cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
)

    # reduce within warp: warps stride over the dims, and each lane folds its own time sub-chunk
for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
cum_decay = 1.0
h = 0.0
if (cuda.blockIdx.x == 0) and (lane == 0):
h = initial_state[i]

for t in range(warp_start, warp_stop):
cum_decay *= decay[i, t]
h = decay[i, t] * h + impulses[i, t]

decay_storage[lane + storage_offset, i] = cum_decay
h_storage[lane + storage_offset, i] = h

cuda.syncthreads()

# reduce within block
for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
cum_decay = 1.0
h = 0.0
for t in range(storage_offset, storage_offset + WARPSIZE):
cum_decay *= decay_storage[t, i]
h = decay_storage[t, i] * h + h_storage[t, i]

decay_storage[WARPSIZE + storage_offset, i] = cum_decay
h_storage[WARPSIZE + storage_offset, i] = h


@cuda.jit
def block_scan_kernel(decay_storage, h_storage, n_dims, n_blocks):
for i in range(
cuda.grid(1),
n_dims,
cuda.gridsize(1),
):
for t in range(1, n_blocks):
cur_idx = t * (WARPSIZE + 1) + WARPSIZE
prev_idx = (t - 1) * (WARPSIZE + 1) + WARPSIZE
h_storage[cur_idx, i] += h_storage[prev_idx, i] * decay_storage[cur_idx, i]
decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]


@cuda.jit
def warp_scan_kernel(
decay, impulses, initial_state, out, decay_storage, h_storage, n_dims, n_steps
):
warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)

for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
offset = cuda.blockIdx.x * (WARPSIZE + 1)
for cur_idx in range(offset, offset + WARPSIZE):
if cur_idx == 0:
continue
prev_idx = cur_idx - 1
h_storage[cur_idx, i] = (
h_storage[prev_idx, i] * decay_storage[cur_idx, i]
+ h_storage[cur_idx, i]
)
decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]

cuda.syncthreads()

warp_start, warp_stop = compute_warp_start_stop(
cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
)

# scan within warp
for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
if (cuda.blockIdx.x == 0) and (lane == 0):
h = initial_state[i]
else:
h = h_storage[lane - 1 + cuda.blockIdx.x * (WARPSIZE + 1), i]

for t in range(warp_start, warp_stop):
h = decay[i, t] * h + impulses[i, t]
out[i, t] = h


def compute_linear_recurrence(
decays, impulses, init_states, out, n_dims: int, n_steps: int
):
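    # One block per roughly WARPSIZE time steps, capped at 128 blocks; each launch uses
    # 512 threads (16 warps). The scratch arrays hold WARPSIZE per-lane partials plus one
    # block-level total per block, hence the (WARPSIZE + 1) row stride.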
n_blocks = min((n_steps + WARPSIZE - 1) // WARPSIZE, 128)

reduction_mem_shape = (n_blocks * (WARPSIZE + 1), n_dims)
decay_storage = cuda.device_array(reduction_mem_shape, dtype=decays.dtype)
h_storage = cuda.device_array(reduction_mem_shape, dtype=impulses.dtype)

reduction_kernel[n_blocks, 512](
decays, impulses, init_states, decay_storage, h_storage, n_dims, n_steps
)

block_scan_kernel[n_blocks, 512](decay_storage, h_storage, n_dims, n_blocks)

warp_scan_kernel[n_blocks, 512](
decays, impulses, init_states, out, decay_storage, h_storage, n_dims, n_steps
)


if __name__ == "__main__":
import numpy as np

n_dims = 16
n_steps = 20480
decays = np.full((n_dims, n_steps), 0.9, dtype=np.float32)
impulses = np.full((n_dims, n_steps), 0.0, dtype=np.float32)
impulses[:, 0] = 1.0
init_states = np.full(n_dims, 0.0, dtype=np.float32)

decays = cuda.to_device(decays)
impulses = cuda.to_device(impulses)
init_states = cuda.to_device(init_states)
out = cuda.device_array((n_dims, n_steps), dtype=np.float32)

compute_linear_recurrence(decays, impulses, init_states, out, n_dims, n_steps)

print(out.copy_to_host())
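
The demo only prints the device result; a minimal host-side check, sketched as a continuation of the block above (an illustration, not part of the commit), compares it against the sequential recurrence:

    # Sequential reference: h[:, t] = decays[:, t] * h[:, t-1] + impulses[:, t]
    ref = np.empty((n_dims, n_steps), dtype=np.float32)
    h = init_states.copy_to_host()
    decays_host = decays.copy_to_host()
    impulses_host = impulses.copy_to_host()
    for t in range(n_steps):
        h = decays_host[:, t] * h + impulses_host[:, t]
        ref[:, t] = h
    print(np.allclose(out.copy_to_host(), ref, atol=1e-5))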
59 changes: 59 additions & 0 deletions torchlpc/recurrence.py
@@ -0,0 +1,59 @@
import torch
import torch.nn.functional as F
from torch.autograd import Function
from numba import cuda

from .parallel_scan import compute_linear_recurrence


class RecurrenceCUDA(Function):
@staticmethod
def forward(
ctx, decay: torch.Tensor, impulse: torch.Tensor, initial_state: torch.Tensor
) -> torch.Tensor:
n_dims, n_steps = decay.shape
out = torch.empty_like(impulse)
compute_linear_recurrence(
cuda.as_cuda_array(decay.detach()),
cuda.as_cuda_array(impulse.detach()),
cuda.as_cuda_array(initial_state.detach()),
cuda.as_cuda_array(out),
n_dims,
n_steps,
)
ctx.save_for_backward(decay, initial_state, out)
return out

@staticmethod
def backward(ctx: torch.Any, grad_out: torch.Tensor) -> torch.Tensor:
decay, initial_state, out = ctx.saved_tensors
grad_decay = grad_impulse = grad_initial_state = None
n_dims, _ = decay.shape

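        # The impulse gradient obeys the same recurrence run backwards in time with the
        # decays shifted by one step (conjugated for complex inputs), so the CUDA scan is
        # reused on flipped inputs; padding grad_out in front additionally yields the
        # initial-state gradient as the last sample of the reversed scan.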
padded_decay = F.pad(decay.unsqueeze(1), (0, 1)).squeeze(1)
if ctx.needs_input_grad[2]:
padded_grad_out = F.pad(grad_out.unsqueeze(1), (1, 0)).squeeze(1)
else:
padded_grad_out = grad_out
padded_decay = padded_decay[:, 1:]

init = padded_grad_out.new_zeros(n_dims)
flipped_grad_impulse = RecurrenceCUDA.apply(
padded_decay.flip(1).conj_physical(),
padded_grad_out.flip(1),
init,
)

if ctx.needs_input_grad[2]:
grad_initial_state = flipped_grad_impulse[:, -1]
flipped_grad_impulse = flipped_grad_impulse[:, :-1]

if ctx.needs_input_grad[1]:
grad_impulse = flipped_grad_impulse.flip(1)

if ctx.needs_input_grad[0]:
valid_out = out[:, :-1]
padded_out = torch.cat([initial_state.unsqueeze(1), valid_out], dim=1)
grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)

return grad_decay, grad_impulse, grad_initial_state
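
For context, a minimal usage sketch of the new code path (an illustration, not part of the commit; the sizes are arbitrary but chosen so that B * WARPSIZE < T triggers the dispatch):

import torch
from torchlpc import sample_wise_lpc

B, T = 4, 4096  # B * WARPSIZE = 128 < T, so the CUDA parallel scan handles this filter
x = torch.randn(B, T, device="cuda")
a = torch.rand(B, T, 1, device="cuda") * 2 - 1  # first-order (order == 1) coefficients
zi = torch.randn(B, 1, device="cuda")

y = sample_wise_lpc(x, a, zi)  # dispatches to RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))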
