Skip to content
This repository has been archived by the owner on Sep 25, 2023. It is now read-only.

Perf Improvements to SOS Filter #377

Merged
merged 5 commits into from May 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 23 additions & 37 deletions cpp/src/filtering/_sosfilt.cu
Expand Up @@ -28,20 +28,14 @@ __device__ void _cupy_sosfilt( const int n_signals,
T *s_buffer ) {

T *s_out { s_buffer };
T *s_zi { reinterpret_cast<T *>( &s_out[n_sections] ) };
T *s_sos { reinterpret_cast<T *>( &s_zi[n_sections * zi_width] ) };
T *s_sos { reinterpret_cast<T *>( &s_out[n_sections] ) };

const int tx { static_cast<int>( threadIdx.x ) };
const int ty { static_cast<int>( blockIdx.y * blockDim.y + threadIdx.y ) };
const int bx { static_cast<int>( blockIdx.x ) };

// Reset shared memory
s_out[tx] = 0;

// Load zi
for ( int i = 0; i < zi_width; i++ ) {
s_zi[tx * zi_width + i] = zi[ty * n_sections * zi_width + tx * zi_width + i];
}

// Load SOS
// b is in s_sos[tx * sos_width + [0-2]]
// a is in s_sos[tx * sos_width + [3-5]]
Expand All @@ -50,81 +44,73 @@ __device__ void _cupy_sosfilt( const int n_signals,
s_sos[tx * sos_width + i] = sos[tx * sos_width + i];
}

__syncthreads( );
// __syncthreads( );

T zi0 = zi[bx * n_sections * zi_width + tx * zi_width + 0];
T zi1 = zi[bx * n_sections * zi_width + tx * zi_width + 1];

const int load_size { n_sections - 1 };
const int unload_size { n_samples - load_size };

T temp {};
T x_n {};

if ( ty < n_signals ) {
if ( bx < n_signals ) {
// Loading phase
for ( int n = 0; n < load_size; n++ ) {
__syncthreads( );
if ( tx == 0 ) {
x_n = x_in[ty * n_samples + n];
x_n = x_in[bx * n_samples + n];
} else {
x_n = s_out[tx - 1];
}

// Use direct form II transposed structure
temp = s_sos[tx * sos_width + 0] * x_n + s_zi[tx * zi_width + 0];

s_zi[tx * zi_width + 0] =
s_sos[tx * sos_width + 1] * x_n - s_sos[tx * sos_width + 4] * temp + s_zi[tx * zi_width + 1];

s_zi[tx * zi_width + 1] = s_sos[tx * sos_width + 2] * x_n - s_sos[tx * sos_width + 5] * temp;
temp = s_sos[tx * sos_width + 0] * x_n + zi0;
zi0 = s_sos[tx * sos_width + 1] * x_n - s_sos[tx * sos_width + 4] * temp + zi1;
zi1 = s_sos[tx * sos_width + 2] * x_n - s_sos[tx * sos_width + 5] * temp;

s_out[tx] = temp;

__syncthreads( );
}

// Processing phase
for ( int n = load_size; n < n_samples; n++ ) {
__syncthreads( );
if ( tx == 0 ) {
x_n = x_in[ty * n_samples + n];
x_n = x_in[bx * n_samples + n];
} else {
x_n = s_out[tx - 1];
}

// Use direct form II transposed structure
temp = s_sos[tx * sos_width + 0] * x_n + s_zi[tx * zi_width + 0];

s_zi[tx * zi_width + 0] =
s_sos[tx * sos_width + 1] * x_n - s_sos[tx * sos_width + 4] * temp + s_zi[tx * zi_width + 1];

s_zi[tx * zi_width + 1] = s_sos[tx * sos_width + 2] * x_n - s_sos[tx * sos_width + 5] * temp;
temp = s_sos[tx * sos_width + 0] * x_n + zi0;
zi0 = s_sos[tx * sos_width + 1] * x_n - s_sos[tx * sos_width + 4] * temp + zi1;
zi1 = s_sos[tx * sos_width + 2] * x_n - s_sos[tx * sos_width + 5] * temp;

if ( tx < load_size ) {
s_out[tx] = temp;
} else {
x_in[ty * n_samples + ( n - load_size )] = temp;
x_in[bx * n_samples + ( n - load_size )] = temp;
}

__syncthreads( );
}

// Unloading phase
for ( int n = 0; n < n_sections; n++ ) {
__syncthreads( );
// retire threads whose index is less than or equal to n (only tx > n keeps working)
if ( tx > n ) {
x_n = s_out[tx - 1];

// Use direct form II transposed structure
temp = s_sos[tx * sos_width + 0] * x_n + s_zi[tx * zi_width + 0];

s_zi[tx * zi_width + 0] =
s_sos[tx * sos_width + 1] * x_n - s_sos[tx * sos_width + 4] * temp + s_zi[tx * zi_width + 1];

s_zi[tx * zi_width + 1] = s_sos[tx * sos_width + 2] * x_n - s_sos[tx * sos_width + 5] * temp;
temp = s_sos[tx * sos_width + 0] * x_n + zi0;
zi0 = s_sos[tx * sos_width + 1] * x_n - s_sos[tx * sos_width + 4] * temp + zi1;
zi1 = s_sos[tx * sos_width + 2] * x_n - s_sos[tx * sos_width + 5] * temp;

if ( tx < load_size ) {
s_out[tx] = temp;
} else {
x_in[ty * n_samples + ( n + unload_size )] = temp;
x_in[bx * n_samples + ( n + unload_size )] = temp;
}
__syncthreads( );
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions python/cusignal/filtering/_sosfilt_cuda.py
Expand Up @@ -77,18 +77,17 @@ def _get_backend_kernel(dtype, grid, block, smem, k_type):

def _sosfilt(sos, x, zi):

threadsperblock = (sos.shape[0], 1) # Up-to (1024, 1) = 1024 max per block
blockspergrid = (1, x.shape[0])
threadsperblock = sos.shape[0] # Up-to (1024, 1) = 1024 max per block
blockspergrid = x.shape[0]

k_type = "sosfilt"

_populate_kernel_cache(x.dtype, k_type)

out_size = threadsperblock[0]
z_size = zi.shape[1] * zi.shape[2]
out_size = threadsperblock
sos_size = sos.shape[0] * sos.shape[1]

shared_mem = (out_size + z_size + sos_size) * x.dtype.itemsize
shared_mem = (out_size + sos_size) * x.dtype.itemsize

kernel = _get_backend_kernel(
x.dtype,
Expand All @@ -97,6 +96,7 @@ def _sosfilt(sos, x, zi):
shared_mem,
k_type,
)
print(zi.shape)

kernel(sos, x, zi)

Expand Down
64 changes: 39 additions & 25 deletions python/cusignal/filtering/filtering.py
Expand Up @@ -142,18 +142,18 @@ def firfilter(b, x, axis=-1, zi=None):
"""
b = cp.asarray(b)
if b.ndim != 1:
raise ValueError('object of too small depth for desired array')
raise ValueError("object of too small depth for desired array")

if x.ndim == 0:
raise ValueError('x must be at least 1-D')
raise ValueError("x must be at least 1-D")

inputs = [b, x]
if zi is not None:
# _linear_filter does not broadcast zi, but does do expansion of
# singleton dims.
zi = cp.asarray(zi)
if zi.ndim != x.ndim:
raise ValueError('object of too small depth for desired array')
raise ValueError("object of too small depth for desired array")
expected_shape = list(x.shape)
expected_shape[axis] = b.shape[0] - 1
expected_shape = tuple(expected_shape)
Expand All @@ -170,15 +170,15 @@ def firfilter(b, x, axis=-1, zi=None):
elif k != axis and zi.shape[k] == 1:
strides[k] = 0
else:
raise ValueError('Unexpected shape for zi: expected '
'%s, found %s.' %
(expected_shape, zi.shape))
zi = cp.lib.stride_tricks.as_strided(zi, expected_shape,
strides)
raise ValueError(
"Unexpected shape for zi: expected "
"%s, found %s." % (expected_shape, zi.shape)
)
zi = cp.lib.stride_tricks.as_strided(zi, expected_shape, strides)
inputs.append(zi)
dtype = cp.result_type(*inputs)

if dtype.char not in 'fdgFDGO':
if dtype.char not in "fdgFDGO":
raise NotImplementedError("input type '%s' not supported" % dtype)

b = cp.array(b, dtype=dtype)
Expand Down Expand Up @@ -363,10 +363,14 @@ def firfilter_zi(b):

def _validate_pad(padtype, padlen, x, axis, ntaps):
"""Helper to validate padding for filtfilt"""
if padtype not in ['even', 'odd', 'constant', None]:
raise ValueError(("Unknown value '%s' given to padtype. padtype "
"must be 'even', 'odd', 'constant', or None.") %
padtype)
if padtype not in ["even", "odd", "constant", None]:
raise ValueError(
(
"Unknown value '%s' given to padtype. padtype "
"must be 'even', 'odd', 'constant', or None."
)
% padtype
)

if padtype is None:
padlen = 0
Expand All @@ -379,15 +383,17 @@ def _validate_pad(padtype, padlen, x, axis, ntaps):

# x's 'axis' dimension must be bigger than edge.
if x.shape[axis] <= edge:
raise ValueError("The length of the input vector x must be greater "
"than padlen, which is %d." % edge)
raise ValueError(
"The length of the input vector x must be greater "
"than padlen, which is %d." % edge
)

if padtype is not None and edge > 0:
# Make an extension of length `edge` at each
# end of the input array.
if padtype == 'even':
if padtype == "even":
ext = _even_ext(x, edge, axis=axis)
elif padtype == 'odd':
elif padtype == "odd":
ext = _odd_ext(x, edge, axis=axis)
else:
ext = _const_ext(x, edge, axis=axis)
Expand All @@ -396,8 +402,9 @@ def _validate_pad(padtype, padlen, x, axis, ntaps):
return edge, ext


def firfilter2(b, x, axis=-1, padtype='odd', padlen=None, method='pad',
irlen=None):
def firfilter2(
b, x, axis=-1, padtype="odd", padlen=None, method="pad", irlen=None
):
"""
Apply a digital filter forward and backward to a signal.
This function applies a linear digital filter twice, once forward and
Expand Down Expand Up @@ -502,8 +509,9 @@ def firfilter2(b, x, axis=-1, padtype='odd', padlen=None, method='pad',
return cp.copy(y)


def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
irlen=None):
def filtfilt(
b, a, x, axis=-1, padtype="odd", padlen=None, method="pad", irlen=None
):
"""
Apply a digital filter forward and backward to a signal.
This function applies a linear digital filter twice, once forward and
Expand Down Expand Up @@ -572,8 +580,15 @@ def filtfilt(b, a, x, axis=-1, padtype='odd', padlen=None, method='pad',
"""
a = cp.atleast_1d(a)
if len(a) == 1:
return firfilter2(b, x, axis=axis, padtype=padtype, padlen=padlen,
method=method, irlen=irlen)
return firfilter2(
b,
x,
axis=axis,
padtype=padtype,
padlen=padlen,
method=method,
irlen=irlen,
)
else:
raise NotImplementedError("IIR support isn't supported yet")

Expand Down Expand Up @@ -704,9 +719,8 @@ def sosfilt(

# Determine how much shared memory is needed
out_size = sos.shape[0]
z_size = zi.shape[1] * zi.shape[2]
sos_size = sos.shape[0] * sos.shape[1]
shared_mem = (out_size + z_size + sos_size) * x.dtype.itemsize
shared_mem = (out_size + sos_size) * x.dtype.itemsize

if shared_mem > max_smem:
max_sections = (
Expand Down