Skip to content
Permalink

Comparing changes

Choose two branches to see what’s changed or to start a new pull request. If you need to, you can also or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: unit8co/darts
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: master
Choose a base ref
...
head repository: spicehq/darts
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: master
Choose a head ref
Can’t automatically merge. Don’t worry, you can still create the pull request.
  • 7 commits
  • 8 files changed
  • 2 contributors

Commits on Feb 7, 2024

  1. ONNX exporting (#1)

    Jeadie authored Feb 7, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    1ce4767 View commit details

Commits on Feb 8, 2024

  1. Update core.txt (#2)

    * Update core.txt
    
    * Update core.txt
    Jeadie authored Feb 8, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    a184762 View commit details

Commits on Feb 14, 2024

  1. Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    213ba81 View commit details

Commits on Jun 12, 2024

  1. Update black -> >=24.3.0 (#5)

    Jeadie authored Jun 12, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    7050b1d View commit details
  2. Update Jinja2 -> >=3.1.4 (#6)

    Jeadie authored Jun 12, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    c1ee2a2 View commit details

Commits on Jul 3, 2024

  1. Update torch.txt (#7)

    y-f-u authored Jul 3, 2024

    Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    bcaf266 View commit details

Commits on Oct 15, 2024

  1. Verified

    This commit was created on GitHub.com and signed with GitHub’s verified signature.
    Copy the full SHA
    064f456 View commit details
29 changes: 28 additions & 1 deletion darts/models/forecasting/catboost_model.py
Original file line number Diff line number Diff line change
@@ -11,8 +11,10 @@

import numpy as np
from catboost import CatBoostRegressor
import onnxmltools
from onnxmltools.convert.common.data_types import FloatTensorType

from darts.logging import get_logger
from darts.logging import get_logger, raise_log
from darts.models.forecasting.regression_model import RegressionModel, _LikelihoodMixin
from darts.timeseries import TimeSeries

@@ -309,3 +311,28 @@ def min_train_series_length(self) -> int:
if "target" in self.lags
else self.output_chunk_length,
)

@property
def supports_exporting_to_onnx(self) -> bool:
return True

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
super().check_export_onnx(path, **onnx_kwargs)
if self.model is None:
raise_log(
AssertionError(
f"Model '{path.__class__}' supports ONNX, but the model does not yet exist."
),
logger=logger,
)

if path is None:
path = f"{self._default_save_path()}.onnx"

# Jeadie: This doesn;t really work yet. Darts is doing something rather odd, so the catboost
# libraries own `.save_model` is returning approx. an empty catboost.
self.model.estimator.__setattr__('_random_seed', '42') # This may be a bug with Darts.
self.model.estimator.save_model(path, format="onnx") # estimator is underlying catboost model.
23 changes: 23 additions & 0 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
@@ -278,6 +278,13 @@ def supports_optimized_historical_forecasts(self) -> bool:
"""
return False

@property
def supports_exporting_to_onnx(self) -> bool:
"""
Whether the model supports exporting the model in ONNX format
"""
return False

@property
def output_chunk_length(self) -> Optional[int]:
"""
@@ -1887,6 +1894,22 @@ def model_params(self) -> dict:
def _default_save_path(cls) -> str:
return f"{cls.__name__}_{datetime.datetime.now().strftime('%Y-%m-%d_%H_%M_%S')}"

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
self.check_export_onnx(path, onnx_kwargs=onnx_kwargs)

def check_export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
if not self.supports_exporting_to_onnx:
raise_log(
AssertionError(
f"Model '{path.__class__}' does not support exporting to ONNX."
),
logger=logger,
)

def save(
self, path: Optional[Union[str, os.PathLike, BinaryIO]] = None, **pkl_kwargs
) -> None:
29 changes: 29 additions & 0 deletions darts/models/forecasting/regression_model.py
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@
from typing_extensions import Literal

import numpy as np
import onnxmltools
import pandas as pd
from sklearn.linear_model import LinearRegression

@@ -469,6 +470,34 @@ def get_multioutput_estimator(self, horizon, target_dim):

return self.model.estimators_[horizon + target_dim]

@property
def supports_exporting_to_onnx(self) -> bool:
"""
Whether the model supports exporting the model in ONNX format
"""
return True

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
super().check_export_onnx(path, **onnx_kwargs)
if self.model is None:
raise_log(
AssertionError(
f"Model '{path.__class__}' supports ONNX, but the model does not yet exist."
),
logger=logger,
)

if path is None:
path = f"{self._default_save_path()}.onnx"

# TODO find and element initial_type, e.g. = [("float_input", FloatTensorType([None, 4]))]
onx = onnxmltools.convert_sklearn(self.model, initial_types=[("input", FloatTensorType([1, 30]))])
with open(path, "wb") as f:
f.write(onx.SerializeToString())

def _create_lagged_data(
self,
target_series: Sequence[TimeSeries],
49 changes: 49 additions & 0 deletions darts/models/forecasting/torch_forecasting_model.py
Original file line number Diff line number Diff line change
@@ -2016,6 +2016,55 @@ def _is_probabilistic(self) -> bool:
else True # all torch models can be probabilistic (via Dropout)
)

@property
def supports_exporting_to_onnx(self) -> bool:
"""
Whether the model supports exporting the model in ONNX format
"""
return True

def export_onnx(self, path: Optional[str] = None, **onnx_kwargs) -> None:
"""
Exports the model as an ONNX file.
"""
super().export_onnx(path, **onnx_kwargs)
if not self.model_created:
raise_log(
AssertionError(
f"Model '{path.__class__}' supports ONNX, but the model does not yet exist."
),
logger=logger,
)

if path is None:
path = f"{self._default_save_path()}.onnx"

# TODO: This only works for PastCovariatesModel so far
if not issubclass(self.__class__, PastCovariatesTorchModel):
raise_log(
AssertionError(
f"For TorchForeacstingModels, currently only PastCovariatesModel are supported."
),
logger=logger,
)

(
past_target,
past_covariates,
future_past_covariates,
static_covariates,
# I think these have to do with future covariates (which isn't supported in Dlinear)
) = [torch.Tensor(x).unsqueeze(0) if x is not None else None for x in self.train_sample]

input_past = torch.cat(
[ds for ds in [past_target, past_covariates] if ds is not None],
dim=2, # Shape is (1, lookback_size, no. of variates (in either target or series))
)

input_sample = [input_past.float(), static_covariates.float() if static_covariates is not None else None]
self.model.float().to_onnx(path, input_sample=input_sample, opset_version=17)
# self.model.to_onnx(path, torch.from_numpy(self.train_sample[0]), **onnx_kwargs)

def _check_optimizable_historical_forecasts(
self,
forecast_horizon: int,
2 changes: 2 additions & 0 deletions requirements/core.txt
Original file line number Diff line number Diff line change
@@ -3,6 +3,8 @@ joblib>=0.16.0
matplotlib>=3.3.0
nfoursid>=1.0.0
numpy>=1.19.0
onnxmltools>=1.12.0
onnxconverter-common
pandas>=1.0.5,<2.0.0; python_version < "3.9"
pandas>=1.0.5; python_version >= "3.9"
pmdarima>=1.8.0
2 changes: 1 addition & 1 deletion requirements/dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
black[jupyter]==22.3.0
black[jupyter]>=24.3.0
flake8==4.0.1
isort==5.11.5
pre-commit
4 changes: 2 additions & 2 deletions requirements/release.txt
Original file line number Diff line number Diff line change
@@ -3,9 +3,9 @@ docutils==0.17.1
ipython==8.10.0
ipykernel==5.3.4
ipywidgets==7.5.1
jupyterlab==4.0.11
jupyterlab>=4.2.5
ipython_genutils==0.2.0
jinja2==3.1.3
Jinja2>=3.1.4
m2r2==0.3.2
nbsphinx==0.8.7
numpydoc==1.1.0
2 changes: 1 addition & 1 deletion requirements/torch.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
pytorch-lightning>=1.5.0,<=2.1.2
pytorch-lightning>=1.5.0
tensorboardX>=2.1
torch>=1.8.0