From f66cc0cf4eb34a9d333978d2fe063cc023b28ef0 Mon Sep 17 00:00:00 2001
From: Chin-Yun Yu
Date: Wed, 10 Jul 2024 20:36:43 +0900
Subject: [PATCH] feat: parallel-scan algorithm for first-order filter (#11)

* feat: parallel scan functions

* feat: use parallel scan in certain conditions

* test: gradient using parallel scan
---
 tests/test_grad.py        | 32 ++++++++
 torchlpc/__init__.py      |  5 ++
 torchlpc/parallel_scan.py | 160 ++++++++++++++++++++++++++++++++++++++
 torchlpc/recurrence.py    | 59 ++++++++++++++
 4 files changed, 256 insertions(+)
 create mode 100644 torchlpc/parallel_scan.py
 create mode 100644 torchlpc/recurrence.py

diff --git a/tests/test_grad.py b/tests/test_grad.py
index d9a7920..9913f4b 100644
--- a/tests/test_grad.py
+++ b/tests/test_grad.py
@@ -108,3 +108,35 @@ def test_float64_vs_32_cuda():
     assert torch.allclose(y64, y32.double(), atol=1e-6), torch.max(
         torch.abs(y64 - y32.double())
     )
+
+
+@pytest.mark.parametrize(
+    "x_requires_grad",
+    [True],
+)
+@pytest.mark.parametrize(
+    "a_requires_grad",
+    [True, False],
+)
+@pytest.mark.parametrize(
+    "zi_requires_grad",
+    [True, False],
+)
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_cuda_parallel_scan(
+    x_requires_grad: bool,
+    a_requires_grad: bool,
+    zi_requires_grad: bool,
+):
+    batch_size = 2
+    samples = 123
+    x = torch.randn(batch_size, samples, dtype=torch.double, device="cuda")
+    A = torch.rand(batch_size, samples, 1, dtype=torch.double, device="cuda") * 2 - 1
+    zi = torch.randn(batch_size, 1, dtype=torch.double, device="cuda")
+
+    A.requires_grad = a_requires_grad
+    x.requires_grad = x_requires_grad
+    zi.requires_grad = zi_requires_grad
+
+    assert gradcheck(LPC.apply, (x, A, zi), check_forward_ad=True)
+    assert gradgradcheck(LPC.apply, (x, A, zi))
diff --git a/torchlpc/__init__.py b/torchlpc/__init__.py
index 44944b6..acecbed 100644
--- a/torchlpc/__init__.py
+++ b/torchlpc/__init__.py
@@ -2,6 +2,8 @@
 from typing import Optional
 
 from .core import LPC
+from .parallel_scan import WARPSIZE
+from .recurrence import RecurrenceCUDA
 
 __all__ = ["sample_wise_lpc"]
 
@@ -35,4 +37,7 @@ def sample_wise_lpc(
     else:
         assert zi.shape == (B, order)
 
+    if order == 1 and x.is_cuda and B * WARPSIZE < T:
+        return RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1))
+
     return LPC.apply(x, a, zi)
diff --git a/torchlpc/parallel_scan.py b/torchlpc/parallel_scan.py
new file mode 100644
index 0000000..fef3183
--- /dev/null
+++ b/torchlpc/parallel_scan.py
@@ -0,0 +1,160 @@
+from numba import cuda
+
+WARPSIZE = 32
+
+# Implementation translated from https://github.com/eamartin/parallelizing_linear_rnns/blob/master/linear_recurrent_net/linear_recurrence.cu
+
+
+@cuda.jit(device=True)
+def divide_work(n_jobs, n_workers, worker_idx) -> tuple:
+    cd = (n_jobs + n_workers - 1) // n_workers
+    d, doing_cd = divmod(n_jobs, n_workers)
+    if worker_idx < doing_cd:
+        x = cd * worker_idx
+        y = x + cd
+    else:
+        x = cd * doing_cd + d * (worker_idx - doing_cd)
+        y = x + d
+    return x, y
+
+
+@cuda.jit(device=True)
+def compute_warp_start_stop(blockIdx, warp_idx, n_blocks, n_steps):
+    block_start, block_stop = divide_work(n_steps, n_blocks, blockIdx)
+    block_jobs = block_stop - block_start
+
+    warp_start, warp_stop = divide_work(block_jobs, WARPSIZE, warp_idx)
+    warp_start += block_start
+    warp_stop += block_start
+
+    return warp_start, warp_stop
+
+
+@cuda.jit
+def reduction_kernel(
+    decay, impulses, initial_state, decay_storage, h_storage, n_dims, n_steps
+):
+    warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)
+
+    storage_offset = cuda.blockIdx.x * (WARPSIZE + 1)
+
+    warp_start, warp_stop = compute_warp_start_stop(
+        cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
+    )
+
+    # reduce within warp
+    for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
+        cum_decay = 1.0
+        h = 0.0
+        if (cuda.blockIdx.x == 0) and (lane == 0):
+            h = initial_state[i]
+
+        for t in range(warp_start, warp_stop):
+            cum_decay *= decay[i, t]
+            h = decay[i, t] * h + impulses[i, t]
+
+        decay_storage[lane + storage_offset, i] = cum_decay
+        h_storage[lane + storage_offset, i] = h
+
+    cuda.syncthreads()
+
+    # reduce within block
+    for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
+        cum_decay = 1.0
+        h = 0.0
+        for t in range(storage_offset, storage_offset + WARPSIZE):
+            cum_decay *= decay_storage[t, i]
+            h = decay_storage[t, i] * h + h_storage[t, i]
+
+        decay_storage[WARPSIZE + storage_offset, i] = cum_decay
+        h_storage[WARPSIZE + storage_offset, i] = h
+
+
+@cuda.jit
+def block_scan_kernel(decay_storage, h_storage, n_dims, n_blocks):
+    for i in range(
+        cuda.grid(1),
+        n_dims,
+        cuda.gridsize(1),
+    ):
+        for t in range(1, n_blocks):
+            cur_idx = t * (WARPSIZE + 1) + WARPSIZE
+            prev_idx = (t - 1) * (WARPSIZE + 1) + WARPSIZE
+            h_storage[cur_idx, i] += h_storage[prev_idx, i] * decay_storage[cur_idx, i]
+            decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]
+
+
+@cuda.jit
+def warp_scan_kernel(
+    decay, impulses, initial_state, out, decay_storage, h_storage, n_dims, n_steps
+):
+    warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)
+
+    for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
+        offset = cuda.blockIdx.x * (WARPSIZE + 1)
+        for cur_idx in range(offset, offset + WARPSIZE):
+            if cur_idx == 0:
+                continue
+            prev_idx = cur_idx - 1
+            h_storage[cur_idx, i] = (
+                h_storage[prev_idx, i] * decay_storage[cur_idx, i]
+                + h_storage[cur_idx, i]
+            )
+            decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]
+
+    cuda.syncthreads()
+
+    warp_start, warp_stop = compute_warp_start_stop(
+        cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
+    )
+
+    # scan within warp
+    for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
+        if (cuda.blockIdx.x == 0) and (lane == 0):
+            h = initial_state[i]
+        else:
+            h = h_storage[lane - 1 + cuda.blockIdx.x * (WARPSIZE + 1), i]
+
+        for t in range(warp_start, warp_stop):
+            h = decay[i, t] * h + impulses[i, t]
+            out[i, t] = h
+
+
+def compute_linear_recurrence(
+    decays, impulses, init_states, out, n_dims: int, n_steps: int
+):
+    n_blocks = min((n_steps + WARPSIZE - 1) // WARPSIZE, 128)
+
+    reduction_mem_shape = (n_blocks * (WARPSIZE + 1), n_dims)
+    decay_storage = cuda.device_array(reduction_mem_shape, dtype=decays.dtype)
+    h_storage = cuda.device_array(reduction_mem_shape, dtype=impulses.dtype)
+
+    reduction_kernel[n_blocks, 512](
+        decays, impulses, init_states, decay_storage, h_storage, n_dims, n_steps
+    )
+
+    block_scan_kernel[n_blocks, 512](decay_storage, h_storage, n_dims, n_blocks)
+
+    warp_scan_kernel[n_blocks, 512](
+        decays, impulses, init_states, out, decay_storage, h_storage, n_dims, n_steps
+    )
+
+
+if __name__ == "__main__":
+    import numpy as np
+
+    n_dims = 16
+    n_steps = 20480
+    decays = np.full((n_dims, n_steps), 0.9, dtype=np.float32)
+    impulses = np.full((n_dims, n_steps), 0.0, dtype=np.float32)
+    impulses[:, 0] = 1.0
+    init_states = np.full(n_dims, 0.0, dtype=np.float32)
+
+    decays = cuda.to_device(decays)
+    impulses = cuda.to_device(impulses)
+    init_states = cuda.to_device(init_states)
+    out = cuda.device_array((n_dims, n_steps), dtype=np.float32)
+
+    compute_linear_recurrence(decays, impulses, init_states, out, n_dims, n_steps)
+
+    print(out.copy_to_host())
diff --git a/torchlpc/recurrence.py b/torchlpc/recurrence.py
new file mode 100644
index 0000000..45bd4c6
--- /dev/null
+++ b/torchlpc/recurrence.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn.functional as F
+from torch.autograd import Function
+from numba import cuda
+
+from .parallel_scan import compute_linear_recurrence
+
+
+class RecurrenceCUDA(Function):
+    @staticmethod
+    def forward(
+        ctx, decay: torch.Tensor, impulse: torch.Tensor, initial_state: torch.Tensor
+    ) -> torch.Tensor:
+        n_dims, n_steps = decay.shape
+        out = torch.empty_like(impulse)
+        compute_linear_recurrence(
+            cuda.as_cuda_array(decay.detach()),
+            cuda.as_cuda_array(impulse.detach()),
+            cuda.as_cuda_array(initial_state.detach()),
+            cuda.as_cuda_array(out),
+            n_dims,
+            n_steps,
+        )
+        ctx.save_for_backward(decay, initial_state, out)
+        return out
+
+    @staticmethod
+    def backward(ctx, grad_out: torch.Tensor) -> tuple:
+        decay, initial_state, out = ctx.saved_tensors
+        grad_decay = grad_impulse = grad_initial_state = None
+        n_dims, _ = decay.shape
+
+        padded_decay = F.pad(decay.unsqueeze(1), (0, 1)).squeeze(1)
+        if ctx.needs_input_grad[2]:
+            padded_grad_out = F.pad(grad_out.unsqueeze(1), (1, 0)).squeeze(1)
+        else:
+            padded_grad_out = grad_out
+            padded_decay = padded_decay[:, 1:]
+
+        init = padded_grad_out.new_zeros(n_dims)
+        flipped_grad_impulse = RecurrenceCUDA.apply(
+            padded_decay.flip(1).conj_physical(),
+            padded_grad_out.flip(1),
+            init,
+        )
+
+        if ctx.needs_input_grad[2]:
+            grad_initial_state = flipped_grad_impulse[:, -1]
+            flipped_grad_impulse = flipped_grad_impulse[:, :-1]
+
+        if ctx.needs_input_grad[1]:
+            grad_impulse = flipped_grad_impulse.flip(1)
+
+        if ctx.needs_input_grad[0]:
+            valid_out = out[:, :-1]
+            padded_out = torch.cat([initial_state.unsqueeze(1), valid_out], dim=1)
+            grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)
+
+        return grad_decay, grad_impulse, grad_initial_state
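
Notes on the patch follow; none of the code below is part of the diff.

The new branch in sample_wise_lpc takes the parallel-scan path only when the filter is first-order, the signal lives on the GPU, and B * WARPSIZE < T, i.e. the sequence is long relative to the batch, so the scan has enough work per sequence to pay off. A minimal usage sketch, assuming the usual torchlpc shapes x: (B, T), a: (B, T, order), and zi defaulting to zeros when omitted; the tensor values are illustrative only:

import torch
from torchlpc import sample_wise_lpc

B, T = 4, 4096
x = torch.randn(B, T, device="cuda")
a = torch.rand(B, T, 1, device="cuda") - 0.5  # order 1, coefficients in (-0.5, 0.5)

# order == 1, x.is_cuda, and B * WARPSIZE = 128 < T = 4096, so this call
# dispatches to RecurrenceCUDA.apply(-a.squeeze(2), x, zi.squeeze(1)).
y = sample_wise_lpc(x, a)

a2 = torch.rand(B, T, 2, device="cuda") - 0.5  # order 2: falls back to LPC.apply
y2 = sample_wise_lpc(x, a2)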
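
compute_linear_recurrence evaluates the first-order recurrence h[t] = decay[:, t] * h[t - 1] + impulses[:, t] with h[-1] = init_states, independently along the first axis. A sequential NumPy reference (linear_recurrence_ref is a name introduced here, not part of the patch) that the __main__ demo can be checked against; with decay 0.9, a unit impulse at t = 0, and zero initial state, it yields out[:, t] = 0.9 ** t, which the CUDA output should match to float32 precision:

import numpy as np

def linear_recurrence_ref(decays, impulses, init_states):
    # Sequential ground truth; same (n_dims, n_steps) layout as the kernels.
    out = np.empty_like(impulses)
    h = init_states.copy()
    for t in range(impulses.shape[1]):
        h = decays[:, t] * h + impulses[:, t]
        out[:, t] = h
    return out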
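
The three kernels parallelize the recurrence by treating each step as an affine map h -> decay * h + impulse, which composes associatively: reduction_kernel collapses every chunk of steps into one such map, block_scan_kernel chains the few per-chunk maps serially, and warp_scan_kernel rescans each chunk from its carry-in state (the (WARPSIZE + 1)-row storage per block holds the per-warp partials plus the block total). A pure-Python model of that three-phase structure, with the warp/block tiling simplified to flat chunks and the initial state threaded through the carry instead of being seeded into chunk 0 as the kernels do; it matches the sequential recurrence for any chunk size:

def chunked_scan(decays, impulses, init_state, chunk):
    n = len(decays)

    # Phase 1 (reduction_kernel): reduce each chunk to an affine map (A, B)
    # with h_out = A * h_in + B.
    maps = []
    for s in range(0, n, chunk):
        A, B = 1.0, 0.0
        for t in range(s, min(s + chunk, n)):
            A, B = A * decays[t], decays[t] * B + impulses[t]
        maps.append((A, B))

    # Phase 2 (block_scan_kernel): serially chain the per-chunk maps to get
    # the state at the end of every chunk.
    carries, h = [], init_state
    for A, B in maps:
        h = A * h + B
        carries.append(h)

    # Phase 3 (warp_scan_kernel): rescan each chunk from its carry-in state.
    out = []
    for k, s in enumerate(range(0, n, chunk)):
        h = init_state if k == 0 else carries[k - 1]
        for t in range(s, min(s + chunk, n)):
            h = decays[t] * h + impulses[t]
            out.append(h)
    return out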
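
RecurrenceCUDA.backward reuses the forward kernel: the gradient with respect to impulse obeys the same first-order recurrence run in reverse with conjugated, one-step-shifted decays (hence the flip, pad, and conj_physical calls), and the gradient with respect to decay is that impulse gradient weighted by the one-step-delayed output. A small CPU cross-check of the decay identity against PyTorch autograd, using a naive differentiable loop (naive_recurrence is introduced here, not part of the patch):

import torch

def naive_recurrence(decay, impulse, zi):
    # Direct per-sample loop, differentiable by autograd.
    h, outs = zi, []
    for t in range(impulse.shape[1]):
        h = decay[:, t] * h + impulse[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)

decay = (torch.rand(2, 50, dtype=torch.double) * 2 - 1).requires_grad_()
impulse = torch.randn(2, 50, dtype=torch.double, requires_grad=True)
zi = torch.randn(2, dtype=torch.double, requires_grad=True)

y = naive_recurrence(decay, impulse, zi)
y.sum().backward()

# dy[t]/ddecay[t] = y[t - 1] (with y[-1] = zi), so the decay gradient equals
# the impulse gradient times the delayed output -- the identity behind
# grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1).
delayed = torch.cat([zi.detach().unsqueeze(1), y.detach()[:, :-1]], dim=1)
assert torch.allclose(decay.grad, impulse.grad * delayed)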