Fix/no lambda function in add_encoders #1957

Merged
9 commits merged on Aug 31, 2023
19 changes: 16 additions & 3 deletions darts/dataprocessing/encoders/encoders.py
@@ -1362,6 +1362,7 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
ValueError
1) if the outermost key is other than (`past`, `future`)
2) if the innermost values are other than type `str` or `Sequence`
3) if any entry in the innermost values is a lambda function
"""
if not params:
return [], []
@@ -1377,9 +1378,7 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
logger,
)

encoders = {
enc: params.get(enc, None) for enc in ENCODER_KEYS if params.get(enc, None)
}
encoders = {enc: params[enc] for enc in ENCODER_KEYS if params.get(enc, None)}

# check input for invalid temporal types
invalid_time_params = list()
@@ -1395,6 +1394,8 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
logger,
)

# check that encoders are not lambda functions (not picklable)
lambda_func_encoders = set()
# convert into tuples of (encoder string identifier, encoder attribute)
past_encoders, future_encoders = list(), list()
for enc, enc_params in encoders.items():
@@ -1414,6 +1415,18 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
else:
future_encoders.append((encoder_id, attr))

if isinstance(attr, Callable) and attr.__name__ == "<lambda>":
lambda_func_encoders.add(enc)

raise_if(
len(lambda_func_encoders) > 0,
f"Encountered lambda function in the following `add_encoders` entries : {lambda_func_encoders} "
f"at model creation. "
f"In order to prevent issues when saving the model, these encoders must be converted to "
f"named functions.",
logger,
)

for temp_enc, takes_temp, temp in [
(past_encoders, self.takes_past_covariates, "past"),
(future_encoders, self.takes_future_covariates, "future"),
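To illustrate the rationale behind this check, here is a minimal, self-contained sketch (independent of the Darts codebase; the names `extract_month` and `lambda_encoder` are only for illustration): module-level named functions can be pickled by reference, lambdas cannot, and a lambda is recognizable by its `__name__`.

import pickle


def extract_month(index):
    """A named, module-level function: pickle can serialize it by reference."""
    return index.month


lambda_encoder = lambda index: index.month  # noqa: E731

# This succeeds: named functions round-trip through pickle.
pickle.dumps(extract_month)

# This fails: the lambda's qualified name is "<lambda>", so pickle cannot
# look it up again when loading (typically a PicklingError).
try:
    pickle.dumps(lambda_encoder)
except (pickle.PicklingError, AttributeError) as err:
    print(f"not picklable: {err}")

# The validation above detects lambdas the same way:
print(extract_month.__name__)   # "extract_month"
print(lambda_encoder.__name__)  # "<lambda>"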
37 changes: 25 additions & 12 deletions darts/tests/dataprocessing/encoders/test_encoders.py
@@ -139,18 +139,15 @@ def test_sequence_encoder_from_model_params(self):
]

valid_encoder_args = {"cyclic": {"past": ["month"]}}
encoders = self.helper_encoder_from_model(
add_encoder_dict=valid_encoder_args, takes_future_covariates=False
)
encoders = self.helper_encoder_from_model(add_encoder_dict=valid_encoder_args)
assert len(encoders.past_encoders) == 1
assert len(encoders.future_encoders) == 0

# test invalid encoder kwarg at model creation
# test invalid encoder kwargs at model creation
bad_encoder = {"no_encoder": {"past": ["month"]}}
with pytest.raises(ValueError):
_ = self.helper_encoder_from_model(add_encoder_dict=bad_encoder)

# test invalid kwargs at model creation
bad_time = {"cyclic": {"ppast": ["month"]}}
with pytest.raises(ValueError):
_ = self.helper_encoder_from_model(add_encoder_dict=bad_time)
@@ -163,6 +160,10 @@ def test_sequence_encoder_from_model_params(self):
with pytest.raises(ValueError):
_ = self.helper_encoder_from_model(add_encoder_dict=bad_type)

bad_callable = {"custom": {"past": [lambda idx: idx.month]}}
with pytest.raises(ValueError):
_ = self.helper_encoder_from_model(add_encoder_dict=bad_callable)

@pytest.mark.skipif(not TORCH_AVAILABLE, reason="requires torch")
def test_encoder_sequence_train(self):
"""Test `SequentialEncoder.encode_train()` output"""
@@ -291,9 +292,7 @@ def helper_sequence_encode_inference(
for fc, fc_in in zip(future_covs_pred, expected_future_idx_ts):
assert fc.time_index.equals(fc_in.time_index)

def helper_encoder_from_model(
self, add_encoder_dict, takes_past_covariates=True, takes_future_covariates=True
):
def helper_encoder_from_model(self, add_encoder_dict):
"""extracts encoders from parameters at model creation"""
model = TFTModel(
input_chunk_length=self.input_chunk_length,
@@ -467,6 +466,13 @@ def test_sequential_encoder_general(self):
ts = tg.linear_timeseries(length=24, freq="MS")
covs = tg.linear_timeseries(length=24, freq="MS")

# encoders must be named functions for pickling
def extract_month(index):
return index.month

def extract_year(index):
return index.year

input_chunk_length = 12
output_chunk_length = 6
add_encoders = {
@@ -480,8 +486,8 @@ def test_sequential_encoder_general(self):
"future": ["relative"],
},
"custom": {
"past": [lambda idx: idx.month, lambda idx: idx.year],
"future": [lambda idx: idx.month, lambda idx: idx.year],
"past": [extract_month, extract_year],
"future": [extract_month, extract_year],
},
"transformer": Scaler(),
}
@@ -887,11 +893,18 @@ def test_callable_encoder(self):
input_chunk_length = 12
output_chunk_length = 6

# encoders must be named functions for pickling
def index_year(index):
return index.year

def index_year_shifted(index):
return index.year - 1

# ===> test callable index encoder <===
encoder_params = {
"custom": {
"past": [lambda index: index.year, lambda index: index.year - 1],
"future": [lambda index: index.year],
"past": [index_year, index_year_shifted],
"future": [index_year],
}
}
encs = SequentialEncoder(
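For users hitting the new error, a minimal usage sketch of the corrected pattern (assumptions: standard Darts imports and a torch installation; `TFTModel` and the chunk lengths simply mirror the test helper above). Custom encoders are declared as named, module-level functions so the resulting model can be pickled and saved.

from darts.dataprocessing.transformers import Scaler
from darts.models import TFTModel


def extract_month(index):
    """Month of each time index entry."""
    return index.month


def extract_year(index):
    """Year of each time index entry."""
    return index.year


add_encoders = {
    "cyclic": {"future": ["month"]},
    "position": {"past": ["relative"], "future": ["relative"]},
    # named functions instead of `lambda idx: idx.month` / `lambda idx: idx.year`
    "custom": {"past": [extract_month, extract_year]},
    "transformer": Scaler(),
}

model = TFTModel(
    input_chunk_length=12,
    output_chunk_length=6,
    add_encoders=add_encoders,
)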
7 changes: 6 additions & 1 deletion darts/tests/explainability/test_shap_explainer.py
@@ -29,6 +29,11 @@
cb_available = not isinstance(CatBoostModel, NotImportedModule)


def extract_year(index):
"""Return year of time index entry, normalized"""
return (index.year - 1950) / 50


class TestShapExplainer:
np.random.seed(42)

@@ -37,7 +42,7 @@ class TestShapExplainer:
"cyclic": {"past": ["month", "day"]},
"datetime_attribute": {"future": ["hour", "dayofweek"]},
"position": {"past": ["relative"], "future": ["relative"]},
"custom": {"past": [lambda idx: (idx.year - 1950) / 50]},
"custom": {"past": [extract_year]},
"transformer": Scaler(scaler),
}

@@ -322,7 +322,12 @@ def test_encoders_support(self, config):
n = 3

target = self.ts_gaussian[:-3]
add_encoders = {"custom": {"future": [lambda x: x.dayofweek]}}

# encoder must be a named function for pickling
def extract_dayofweek(index):
return index.dayofweek

add_encoders = {"custom": {"future": [extract_dayofweek]}}

series = (
target
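The motivation stated in the error message ("prevent issues when saving the model") comes from the fact that saving a model pickles it together with its encoders. A hedged sketch of the save/load round-trip that named functions keep working (assumptions: `LinearRegressionModel` is used purely for illustration, and `linear_timeseries` with daily frequency supplies the `dayofweek` attribute used in the test above):

from darts.models import LinearRegressionModel
from darts.utils.timeseries_generation import linear_timeseries


def extract_dayofweek(index):
    """Day of week of each time index entry."""
    return index.dayofweek


series = linear_timeseries(length=100, freq="D")

model = LinearRegressionModel(
    lags=12,
    lags_future_covariates=[0],
    add_encoders={"custom": {"future": [extract_dayofweek]}},
)
model.fit(series)  # the encoder generates the future covariates automatically

# Saving pickles the model, including the encoder functions; this works for
# named functions, while a lambda would now be rejected at model creation.
model.save("model_with_encoders.pkl")
loaded = LinearRegressionModel.load("model_with_encoders.pkl")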
33 changes: 13 additions & 20 deletions examples/00-quickstart.ipynb
@@ -538,9 +538,7 @@
"<Figure size 720x360 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"metadata": {},
"output_type": "display_data"
}
],
@@ -900,9 +898,7 @@
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"metadata": {},
"output_type": "display_data"
}
],
@@ -954,9 +950,7 @@
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"metadata": {},
"output_type": "display_data"
}
],
@@ -1111,9 +1105,7 @@
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"metadata": {},
"output_type": "display_data"
}
],
@@ -1150,9 +1142,7 @@
"<Figure size 576x432 with 3 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"metadata": {},
"output_type": "display_data"
}
],
@@ -1478,9 +1468,7 @@
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"metadata": {},
"output_type": "display_data"
}
],
@@ -1692,11 +1680,16 @@
"metadata": {},
"outputs": [],
"source": [
"def extract_year(idx):\n",
" \"\"\"Extract the year each time index entry and normalized it.\"\"\"\n",
" return (idx.year - 1950) / 50\n",
"\n",
"\n",
"encoders = {\n",
" \"cyclic\": {\"future\": [\"month\"]},\n",
" \"datetime_attribute\": {\"future\": [\"hour\", \"dayofweek\"]},\n",
" \"position\": {\"past\": [\"absolute\"], \"future\": [\"relative\"]},\n",
" \"custom\": {\"past\": [lambda idx: (idx.year - 1950) / 50]},\n",
" \"custom\": {\"past\": [extract_year]},\n",
" \"transformer\": Scaler(),\n",
"}"
]
@@ -1715,7 +1708,7 @@
"* An additional custom function of the year should be used as past covariates.\n",
"* All the above covariates should be scaled using a `Scaler`, which will be fit upon calling the model `fit()` function and used afterwards to transform the covariates.\n",
"\n",
"We refer to [the API doc](https://unit8co.github.io/darts/generated_api/darts.utils.data.encoders.html#darts.utils.data.encoders.SequentialEncoder) for more informations about how to use encoders.\n",
"We refer to [the API doc](https://unit8co.github.io/darts/generated_api/darts.utils.data.encoders.html#darts.utils.data.encoders.SequentialEncoder) for more informations about how to use encoders. Note that lambda functions cannot be used as they are not pickable.\n",
"\n",
"To replicate our example with month and year used as past covariates with N-BEATS, we can use some encoders as follows:"
]