feat: parallel-scan algorithm for first-order filter (#11)
* feat: parallel scan functions
* feat: use parallel scan in certain conditions
* test: gradient using parallel scan
1 parent 28f6b16 · commit f66cc0c
4 changed files with 256 additions and 0 deletions.
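The commit adds a parallel scan for the first-order linear recurrence h[t] = decay[:, t] * h[t-1] + impulse[:, t], with h[-1] = initial_state: a per-warp reduction kernel, a scan across block summaries, and a final kernel that replays the recurrence from each chunk's carried-in state. For intuition and testing, a minimal sequential NumPy sketch of the same recurrence (linear_recurrence_ref is an illustrative name, not part of the commit):

import numpy as np

def linear_recurrence_ref(decays, impulses, init_states):
    # Sequential reference: h[t] = decays[:, t] * h[t-1] + impulses[:, t].
    out = np.empty_like(impulses)
    h = init_states.copy()
    for t in range(impulses.shape[1]):
        h = decays[:, t] * h + impulses[:, t]
        out[:, t] = h
    return out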
parallel_scan.py (new file, 160 additions)
from numba import cuda

WARPSIZE = 32

# Implementation translated from
# https://github.com/eamartin/parallelizing_linear_rnns/blob/master/linear_recurrent_net/linear_recurrence.cu


@cuda.jit(device=True)
def divide_work(n_jobs, n_workers, worker_idx) -> tuple:
    # Split n_jobs as evenly as possible: the first (n_jobs % n_workers)
    # workers handle ceil(n_jobs / n_workers) jobs each, the rest handle
    # floor(n_jobs / n_workers). Returns this worker's [start, stop) range.
    # e.g. divide_work(10, 3, w) -> (0, 4), (4, 7), (7, 10) for w = 0, 1, 2.
    cd = (n_jobs + n_workers - 1) // n_workers
    d, doing_cd = divmod(n_jobs, n_workers)
    if worker_idx < doing_cd:
        x = cd * worker_idx
        y = x + cd
    else:
        x = cd * doing_cd + d * (worker_idx - doing_cd)
        y = x + d
    return x, y


@cuda.jit(device=True)
def compute_warp_start_stop(blockIdx, warp_idx, n_blocks, n_steps):
    # Two-level split: first divide the time steps across blocks, then divide
    # each block's share across the WARPSIZE lanes of a warp.
    block_start, block_stop = divide_work(n_steps, n_blocks, blockIdx)
    block_jobs = block_stop - block_start

    warp_start, warp_stop = divide_work(block_jobs, WARPSIZE, warp_idx)
    warp_start += block_start
    warp_stop += block_start

    return warp_start, warp_stop


@cuda.jit
def reduction_kernel(
    decay, impulses, initial_state, decay_storage, h_storage, n_dims, n_steps
):
    # Phase 1: each lane sequentially reduces its chunk of time steps to a
    # (cumulative decay, partial state) pair; each block then folds its 32
    # lane results into the summary slot WARPSIZE of its storage section.
    warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)

    storage_offset = cuda.blockIdx.x * (WARPSIZE + 1)

    warp_start, warp_stop = compute_warp_start_stop(
        cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
    )

    # reduce within warp
    for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
        cum_decay = 1.0
        h = 0.0
        if (cuda.blockIdx.x == 0) and (lane == 0):
            h = initial_state[i]

        for t in range(warp_start, warp_stop):
            cum_decay *= decay[i, t]
            h = decay[i, t] * h + impulses[i, t]

        decay_storage[lane + storage_offset, i] = cum_decay
        h_storage[lane + storage_offset, i] = h

    cuda.syncthreads()

    # reduce within block
    for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
        cum_decay = 1.0
        h = 0.0
        for t in range(storage_offset, storage_offset + WARPSIZE):
            cum_decay *= decay_storage[t, i]
            h = decay_storage[t, i] * h + h_storage[t, i]

        decay_storage[WARPSIZE + storage_offset, i] = cum_decay
        h_storage[WARPSIZE + storage_offset, i] = h


@cuda.jit
def block_scan_kernel(decay_storage, h_storage, n_dims, n_blocks):
    # Phase 2: sequential scan over the per-block summaries, so that each
    # block's summary slot holds the state carried in from all earlier blocks.
    for i in range(
        cuda.grid(1),
        n_dims,
        cuda.gridsize(1),
    ):
        for t in range(1, n_blocks):
            cur_idx = t * (WARPSIZE + 1) + WARPSIZE
            prev_idx = (t - 1) * (WARPSIZE + 1) + WARPSIZE
            h_storage[cur_idx, i] += h_storage[prev_idx, i] * decay_storage[cur_idx, i]
            decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]


@cuda.jit
def warp_scan_kernel(
    decay, impulses, initial_state, out, decay_storage, h_storage, n_dims, n_steps
):
    # Phase 3: scan the per-lane summaries within each block, then replay the
    # recurrence inside each lane's chunk starting from the carried-in state.
    warp, lane = divmod(cuda.threadIdx.x, WARPSIZE)

    # scan within block
    for i in range(cuda.threadIdx.x, n_dims, cuda.blockDim.x):
        offset = cuda.blockIdx.x * (WARPSIZE + 1)
        for cur_idx in range(offset, offset + WARPSIZE):
            if cur_idx == 0:
                continue
            prev_idx = cur_idx - 1
            h_storage[cur_idx, i] = (
                h_storage[prev_idx, i] * decay_storage[cur_idx, i]
                + h_storage[cur_idx, i]
            )
            decay_storage[cur_idx, i] *= decay_storage[prev_idx, i]

    cuda.syncthreads()

    warp_start, warp_stop = compute_warp_start_stop(
        cuda.blockIdx.x, lane, cuda.gridDim.x, n_steps
    )

    # scan within warp
    for i in range(warp, n_dims, (cuda.blockDim.x + WARPSIZE - 1) // WARPSIZE):
        if (cuda.blockIdx.x == 0) and (lane == 0):
            h = initial_state[i]
        else:
            # lane 0 of block b > 0 reads the previous block's summary slot.
            h = h_storage[lane - 1 + cuda.blockIdx.x * (WARPSIZE + 1), i]

        for t in range(warp_start, warp_stop):
            h = decay[i, t] * h + impulses[i, t]
            out[i, t] = h


def compute_linear_recurrence(
    decays, impulses, init_states, out, n_dims: int, n_steps: int
):
    # Launch the three phases; each block gets a (WARPSIZE + 1)-row section
    # of scratch storage: one row per lane plus one block-summary row.
    n_blocks = min((n_steps + WARPSIZE - 1) // WARPSIZE, 128)

    reduction_mem_shape = (n_blocks * (WARPSIZE + 1), n_dims)
    decay_storage = cuda.device_array(reduction_mem_shape, dtype=decays.dtype)
    h_storage = cuda.device_array(reduction_mem_shape, dtype=impulses.dtype)

    reduction_kernel[n_blocks, 512](
        decays, impulses, init_states, decay_storage, h_storage, n_dims, n_steps
    )

    block_scan_kernel[n_blocks, 512](decay_storage, h_storage, n_dims, n_blocks)

    warp_scan_kernel[n_blocks, 512](
        decays, impulses, init_states, out, decay_storage, h_storage, n_dims, n_steps
    )


if __name__ == "__main__":
    import numpy as np

    n_dims = 16
    n_steps = 20480
    decays = np.full((n_dims, n_steps), 0.9, dtype=np.float32)
    impulses = np.full((n_dims, n_steps), 0.0, dtype=np.float32)
    impulses[:, 0] = 1.0
    init_states = np.full(n_dims, 0.0, dtype=np.float32)

    decays = cuda.to_device(decays)
    impulses = cuda.to_device(impulses)
    init_states = cuda.to_device(init_states)
    out = cuda.device_array((n_dims, n_steps), dtype=np.float32)

    compute_linear_recurrence(decays, impulses, init_states, out, n_dims, n_steps)

    print(out.copy_to_host())
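Since the demo uses a constant decay of 0.9 and a unit impulse at t = 0, the exact answer is out[i, t] = 0.9**t, which makes the result easy to verify on the host. A quick check (a sketch; the tolerances are a judgment call given float32 underflow at large t):

expected = np.power(np.float32(0.9), np.arange(n_steps, dtype=np.float32))
np.testing.assert_allclose(out.copy_to_host()[0], expected, rtol=1e-3, atol=1e-6)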
RecurrenceCUDA autograd wrapper (new file, 59 additions)
import torch
import torch.nn.functional as F
from torch.autograd import Function
from numba import cuda

from .parallel_scan import compute_linear_recurrence


class RecurrenceCUDA(Function):
    @staticmethod
    def forward(
        ctx, decay: torch.Tensor, impulse: torch.Tensor, initial_state: torch.Tensor
    ) -> torch.Tensor:
        n_dims, n_steps = decay.shape
        out = torch.empty_like(impulse)
        compute_linear_recurrence(
            cuda.as_cuda_array(decay.detach()),
            cuda.as_cuda_array(impulse.detach()),
            cuda.as_cuda_array(initial_state.detach()),
            cuda.as_cuda_array(out),
            n_dims,
            n_steps,
        )
        ctx.save_for_backward(decay, initial_state, out)
        return out

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # The adjoint of a linear recurrence is the same recurrence run
        # backwards in time with shifted, conjugated decays, so the gradients
        # are obtained by re-applying the op to flipped inputs.
        decay, initial_state, out = ctx.saved_tensors
        grad_decay = grad_impulse = grad_initial_state = None
        n_dims, _ = decay.shape

        padded_decay = F.pad(decay.unsqueeze(1), (0, 1)).squeeze(1)
        if ctx.needs_input_grad[2]:
            # Prepend a zero so the reverse scan also yields the gradient
            # with respect to the initial state.
            padded_grad_out = F.pad(grad_out.unsqueeze(1), (1, 0)).squeeze(1)
        else:
            padded_grad_out = grad_out
            padded_decay = padded_decay[:, 1:]

        init = padded_grad_out.new_zeros(n_dims)
        flipped_grad_impulse = RecurrenceCUDA.apply(
            padded_decay.flip(1).conj_physical(),
            padded_grad_out.flip(1),
            init,
        )

        if ctx.needs_input_grad[2]:
            grad_initial_state = flipped_grad_impulse[:, -1]
            flipped_grad_impulse = flipped_grad_impulse[:, :-1]

        if ctx.needs_input_grad[1]:
            grad_impulse = flipped_grad_impulse.flip(1)

        if ctx.needs_input_grad[0]:
            # d out[t] / d decay[t] = h[t-1], i.e. the previous output
            # (or the initial state at t = 0).
            valid_out = out[:, :-1]
            padded_out = torch.cat([initial_state.unsqueeze(1), valid_out], dim=1)
            grad_decay = padded_out.conj_physical() * flipped_grad_impulse.flip(1)

        return grad_decay, grad_impulse, grad_initial_state
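A minimal correctness sketch, assuming a CUDA-capable GPU and that this module is importable; it checks the op and its gradients against autograd through an explicit PyTorch loop (recurrence_ref is an illustrative helper, not part of the commit):

import torch

def recurrence_ref(decay, impulse, initial_state):
    # Plain-PyTorch loop: h[t] = decay[:, t] * h[t-1] + impulse[:, t].
    h = initial_state
    outs = []
    for t in range(impulse.shape[1]):
        h = decay[:, t] * h + impulse[:, t]
        outs.append(h)
    return torch.stack(outs, dim=1)

torch.manual_seed(0)
decay = torch.rand(4, 64, device="cuda", dtype=torch.float64, requires_grad=True)
impulse = torch.randn(4, 64, device="cuda", dtype=torch.float64, requires_grad=True)
init = torch.randn(4, device="cuda", dtype=torch.float64, requires_grad=True)

out = RecurrenceCUDA.apply(decay, impulse, init)
ref = recurrence_ref(decay, impulse, init)
torch.testing.assert_close(out, ref)

# Compare the parallel-scan backward against autograd through the loop.
g = torch.randn_like(out)
grads = torch.autograd.grad(out, (decay, impulse, init), g)
grads_ref = torch.autograd.grad(ref, (decay, impulse, init), g)
for a, b in zip(grads, grads_ref):
    torch.testing.assert_close(a, b)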