
EpochScoring caching does not work in every case #552

Closed
marrrcin opened this issue Nov 7, 2019 · 5 comments

Comments


marrrcin commented Nov 7, 2019

I wanted to use skorch's EpochScoring with a ~100 GB dataset stored as files in GCS (or S3, it does not really matter here). The dataset needs to be loaded on the fly due to memory constraints, so in every epoch the data is downloaded again.

Reproduction steps (a minimal sketch of this setup follows the list):

  1. Implement a PyTorch Dataset that loads its data on the fly from an external source (HTTP/GCS/S3).
  2. Create Datasets for training and validation (train_ds, val_ds).
  3. Create a NeuralNetClassifier for your data with train_split=predefined_split(val_ds).
  4. Add EpochScoring callbacks (e.g. for precision, recall and F1) for both training and validation, with caching enabled.
  5. Run network.fit(train_ds, y=None).
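
A minimal sketch of this setup (MyModule, the URL lists and download_and_parse are only placeholders, not code from my project):

import torch
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring
from skorch.helper import predefined_split


class RemoteDataset(torch.utils.data.Dataset):
    # Loads each sample on the fly from HTTP/GCS/S3 instead of keeping
    # the ~100 GB dataset in memory.
    def __init__(self, urls):
        self.urls = urls

    def __len__(self):
        return len(self.urls)

    def __getitem__(self, i):
        X, y = download_and_parse(self.urls[i])  # hypothetical download helper
        return X, y


train_ds = RemoteDataset(train_urls)  # placeholder URL lists
val_ds = RemoteDataset(val_urls)

net = NeuralNetClassifier(
    MyModule,  # placeholder torch.nn.Module
    train_split=predefined_split(val_ds),
    callbacks=[
        EpochScoring('precision', lower_is_better=False, use_caching=True),
        EpochScoring('recall', lower_is_better=False, use_caching=True),
        EpochScoring('f1', lower_is_better=False, use_caching=True),
    ],
)
net.fit(train_ds, y=None)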

Expected result:
EpochScoring runs fast, because the data for y and y_pred should already be cached.

Actual result:
EpochScoring is slow, even though the cache is enabled. The cause is that it uses network.predict under the hood, which iterates over the whole validation dataset again. So with 6 metrics in the callbacks, I had to loop through the whole dataset 6 times.

I think that EpochScoring should call the scorers directly when caching is enabled, because the case above shows that overriding infer at the time the callback runs does not cover every case properly, and the provided cache is useless there.

@BenjaminBossan
Collaborator

Thanks for the detailed report. This is true, we hadn't considered the case where the pure iteration over the dataset could be a bottleneck.

The whole caching part is somewhat delicate and we tried a bunch of different solutions, all with some benefits and drawbacks. My first idea for this problem would be to patch forward_iter instead of infer. It is a bit more involved but could work. @ottonemo, what do you think?
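
To make it concrete, a rough sketch of that idea (only an illustration, not an actual patch; the helper name is made up): temporarily replace net.forward_iter with a function that yields the prediction batches already collected during the epoch, so that everything built on top of it, predict included, consumes the cache instead of iterating over the dataset again.

from contextlib import contextmanager
from unittest.mock import patch


@contextmanager
def cache_forward_iter(net, use_caching, y_preds):
    # Serve the cached prediction batches instead of running the module
    # over the dataset again.
    if not use_caching:
        yield net
        return

    def cached_forward_iter(*args, device=net.device, **kwargs):
        for yp in y_preds:
            yield yp.to(device)

    with patch.object(net, 'forward_iter', cached_forward_iter):
        yield net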

> I think that EpochScoring should call scorers directly,

I'm not exactly sure what you mean. I think it's important that we can reuse as much of the existing infrastructure as possible, which is why we went with the "pretending to predict" route in the first place. If you have an example that implements your proposal, we can discuss it.


marrrcin commented Nov 9, 2019

The nasty workaround I put together just to force this to work is the following:

from inspect import getfullargspec

import torch
from sklearn.metrics import check_scoring
from skorch.callbacks import EpochScoring
from skorch.utils import to_numpy


class FastEpochScoring(EpochScoring):
    def __init__(self, scoring, multi_class=True, lower_is_better=True,
                 on_train=False, name=None, target_extractor=to_numpy,
                 use_caching=True):
        super().__init__(scoring, lower_is_better, on_train, name,
                         target_extractor, use_caching)
        self.multi_class = multi_class

    def on_epoch_end(self, net, dataset_train, dataset_valid, **kwargs):
        # Without caching there is nothing to short-circuit, so fall back
        # to the default behaviour.
        if not self.use_caching:
            super().on_epoch_end(net, dataset_train, dataset_valid, **kwargs)
            return

        X_test, y_test, y_pred = self.get_test_data(dataset_train, dataset_valid)
        if X_test is None:
            return

        # Flatten the cached batches of predictions into class labels
        # instead of calling net.predict, which would iterate over the
        # whole dataset again.
        y_pred_flat = torch.cat(
            [torch.argmax(y.cpu(), dim=1) for y in y_pred]).numpy()

        score = self._calculate_score(net, y_pred_flat, y_test)
        self._record_score(net.history, score)

    def _calculate_score(self, net, y_pred_flat, y_test):
        # Call the underlying sklearn metric directly on the cached data.
        metric = check_scoring(net, self.scoring_)._score_func
        metric_kw = {}
        if self.multi_class and "average" in getfullargspec(metric).args:
            metric_kw["average"] = "macro"
        return metric(y_test, y_pred_flat, **metric_kw)
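
It can be registered like the regular callback, e.g. (the metric names here are just an example):

callbacks=[
    FastEpochScoring('precision', lower_is_better=False),
    FastEpochScoring('recall', lower_is_better=False),
    FastEpochScoring('f1', lower_is_better=False),
]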

Maybe we can work from here?


Another idea I had is to provide a "fake" dataset during your "pretending to predict" phase.

@BenjaminBossan
Collaborator

@marrrcin

I implemented a potential fix for your issue in #557. If you have time, you could test if this works for you.

I went with my proposal since it sticks closer to the current solution and I want to be as backwards compatible as possible. Changing the caching mechanism could have some unintended side-effects which are hard to predict.

@BenjaminBossan
Collaborator

@marrrcin I consider this issue fixed for now. Should you still encounter this issue in the future, feel free to re-open the ticket.

@marrrcin
Author

@BenjaminBossan I've tested your implementation from current master (commit 09be626e74512124eb74c76b4cbabad4d3b1f274) and it seems to work well, great job, thanks!
