Skip to content

Commit

Permalink
✨ support datetime64 y
Browse files Browse the repository at this point in the history
  • Loading branch information
jvdd committed Dec 19, 2022
1 parent 7064703 commit 6a0f92e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
7 changes: 6 additions & 1 deletion tests/test_config.py
Expand Up @@ -16,7 +16,12 @@
]

supported_dtypes_x = _core_supported_dtypes
supported_dtypes_y = _core_supported_dtypes + [np.int8, np.uint8, np.bool8]
supported_dtypes_y = _core_supported_dtypes + [
np.timedelta64,
np.int8,
np.uint8,
np.bool8,
]

_core_rust_primitive_types = [
"f16",
Expand Down
9 changes: 8 additions & 1 deletion tsdownsample/downsampling_interface.py
Expand Up @@ -124,7 +124,7 @@ def downsample(self, *args, n_out: int, **kwargs): # x and y are optional
]
# <= 8-bit x-dtypes are not supported as the range of the values is too small to require
# downsampling
_y_rust_dtypes = _rust_dtypes + ["int8", "uint8", "bool"]
_y_rust_dtypes = _rust_dtypes + ["timedelta64", "int8", "uint8", "bool"]


class AbstractRustDownsampler(AbstractDownsampler, ABC):
Expand Down Expand Up @@ -195,6 +195,7 @@ def _switch_mod_with_y(
elif y_dtype == np.int64:
return getattr(mod, downsample_func + "_i64")
# DATETIME -> i64 (datetime64 is viewed as int64)
# TIMEDELTA -> i64 (timedelta64 is viewed as int64)
# BOOLS -> int8 (bool is viewed as int8)
raise ValueError(f"Unsupported data type (for y): {y_dtype}")

Expand Down Expand Up @@ -279,18 +280,24 @@ def _downsample(
)
else:
mod = self.mod_multi_core
## Viewing the y-data as different dtype (if necessary)
if y.dtype == "bool":
# bool is viewed as int8
y = y.view(dtype=np.int8)
elif np.issubdtype(y.dtype, np.datetime64):
# datetime64 is viewed as int64
y = y.view(dtype=np.int64)
elif np.issubdtype(y.dtype, np.timedelta64):
# timedelta64 is viewed as int64
y = y.view(dtype=np.int64)
## Viewing the x-data as different dtype (if necessary)
if x is None:
downsample_f = self._switch_mod_with_y(y.dtype, mod)
return downsample_f(y, n_out, **kwargs)
elif np.issubdtype(x.dtype, np.datetime64):
# datetime64 is viewed as int64
x = x.view(dtype=np.int64)
## Getting the appropriate downsample function
downsample_f = self._switch_mod_with_x_and_y(x.dtype, y.dtype, mod)
return downsample_f(x, y, n_out, **kwargs)

Expand Down

0 comments on commit 6a0f92e

Please sign in to comment.