Skip to content

Commit

Permalink
fix: bad file descriptor (#769)
Browse files Browse the repository at this point in the history
* fix: bad file descriptor

When using Prophet in a multi-threaded environment such as Dask, the error `[Errno 9] Bad File Descriptor` can occur. I can consistently reproduce this error while running forecasts with Dask. This implementation seems to fix the issue.

I do not fully understand why this happens, because it only occurs after Dask has been running for a long time. My implementation is semantically equivalent to the original, but I guess that avoiding opening `devnull` too many times helps.

Context:
I am running forecasts using Prophet for about 500_000 time series, all in one big Dask data frame (split into 91 smaller Pandas data frames).
If I don't use Dask, and instead use *joblib* to parallelize the forecasting, then this error does not happen. I think this is because joblib creates a separate thread for every worker, whereas Dask can have multiple threads for a worker.

* revert changes

Update logging.py

* make log suppression optional

* add unit test

* provide better data for test, update model dep

* restore original intention of model construction

* test case to make sure Prophet model is loaded by default

* reformat to black style

* sort imports

* underscore private variables

Co-authored-by: Julien Herzen <julien@unit8.co>
  • Loading branch information
khanetor and hrzn committed Feb 25, 2022
1 parent c2123c9 commit cbe49bf
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
6 changes: 3 additions & 3 deletions darts/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,10 @@ def timed(*args, **kwargs):

class SuppressStdoutStderr:
"""
A context manager for doing a "deep suppression" of stdout and stderr in
Python, i.e. will suppress all print, even if the print originates in a
A context manager for "deep suppression" of stdout and stderr in
Python, i.e. it suppresses all print, even if the print originates in a
compiled C/Fortran sub-function.
This will not suppress raised exceptions, since exceptions are printed
This does not suppress raised exceptions, since exceptions are printed
to stderr just before a script exits, and after the context manager has
exited (at least, I think that is why it lets exceptions through).
Expand Down
16 changes: 14 additions & 2 deletions darts/models/forecasting/prophet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def __init__(
self,
add_seasonalities: Optional[Union[dict, List[dict]]] = None,
country_holidays: Optional[str] = None,
suppress_stdout_stderror: bool = True,
**prophet_kwargs,
):
"""Facebook Prophet
Expand Down Expand Up @@ -66,6 +67,8 @@ def __init__(
countries: Brazil (BR), Indonesia (ID), India (IN), Malaysia (MY), Vietnam (VN),
Thailand (TH), Philippines (PH), Turkey (TU), Pakistan (PK), Bangladesh (BD),
Egypt (EG), China (CN), and Russia (RU).
suppress_stdout_stderror
Optionally suppress the log output produced by Prophet during training.
prophet_kwargs
Some optional keyword arguments for Prophet.
For information about the parameters see:
Expand All @@ -88,6 +91,10 @@ def __init__(
self.country_holidays = country_holidays
self.prophet_kwargs = prophet_kwargs
self.model = None
self.suppress_stdout_stderr = suppress_stdout_stderror

self._execute_and_suppress_output = execute_and_suppress_output
self._model_builder = prophet.Prophet

def __str__(self):
return "Prophet"
Expand All @@ -101,7 +108,7 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
data={"ds": series.time_index, "y": series.univariate_values()}
)

self.model = prophet.Prophet(**self.prophet_kwargs)
self.model = self._model_builder(**self.prophet_kwargs)

# add user defined seasonalities (from model creation and/or pre-fit self.add_seasonalities())
interval_length = self._freq_to_days(series.freq_str)
Expand All @@ -127,7 +134,12 @@ def _fit(self, series: TimeSeries, future_covariates: Optional[TimeSeries] = Non
if self.country_holidays is not None:
self.model.add_country_holidays(self.country_holidays)

execute_and_suppress_output(self.model.fit, logger, logging.WARNING, fit_df)
if self.suppress_stdout_stderr:
self._execute_and_suppress_output(
self.model.fit, logger, logging.WARNING, fit_df
)
else:
self.model.fit(fit_df)

return self

Expand Down
44 changes: 44 additions & 0 deletions darts/tests/models/forecasting/test_prophet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from unittest.mock import Mock

import numpy as np
import pandas as pd

from darts import TimeSeries
from darts.logging import get_logger
from darts.tests.base_test_class import DartsBaseTestClass
Expand Down Expand Up @@ -102,6 +107,45 @@ def test_prophet_model(self):
period=period, freq=freq, compare_all_models=False
)

def test_prophet_model_without_stdout_suppression(self):
    """Fit with suppression disabled and check the suppression helper is bypassed.

    Both the suppression helper and the model builder are replaced by mocks,
    so the object fitted here is a Mock rather than a real Prophet model.
    """
    model = Prophet(suppress_stdout_stderror=False)
    model._execute_and_suppress_output = Mock(return_value=True)
    model._model_builder = Mock(return_value=Mock(fit=Mock(return_value=True)))
    df = pd.DataFrame(
        {
            "ds": pd.date_range(start="2022-01-01", periods=30, freq="D"),
            "y": np.linspace(0, 10, 30),
        }
    )
    ts = TimeSeries.from_dataframe(df, time_col="ds", value_cols="y")
    model.fit(ts)

    # Use real assert statements: writing `mock.assert_*(), "msg"` only builds
    # a discarded tuple, so the message would never be reported on failure.
    assert (
        not model._execute_and_suppress_output.called
    ), "Suppression should not be called"
    assert model.model.fit.call_count == 1, "Model should still be fitted"

def test_prophet_model_with_stdout_suppression(self):
    """Fit with suppression enabled and check the suppression helper is used once.

    Both the suppression helper and the model builder are replaced by mocks,
    so the object fitted here is a Mock rather than a real Prophet model.
    """
    model = Prophet(suppress_stdout_stderror=True)
    model._execute_and_suppress_output = Mock(return_value=True)
    model._model_builder = Mock(return_value=Mock(fit=Mock(return_value=True)))
    df = pd.DataFrame(
        {
            "ds": pd.date_range(start="2022-01-01", periods=30, freq="D"),
            "y": np.linspace(0, 10, 30),
        }
    )
    ts = TimeSeries.from_dataframe(df, time_col="ds", value_cols="y")
    model.fit(ts)

    # Use a real assert statement: writing `mock.assert_*(), "msg"` only builds
    # a discarded tuple, so the message would never be reported on failure.
    assert (
        model._execute_and_suppress_output.call_count == 1
    ), "Suppression should be called once"

def test_prophet_model_default_with_prophet_constructor(self):
    """A default-constructed model must keep Facebook Prophet as its builder."""
    import prophet as fb_prophet

    default_builder = Prophet()._model_builder
    assert (
        default_builder == fb_prophet.Prophet
    ), "model should use Facebook Prophet"

def helper_test_freq_coversion(self, test_cases):
for freq, period in test_cases.items():
ts_sine = tg.sine_timeseries(
Expand Down

0 comments on commit cbe49bf

Please sign in to comment.