Skip to content

Commit

Permalink
Thread STS name into its joint distribution
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 427855341
  • Loading branch information
sharadmv authored and tensorflower-gardener committed Feb 10, 2022
1 parent afbc1a5 commit 6e83921
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 31 deletions.
3 changes: 2 additions & 1 deletion tensorflow_probability/python/sts/structural_time_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,8 @@ def state_space_model_likelihood(**param_vals):
# Likelihood.
[('observed_time_series', state_space_model_likelihood)]),
use_vectorized_map=False,
batch_ndims=batch_ndims))
batch_ndims=batch_ndims,
name=self.name))

if observed_time_series is not None:
return joint_distribution.experimental_pin(
Expand Down
85 changes: 55 additions & 30 deletions tensorflow_probability/python/sts/structural_time_series_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,11 @@ def test_prior_sample(self):
2,
] + param.prior.batch_shape.as_list() + param.prior.event_shape.as_list())

def test_joint_distribution_name(self):
model = self._build_sts(name='foo')
jd = model.joint_distribution(num_timesteps=5)
self.assertEqual('foo', jd.name)

def test_joint_distribution_log_prob(self):
model = self._build_sts(
# Dummy series to build the model with float64 priors. Working in
Expand Down Expand Up @@ -402,14 +407,17 @@ def test_add_component(self):
@test_util.test_all_tf_execution_regimes
class AutoregressiveTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
return Autoregressive(order=3, observed_time_series=observed_time_series)
def _build_sts(self, observed_time_series=None, **kwargs):
return Autoregressive(
order=3,
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class ARMATest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
one = 1.
if observed_time_series is not None:
observed_time_series = (
Expand All @@ -421,36 +429,42 @@ def _build_sts(self, observed_time_series=None):
ma_order=1,
integration_degree=0,
level_drift_prior=tfd.Normal(loc=one, scale=one),
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class ARIMATest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
return AutoregressiveIntegratedMovingAverage(
ar_order=1, ma_order=2, integration_degree=2,
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class LocalLevelTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
return LocalLevel(observed_time_series=observed_time_series)
def _build_sts(self, observed_time_series=None, **kwargs):
return LocalLevel(
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class LocalLinearTrendTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
return LocalLinearTrend(observed_time_series=observed_time_series)
def _build_sts(self, observed_time_series=None, **kwargs):
return LocalLinearTrend(
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SeasonalTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
# Note that a Seasonal model with `num_steps_per_season > 1` would have
# deterministic dependence between timesteps, so evaluating `log_prob` of an
# arbitrary time series leads to Cholesky decomposition errors unless the
Expand All @@ -461,79 +475,87 @@ def _build_sts(self, observed_time_series=None):
return Seasonal(num_seasons=7,
num_steps_per_season=1,
observed_time_series=observed_time_series,
constrain_mean_effect_to_zero=False)
constrain_mean_effect_to_zero=False,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SeasonalWithZeroMeanConstraintTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
return Seasonal(num_seasons=7,
num_steps_per_season=1,
observed_time_series=observed_time_series,
constrain_mean_effect_to_zero=True)
constrain_mean_effect_to_zero=True,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SeasonalWithMultipleStepsAndNoiseTest(test_util.TestCase,
_StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
day_of_week = tfp.sts.Seasonal(num_seasons=7,
num_steps_per_season=24,
allow_drift=False,
observed_time_series=observed_time_series,
name='day_of_week')
return tfp.sts.Sum(components=[day_of_week],
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SemiLocalLinearTrendTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
return SemiLocalLinearTrend(observed_time_series=observed_time_series)
def _build_sts(self, observed_time_series=None, **kwargs):
return SemiLocalLinearTrend(
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SmoothSeasonalTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
return SmoothSeasonal(period=42,
frequency_multipliers=[1, 2, 4],
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SmoothSeasonalWithNoDriftTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
smooth_seasonal = SmoothSeasonal(period=42,
frequency_multipliers=[1, 2, 4],
allow_drift=False,
observed_time_series=observed_time_series)
# The test harness doesn't like models with no parameters, so wrap with Sum.
return tfp.sts.Sum([smooth_seasonal],
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SumTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
first_component = LocalLinearTrend(
observed_time_series=observed_time_series, name='first_component')
second_component = LocalLinearTrend(
observed_time_series=observed_time_series, name='second_component')
return Sum(
components=[first_component, second_component],
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class LinearRegressionTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
max_timesteps = 100
num_features = 3

Expand All @@ -557,13 +579,14 @@ def _build_sts(self, observed_time_series=None):
max_timesteps, num_features).astype(dtype),
weights_prior=prior)
return Sum(components=[regression],
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class SparseLinearRegressionTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
max_timesteps = 100
num_features = 3

Expand All @@ -584,13 +607,14 @@ def _build_sts(self, observed_time_series=None):
max_timesteps, num_features).astype(dtype),
weights_batch_shape=batch_shape)
return Sum(components=[regression],
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


@test_util.test_all_tf_execution_regimes
class DynamicLinearRegressionTest(test_util.TestCase, _StsTestHarness):

def _build_sts(self, observed_time_series=None):
def _build_sts(self, observed_time_series=None, **kwargs):
max_timesteps = 100
num_features = 3

Expand All @@ -604,7 +628,8 @@ def _build_sts(self, observed_time_series=None):
return DynamicLinearRegression(
design_matrix=np.random.randn(
max_timesteps, num_features).astype(dtype),
observed_time_series=observed_time_series)
observed_time_series=observed_time_series,
**kwargs)


if __name__ == '__main__':
Expand Down

0 comments on commit 6e83921

Please sign in to comment.