Skip to content

Commit

Permalink
Fix params
Browse files Browse the repository at this point in the history
  • Loading branch information
bashtage committed Jul 15, 2020
1 parent 1f40dc5 commit c7ff04d
Showing 1 changed file with 76 additions and 1 deletion.
77 changes: 76 additions & 1 deletion statsmodels/tsa/holtwinters/model.py
Expand Up @@ -13,6 +13,7 @@
"""
from statsmodels.compat.pandas import deprecate_kwarg

import contextlib
from typing import Any, Hashable, Sequence
import warnings

Expand Down Expand Up @@ -166,7 +167,7 @@ class ExponentialSmoothing(TimeSeriesModel):
excluding the initial values if estimated. The parameters are ordered
[alpha, beta, gamma, phi]. Each element is a tuple of the form
(lower, upper). Default is (0.0001, 0.9999) for the level, trend, and
seasonal smoothing parameters and (0.8, 0.98) for the trend damping
seasonal smoothing parameters and (0.8, 0.995) for the trend damping
parameter.
Returns
Expand Down Expand Up @@ -265,10 +266,34 @@ def __init__(
self._estimate_level = estimated
self._estimate_trend = estimated and self.trend
self._estimate_seasonal = estimated and self.seasonal
self._bounds = self._check_bounds(bounds)
self._use_boxcox = use_boxcox
self._lambda = np.nan
self._y = self._boxcox()
self._initialize()
self._fixed_parameters = {}

def _check_bounds(self, bounds):
if bounds is None:
return
msg = (
"bounds must be a list of 2-element tuples of the form"
" (lb, ub) where lb < ub, lb>=0 and ub<=1"
)
try:
bounds_len = len(bounds) != 4
except:
raise TypeError(msg)
if bounds_len != 4:
raise TypeError(msg)
for bound in bounds:
if not isinstance(bound, tuple):
raise TypeError(msg)
if len(bound) != 2 or bound[0] >= bound[1]:
raise ValueError(msg)
if bound[0] < 0.0 or bound[1] > 1.0:
raise ValueError(msg)
return list(bounds)

def _boxcox(self):
if (
Expand All @@ -291,6 +316,53 @@ def _boxcox(self):
raise TypeError("use_boxcox must be True, False or a float.")
return y

@contextlib.contextmanager
def fix_params(self, values):
"""
Temporarily fix parameters for estimation
Parameters
----------
values : dict
Values to fix. The key is the parameter name and the value is the
fixed value.
Examples
--------
# TODO
"""
values = dict_like(values, "values")
valid_keys = ("smoothing_level",)
if self.has_trend:
valid_keys += ("smoothing_trend",)
if self.has_seasonal:
valid_keys += ("smoothing_seasonal",)
valid_keys += tuple(
[f"initial_seasonal.{i}" for i in range(self.seasonal_periods)]
)
if self.damped_trend:
valid_keys += ("damping_trend",)
if self._initialization_method in ("estimated", None):
extra_keys = [
key.replace("smoothing_", "initial_")
for key in valid_keys
if "smoothing_" in key
]
valid_keys += tuple(extra_keys)

for key in values:
if key not in valid_keys:
valid = ", ".join(valid_keys[:-1]) + ", and " + valid_keys[-1]
raise KeyError(
f"{key} if not allowed. Only {valid} are supported in this specification."
)

try:
self._fixed_parameters = values
yield
finally:
self._fixed_parameters = {}

def _initialize(self):
if self._initialization_method is None:
warnings.warn(
Expand Down Expand Up @@ -487,6 +559,9 @@ def _optimize_parameters(
(0.8, 0.995), # phi
]
bounds += [(None, None)] * m
if self._bounds is not None:
for i, loc in enumerate((0, 1, 2, 5)):
bounds[loc] = self._bounds[i]
hw_args = HoltWintersArgs(sel.astype(int), params, y, m, self.nobs)

if start_params is None and use_brute and np.any(sel[:3]):
Expand Down

0 comments on commit c7ff04d

Please sign in to comment.