# Layers

> To build and initialize NN layers.

In [None]:
#| default_exp layers

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore.all import *
from fastai.torch_basics import *
from fastai.callback.hook import Hook

## init

In [None]:
#| export
# Layer types that carry a `weight` (and optionally a `bias`) handled by the initializers in this module.
_fc_conv_filter = (
    nn.Linear,
    nn.Conv1d, nn.Conv2d, nn.Conv3d,
    nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d,
)

def lambda_init(m: nn.Module, func: callable=lambda w,b: (nn.init.kaiming_normal_(w), nn.init.zeros_(b))):
    "Call `func(weight, bias)` on every linear and (transposed) convolution layer found in `m`."
    for layer in m.modules():
        if not isinstance(layer, _fc_conv_filter): continue
        # Bias-free layers get a throwaway tensor so `func` can always take two arguments.
        bias = torch.empty(1) if layer.bias is None else layer.bias
        func(layer.weight, bias)

In [None]:
#| export
def lsuv_init(m: nn.Module, # model
              xb: torch.Tensor, # mini-batch input
              tol: float=0.01, # tolerance
              n_iter: int=10, # maximum number of iterations
              verbose: bool=False, # print out details
             ):
    """Refer to [All you need is a good init](https://arxiv.org/abs/1511.06422).

    Every linear / convolution layer is first initialized orthogonally with zero
    bias. Then, layer by layer, the weight is divided by the std of the layer's
    output on `xb` (and the bias shifted by its mean) until the std is within
    `tol` of 1 (and, for layers with a bias, the mean within `tol` of 0) or
    `n_iter` updates have been made.

    Side effects: `m` is moved to the device of `xb.cpu()` and is left in eval mode.
    """
    xb = xb.cpu()
    m.to(xb.device)

    # orthogonal init
    lambda_init(m, lambda w,b: (nn.init.orthogonal_(w), nn.init.zeros_(b)))

    # LSUV init
    m.eval()
    for l in m.modules():
        if not isinstance(l, _fc_conv_filter): continue
        n = 0
        # The hook records (mean, std) of this layer's output on each forward pass.
        h = Hook(l, lambda m,i,o: (o.mean().item(), o.std().item()), cpu=True)
        try:
            with torch.inference_mode():
                # `m(xb) is not None` is only there to run a forward pass, which refreshes
                # `h.stored` before the statistics are tested; it runs at least once.
                while (m(xb) is not None
                       and ((l.bias is not None and abs(h.stored[0] - 0.) > tol) or abs(h.stored[1] - 1.) > tol)
                       and n < n_iter):
                    l.weight /= (h.stored[1] + 1e-8)  # epsilon guards against a zero std
                    if l.bias is not None: l.bias -= h.stored[0]
                    n += 1
            if verbose: print(f"{str(l):80}| took {n} iterations, mean={h.stored[0]:7.4f}, std={h.stored[1]:.4f}")
        finally:
            # Always detach the hook, even if the forward pass raised, so it cannot leak onto the model.
            h.remove()

The `xb` mini-batch is used to estimate the statistics (mean and std) for scaling the weights, similar to `BatchNorm`, but only at initialization. The batch size of `xb` may differ from the batch size used during training.

In [None]:
# Small test network: a biased conv, a bias-free 1x1 conv and a linear head.
tst_model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3),
    nn.ReLU(),
    nn.Conv2d(in_channels=10, out_channels=10, kernel_size=1, bias=False),
    nn.LeakyReLU(),
    nn.Flatten(),
    nn.Linear(in_features=640, out_features=10),
)

In [None]:
# Random mini-batch of 50 samples shaped for the test model (3 channels, 10x10);
# it is only used to estimate the per-layer output statistics.
xb = torch.randn(50,3,10,10)
lsuv_init(tst_model, xb, verbose=True)

Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))                                | took 0 iterations, mean= 0.0023, std=0.9994
Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)                   | took 1 iterations, mean=-0.1563, std=1.0000
Linear(in_features=640, out_features=10, bias=True)                             | took 2 iterations, mean=-0.0000, std=1.0000


In [None]:
#| hide
from fastai.callback.hook import Hooks

In [None]:
#| hide
def print_forward_stats(mod, xb):
    "Run `xb` through `mod` once and print the output mean and std of each module hooked by `Hooks(mod, ...)`."
    def _print_stats(m, i, o):
        print(f'{type(m).__name__:12}| mean={o.mean():7.4f}, std={o.std():.4f}')

    hooks = Hooks(mod, _print_stats)
    try:
        with torch.inference_mode():
            mod(xb)
    finally:
        # Detach the hooks even if the forward pass raised, so they cannot leak onto `mod`.
        hooks.remove()

In [None]:
# Per-layer output statistics on a fresh random batch, following the LSUV init above.
print_forward_stats(tst_model, torch.randn(50,3,10,10))

Conv2d      | mean=-0.0115, std=1.0142
ReLU        | mean= 0.3987, std=0.5902
Conv2d      | mean=-0.1626, std=1.0073
LeakyReLU   | mean= 0.3056, std=0.5329
Flatten     | mean= 0.3056, std=0.5329
Linear      | mean=-0.0253, std=0.9501


In [None]:
#| export
def default_init(m: nn.Module, # model
                 normal: bool=True, # use normal distribution
                 verbose: bool=False, # print out details
                ):
    "Xavier/Kaiming init for `nn.Linear` and `nn.ConvXd` weights, chosen from the preceding activation; biases are zeroed."
    if normal: dist, xavier_, kaiming_ = 'normal',  nn.init.xavier_normal_,  nn.init.kaiming_normal_
    else:      dist, xavier_, kaiming_ = 'uniform', nn.init.xavier_uniform_, nn.init.kaiming_uniform_

    # Activations followed by Kaiming init vs. those followed by Xavier init with a custom gain.
    relu_like = (nn.ReLU, nn.LeakyReLU, nn.SiLU, nn.GELU, nn.Softsign)
    custom_gain = (nn.Tanh, nn.Sigmoid, nn.Softplus, nn.ELU)
    tracked = _fc_conv_filter + relu_like + custom_gain

    # Keep only weight layers and known activations, in module traversal order.
    layers = [o for o in m.modules() if isinstance(o, tracked)]

    # `prev` is the tracked module just before `cur`; the first module is paired with itself.
    prev = layers[0] if layers else None
    for cur in layers:
        if isinstance(cur, _fc_conv_filter):
            if isinstance(prev, relu_like):
                # ReLU and its variants
                slope = prev.negative_slope if isinstance(prev, nn.LeakyReLU) else 0.
                kaiming_(cur.weight, slope, nonlinearity='leaky_relu')
                if verbose: print(f"{str(cur):80}| He_{dist}, negative_slope={slope}")
            elif isinstance(prev, custom_gain):
                # Custom gains by trial and error
                gain = 1.79 if isinstance(prev, nn.Sigmoid) else 1.17
                xavier_(cur.weight, gain)
                if verbose: print(f"{str(cur):80}| Xavier_{dist}, gain={gain}")
            else:
                # No tracked activation in front of this layer
                xavier_(cur.weight)
                if verbose: print(f"{str(cur):80}| Xavier_{dist}, gain=1.")
            if cur.bias is not None: nn.init.zeros_(cur.bias)
        prev = cur

In [None]:
# Re-initialize the test model with `default_init`, printing the scheme chosen for each layer.
default_init(tst_model,verbose=True)

Conv2d(3, 10, kernel_size=(3, 3), stride=(1, 1))                                | Xavier_normal, gain=1.
Conv2d(10, 10, kernel_size=(1, 1), stride=(1, 1), bias=False)                   | He_normal, negative_slope=0.0
Linear(in_features=640, out_features=10, bias=True)                             | He_normal, negative_slope=0.01


In [None]:
# Per-layer output statistics on a fresh random batch, following `default_init` above.
print_forward_stats(tst_model, torch.randn(50,3,10,10))

Conv2d      | mean=-0.0086, std=0.6737
ReLU        | mean= 0.2621, std=0.3903
Conv2d      | mean=-0.0558, std=0.6318
LeakyReLU   | mean= 0.2018, std=0.3607
Flatten     | mean= 0.2018, std=0.3607
Linear      | mean=-0.0345, std=0.5574


In [None]:
#| export
def rai_init(m: nn.Module):
    """Randomized asymmetric initializer.

    Refer to [Dying ReLU and Initialization: Theory and Numerical Examples](https://arxiv.org/abs/1903.06733)

    The first linear / convolution layer of `m` gets Kaiming-normal weights and a
    zero bias. For every later one, each output unit's weights and bias are drawn
    from a zero-mean normal with std `0.6007 / sqrt(fan_in)`, after which one
    randomly chosen entry per unit (possibly the bias) is replaced by a
    Beta(2, 1) sample.
    """
    _transposed = (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)
    _is_first = True

    for l in m.modules():
        if not isinstance(l, _fc_conv_filter): continue
        if _is_first:
            nn.init.kaiming_normal_(l.weight)
            if l.bias is not None: nn.init.zeros_(l.bias)
            _is_first = False; continue

        # Get correct fan_in and fan_out.
        # Linear / Conv weights are (out, ...), but transposed convolutions store
        # (in_channels, out_channels/groups, *kernel), so `shape[0]` is not the
        # number of output units there and the bias column would have the wrong length.
        is_transposed = isinstance(l, _transposed)
        fan_out = l.out_channels if is_transposed else l.weight.shape[0]
        fan_in = l.weight.numel() // fan_out

        # RAI: one row per output unit; the last column holds the bias.
        V = torch.randn(fan_out, fan_in+1) * 0.6007 / (fan_in ** 0.5)
        for j in range(fan_out):
            k = np.random.randint(0, fan_in+1)
            V[j,k] = np.random.beta(2, 1)

        W = V[:, :-1]
        if is_transposed:
            # Rows are ordered (group, out_per_group); swap to the layer's
            # (group, in_per_group, out_per_group, *kernel) layout before the final reshape.
            g = l.groups
            W = W.reshape(g, fan_out // g, l.in_channels // g, *l.kernel_size).transpose(1, 2)
        with torch.no_grad():
            l.weight.copy_(W.reshape(l.weight.shape))
            if l.bias is not None: l.bias.copy_(V[:, -1])

In [None]:
#| hide
# Run nbdev's export step (presumably writes the `#| export` cells to the module named by `default_exp` — confirm against the nbdev docs).
import nbdev; nbdev.nbdev_export()