Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix/no lambda function in add_encoders #1957

Merged
merged 9 commits into from
Aug 31, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
- Fixed a bug in `RegressionEnsembleModel.extreme_lags` when the forecasting models have only covariates lags. [#1942](https://github.com/unit8co/darts/pull/1942) by [Antoine Madrona](https://github.com/madtoinou).
- Fixed a bug when using `TFTExplainer` with a `TFTModel` running on GPU. [#1949](https://github.com/unit8co/darts/pull/1949) by [Dennis Bader](https://github.com/dennisbader).
- Fixed a bug in `TorchForecastingModel.load_weights()` that raised an error when loading the weights from a valid architecture. [#1952](https://github.com/unit8co/darts/pull/1952) by [Antoine Madrona](https://github.com/madtoinou).
- 🔴 Dropped support for lambda functions in `add_encoders`’s “custom” encoder in favor of named functions to ensure that models can be exported. [#1957](https://github.com/unit8co/darts/pull/1957) by [Antoine Madrona]

### For developers of the library:

Expand Down
19 changes: 16 additions & 3 deletions darts/dataprocessing/encoders/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,7 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
ValueError
1) if the outermost key is other than (`past`, `future`)
2) if the innermost values are other than type `str` or `Sequence`
3) if any of entry in the innermost values is a lambda function
"""
if not params:
return [], []
Expand All @@ -1377,9 +1378,7 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
logger,
)

encoders = {
enc: params.get(enc, None) for enc in ENCODER_KEYS if params.get(enc, None)
}
encoders = {enc: params[enc] for enc in ENCODER_KEYS if params.get(enc, None)}

# check input for invalid temporal types
invalid_time_params = list()
Expand All @@ -1395,6 +1394,8 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
logger,
)

# check that encoders are not lambda functions (not pickable)
lambda_func_encoders = set()
# convert into tuples of (encoder string identifier, encoder attribute)
past_encoders, future_encoders = list(), list()
for enc, enc_params in encoders.items():
Expand All @@ -1414,6 +1415,18 @@ def _process_input_encoders(self, params: Dict) -> Tuple[List, List]:
else:
future_encoders.append((encoder_id, attr))

if isinstance(attr, Callable) and attr.__name__ == "<lambda>":
lambda_func_encoders.add(enc)

raise_if(
len(lambda_func_encoders) > 0,
f"Encountered lambda function in the following `add_encoders` entries : {lambda_func_encoders} "
f"at model creation. "
f"In order to prevent issues when saving the model, these encoders must be converted to "
f"named functions.",
logger,
)

for temp_enc, takes_temp, temp in [
(past_encoders, self.takes_past_covariates, "past"),
(future_encoders, self.takes_future_covariates, "future"),
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/block_rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/dlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,11 +335,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/lgbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/linear_regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/nbeats.py
Original file line number Diff line number Diff line change
Expand Up @@ -656,11 +656,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/nhits.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,11 +592,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/nlinear.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,11 +286,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/random_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
10 changes: 8 additions & 2 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down Expand Up @@ -1334,11 +1337,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/rnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,11 +321,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/tcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,11 +357,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/tft_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,11 +804,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/tide_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,11 +461,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/transformer_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,11 +445,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down
5 changes: 4 additions & 1 deletion darts/models/forecasting/xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,14 @@ def __init__(
.. highlight:: python
.. code-block:: python

def encode_year(idx):
return (idx.year - 1950) / 50

add_encoders={
'cyclic': {'future': ['month']},
'datetime_attribute': {'future': ['hour', 'dayofweek']},
'position': {'past': ['relative'], 'future': ['relative']},
'custom': {'past': [lambda idx: (idx.year - 1950) / 50]},
'custom': {'past': [encode_year]},
'transformer': Scaler()
}
..
Expand Down