No way to correctly reset weights of a model with spectral norm #25092

Open
bartwojcik opened this issue Aug 23, 2019 · 9 comments
Labels
has workaround · module: nn (Related to torch.nn) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


bartwojcik commented Aug 23, 2019

I didn't find any similar issue (with spectral norm), so excuse me if this is a duplicate.

🐛 Bug

There is no way to fully reset the weights (and buffers) of a model with spectral norm. When I reuse my model, I get good-looking samples almost instantly on any full MNIST 30-epoch training run other than the first, which I presume is caused by how spectral norm is implemented. There is no reset_parameters for spectral norm (it isn't even implemented as a layer). Removing and re-adding it with remove_spectral_norm and spectral_norm does not work as a workaround either, as I get:

RuntimeError: Unexpected key in metadata['spectral_norm']: weight.version

I know I could simply create a new model for each run, as that isn't costly, but resetting should be possible without copying.

To Reproduce

Steps to reproduce the behavior:

  1. Use a simple GAN model with SN in both the discriminator and generator.
  2. Train this model, save images while training.
  3. Reset weights / Apply init on the model.
  4. Train this model again, save images while training.

My init function:

import torch.nn as nn

def init_weights(m):
    # let the layer reset itself first, if it knows how
    if hasattr(m, 'reset_parameters'):
        m.reset_parameters()
    # then apply DCGAN-style initialization on top
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        m.weight.data.normal_(0.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
    elif isinstance(m, nn.BatchNorm2d):
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
    elif isinstance(m, nn.Linear):
        m.weight.data.normal_(0.0, 0.1)
        if m.bias is not None:
            m.bias.data.fill_(0)
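
The function is applied recursively with model.apply. Note that on a module wrapped by the hook-based spectral norm, m.weight is just the tensor computed from weight_orig at the last forward, so writing to m.weight.data gets silently overwritten on the next forward. A minimal sketch of the failure mode (the Sequential below is a stand-in, not the model from the repro):

import torch
import torch.nn as nn

# stand-in for the GAN from the repro steps
model = nn.Sequential(
    nn.utils.spectral_norm(nn.Conv2d(3, 64, 3)),
    nn.ReLU(),
)
model.apply(init_weights)
# The SN-wrapped conv exposes `weight_orig` (the real parameter) and a
# derived `weight`; init_weights only touched the latter, so the next
# forward recomputes `weight` from the untouched `weight_orig`.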

Expected behavior

I would expect my model to be fully reset, especially since this is the "recommended" way: https://discuss.pytorch.org/t/cross-validation-model-reset/21176

Environment

PyTorch version: 1.2.0
Is debug build: No
CUDA used to build PyTorch: 10.0.130

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.11) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: GeForce RTX 2080 Ti
GPU 1: GeForce RTX 2080 Ti

Nvidia driver version: 430.34
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.16.4
[pip] torch==1.2.0
[pip] torchvision==0.4.0
[conda] blas 1.0 mkl
[conda] mkl 2019.4 243
[conda] mkl_fft 1.0.12 py37ha843d7b_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.2.0 pypi_0 pypi
[conda] torchvision 0.4.0 pypi_0 pypi

cc @ezyang @gchanan @ssnl

@izdeby added the high priority, triage review, and triaged labels on Aug 23, 2019
@ezyang added the module: nn label on Aug 23, 2019
ezyang (Contributor) commented Aug 23, 2019

I agree, this sounds like a bug. Removing and then re-adding the spectral norm should work.

bartwojcik (Author) commented:
I would argue that reset_parameters() in init should cover that, as it will always be unintuitive for new users if we try to "reset" the model and something stays the same. Removing and re-adding the spectral norm looks like a workaround to me, though I admit I might not see the whole picture of how SN should be implemented. Note that BN has reset_parameters implemented, and it resets the running mean estimate too.

ssnl (Collaborator) commented Aug 24, 2019

Could you give a script of how you remove and re-add the spectral norm hook? It seems to work for me:

In [7]: c = torch.nn.Conv2d(3, 64, 3, 3)

In [8]: csn = torch.nn.utils.spectral_norm(c)

In [9]: c = torch.nn.utils.remove_spectral_norm(csn)

In [10]: c
Out[10]: Conv2d(3, 64, kernel_size=(3, 3), stride=(3, 3))

In [11]: c.reset_parameters()

In [12]: csn = torch.nn.utils.spectral_norm(c)

RuiShu commented Aug 31, 2019

I think if we wish to reset the model parameters without relying on the workaround of (remove spectral norm → reset parameters → re-apply spectral norm), we'll either have to:

  1. Re-write all the spectral-normalizable layers to be aware of spectral norm, so that the layer knows to re-initialize layer.weight_orig instead of layer.weight, or
  2. Re-write SpectralNorm as a fully-fledged nn.Module that borrows layer.weight and layer.reset_parameters, so that model.apply(weights_init) reaches the SpectralNorm module and the init gets applied to layer.weight.

No. 1 is safer as it allows us to potentially establish a contract between a spectral normalizable layer and SpectralNorm so that all attributes exposed to the end-user will work as intended.

No. 2 is easier to implement but may still break on some corner cases.
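
A minimal sketch of No. 1 under the old hook-based implementation, which renames the parameter to weight_orig and keeps the power-iteration vector as the buffer weight_u (sn_aware_init is a hypothetical helper, not part of torch):

import torch
import torch.nn as nn

def sn_aware_init(m):
    if hasattr(m, 'weight_orig'):
        # re-init the real parameter, not the derived `weight`
        nn.init.normal_(m.weight_orig, 0.0, 0.02)
        # restart power iteration from a fresh random direction
        with torch.no_grad():
            m.weight_u.normal_().div_(m.weight_u.norm().clamp_min(1e-12))
    elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight, 0.0, 0.02)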

jbschlosser (Contributor) commented:
Was this addressed by #57784? @lezcano @albanD @soulitzer

soulitzer (Contributor) commented:
No, but I think it should be pretty easy to do now. We'd just need to implement a .reset_parameters function that resets the buffers u and v.

lezcano (Collaborator) commented May 19, 2021

Going through the OP, I see that removing and re-adding spectral_norm now does work with the new implementation.

We could of course implement a reset_parameters function. In that case, it would be good if it accepted the matrix it is instantiated on as a parameter, so that u and v can be initialised properly.
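
Something along those lines, sketched against the parametrized implementation. This assumes the _SpectralNorm parametrization keeps its power-iteration vectors as the buffers _u and _v; reset_sn_vectors is a hypothetical helper, not an existing API:

import torch
import torch.nn.functional as F

def reset_sn_vectors(sn, weight):
    # `sn` is the _SpectralNorm parametrization module; `weight` is the
    # freshly re-initialized tensor it is registered on, flattened to 2D
    # the same way the forward pass does
    weight_mat = weight.flatten(1)
    with torch.no_grad():
        sn._u = F.normalize(weight_mat.new_empty(weight_mat.size(0)).normal_(), dim=0)
        sn._v = F.normalize(weight_mat.new_empty(weight_mat.size(1)).normal_(), dim=0)
        for _ in range(10):  # a few warm-up power-iteration steps
            sn._v = F.normalize(weight_mat.t().mv(sn._u), dim=0)
            sn._u = F.normalize(weight_mat.mv(sn._v), dim=0)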

okbalefthanded commented:
Any progress on solving the issue with the new parametrization implementation?

lezcano (Collaborator) commented Apr 18, 2023

As discussed in #25092 (comment), a simple workaround is to remove the parametrization and set it again; see e.g. https://pytorch.org/tutorials/intermediate/parametrizations.html#removing-parametrizations.

Now, we would accept a PR that implements this function. It should be fairly straightforward to implement.
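
For reference, a sketch of that workaround with the parametrization API:

import torch.nn as nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import spectral_norm

conv = spectral_norm(nn.Conv2d(3, 64, 3))
# strip the parametrization (the current normalized weight is baked back
# into conv.weight, since leave_parametrized defaults to True)
parametrize.remove_parametrizations(conv, "weight")
conv.reset_parameters()      # now re-initializes the real weight
conv = spectral_norm(conv)   # re-apply with freshly initialized u and v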
