BatchNormalization meddles in finetuning schedule. #5

Closed
JohannesK14 opened this issue Oct 11, 2022 · 1 comment
Labels
bug Something isn't working

JohannesK14 commented Oct 11, 2022

🐛 Bug

Hi @speediedan, thanks for creating this fine-tuning callback!

I want to fine-tune the torchvision.models.segmentation.deeplabv3_resnet50 model, which contains batch normalization parameters. As a first step, I would like to fine-tune with a simple two-phase approach: start with the classification head, then continue with the backbone. When I run the attached script, finetuning-scheduler issues the following warning:

UserWarning: FinetuningScheduler configured the provided model to have 23 trainable parameters in phase 0 (the initial training phase) but the optimizer has subsequently been initialized with 129 trainable parameters. If you have manually added additional trainable parameters you may want to ensure the manually added new trainable parameters do not collide with the 182 parameters FinetuningScheduler has been scheduled to thaw in the provided schedule.

I think this is caused by freeze_before_training calling freeze, whose train_bn parameter defaults to True and therefore leaves requires_grad set to True for every batch normalization parameter.
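To illustrate the suspected mechanism, here is a minimal sketch in plain PyTorch. The helpers `freeze_sketch` and `count_trainable` are hypothetical and only mimic the behavior described above; they are not the Lightning or finetuning-scheduler implementation, and the toy model stands in for deeplabv3_resnet50.

```python
import torch.nn as nn

# Module types treated as batch normalization in this sketch (an assumption).
BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


def freeze_sketch(model: nn.Module, train_bn: bool = True) -> None:
    """Hypothetical helper: freeze every parameter, optionally keeping BatchNorm trainable."""
    for module in model.modules():
        is_bn = isinstance(module, BN_TYPES)
        # recurse=False so each parameter is handled by the module that owns it
        for param in module.parameters(recurse=False):
            param.requires_grad = is_bn and train_bn


def count_trainable(model: nn.Module) -> int:
    """Count parameter tensors that would pass a requires_grad filter."""
    return sum(1 for p in model.parameters() if p.requires_grad)


# Toy stand-in for a backbone containing one batch normalization layer.
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 2, kernel_size=1),
)

freeze_sketch(toy, train_bn=True)
trainable_with_bn = count_trainable(toy)     # BatchNorm weight and bias stay trainable

freeze_sketch(toy, train_bn=False)
trainable_without_bn = count_trainable(toy)  # everything is frozen
```

With train_bn=True the "frozen" model still exposes the batch normalization weight and bias to a `requires_grad` filter like the one in configure_optimizers, which would explain why the optimizer sees more trainable parameters than the schedule's phase 0 defines.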

Downstream, the code crashes because parameters appear in more than one parameter group. (The attached sample does not reach that crash because the data does not fit the model, but it does reproduce the initial warning.)
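The downstream crash can be sketched in isolation, assuming it comes from PyTorch's own check when a parameter that is already in the optimizer is added again in a new parameter group. The parameter `bn_weight` is a hypothetical stand-in for a batch normalization parameter that was left trainable in phase 0 and is later thawed by the schedule.

```python
import torch

# Stand-in for a batch norm parameter that was unintentionally left trainable.
bn_weight = torch.nn.Parameter(torch.ones(4))

# Phase 0: the optimizer is built from every parameter with requires_grad=True,
# so it already contains bn_weight.
optimizer = torch.optim.Adam([bn_weight], lr=0.1)

# Later phase: thawing the backbone adds the same parameter in a new group.
try:
    optimizer.add_param_group({"params": [bn_weight], "lr": 0.001})
    collided = False
except ValueError as err:
    collided = True
    message = str(err)
```

PyTorch rejects the second group with a ValueError, which matches the "more than one parameter group" failure described above.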

To Reproduce

import torch

from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, LightningDataModule
from pytorch_lightning.cli import LightningArgumentParser, LightningCLI
from torchvision.models.segmentation import deeplabv3_resnet50


class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        
        self.model = deeplabv3_resnet50()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("valid_loss", loss)

    def test_step(self, batch, batch_idx):
        loss = self(batch).sum()
        self.log("test_loss", loss)

    def configure_optimizers(self):
        params = list(filter(lambda x: x.requires_grad, self.parameters()))
        optimizer = torch.optim.Adam(params, lr=0.1)

        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, factor=0.5, patience=5)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "valid_loss",
            },
        }


class MyDataModule(LightningDataModule):

    def __init__(self):
        super().__init__()

    def train_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=2)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(RandomDataset(32, 64), batch_size=2)


from finetuning_scheduler import FinetuningScheduler, FTSCheckpoint, FTSEarlyStopping


class BoringCLI(LightningCLI):

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        parser.add_lightning_class_args(FinetuningScheduler, 'finetune_scheduler')
        parser.set_defaults({
            'finetune_scheduler.ft_schedule': {
                0: {
                    'params': ['model.classifier.*'],  # the parameters for each phase definition can be fully specified
                    'max_transition_epoch': 2,
                    'lr': 0.001,
                },
                1: {
                    'params': ['model.backbone.*'],
                    'lr': 0.001 
                }
            },
        })

        # EarlyStopping
        parser.add_lightning_class_args(FTSEarlyStopping, "early_stopping")
        early_stopping_defaults = {
            "early_stopping.monitor": "valid_loss",
            "early_stopping.patience": 99999,  # disable early_stopping
            "early_stopping.mode": "min",
            "early_stopping.min_delta": 0.01,
        }
        parser.set_defaults(early_stopping_defaults)

        # ModelCheckpoint
        parser.add_lightning_class_args(FTSCheckpoint, 'model_checkpoint')
        model_checkpoint_defaults = {
            "model_checkpoint.filename": "epoch{epoch}_val_loss{valid_loss:.2f}",
            "model_checkpoint.monitor": "valid_loss",
            "model_checkpoint.mode": "min",
            "model_checkpoint.every_n_epochs": 1,
            "model_checkpoint.save_top_k": 5,
            "model_checkpoint.auto_insert_metric_name": False,
            "model_checkpoint.save_last": True
        }
        parser.set_defaults(model_checkpoint_defaults)


if __name__ == "__main__":
    BoringCLI(BoringModel, MyDataModule, seed_everything_default=False, save_config_overwrite=True)

Expected behavior

freeze should be called with train_bn=False so that batch normalization parameters are not trained by default.

Environment

  • CUDA:
    - GPU: NVIDIA GeForce RTX 3090
    - available: True
    - version: 11.3
  • Packages:
    - finetuning-scheduler: 0.2.3
    - numpy: 1.23.3
    - pyTorch_debug: False
    - pyTorch_version: 1.12.1
    - pytorch-lightning: 1.7.7
    - tqdm: 4.64.1
  • System:
    - OS: Linux
    - architecture: 64bit, ELF
    - processor: x86_64
    - python: 3.10.6
    - version: #54~20.04.1-Ubuntu SMP Thu Sep 1 16:17:26 UTC 2022

Additional context

@JohannesK14 JohannesK14 added the bug Something isn't working label Oct 11, 2022
@speediedan (Owner)

Thanks for taking the time to file the issue @JohannesK14!

I've reproduced the issue, made the appropriate change in a02b864, and added a relevant test to guard against future regressions in this behavior.

The aforementioned commit should be included either in FTS 0.2.4 (if PL releases another patch version of 1.7) or in the first release of FTS 0.3 (0.3.0).

Your clear description and attention to detail in the reproduction really accelerated resolution, a model bug report! 🚀 🎉
Thanks again, and glad you've found FTS of at least some utility. Feel free to reach out anytime if you have other issues or want to share your use case. Best of luck with your work!
