Skip to content

Commit

Permalink
add a unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
hrzn committed Mar 3, 2022
1 parent 7841cef commit cde0184
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions darts/tests/models/forecasting/test_probabilistic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,24 @@ def _get_avgs(series):
"The difference between the mean forecast and the mean series is larger "
"than expected on component 1 for distribution {}".format(lkl),
)

def test_stochastic_inputs(self):
model = RNNModel(input_chunk_length=5)
model.fit(self.constant_ts, epochs=2)

# build a stochastic series
target_vals = self.constant_ts.values()
stochastic_vals = np.random.normal(
loc=target_vals, scale=1.0, size=(len(self.constant_ts), 100)
)
stochastic_vals = np.expand_dims(stochastic_vals, axis=1)
stochastic_series = TimeSeries.from_times_and_values(
self.constant_ts.time_index, stochastic_vals
)

# A deterministic model forecasting a stochastic series
# should return stochastic samples
preds = [model.predict(series=stochastic_series, n=10) for _ in range(2)]

# random samples should differ
self.assertFalse(np.alltrue(preds[0].values() == preds[1].values()))

0 comments on commit cde0184

Please sign in to comment.