# VanillaNet Components

In [4]:
from torch import nn
import torch
import torch.nn.functional as F

In [3]:
from timm.models.layers import DropPath, trunc_normal_

## Activation
> Series Informed Activation Function


In [70]:
def fusing_conv_and_bn(kernel, bias, gamma, beta, running_mean, running_var, eps: float = 1e-6):
    """
    Combine the convolutional kernel and the batch normalization parameters.
    besides eps, all parameters are learnable parameters.

    kernel, bias are from the convolutional layer.
    gamma, beta, running_mean, running_var are from the batch normalization layer.

    The output is a new kernel and a new bias, which can be used to replace the original convolutional layer.
    """
    with torch.no_grad():
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta + (bias - running_mean) * gamma / std

In [61]:
class SIActivation(nn.Module):
    """
    Series informed activation.
    """
    def __init__(self, dim: int, act_num: int = 3, inference: bool = False, activation_module: nn.Module = nn.ReLU()):
        super().__init__()

        self.dim = dim
        self.act_num = act_num
        self.inference = inference

        # The weight is in size of (channel_in, 1, kernel_size, kernel_size)
        self.weight = nn.Parameter(torch.randn(dim, 1, act_num*2+1, act_num*2+1))
        self.activation_module = activation_module
        
        if inference:
            self.bias = nn.Parameter(torch.zeros(dim))
        else:
            self.bn = nn.BatchNorm2d(dim, eps=1e-6)

        trunc_normal_(self.weight, std=.02)

    def _fuse_bn_tensor(self, weight: nn.Module, bn: nn.Module):
        """
        Fuse the convolutional layer weights and the batchnorm layer into one layer.
        """
        new_kernel, new_bias = fusing_conv_and_bn(
            kernel=weight,
            bias=0, # we don't have bias before combine
            gamma=bn.weight,
            beta=bn.bias,
            running_mean=bn.running_mean,
            running_var=bn.running_var, eps=bn.eps)
        return new_kernel, new_bias
        
    def switch_to_inference(self):
        if self.inference:
            print("already in inference mode")
            return
        kernel, bias = self._fuse_bn_tensor(self.weight, self.bn)

        self.weight.data = kernel
        self.bias = nn.Parameter(bias)
        self.bn.to('cpu')
        self.__delattr__('bn')
        self.inference = True

    def forward(self, x: torch.Tensor):
        x = self.activation_module(x)
        if self.inference:
            return F.conv2d(x, self.weight, bias=self.bias, padding=self.act_num,
                groups=self.dim, #number of input channels
            )
        else:
            return self.bn(
                F.conv2d(x, self.weight, padding=self.act_num,
                    groups=self.dim, #number of input channels
                ))
        

In [66]:
si_leaky = SIActivation(64, 3, False, nn.LeakyReLU())

In [67]:
x = torch.randn(1, 64, 32, 32)

In [68]:
with torch.no_grad():
    print(f"is inference: {si_leaky.inference}")
    y_1 = si_leaky(x.clone())

si_leaky.switch_to_inference()

with torch.no_grad():
    print(f"is inference: {si_leaky.inference}")
    y_2 = si_leaky(x.clone())

is inference: False
is inference: True


In [69]:
(y_1 - y_2)

tensor([[[[ 0.5920, -0.1833,  0.2094,  ...,  1.1171,  1.1279, -0.2087],
          [ 0.2836, -0.8651,  0.1697,  ...,  0.5286,  0.3792,  0.3336],
          [ 1.3090, -0.4084,  0.0641,  ...,  0.6949,  1.4974,  0.8812],
          ...,
          [-0.4187, -1.1463, -0.2624,  ...,  0.2385, -0.6644,  0.9209],
          [ 0.0413,  0.1789, -1.1674,  ..., -0.5877,  0.1802, -0.5983],
          [ 0.2251, -0.6933, -0.5823,  ..., -0.9055, -0.1351,  0.3962]],

         [[-0.8529, -0.8245, -0.3242,  ..., -0.9185, -0.9937, -0.7371],
          [-0.3071, -0.0938, -0.4368,  ..., -0.2782, -0.9000, -0.9356],
          [-1.0432, -0.3383, -0.9051,  ..., -0.0265, -0.9231,  0.1820],
          ...,
          [-1.5265,  0.8418, -0.0559,  ...,  0.1747,  0.4012,  0.0962],
          [-0.1305,  0.0170,  0.4680,  ...,  0.4032,  0.8075,  1.1066],
          [ 0.6469,  0.3358,  0.9091,  ..., -0.1819,  0.6818,  0.2476]],

         [[-0.1031,  0.7024, -1.0838,  ..., -0.8862,  1.0748,  0.7174],
          [-0.5285, -0.2928, -

In [6]:
x = torch.randn(2, 64, 224, 224)

In [7]:
nn.Conv2d(64, 64, 3, padding=1).weight.shape

torch.Size([64, 64, 3, 3])

In [9]:
bn  = nn.BatchNorm2d(64, eps=1e-6)

In [12]:
bn.weight, bn.weight.shape

(Parameter containing:
 tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
         1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], requires_grad=True),
 torch.Size([64]))

In [14]:
bn.bias, bn.bias.shape

(Parameter containing:
 tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        requires_grad=True),
 torch.Size([64]))

### Batch Normalization

Under usual batch normalization the following is computed and updated, notice it involves 4 players (learnable parameters): $\gamma, \beta, \mu, \sigma^2$

#### Calculation & Update

Mean of the batch, easy to compute

$
\mu_B = \frac{1}{m} \sum_{i=1}^{m} x_i
$

The following is basically standard deviation of the batch

$
\sigma_B^2 = \frac{1}{m} \sum_{i=1}^{m} (x_i - \mu_B)^2
$

Normalize the batch, shift the distribution to zero mean and unit variance

$
\hat{x_i} = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}
$

Running mean and variance update (using mmentum to decay the running mean and variance)

$
\mu = \alpha \mu + (1 - \alpha) \mu_B$

$
\sigma^2 = \alpha \sigma^2 + (1 - \alpha) \sigma_B^2
$

Scale and shift the normalized batch

$
y_i = \gamma \hat{x_i} + \beta
$

In [15]:
x = bn(x)
x

tensor([[[[-8.9798e-01,  9.2157e-01, -9.0759e-01,  ...,  5.8390e-01,
            6.2553e-01,  3.3800e-01],
          [-4.8013e-01,  3.0036e-01, -4.2789e-01,  ...,  2.6584e-01,
            1.3161e+00,  2.2439e-01],
          [ 6.3771e-01, -1.6585e+00,  1.7861e-01,  ...,  2.4691e+00,
           -8.1982e-01, -1.3654e-01],
          ...,
          [ 1.1954e+00, -8.5129e-01, -5.4851e-01,  ..., -9.7521e-01,
           -1.1201e+00,  1.7426e-01],
          [-2.1337e-01,  6.2733e-01, -6.7771e-01,  ..., -9.2922e-01,
           -2.6842e+00,  1.1628e-01],
          [ 1.9230e-01, -2.1457e+00, -1.4950e-01,  ...,  4.6629e-01,
           -1.9567e+00, -6.1159e-02]],

         [[ 1.2826e+00, -6.5572e-01, -1.1191e-01,  ...,  1.3191e+00,
            3.3968e-01, -9.8841e-01],
          [ 1.7222e-01,  3.9568e-01, -8.2014e-01,  ...,  1.3924e+00,
           -1.1021e+00, -1.4651e+00],
          [-9.3531e-01,  6.5933e-01,  8.7128e-02,  ..., -1.1898e+00,
           -1.5951e+00,  1.3119e+00],
          ...,
     

In [17]:
x.mean(), x.std()

(tensor(-1.1879e-10, grad_fn=<MeanBackward0>),
 tensor(1.0000, grad_fn=<StdBackward0>))

In [18]:
bn.running_mean, bn.running_mean.shape

(tensor([ 2.6744e-04,  5.1428e-04,  4.7100e-04, -8.5480e-05, -1.2897e-04,
          2.7630e-04, -1.9318e-04,  2.5398e-05,  4.5987e-04,  6.8352e-04,
          6.1441e-05, -3.0484e-04, -1.5357e-04,  2.7439e-04,  1.8183e-04,
         -6.1147e-04, -5.4306e-05, -8.7659e-05, -6.8270e-04, -2.1659e-04,
          3.6032e-05,  5.3290e-05, -2.2876e-04, -2.5380e-05, -2.3229e-04,
         -1.1275e-05, -3.1075e-05,  6.5662e-05,  3.9315e-05,  3.7166e-04,
         -5.5835e-04, -1.9106e-04, -9.2222e-06,  2.0686e-04, -3.1161e-04,
         -1.9846e-05, -3.0557e-04, -1.1270e-04,  2.1897e-04, -2.0205e-04,
          2.1866e-04,  8.5770e-05, -1.2760e-04, -6.1896e-04, -2.0271e-04,
         -3.5679e-05, -2.4435e-04,  3.9585e-04, -2.7237e-04, -3.5685e-04,
         -3.3539e-04,  1.8854e-04,  1.9764e-04, -8.5256e-05,  3.4071e-04,
          2.7728e-04,  9.3718e-05, -4.8139e-04,  5.0605e-05, -5.1907e-04,
         -1.5690e-04, -2.1815e-04, -3.3568e-04,  1.0267e-04]),
 torch.Size([64]))

In [21]:
bn??

[0;31mSignature:[0m      [0mbn[0m[0;34m([0m[0;34m*[0m[0minput[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m           BatchNorm2d
[0;31mString form:[0m    BatchNorm2d(64, eps=1e-06, momentum=0.1, affine=True, track_running_stats=True)
[0;31mFile:[0m           ~/anaconda3/lib/python3.9/site-packages/torch/nn/modules/batchnorm.py
[0;31mSource:[0m        
[0;32mclass[0m [0mBatchNorm2d[0m[0;34m([0m[0m_BatchNorm[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;34mr"""Applies Batch Normalization over a 4D input (a mini-batch of 2D inputs[0m
[0;34m    with additional channel dimension) as described in the paper[0m
[0;34m    `Batch Normalization: Accelerating Deep Network Training by Reducing[0m
[0;34m    Internal Covariate Shift <https://arxiv.org/abs/1502.03167>`__ .[0m
[0;34m[0m
[0;34m    .. math::[0m
[0;34m[0m
[0;34m        y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \ga

In [22]:
nn.Conv2d(64, 64, 3, padding=1).

Parameter containing:
tensor([ 0.0052,  0.0118,  0.0075,  0.0313, -0.0141, -0.0022,  0.0135,  0.0336,
        -0.0182,  0.0095,  0.0362, -0.0291, -0.0096,  0.0136,  0.0333,  0.0364,
        -0.0148,  0.0255, -0.0406,  0.0132,  0.0178, -0.0404, -0.0210,  0.0030,
         0.0086,  0.0182,  0.0111,  0.0347,  0.0068, -0.0010,  0.0102, -0.0194,
         0.0312, -0.0225, -0.0306, -0.0191,  0.0356,  0.0318, -0.0408, -0.0364,
        -0.0085,  0.0140, -0.0076, -0.0202, -0.0031, -0.0083, -0.0202,  0.0381,
         0.0029,  0.0084, -0.0041, -0.0227,  0.0076,  0.0283,  0.0096, -0.0325,
        -0.0016, -0.0215,  0.0315,  0.0345,  0.0133, -0.0228, -0.0092,  0.0322],
       requires_grad=True)