Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support a user defined function name in the window transformation output #1676

Merged
Merged
8 changes: 8 additions & 0 deletions darts/dataprocessing/transformers/window_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def __init__(
transformation should be applied. If not specified, the transformation will be
applied on all components.

:``"function_name"``: Optional. A string specifying the function name referenced as part of
the transformation output name. For example, given a user-provided function
transformation on rolling window size of 5 on the component "comp", the
default transformation output name is "rolling_udf_5_comp" whereby "udf"
refers to "user defined function". If specified, the ``"function_name"`` will
replace the default name "udf". Similarly, the ``"function_name"`` will replace
the name of the pandas builtin transformation function name in the output name.

All other dictionary items provided will be treated as keyword arguments for the windowing mode
(i.e., ``rolling/ewm/expanding``) or for the specific function
in that mode (i.e., ``pandas.DataFrame.rolling.mean/std/max/min...`` or
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,23 @@ def test_ts_windowtransf_output_series(self):
],
)

# test customized function name that overwrites the pandas builtin transformation
transforms = {
"function": "sum",
"mode": "rolling",
"window": 1,
"function_name": "customized_name",
}
transformed_ts = self.series_univ_det.window_transform(transforms=transforms)
self.assertEqual(
transformed_ts.components.to_list(),
[
f"{transforms['mode']}_{transforms['function_name']}_{str(transforms['window'])}_{comp}"
for comp in self.series_univ_det.components
],
)
del transforms["function_name"]

# multivariate deterministic input
# transform one component
transforms.update({"components": "0"})
Expand Down Expand Up @@ -242,6 +259,39 @@ def test_ts_windowtransf_output_series(self):
transformed_ts = self.series_multi_prob.window_transform(transforms=transforms)
self.assertEqual(transformed_ts.n_samples, 2)

def test_user_defined_function_behavior(self):
def count_above_mean(array):
mean = np.mean(array)
return np.where(array > mean)[0].size

transformation = {
"function": count_above_mean,
"mode": "rolling",
"window": 5,
}
transformed_ts = self.target.window_transform(
transformation,
)
expected_transformed_series = TimeSeries.from_times_and_values(
self.times,
[0, 1, 1, 2, 2, 2, 2, 2, 2, 2],
columns=["rolling_udf_5_0"],
)
self.assertEqual(transformed_ts, expected_transformed_series)

# test if a customized function name is provided
transformation.update({"function_name": "count_above_mean"})
transformed_ts = self.target.window_transform(
transformation,
)
self.assertEqual(
transformed_ts.components.to_list(),
[
f"{transformation['mode']}_{transformation['function_name']}_{str(transformation['window'])}_{comp}"
for comp in self.target.components
],
)

def test_ts_windowtransf_output_nabehavior(self):
window_transformations = {
"function": "sum",
Expand Down
16 changes: 15 additions & 1 deletion darts/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -3255,6 +3255,14 @@ def window_transform(
transformation should be applied. If not specified, the transformation will be
applied on all components.

:``"function_name"``: Optional. A string specifying the function name referenced as part of
the transformation output name. For example, given a user-provided function
transformation on rolling window size of 5 on the component "comp", the
default transformation output name is "rolling_udf_5_comp" whereby "udf"
refers to "user defined function". If specified, the ``"function_name"`` will
replace the default name "udf". Similarly, the ``"function_name"`` will replace
the name of the pandas builtin transformation function name in the output name.

All other dictionary items provided will be treated as keyword arguments for the windowing mode
(i.e., ``rolling/ewm/expanding``) or for the specific function
in that mode (i.e., ``pandas.DataFrame.rolling.mean/std/max/min...`` or
Expand Down Expand Up @@ -3409,6 +3417,7 @@ def _get_kwargs(transformation, forecasting_safe):
"function",
"group",
"components",
"function_name",
}

window_mode_expected_args = set(window_mode.__code__.co_varnames)
Expand Down Expand Up @@ -3536,8 +3545,13 @@ def _get_kwargs(transformation, forecasting_safe):
)
min_periods = transformation["min_periods"]
# set new columns names
fn_name = transformation.get("function_name")
if fn_name:
function_name = fn_name
else:
function_name = fn if fn != "apply" else "udf"
name_prefix = (
f"{window_mode}_{fn if fn != 'apply' else 'udf'}"
f"{window_mode}_{function_name}"
f"{'_'+str(transformation['window']) if 'window' in transformation else ''}"
f"{'_'+str(min_periods) if min_periods>1 else ''}"
)
Expand Down