Skip to content

Commit

Permalink
reintegrate prophet in core deps (#1054)
Browse files Browse the repository at this point in the history
* reintegrate prophet in core deps

* less strict requirement on ipython

* rm pmdarima flavour; rm conditional tests

* remove flavours from build.gradle

* remove flavours from build.gradle

* remove pmdarima from setup.py
  • Loading branch information
hrzn committed Jul 12, 2022
1 parent 12c3702 commit 7ffa3fc
Show file tree
Hide file tree
Showing 11 changed files with 243 additions and 297 deletions.
6 changes: 3 additions & 3 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ void createPipInstallTask(String flavour) {
}
}

String[] flavours = ["core", "prophet", "pmdarima", "dev", "torch", "release"];
String[] flavours = ["core", "dev", "torch", "release"];

for(String flavour : flavours) {
createPipInstallTask(flavour);
Expand All @@ -94,7 +94,7 @@ task pipInstall() {
doFirst {
setupPip
}
dependsOn pip_core, pip_dev, pip_pmdarima, pip_torch, pip_release
dependsOn pip_core, pip_dev, pip_torch, pip_release
}

task lint_black(type: Exec) {
Expand Down Expand Up @@ -134,7 +134,7 @@ void createPipRelatedTask(String flavour) {
}
}

flavours = ["core", "prophet", "pmdarima", "torch"];
flavours = ["core", "torch"];

for(String flavour : flavours) {
createPipRelatedTask(flavour);
Expand Down
24 changes: 3 additions & 21 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,41 +9,23 @@

# Forecasting
from darts.models.forecasting.arima import ARIMA
from darts.models.forecasting.auto_arima import AutoARIMA
from darts.models.forecasting.baselines import NaiveDrift, NaiveMean, NaiveSeasonal
from darts.models.forecasting.catboost_model import CatBoostModel
from darts.models.forecasting.croston import Croston
from darts.models.forecasting.exponential_smoothing import ExponentialSmoothing
from darts.models.forecasting.fft import FFT
from darts.models.forecasting.kalman_forecaster import KalmanForecaster
from darts.models.forecasting.linear_regression_model import LinearRegressionModel
from darts.models.forecasting.prophet_model import Prophet
from darts.models.forecasting.random_forest import RandomForest
from darts.models.forecasting.regression_ensemble_model import RegressionEnsembleModel
from darts.models.forecasting.regression_model import RegressionModel
from darts.models.forecasting.sf_auto_arima import StatsForecastAutoARIMA
from darts.models.forecasting.tbats import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.varima import VARIMA

try:
from darts.models.forecasting.auto_arima import AutoARIMA
from darts.models.forecasting.tbats import BATS, TBATS
except ModuleNotFoundError:
logger.warning(
"Support for AutoARIMA, BATS and TBATS is not available."
"To enable it, install u8darts[pmdarima] or u8darts[all]."
)

try:
from darts.models.forecasting.prophet_model import Prophet
except ModuleNotFoundError:
pass
"""
logger.warning(
"Support for Facebook Prophet is not available. "
"To enable it, install the prophet package in your environment; see "
"https://facebook.github.io/prophet/docs/installation.html"
)
"""

try:
from darts.models.forecasting.block_rnn_model import BlockRNNModel
from darts.models.forecasting.nbeats import NBEATSModel
Expand Down
34 changes: 13 additions & 21 deletions darts/tests/models/forecasting/test_local_forecasting_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@
from darts.metrics import mape
from darts.models import (
ARIMA,
BATS,
FFT,
TBATS,
VARIMA,
AutoARIMA,
Croston,
ExponentialSmoothing,
FourTheta,
KalmanForecaster,
NaiveSeasonal,
Prophet,
StatsForecastAutoARIMA,
Theta,
)
Expand Down Expand Up @@ -70,25 +74,14 @@
dual_models = [ARIMA(), StatsForecastAutoARIMA(period=12)]


try:
from darts.models import Prophet

models.append((Prophet(), 13.5))
dual_models.append(Prophet())
except ImportError:
logger.warning("Prophet not installed - will be skipping Prophet tests")
models.append((Prophet(), 13.5))
dual_models.append(Prophet())

try:
from darts.models import BATS, TBATS, AutoARIMA
models.append((AutoARIMA(), 12.2))
models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0))
models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0))
dual_models.append(AutoARIMA())

models.append((AutoARIMA(), 12.2))
models.append((TBATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 8.0))
models.append((BATS(use_trend=True, use_arma_errors=True, use_box_cox=True), 10.0))
dual_models.append(AutoARIMA())
PMDARIMA_AVAILABLE = True
except ImportError:
logger.warning("pmdarima not installed - will be skipping AutoARIMA tests")
PMDARIMA_AVAILABLE = False

try:
from darts.models import TCNModel # noqa: F401
Expand Down Expand Up @@ -228,7 +221,6 @@ def test_dummy_series(self):
with self.assertRaises(ValueError):
varima.fit(series=ts)

if PMDARIMA_AVAILABLE:
autoarima = AutoARIMA(trend="t")
with self.assertRaises(ValueError):
autoarima.fit(series=ts)
autoarima = AutoARIMA(trend="t")
with self.assertRaises(ValueError):
autoarima.fit(series=ts)
59 changes: 25 additions & 34 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,13 @@
from darts import TimeSeries
from darts.logging import get_logger
from darts.metrics import mae
from darts.models import ARIMA, ExponentialSmoothing
from darts.models import ARIMA, BATS, TBATS, ExponentialSmoothing
from darts.models.forecasting.forecasting_model import GlobalForecastingModel
from darts.tests.base_test_class import DartsBaseTestClass
from darts.utils import timeseries_generation as tg

logger = get_logger(__name__)

try:
from darts.models import BATS, TBATS

PMDARIMA_AVAILABLE = True
except ImportError:
logger.warning("pmdarima not available. BATS/TBATS probabilistic tests skipped.")
PMDARIMA_AVAILABLE = False

try:
import torch

Expand Down Expand Up @@ -58,31 +50,30 @@
(ARIMA, {"p": 1, "d": 0, "q": 1}, 0.03),
]

if PMDARIMA_AVAILABLE:
models_cls_kwargs_errs += [
(
BATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
(
TBATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
]
models_cls_kwargs_errs += [
(
BATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
(
TBATS,
{
"use_trend": False,
"use_damped_trend": False,
"use_box_cox": True,
"use_arma_errors": False,
"random_state": 42,
},
0.3,
),
]

if TORCH_AVAILABLE:
models_cls_kwargs_errs += [
Expand Down
Loading

0 comments on commit 7ffa3fc

Please sign in to comment.