
CLEAN : rm duplicate code in fit_loop. #564

Merged
merged 4 commits into from Feb 10, 2020

Conversation

YannDubs
Contributor

@YannDubs commented Nov 26, 2019

For the third time now, I've found that I had to modify fit_loop (once to skip some validation steps when doing few-shot learning, another time as a workaround for #245). Every time I'm a bit hesitant, because it's a large function, so my changes will likely break with new versions of skorch. I think this should be made a little cleaner; in addition, there's a large chunk of duplicate code, which is bad practice. I think it should be split into 2 simple functions.

Nothing very important, but it's a bit better. I also pass epoch as an argument because it's something I usually need, and it makes sense to give it to a function that computes a single epoch (e.g. for logging).

@BenjaminBossan
Collaborator

Thanks for the PR. This is not an easy one :)

Every time I'm a bit hesitant because it's a large function so my changes will likely break for new versions of skorch.

We really try to not make changes that affect the public API (such as fit_loop) in a backwards-incompatible way. Have you encountered such breakage?

there's a large chunk of duplicate code, which is bad practice

This was actually a very conscious decision we made. In most cases, I would agree that you should not repeat yourself. Here, we wanted to lay bare the structure of the fit loop without any indirection, which is why we chose not to use sub-methods for that.

Would your use cases have been easier with this change?

To argue in favor of the change: At the time we made the decision, the fit_loop was smaller -- we added a few lines since then, making it more unwieldy in the process. Therefore, I would consider making the change. @ottonemo @thomasjpfan what are your opinions on that?

I also give epoch as argument because this is something I usually need and makes sense to give to a function that computes a single epoch (e.g. for logging).

I wonder if that is really necessary. Can you not perform the logging within fit_loop, which knows the current epoch?

Other changes or potential changes:

  • I would argue for using a different method name than _single_epoch, since it doesn't quite convey the content. Also, I would make it public.
  • prfx should be renamed to prefix (consistent with the rest of the skorch code)
  • Should the step_fn be passed from the fit_loop?

@ottonemo
Member

OK, my take on this is that

  1. it is hard to overwrite fit_loop as it is since there is a lot going on that is important and it is hard to keep track of everything
  2. there is a need to overwrite fit_loop for some use-cases

While I am in favor of fixing (1) to enable easy modification of the whole training loop I would also love to understand (2), i.e. @YannDubs what problems are you solving by overwriting fit_loop? Can you show us your workarounds so we can understand the problem a bit better?

I also agree with @BenjaminBossan that the fit loop in its current form is a deliberate choice, and there is nothing inherently wrong with having a central point that outlines the complete training cycle, even if that means some minor code duplication. However, if we can retain the global structure while reducing complexity for the user, I'm all for it (and this PR goes in that direction).

@YannDubs
Contributor Author

YannDubs commented Nov 26, 2019

We really try to not make changes that affect the public API (such as fit_loop) in a backwards-incompatible way. Have you encountered such breakage?

No, I never have, but I'm usually more reluctant to modify large functions. This is something I really like about skorch: there are usually a lot of specific, easy-to-understand functions that can be quickly modified, but I don't feel that is the case for fit_loop.

Would your use cases have been easier with this change?

The use cases I had were usually very specific; I'm not saying others will want that. Here are 2 use cases I currently have to deal with:

  1. When running few-shot learning experiments, the training set is usually orders of magnitude smaller (100x-1000x) than the validation set (it's not a "realistic scenario" but often done in research :/ ). In such cases I need to run many more epochs, and I don't want to spend most of the training time computing validation scores. What I do is validate only if the epoch number is in list(range(10)) + list(range(10, 100, 10)) + list(range(100, 1000, 100)). The best way would be to implement a scoring callback that keeps track of the epoch number and decides whether or not to skip validation. I don't see how to do that, as the callback cannot skip the validation loop (it can only skip the saving of the validation scores). In my code, I override _single_epoch to call the parent's _single_epoch only if the epoch's scores should be saved.

  2. Solving Support for multiple criteria? #245 is actually quite tricky. As I really need to save multiple temporary losses (quite a lot of them) and want them in the history (but don't need batch scores), I end up storing every loss (sum over batches and counts) in a self.to_store = dict() on the criterion. I then use the following lines of code at the end of the overridden _single_epoch:

from contextlib import suppress

if hasattr(self.criterion_, "to_store"):
    for k, v in self.criterion_.to_store.items():
        with suppress(NotImplementedError):
            # pytorch raises NotImplementedError on wrong types
            self.history.record(prfx + "_" + k, (v["sum"] / v["count"]).item())
    self.criterion_.to_store = dict()

Note that this saves the losses for training and testing using the prefix. Although I would prefer using callbacks, I don't see how to do that, as I don't have access to self.criterion_. I think the callback framework could be made much more powerful by giving callbacks access to the trainer self. But that seems strange and might cause some circular issues in specific cases (although I don't have any in mind).

I'm really not saying that either of the aforementioned methods is a good way of doing it, quite the contrary: these are ad hoc tricks to get the job done. But I still believe that using 2 functions makes it easier to do such things.
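The first workaround can be sketched as follows. This is a self-contained illustration, not skorch code: BaseNet is a minimal stand-in for skorch's NeuralNet so the snippet runs on its own, and run_single_epoch is the method name settled on later in this PR; the validation schedule is the one from the comment above.

```python
def validation_epochs(max_epochs=1000):
    """Epochs on which to validate: every epoch up to 10, every 10th
    up to 100, every 100th up to max_epochs."""
    return (set(range(10))
            | set(range(10, 100, 10))
            | set(range(100, max_epochs, 100)))


class BaseNet:
    """Stand-in for skorch's NeuralNet; only models the relevant pieces."""
    def __init__(self):
        self.history = []       # one entry per completed epoch
        self.validated_on = []  # for illustration only

    def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
        if not training:
            self.validated_on.append(len(self.history))


class SparseValidationNet(BaseNet):
    def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
        # len(self.history) acts as the epoch counter; it also behaves
        # correctly with partial_fit, where epochs don't restart at zero.
        if not training and len(self.history) not in validation_epochs():
            return  # skip the whole validation pass this epoch
        super().run_single_epoch(dataset, training, prefix, step_fn, **fit_params)
```

With this schedule, a 30-epoch run would validate on epochs 0-9, 10, and 20 only.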

I wonder if that is really necessary. Can you not perform the logging within fit_loop, which knows the current epoch?

Not necessary.

Other changes or potential changes:

I'll make those changes if you all agree it's worth it. I'm not convinced about passing step_fn to the function. It would decrease the lines of code, but I don't really see the gains; will we ever pass another step_fn?

Can you show us your workarounds so we can understand the problem a bit better?

See answer above.

@BenjaminBossan
Collaborator

Thanks for your detailed answer. A few comments:

1. I don't want to spend most of the training time computing validation scores.

Interesting one. My first reflex would probably be to not have any validation data at all (train_split=False) and then add a callback that has a reference to the predefined validation data, which only validates on specific epochs. Not sure if that would completely cover your problem (e.g. doesn't work with grid search).

Multiple criteria: That's a tough one, as can be seen from the linked discussion. We really need a better story for that but as I'm not working on any use case that would involve this, it's hard for me to come up with a practical solution.

Although I would prefer using callbacks, I don't see how to do that as I don't have access to self.criterion_

Not sure if I understand you correctly, but net is always passed as the first argument to each on_* method, so through this, you should have access to criterion_.
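This can be sketched like so. Callback and the net object here are minimal stand-ins so the snippet is self-contained; in real code you would subclass skorch.callbacks.Callback, and to_store is the ad-hoc dict from the workaround above.

```python
from contextlib import suppress


class Callback:  # stand-in for skorch.callbacks.Callback
    def on_epoch_end(self, net, **kwargs):
        pass


class RecordCriterionLosses(Callback):
    """Move extra losses accumulated on the criterion into the history."""
    def __init__(self, prefix="train"):
        self.prefix = prefix

    def on_epoch_end(self, net, **kwargs):
        # net is passed to every on_* method, so criterion_ is reachable
        criterion = net.criterion_
        for k, v in getattr(criterion, "to_store", {}).items():
            with suppress(NotImplementedError):
                net.history.record(self.prefix + "_" + k, v["sum"] / v["count"])
        criterion.to_store = {}
```

The callback never needs a back-reference to the trainer: the net arrives as an argument on each notification.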

these are ad hoc tricks to get the job done

That was one of our hopes when we designed skorch: We knew that we couldn't cover all use cases, but we wanted to make sure to at least make most of them not too hard to implement. Even if it means using "ad hoc tricks".

will we ever pass another step_fn?

I was wondering about exactly that. Just a very wild guess: training a net that consists of separate components like a GAN? What would we lose by making the change?

Regarding the name, how about something as plain as fit_one_epoch?

I'll make those changes if you all agree it's worth it.

At the moment, I'm inclined to say yes, but since it's a fundamental change, I'm conservative here. I would like to hear @thomasjpfan's opinion on that too.

@YannDubs
Contributor Author

YannDubs commented Nov 26, 2019

Not sure if I understand you correctly, but net is always passed as the first argument to each on_* method, so through this, you should have access to criterion_.

That's a really important point that I did not understand (I thought the first argument was module_ for some reason). Then the second use case can be solved with a callback.

Interesting one. My first reflex would probably be to not have any validation data at all (train_split=False) and then add a callback that has a reference to the predefined validation data, which only validates on specific epochs.

That would work. It requires quite a lot of additional code (as I have to implement the validation rather than just skip it), but it's probably cleaner.

That was one of our hopes when we designed skorch

And you've done a great job at doing just that; that's why I thought such a PR might be in line with your usual approach.

I was wondering about exactly that. Just a very wild guess: training a net that consists of separate components like a GAN? What would we lose by making the change?

You convinced me.

Regarding the name, how about something as plain as fit_one_epoch?

I thought about it but it seems strange to have the word fit when validating.

@BenjaminBossan
Collaborator

I thought about it but it seems strange to have the word fit when validating.

True. Is there a word for either fit or validate?

@BenjaminBossan
Collaborator

@thomasjpfan any opinion on this?

@thomasjpfan
Member

@thomasjpfan left a comment

Skipping validation based on epoch looks like a valid use case.

I am +1 with this PR with a small update suggested in my review. I think with the suggestion, this will reduce the complexity of fit_loop.

skorch/net.py Outdated
return self

def _single_epoch(self, dataset, training, epoch, **fit_params):

We can update this signature to

def _single_epoch(self, dataset, training, prefix,
                  step_fn, epoch, **fit_params):
    ...

And let the caller pass the arguments in:

for epoch in range(epochs):
    self.notify('on_epoch_begin', **on_epoch_kwargs)
    self._single_epoch(dataset_train, training=True, prefix="train",
                       step_fn=self.train_step, epoch=epoch, **fit_params)

    if dataset_valid is not None:
        self._single_epoch(dataset_valid, training=False, prefix="valid",
                           step_fn=self.validation_step, epoch=epoch,
                           **fit_params)
    self.notify("on_epoch_end", **on_epoch_kwargs)

@YannDubs
Contributor Author

YannDubs commented Dec 5, 2019

I made the changes you asked for and added documentation.

@BenjaminBossan
Collaborator

Just FYI, the failing tests are caused by changes in sklearn 0.22 and not by changes in this PR. We'll have to fix those tests first before we can make progress on this one.

@BenjaminBossan
Collaborator

I still wonder what a good name for the new method could be, when seeing single_epoch, I wouldn't have guessed what it does.

@YannDubs
Contributor Author

YannDubs commented Dec 6, 2019

Do you prefer run_epoch? I can also go with fit_epoch if you think it's better; I just find it strange to call it fit in the validation case.

@BenjaminBossan
Collaborator

run_epoch is definitely better, or run_one_epoch/run_single_epoch. fit_epoch is not so good, for the reason you mentioned.

@ottonemo ottonemo added this to In progress in 0.8.0 via automation Dec 17, 2019
@ottonemo
Member

The change LGTM. Just checking, does this solve your initial problem efficiently @YannDubs?

I'm assuming you would implement a more efficient validation by adding something like this?

def run_single_epoch(self, dataset, training, prefix, step_fn, **fit_params):
    if training or (not training and len(self.history) % 10 == 0):
        return super().run_single_epoch(dataset, training, prefix, step_fn, **fit_params)
    # otherwise: skip the validation pass for this epoch

I wonder if it makes sense to supply the epoch counter as an argument as well.

@YannDubs
Contributor Author

@ottonemo Yes, this is what I was doing, which is pretty efficient in terms of lines of code.

Passing in the epoch argument is definitely cleaner for my use case, but I don't have a strong feeling about it, as it's easy to recover (as shown in your snippet) and is not used by default.

@BenjaminBossan
Collaborator

I wonder if it makes sense to supply the epoch counter as an argument as well.

If you'd want to do that, this wouldn't work:

- for _ in range(epochs):
+ for epoch in range(epochs):

The reason is that epochs might not always start with 0 (e.g. using partial_fit). So if you want to change it, please take this into account.

At a later point, we might want to add an example into the FAQ on how to skip validation as shown above.
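A tiny illustration of that pitfall (FakeNet is a stand-in that only models the bookkeeping): a local `for epoch in range(epochs)` counter restarts at 0 on every partial_fit call, whereas len(self.history) keeps counting across calls.

```python
class FakeNet:
    """Stand-in that only models the history/epoch bookkeeping."""
    def __init__(self):
        self.history = []

    def partial_fit(self, epochs):
        for epoch in range(epochs):
            # the local counter restarts at 0 on every call;
            # len(self.history) does not
            self.history.append(
                {"local_epoch": epoch, "true_epoch": len(self.history)}
            )
```

Calling partial_fit(3) then partial_fit(2) yields local epochs 0, 1, 2, 0, 1 but history lengths 0 through 4, which is why the epoch argument, if added, should be derived from the history rather than the loop variable.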

@BenjaminBossan
Collaborator

@YannDubs Any progress here?

@YannDubs
Contributor Author

@BenjaminBossan I think this should be merged; was there anything else you wanted me to do?

@BenjaminBossan BenjaminBossan merged commit d47357a into skorch-dev:master Feb 10, 2020
0.8.0 automation moved this from In progress to Done Feb 10, 2020
@BenjaminBossan
Collaborator

Thanks @YannDubs for this great addition.

Now, I was a bit too fast with merging: I think we should mention the changes in CHANGES.md. Could you provide this in a separate PR?

sthagen added a commit to sthagen/skorch-dev-skorch that referenced this pull request Feb 11, 2020
CLEAN : rm duplicate code in fit_loop. (skorch-dev#564)