Skip to content

Commit

Permalink
Fix/predict single value (#108)
Browse files Browse the repository at this point in the history
* fix(TorchForecastingModel): solved bug at TorchForecastingModel.predict(n) with n = 1

* feature(testing): added tests for length 1 predictions for RNN and TCN, set torch random seed for TCN test

Co-authored-by: pennfranc <flaessig@student.ethz.ch>
Co-authored-by: TheMP <marek.pasieka@gmail.com>
  • Loading branch information
3 people authored and hrzn committed Jun 25, 2020
1 parent 9d2bcca commit e414c73
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
3 changes: 1 addition & 2 deletions darts/models/torch_forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,7 @@ def predict(self, n: int) -> TimeSeries:
pred_in[:, -1, :] = out[:, self.first_prediction_index]
test_out.append(out.cpu().detach().numpy()[0, self.first_prediction_index])
test_out = np.stack(test_out)

return self._build_forecast_series(test_out.squeeze())
return self._build_forecast_series(test_out.squeeze(1))

@property
def first_prediction_index(self) -> int:
Expand Down
4 changes: 4 additions & 0 deletions darts/tests/test_RNN.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ def test_fit(self):
pred3 = model3.predict(n=6)
self.assertNotEqual(sum(pred1.values() - pred3.values()), 0.)

# test short predict
pred4 = model3.predict(n=1)
self.assertEqual(len(pred4), 1)

shutil.rmtree('.darts')
5 changes: 5 additions & 0 deletions darts/tests/test_TCN.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,12 @@ def test_fit(self):
pred2 = model2.predict(n=2).values()[0]
self.assertTrue(abs(pred2 - 10) < abs(pred - 10))

# test short predict
pred3 = model2.predict(n=1)
self.assertEqual(len(pred3), 1)

def test_coverage(self):
torch.manual_seed(0)
input_lengths = range(20, 50)
kernel_sizes = range(2, 5)
dilation_bases = range(2, 5)
Expand Down

0 comments on commit e414c73

Please sign in to comment.