In [1]:
import torch 
import torch.nn as nn 

VGG has plain (single-path) conv layers, unlike ResNets, which have residual (multi-path) layers. This notebook walks through the idea behind [`RepVGG: Making VGG-style ConvNets Great Again`](https://arxiv.org/pdf/2101.03697.pdf).


### VGG
$$
y = f(x)
$$

### ResNet 
$$
y = f(x)+x 
$$
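A minimal PyTorch sketch of the two formulas, assuming `f` is a single 3x3 conv. The class names `PlainBlock` and `ResidualBlock` are illustrative only and do not come from the RepVGG code. Note that the residual version has to keep `x` around until the addition.

```python
import torch
import torch.nn as nn


class PlainBlock(nn.Module):
    """VGG-style single path: y = f(x)."""

    def __init__(self, channels: int):
        super().__init__()
        self.f = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.f(x)


class ResidualBlock(nn.Module):
    """ResNet-style two paths: y = f(x) + x."""

    def __init__(self, channels: int):
        super().__init__()
        self.f = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x must stay in memory until f(x) has been computed and added
        return self.f(x) + x


x = torch.randn(3, 24, 10, 10)
plain, residual = PlainBlock(24), ResidualBlock(24)
residual.f.load_state_dict(plain.f.state_dict())  # use the same f in both blocks
with torch.no_grad():
    y_plain = plain(x)
    y_res = residual(x)
print(y_plain.shape, y_res.shape)
print(torch.allclose(y_res - x, y_plain, atol=1e-5))  # the only difference is "+ x"
```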


> 1. Multi-path networks need more memory than single-path networks.

In `VGG`, `x` is passed to `f`, which produces `y`, and `x` can be discarded as soon as `f(x)` is available. In `ResNet`, `x` must stay in memory until `f(x)` has been computed, and only then are the two added to get `y`. Holding `x` alongside `f(x)` roughly doubles the memory the block needs.


> 2. Even when multi-path networks have fewer `FLOPs`, they can take more time than single-path networks.

In the above scenario, although the addition itself is cheap, its `memory access cost` is significant. Also, a network with a few large operators achieves higher `data parallelism` than a network with many small operators (e.g. Inception, EfficientNet).


Considering the two points above, `VGG`-like architectures seem more efficient than `ResNet`. This is partly true, but when trained on `ImageNet`, ResNet outperforms VGG in accuracy, because residual connections mitigate the `vanishing gradient problem` discussed [here](https://medium.com/@14prakash/understanding-and-implementing-architectures-of-resnet-and-resnext-for-state-of-the-art-image-cf51669e1624).


So is there a way to get `ResNet`-like accuracy with a `VGG`-like architecture?

### Fusing a multi-path `ResNet`-style architecture (training) into a single-path architecture (inference)

Let's take a ResNet-like conv block.

It applies three conv + batch-norm branches to the input in parallel (a 3x3 conv, a 1x1 conv, and an identity mapping written as a 3x3 conv) and adds the three branch outputs together.

> 1. Input tensor `x`

In [2]:
x = torch.randn((3, 24, 10, 10))
x.shape

torch.Size([3, 24, 10, 10])

> 2. Create the conv and batch-norm layers.

In [3]:
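# branch 1: 3x3 conv; branch 2: 1x1 conv
conv1 = nn.Conv2d(24, 24, kernel_size=(3,3), padding=1, bias=False)
conv2 = nn.Conv2d(24, 24, kernel_size=(1,1), bias=False)
# branch 3: 3x3 conv initialised as an identity mapping
conv3 = nn.Conv2d(24, 24, kernel_size=(3,3), stride=(1, 1), padding=1, bias=False)
conv3.weight.data.fill_(0)
# output channel i reads only input channel i, at the kernel centre, with weight 1
for i in range(conv3.weight.data.shape[0]):
    conv3.weight.data[i, i, 1, 1] = 1
print("Done")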

Done


In [4]:
bn1 = nn.BatchNorm2d(24)
bn2 = nn.BatchNorm2d(24)
bn3 = nn.BatchNorm2d(24)

### `conv3` is just an identity mapping

In [5]:
with torch.no_grad():
    y = conv3(x)
    print(y.shape)
    print(torch.allclose(y, x))

torch.Size([3, 24, 10, 10])
True
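The loop that builds `conv3` writes a 1 at the kernel centre for each channel pair (i, i). A vectorised alternative (a sketch, not the notebook's own code) fills the centre taps with `torch.eye` and checks the result with `torch.nn.functional.conv2d`:

```python
import torch
import torch.nn.functional as F

channels = 24
# weight shape: (out_channels, in_channels, kH, kW)
identity_kernel = torch.zeros(channels, channels, 3, 3)
# centre tap: output channel i reads input channel i with weight 1, all others 0
identity_kernel[:, :, 1, 1] = torch.eye(channels)

x = torch.randn(3, channels, 10, 10)
y = F.conv2d(x, identity_kernel, padding=1)  # padding=1 keeps the 10x10 size
print(torch.allclose(y, x, atol=1e-5))
```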


> Compute the multi-branch output (the sum of the three conv + BN branches).

In [6]:
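with torch.no_grad():
    # eval mode: BN normalises with its running statistics. In training mode it
    # would use batch statistics (and update its running stats), which a fused
    # conv cannot reproduce.
    for bn in (bn1, bn2, bn3):
        bn.eval()
    y1 = bn1(conv1(x))
    y2 = bn2(conv2(x))
    y3 = bn3(conv3(x))
    y_unfused = y1+y2+y3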
    print(y_unfused.shape)

torch.Size([3, 24, 10, 10])




### Fusing multiple paths into a single conv layer
Here we need two kinds of fusing:
- Fuse each conv layer with its batch norm.
- Fuse the parallel conv layers into one.

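Mathematically, this is what happens when we apply a conv layer and then batch norm one after the other (per output channel, with $*$ denoting convolution):

$$
y' = I * W
$$

$$
y'' = BN(y')
$$

where, in eval mode, $\mu$ and $\mathrm{var}$ are the running mean and variance and $\sigma = \sqrt{\mathrm{var} + \epsilon}$:
$$
BN(y') = (y' - \mu) \cdot \frac{\gamma}{\sigma} + \beta
$$

we can rewrite this as
$$
y'' = I * \left(W \cdot \frac{\gamma}{\sigma}\right) + \left(\beta - \frac{\mu \gamma}{\sigma}\right)
$$

we will keep
$$
bias = \beta - \frac{\mu \gamma}{\sigma}
$$

then the new `weight` is
$$
W' = W \cdot \frac{\gamma}{\sigma}
$$

so finally we have

$$
y'' = I * W' + bias
$$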
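A scalar sanity check of the derivation, using made-up numbers for a single channel and reducing the convolution to a plain product `I * W` (all values are illustrative):

```python
import math

# illustrative single-channel values
I, W = 1.5, 0.8                 # input value and conv weight (conv reduced to a product)
mu, var, eps = 0.3, 2.0, 1e-5   # BN running mean, running variance, epsilon
gamma, beta = 1.2, -0.4         # BN scale and shift
sigma = math.sqrt(var + eps)    # the std BN actually divides by

# unfused: conv, then batch norm
y_prime = I * W
y_unfused = (y_prime - mu) * (gamma / sigma) + beta

# fused: new weight W' and bias
W_fused = W * (gamma / sigma)
bias = beta - mu * gamma / sigma
y_fused = I * W_fused + bias

print(y_unfused, y_fused)
print(math.isclose(y_unfused, y_fused))
```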

Let's do this in PyTorch.

### Fusing conv layer + batchnorm 
Here we fuse `conv1` + `bn1` into `conv1_fused`. The function below is copied from [here](https://gist.githubusercontent.com/FrancescoSaverioZuppichini/42056dee938e5c694d5ea3caca64833f/raw/ad0af226916720dc3baa7a61cc52bf91fbbc8bee/repvgg-5.py).

In [17]:
def get_fused_bn_to_conv_state_dict(
    conv: nn.Conv2d, bn: nn.BatchNorm2d
):
    # in the paper's notation, bn.weight is gamma and bn.bias is beta
    bn_mean, bn_var, bn_gamma, bn_beta = (
        bn.running_mean,
        bn.running_var,
        bn.weight,
        bn.bias,
    )
    # we need the std!
    bn_std = (bn_var + bn.eps).sqrt()
    # eq (3)
    conv_weight = nn.Parameter((bn_gamma / bn_std).reshape(-1, 1, 1, 1) * conv.weight)
    # still eq (3)
    conv_bias = nn.Parameter(bn_beta - bn_mean * bn_gamma / bn_std)
    return {"weight": conv_weight, "bias": conv_bias}

In [12]:
conv1_bn = nn.Sequential(
    conv1,
    bn1)

In [18]:
with torch.no_grad():
    # be sure to switch to eval mode!!
    conv1_bn = conv1_bn.eval()
    
    # create a fused layer. 
    conv1_fused = nn.Conv2d(24, 24, kernel_size=(3,3), padding=1, bias=True)
    conv1_fused.load_state_dict(get_fused_bn_to_conv_state_dict(conv1_bn[0], conv1_bn[1]))
    
    print(torch.allclose(conv1_fused(x), conv1_bn(x), atol=1e-5))
    

True


### Converting a 1x1 conv layer into a 3x3 conv layer with the same output

In [25]:
conv2_3x3 = nn.Conv2d(24, 24, kernel_size=(3,3), bias=False, padding=1)
conv2_3x3.load_state_dict({"weight": torch.nn.functional.pad(conv2.weight.data, [1, 1, 1, 1])})
print("Loaded weight")

Loaded weight


In [26]:
with torch.no_grad():
    y = conv2(x)
    print(torch.allclose(y, conv2_3x3(x), atol=1e-5))

True
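With every branch now expressible as a 3x3 conv, the remaining step, fusing the parallel conv layers, relies on convolution being linear in its kernel: for branches with the same kernel size, stride and padding, conv(x, W1) + b1 + conv(x, W2) + b2 = conv(x, W1 + W2) + (b1 + b2). A quick check with random kernels (purely illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 24, 10, 10)
w1, w2 = torch.randn(24, 24, 3, 3), torch.randn(24, 24, 3, 3)
b1, b2 = torch.randn(24), torch.randn(24)

# multi-path: two parallel 3x3 convs whose outputs are summed
y_two_paths = F.conv2d(x, w1, b1, padding=1) + F.conv2d(x, w2, b2, padding=1)
# single path: one 3x3 conv with the summed kernels and biases
y_one_path = F.conv2d(x, w1 + w2, b1 + b2, padding=1)
print(torch.allclose(y_two_paths, y_one_path, atol=1e-3))
```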


## Conclusion
We performed and checked three key operations:
- We can fuse a conv layer + BatchNorm into a single conv layer (with a bias).
- We can convert a 1x1 conv layer into a 3x3 conv layer by zero-padding its kernel and using padding=1.
- We can express an identity mapping as a 3x3 conv layer.

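Because convolution is linear, the resulting 3x3 kernels and biases of the parallel branches can then be summed into a single 3x3 conv layer. So whether we compute the block `ResNet`-style (multi-path) or `VGG`-style (single path), we get the same output. This is why we can train the network in `ResNet` style and run inference in `VGG` style.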

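Note: padding is chosen so that every branch produces an output of the same spatial size (padding=1 for the 3x3 kernels, none for the original 1x1 conv), and the fused layer needs `bias=True` to absorb the batch-norm shift.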
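Putting the pieces together: the sketch below builds the three branches used in this notebook (3x3 conv, 1x1 conv, identity written as a 3x3 conv, each followed by BN), gives the BN layers random running statistics so the check is not trivial, fuses everything into one 3x3 conv, and compares the outputs. The helper name `fuse_conv_bn` is hypothetical; this mirrors the steps above rather than reproducing the official RepVGG code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
C = 24


def fuse_conv_bn(weight: torch.Tensor, bn: nn.BatchNorm2d):
    """Fold an eval-mode BN into a bias-free conv weight; returns (W', bias)."""
    std = (bn.running_var + bn.eps).sqrt()
    scale = bn.weight / std
    return weight * scale.reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * scale


# branch kernels: 3x3, 1x1, and identity expressed as a 3x3 kernel
w3 = torch.randn(C, C, 3, 3) * 0.1
w1 = torch.randn(C, C, 1, 1) * 0.1
w_id = torch.zeros(C, C, 3, 3)
w_id[:, :, 1, 1] = torch.eye(C)

bns = [nn.BatchNorm2d(C) for _ in range(3)]
with torch.no_grad():
    for bn in bns:
        # random running stats and affine params so the check is non-trivial
        bn.running_mean.uniform_(-1, 1)
        bn.running_var.uniform_(0.5, 2)
        bn.weight.uniform_(0.5, 1.5)
        bn.bias.uniform_(-0.5, 0.5)
        bn.eval()  # use running stats, not batch stats

x = torch.randn(3, C, 10, 10)
with torch.no_grad():
    # multi-path (training-style) computation
    y_multi = (bns[0](F.conv2d(x, w3, padding=1))
               + bns[1](F.conv2d(x, w1))
               + bns[2](F.conv2d(x, w_id, padding=1)))

    # fold each BN into its conv, pad the 1x1 kernel to 3x3, then sum everything
    k3, b3 = fuse_conv_bn(w3, bns[0])
    k1, b1 = fuse_conv_bn(w1, bns[1])
    kid, bid = fuse_conv_bn(w_id, bns[2])
    k1 = F.pad(k1, [1, 1, 1, 1])  # 1x1 -> 3x3, original weight at the centre
    fused = nn.Conv2d(C, C, kernel_size=3, padding=1, bias=True)
    fused.weight.copy_(k3 + k1 + kid)
    fused.bias.copy_(b3 + b1 + bid)
    y_single = fused(x)  # single-path (VGG-style) computation

print(torch.allclose(y_multi, y_single, atol=1e-4))
```

The same recipe extends to any number of parallel branches, as long as each one can be written as a 3x3 conv with matching stride and padding.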