Skip to content

Commit

Permalink
adapt TFT tests
Browse files Browse the repository at this point in the history
  • Loading branch information
hrzn committed May 4, 2022
1 parent 0eeba80 commit f6711f5
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions darts/tests/models/forecasting/test_TFT.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,10 +282,14 @@ def helper_fit_predict(
series=series,
past_covariates=past_covariates,
future_covariates=future_covariates,
num_samples=100,
num_samples=(100 if model._is_probabilistic() else 1),
)

if isinstance(y_hat, TimeSeries):
y_hat = y_hat.quantile_timeseries(0.5)
y_hat = y_hat.quantile_timeseries(0.5) if y_hat.n_samples > 1 else y_hat
else:
y_hat = [ts.quantile_timeseries(0.5) for ts in y_hat]
y_hat = [
ts.quantile_timeseries(0.5) if ts.n_samples > 1 else ts
for ts in y_hat
]
return y_hat

0 comments on commit f6711f5

Please sign in to comment.