EpochScoring caching does not work in every case #552
Thanks for the detailed report. This is true; we hadn't considered the case where the pure iteration over the dataset could itself be a bottleneck. The whole caching part is somewhat delicate, and we tried a bunch of different solutions, all with some benefits and drawbacks. My first idea for this problem would be to patch
I'm not exactly sure what you mean. I think it's important that we reuse as much of the existing infrastructure as possible, which is why we went with the "pretending to predict" route in the first place. If you have an example that implements your proposal, we can discuss this.
The nasty workaround I did just for the sake of forcing this to work is the following:

```python
from inspect import getfullargspec

import torch
from sklearn.metrics import check_scoring
from skorch.callbacks import EpochScoring
from skorch.utils import to_numpy


class FastEpochScoring(EpochScoring):
    def __init__(self, scoring, multi_class=True, lower_is_better=True,
                 on_train=False, name=None, target_extractor=to_numpy,
                 use_caching=True):
        super().__init__(scoring, lower_is_better, on_train, name,
                         target_extractor, use_caching)
        self.multi_class = multi_class

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        if not self.use_caching:
            super().on_epoch_end(net, dataset_train, dataset_valid, **kwargs)
            return
        # score directly on the cached predictions instead of going
        # through net.predict, which would iterate the dataset again
        X_test, y_test, y_pred = self.get_test_data(dataset_train, dataset_valid)
        if X_test is None:
            return
        y_pred_flat = torch.cat(
            [torch.argmax(y.cpu(), dim=1) for y in y_pred]).numpy()
        score = self._calculate_score(net, y_pred_flat, y_test)
        self._record_score(net.history, score)

    def _calculate_score(self, net, y_pred_flat, y_test):
        metric = check_scoring(net, self.scoring_)._score_func
        metric_kw = {}
        if self.multi_class and "average" in getfullargspec(metric).args:
            metric_kw["average"] = "macro"
        return metric(y_test, y_pred_flat, **metric_kw)
```

Maybe we can work from here? Another idea I had is to provide a "fake" dataset during your "pretending to predict" phase.
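For illustration, a minimal sketch of how this workaround could be attached to a net; `MyModule`, `train_ds`, and `val_ds` are hypothetical placeholders, not part of the issue:

```python
# Hypothetical usage sketch of the FastEpochScoring workaround above;
# MyModule, train_ds and val_ds stand in for the real network and data.
from skorch import NeuralNetClassifier
from skorch.helper import predefined_split

net = NeuralNetClassifier(
    MyModule,
    train_split=predefined_split(val_ds),
    callbacks=[
        # scores come straight from the cached predictions, so no
        # extra pass over the validation data is made per metric
        FastEpochScoring('f1', lower_is_better=False),
        FastEpochScoring('precision', lower_is_better=False),
    ],
)
net.fit(train_ds, y=None)
```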
I implemented a potential fix for your issue in #557. If you have time, you could test whether it works for you. I went with my proposal since it sticks closer to the current solution and I want to stay as backwards compatible as possible. Changing the caching mechanism could have unintended side effects that are hard to predict.
@marrrcin I consider this issue fixed for now. Should you still encounter it in the future, feel free to re-open the ticket.
@BenjaminBossan I've tested your implementation from current master (commit |
I wanted to use skorch's `EpochScoring` with a ~100 GB dataset stored as files in GCS (or S3, it does not really matter here). The dataset needs to be loaded on the fly due to memory constraints, so in every epoch the data is downloaded again.

Reproduction steps (a sketch of this setup follows the list):

1. Prepare your datasets (`train_ds`, `val_ds`).
2. Create a `NeuralNetClassifier` for your data, with the following args: `EpochScoring` callbacks (i.e. for Precision, Recall and F1) for both training and validation, with caching enabled.
3. Run `network.fit(train_ds, y=None)`.
Expected result:
`EpochScoring` runs fast, because the data for `y` and `y_pred` should already be cached.

Actual result:
`EpochScoring` is slow, although the cache is enabled. The cause is that it uses `network.predict` under the hood, which iterates over the whole validation dataset again. So with 6 metrics in callbacks, I had to loop through the whole dataset 6 times.

I think that `EpochScoring` should call the scorers directly when caching is enabled, because the case above shows that overriding `infer` at the time of running the callback does not cover every case properly, and the provided cache is useless in it.