
Wrong initialisation gain for SNNs/SELU #54055

@mrTsjolder

Description

🐛 Bug

The default initialisation gain for the SELU activation function breaks SNNs, which explicitly require initial weights to have variance 1 / fan_in. With the default gain of 0.75, Kaiming initialisation instead produces weights with variance 0.75² / fan_in ≈ 0.5625 / fan_in.
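
For reference, the gain that gets applied can be inspected directly; the values in the comments assume PyTorch 1.8:

from torch import nn

# Gain used by the kaiming_* initialisers for nonlinearity='selu' (3/4 in PyTorch 1.8)
print(nn.init.calculate_gain('selu'))  # 0.75

# kaiming_uniform_ draws from U(-b, b) with b = gain * sqrt(3 / fan_in),
# so the resulting weight variance is gain**2 / fan_in rather than 1 / fan_in.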

To Reproduce

Steps to reproduce the behaviour:

  1. Create multi-layer SNN
  2. Use default initialisation
  3. Print mean and variance in every layer
import torch
from torch import nn

def init(m):
    # Default recommended initialisation for SELU layers
    if isinstance(m, nn.Linear):
        nn.init.kaiming_uniform_(m.weight, nonlinearity='selu')

# 20 hidden layers of width 100 between a 30-d input and a 10-d output
layers = [30] + [100] * 20 + [10]
snn = nn.Sequential(*(
    nn.Sequential(
        nn.Linear(n_in, n_out, bias=False),
        nn.SELU(),
    ) for n_in, n_out in zip(layers[:-1], layers[1:])
))
snn.apply(init)

# Track mean and variance of the activations through the network
x = torch.randn(1000, 30)
for layer in snn:
    x = layer(x)
    print(f"{x.mean().item():.5f}, {x.var().item():.5f}")

Expected behavior

Normalised activations in every layer, i.e. zero mean and unit variance. This can be obtained by using nonlinearity='linear' instead, as in the sketch below.
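
As a workaround until the default gain is fixed, the init function from the snippet above can be adapted as follows (a minimal sketch; calculate_gain('linear') is 1, so the weights receive the required variance of 1 / fan_in):

from torch import nn

def init(m):
    if isinstance(m, nn.Linear):
        # The gain for 'linear' is 1, so the weight variance is 1 / fan_in,
        # as required for the self-normalising property
        nn.init.kaiming_uniform_(m.weight, nonlinearity='linear')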

Environment

  • PyTorch Version (e.g., 1.0): 1.8.0
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, source): conda
  • Build command you used (if compiling from source):
  • Python version: 3.8

Additional context

This issue was introduced with PR #50664, which aimed to tackle #24991. In that PR, an empirical value for the gain was chosen so that gradients "behave better" (i.e., do not grow or shrink as much). I have already proposed a fix in PR #53694, but it would be BC-breaking and therefore needs more incentive before it can be merged. A discussion of why it does not make sense to use the empirical value is included in that PR. The goal of this issue is therefore to provide that incentive by collecting upvotes.

TL;DR: Trying to collect upvotes to get this bug fixed

cc @albanD @mruberry @jbschlosser


Labels

module: initialization (Related to weight initialization on operators), module: nn (Related to torch.nn), triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
