
Allow overriding optimiser step / zero grad action #106

Closed
soraxas opened this issue Aug 13, 2019 · 15 comments

@soraxas commented Aug 13, 2019

Currently the logic of optimiser.step() and optimiser.zero_grad() is hard-coded in the trainer, but sometimes it would be beneficial NOT to zero_grad(), or to perform it only at certain iterations (e.g., for RNNs).
This is also related to #29, implementing GANs in Lightning.
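
For context, a rough sketch of the hard-coded behaviour being described (illustrative only, not the actual trainer source; names assumed):

    # what the trainer effectively does today (sketch)
    for batch_nb, batch in enumerate(dataloader):
        loss = model.training_step(batch, batch_nb)['loss']
        loss.backward()
        optimizer.step()       # always stepped
        optimizer.zero_grad()  # always zeroed -- no way to override either call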

@williamFalcon (Owner) commented Aug 13, 2019

@soraxas good suggestion. Can you propose an alternative setup? Would love to see other options for generalizing.

Maybe some hooks or trainer flags?

@williamFalcon (Owner) commented Aug 13, 2019

Would #107 work?

The new approach is as follows:

Case 1:

Regular training, nothing needs to be changed.

Case 2:

GAN training:

  1. Pass in 2 optimizers (already supported).
  2. training_step will now have an optimizer_i arg so the training_step can be adjusted accordingly:

    def training_step(self, batch, batch_nb, optimizer_i)

  3. If you want to .step() at different intervals, implement the hook:

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        """
        Do something instead of the standard optimizer behavior.

        :param epoch_nb: current epoch index
        :param batch_nb: current batch index
        :param optimizer: the optimizer to step
        :param optimizer_i: index of the optimizer (order from configure_optimizers)
        """
        optimizer.step()

        # clear gradients
        optimizer.zero_grad()

For example, in the GAN case maybe you do:

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        """
        Do something instead of the standard optimizer behavior.
        """
        # update generator opt every 2 steps
        if optimizer_i == 0:
            if batch_nb % 2 == 0:
                optimizer.step()
                optimizer.zero_grad()

        # update discriminator opt every 4 steps
        if optimizer_i == 1:
            if batch_nb % 4 == 0:
                optimizer.step()
                optimizer.zero_grad()
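
The same hook also covers the original ask in this issue, deferring zero_grad() to accumulate gradients (e.g. for RNN training). A minimal sketch, with an arbitrary accumulation interval:

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        # accumulate gradients over 4 batches: gradients keep summing
        # because zero_grad() is deferred until the step actually happens
        if batch_nb % 4 == 0:
            optimizer.step()
            optimizer.zero_grad()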

@williamFalcon (Owner) commented Aug 13, 2019

@soraxas this fix is on master now. Try it out and let me know if it solves your issue.

Use case 2 for a GAN.

@AlphabetMan commented Aug 13, 2019

GANs are getting more and more complex; there might be more than 2 backward calls in one training step. I suggest defining optional methods training_substep_1, _2, _3, etc., each consisting of a forward-backward pass and returning a loss. The substeps would be contained within training_step, where all losses would be accumulated for logging.

@williamFalcon (Owner) commented Aug 13, 2019

@AlphabetMan can you give me a pseudocode example? This pattern isn't limited to 2 optimizers; I just used 2 here for illustration. You could easily do:

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        """
        Do something instead of the standard optimizer behavior.
        """
        # update generator opt every 2 steps
        if optimizer_i == 0:
            if batch_nb % 2 == 0:
                optimizer.step()
                optimizer.zero_grad()

        # update discriminator opt every 4 steps
        if optimizer_i == 1:
            if batch_nb % 4 == 0:
                optimizer.step()
                optimizer.zero_grad()

        # update other opt every 6 steps
        if optimizer_i == 2:
            if batch_nb % 6 == 0:
                optimizer.step()
                optimizer.zero_grad()

        # ... support N optimizers
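
If the per-index branches get repetitive, the same logic collapses to a lookup table. A hypothetical generalization (the intervals are illustrative):

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        # one step interval per optimizer; extend the list for N optimizers
        intervals = [2, 4, 6]
        if batch_nb % intervals[optimizer_i] == 0:
            optimizer.step()
            optimizer.zero_grad()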

@AlphabetMan commented Aug 13, 2019

Optimizer indexing seems to be inevitable. However, I don't like the idea of skipping batches, even if it might not affect the quality of the results. I have come up with this; it gets complicated rather quickly.

    class CoolModel(pl.LightningModule):

        def __init__(self):
            super(CoolModel, self).__init__()
            # not the best model...
            self.discriminator = torch.nn.Linear(28 * 28, 10)
            self.generator = torch.nn.Linear(28 * 28, 10)
            self.loss1 = nn.MSELoss()
            self.loss2 = nn.BCELoss()

        def forward(self, x):
            pass

        @MaybeSomeDecorator
        def training_substep_1(self, batch):
            # on real
            output1 = self.discriminator(batch)
            loss1 = self.loss2(output1, torch.ones_like(output1))
            # on fake
            fake_image = self.generator(torch.randn(8, 28 * 28))
            output2 = self.discriminator(fake_image.detach())
            loss2 = self.loss2(output2, torch.zeros_like(output2))
            loss = loss1 + loss2
            return {'loss': loss, 'optimizer_idx': 0, 'tensor': fake_image}

        def training_step(self, batch, batch_nb):
            intermediate = self.training_substep_1(batch)  # backward is called for discriminator
            output1 = self.discriminator(intermediate['tensor'])
            loss = self.loss2(output1, torch.ones_like(output1))
            return {'loss': loss, 'optimizer_idx': 1}  # backward is called for generator

        def configure_optimizers(self):
            # REQUIRED
            return [torch.optim.Adam(self.discriminator.parameters(), lr=0.02),
                    torch.optim.Adam(self.generator.parameters(), lr=0.02)]

@williamFalcon (Owner) commented Aug 13, 2019

@AlphabetMan the current approach, live on 0.4.5, is:

Training step:

    def training_step(self, batch, batch_nb, optimizer_i):
        if optimizer_i == 0:
            # do generator stuff
            ...
        if optimizer_i == 1:
            # do discriminator stuff
            ...

Control optimizer behavior:

    def optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i):
        """
        Do something instead of the standard optimizer behavior.
        """
        # update generator opt every 2 steps
        if optimizer_i == 0:
            if batch_nb % 2 == 0:
                optimizer.step()
                optimizer.zero_grad()

        # update discriminator opt every 4 steps
        if optimizer_i == 1:
            if batch_nb % 4 == 0:
                optimizer.step()
                optimizer.zero_grad()

And of course, return 2 optimizers:

    def configure_optimizers(self):
        # REQUIRED
        return [torch.optim.Adam(self.discriminator.parameters(), lr=0.02),
                torch.optim.Adam(self.generator.parameters(), lr=0.02)]

@williamFalcon (Owner) commented Aug 13, 2019

@AlphabetMan thoughts on the approach that's currently available?

@AlphabetMan commented Aug 13, 2019

So if there are multiple optimizers, then for a single batch from the dataloader the trainer calls, sequentially:

    training_step(self, batch, batch_nb, optimizer_i=0)
    optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i=0)
    training_step(self, batch, batch_nb, optimizer_i=1)
    optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i=1)
    training_step(self, batch, batch_nb, optimizer_i=2)
    optimizer_step(self, epoch_nb, batch_nb, optimizer, optimizer_i=2)

and so on, until it runs out of optimizers. Do I understand this correctly?
If yes, then it works. It's just that the discriminator is updated with a detached tensor from the generated fake image, and usually that same tensor is then passed through the discriminator again to update the generator. But in this implementation the image has to be generated again to update the generator, which is not efficient. Unless, of course, we save the generated tensor in an unused attribute at every step, for example self.temporary_tensor.
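
A sketch of that caching idea (reusing self.loss2 and the two-optimizer setup from the earlier snippets; self.temporary_tensor holds the fake image between the two calls):

    def training_step(self, batch, batch_nb, optimizer_i):
        if optimizer_i == 0:
            # discriminator update: generate once and cache the fake image
            fake_image = self.generator(torch.randn(8, 28 * 28))
            self.temporary_tensor = fake_image
            out_real = self.discriminator(batch)
            out_fake = self.discriminator(fake_image.detach())
            loss = self.loss2(out_real, torch.ones_like(out_real)) + \
                   self.loss2(out_fake, torch.zeros_like(out_fake))
            return {'loss': loss}
        if optimizer_i == 1:
            # generator update: reuse the cached tensor instead of generating
            # again (its graph is still attached, since the discriminator
            # update only backpropagated through the detached copy)
            out = self.discriminator(self.temporary_tensor)
            return {'loss': self.loss2(out, torch.ones_like(out))}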

@williamFalcon (Owner) commented Aug 13, 2019

Yup... that's what happens right now!

@williamFalcon (Owner) commented Aug 13, 2019

@AlphabetMan it would be super helpful if you put together a simple GAN example to:
A. make sure it makes sense, and
B. have as an official example of GAN training with Lightning.

@williamFalcon (Owner) commented Aug 14, 2019

[image]

@AlphabetMan commented Aug 14, 2019

Very nice. There are multiple ways to define the training step, and either network could be updated before the other, but this example is clear and simple. GANs really benefit from distributed training, so this framework is really helpful. I think any deep-net tutorial on the PyTorch website should be implementable here.
