Tsformer pretrain question #18
Comments
Thanks for your question.
Thanks for your reply. So the right test method is the `test` method in the file `base_tsf_runner.py`?
No, the `test` method is:

```python
@torch.no_grad()
@master_only
def test(self):
    """Evaluate the model.

    Args:
        train_epoch (int, optional): current epoch if in training process.
    """
    for _, data in enumerate(self.test_data_loader):
        forward_return = self.forward(data=data, epoch=None, iter_num=None, train=False)
        # re-scale data
        prediction_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[0], **self.scaler["args"])
        real_value_rescaled = SCALER_REGISTRY.get(self.scaler["func"])(forward_return[1], **self.scaler["args"])
        # metrics
        for metric_name, metric_func in self.metrics.items():
            metric_item = metric_func(prediction_rescaled, real_value_rescaled, null_val=self.null_val)
            self.update_epoch_meter("test_" + metric_name, metric_item.item())
```

Note that the metric computation and `update_epoch_meter` calls are inside the batch loop, so every batch of the test data loader contributes to the epoch metrics.
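The pitfall being discussed can be sketched in isolation with toy numbers (hypothetical names, not the BasicTS API): if the metric is read only after the loop finishes, it reflects just the last batch, whereas accumulating inside the loop averages over all batches.

```python
# Toy "test_data_loader": three batches of values (hypothetical data).
batches = [[1.0, 2.0], [3.0, 4.0], [100.0, 200.0]]

def mean_abs(batch):
    """A stand-in per-batch metric: mean absolute value."""
    return sum(abs(x) for x in batch) / len(batch)

# Buggy pattern: the metric is only read AFTER the loop,
# so only the last batch is ever measured.
for data in batches:
    metric_item = mean_abs(data)
metric_buggy = metric_item  # equals mean_abs of the last batch only

# Correct pattern: accumulate per batch inside the loop,
# then average over the number of batches (a simple epoch meter).
running_sum, n = 0.0, 0
for data in batches:
    running_sum += mean_abs(data)
    n += 1
metric_correct = running_sum / n

print(metric_buggy)    # 150.0 (last batch only)
print(metric_correct)  # ~51.67 (all three batches)
```

The per-batch means here are 1.5, 3.5, and 150.0, so the two patterns give very different results; an indentation slip is enough to turn one into the other in Python.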
Thanks for your answer, I got your idea. :)
This bug is now fixed. Thanks again for your report!
In the file `tsformer_runner.py`, the `test` method is used for evaluation during training; however, I found that it only uses the last batch of the `test_data_loader`. Is there something wrong with the for loop? Am I wrong?

```python
@torch.no_grad()
@master_only
def test(self):
    """Evaluate the model.
```