# Model

In [1]:
import torch

if torch.cuda.is_available():
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

import torchvision

import torchvision.transforms as transforms

from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader

import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF

import torch.optim as optim

## ConvNet

In [2]:
class ConvNet(nn.Module):
    """AlexNet-style image classifier that records its max-pooling switches.

    Five convolutional layers (max-pooling after layers 1, 2 and 5, local
    response normalisation after the first two pools) feed two fully
    connected layers with dropout and a final linear classifier.

    The pooling indices ("switches") of the most recent forward pass are
    stored in ``self.switches`` in the order pool1, pool2, pool5, so that a
    matching unpooling network can reuse them.

    Args:
        in_channels: Number of channels of the input images.
        out_channels: Number of classes (width of the output layer).
        batch_size: Unused; accepted only for backward compatibility.

    Note:
        ``layer6`` expects 9216 (= 256 * 6 * 6) flattened features, which is
        what a 224 x 224 input produces with the strides/paddings below.
    """

    def __init__(self, in_channels=3, out_channels=2, batch_size=1):
        super().__init__()

        # Pooling indices of the latest forward pass (pool1, pool2, pool5).
        self.switches = []

        ## Layer 1
        self.layer1 = nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=7, stride=2, padding=1)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.norm1 = nn.LocalResponseNorm(96)

        ## Layer 2
        self.layer2 = nn.Conv2d(in_channels=96, out_channels=256, kernel_size=5, stride=2, padding=0)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1, return_indices=True)
        self.norm2 = nn.LocalResponseNorm(256)

        ## Layer 3
        self.layer3 = nn.Conv2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)

        ## Layer 4
        self.layer4 = nn.Conv2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)

        ## Layer 5
        self.layer5 = nn.Conv2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)
        self.pool5 = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, return_indices=True)

        ## Layer 6
        self.layer6 = nn.Linear(9216, 4096)
        self.drop6 = nn.Dropout(p=0.5)

        ## Layer 7
        self.layer7 = nn.Linear(4096, 4096)
        self.drop7 = nn.Dropout(p=0.5)

        ## Output
        self.output = nn.Linear(4096, out_channels)

        ## Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Set every Conv2d/Linear weight to 1e-2 and every bias to 0.

        Called once per sub-module through ``nn.Module.apply``; modules of
        any other type are left untouched.
        """
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.zeros_(module.bias)
            nn.init.constant_(module.weight, 1e-2)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``(N, in_channels, H, W)``; H and W must
                lead to 9216 flattened features before ``layer6``.

        Returns:
            Tensor of shape ``(N, out_channels)`` holding softmax
            probabilities (each row sums to 1). Because softmax is already
            applied here, do not feed the result to a loss that expects raw
            logits such as ``nn.CrossEntropyLoss``.

        Side effects:
            ``self.switches`` is emptied and refilled with the pooling
            indices of this call.
        """
        # Forget the previous pass. Without this the list grows on every
        # call and stale indices precede the current ones. Clearing in place
        # (rather than rebinding) keeps outside references to the list valid.
        self.switches.clear()

        ## Layer - 1
        x = F.relu(self.layer1(x))
        x, pool_idx = self.pool1(x)
        self.switches.append(pool_idx)
        x = self.norm1(x)

        ## Layer - 2
        x = F.relu(self.layer2(x))
        x, pool_idx = self.pool2(x)
        self.switches.append(pool_idx)
        x = self.norm2(x)

        ## Layer - 3
        x = F.relu(self.layer3(x))

        ## Layer - 4
        x = F.relu(self.layer4(x))

        ## Layer - 5
        x = F.relu(self.layer5(x))
        x, pool_idx = self.pool5(x)
        self.switches.append(pool_idx)

        ## Flatten everything but the batch dimension
        x = torch.flatten(x, 1)

        ## Layer - 6
        x = F.relu(self.drop6(self.layer6(x)))

        ## Layer - 7
        x = F.relu(self.drop7(self.layer7(x)))

        ## Output Layer
        x = self.output(x)
        return F.softmax(x, dim=1)

## DeConvNet

In [3]:
class DeConvNet(nn.Module):
    """Deconvolution counterpart of the first five ``ConvNet`` layers.

    For each convolutional layer ``layerN`` (N = 1..5) of the given model a
    ``ConvTranspose2d`` named ``deconvN`` is created with the same kernel
    size, stride and padding, and for each pooling layer a ``MaxUnpool2d``.
    The filters are taken from the model, and the model's list of pooling
    switches is shared.

    Args:
        model: Network to mirror. It must expose ``layer1`` .. ``layer5``
            (each with ``weight`` and ``bias``) and a ``switches`` list.
        in_channels: Number of channels of the images the model consumes.
    """

    def __init__(self, model: "ConvNet", in_channels=3):
        super().__init__()

        # Same list object as the model's (not a copy).
        self.switches = model.switches

        ## Layer 1
        self.deconv1 = nn.ConvTranspose2d(in_channels=96, out_channels=in_channels, kernel_size=7, stride=2, padding=1)
        self.unpool1 = nn.MaxUnpool2d(kernel_size=3, stride=2, padding=1)

        ## Layer 2
        self.deconv2 = nn.ConvTranspose2d(in_channels=256, out_channels=96, kernel_size=5, stride=2, padding=0)
        self.unpool2 = nn.MaxUnpool2d(kernel_size=3, stride=2, padding=1)

        ## Layer 3
        self.deconv3 = nn.ConvTranspose2d(in_channels=384, out_channels=256, kernel_size=3, stride=1, padding=1)

        ## Layer 4
        self.deconv4 = nn.ConvTranspose2d(in_channels=384, out_channels=384, kernel_size=3, stride=1, padding=1)

        ## Layer 5
        self.deconv5 = nn.ConvTranspose2d(in_channels=256, out_channels=384, kernel_size=3, stride=1, padding=1)
        self.unpool5 = nn.MaxUnpool2d(kernel_size=3, stride=2, padding=0)

        ## Initialize weights
        self._init_weights(model)

    def _init_weights(self, model):
        """Tie ``deconv1`` .. ``deconv5`` to the model's ``layer1`` .. ``layer5``.

        Each deconv weight becomes the conv weight with its first two axes
        swapped, and each deconv bias becomes the conv bias. Both are
        assigned without copying, so they share storage with the model's
        parameters.
        """
        # NOTE(review): ConvTranspose2d stores its weight as
        # (in_channels, out_channels, kH, kW), which for these layers already
        # equals the Conv2d layout (out_channels, in_channels, kH, kW).
        # After the swap below deconv1.weight is (3, 96, 7, 7) and
        # deconv1.bias has 96 entries, whereas the module was built for
        # (96, 3, 7, 7) / 3. Confirm the intended layout before implementing
        # forward().
        for layer_id in range(1, 6):
            conv = getattr(model, f"layer{layer_id}")
            deconv = getattr(self, f"deconv{layer_id}")
            deconv.weight.data = torch.transpose(conv.weight.data, 0, 1)
            deconv.bias.data = conv.bias.data

    def forward(self, x):
        """Not implemented yet; returns ``None`` for any input."""
        # TODO: implement the unpool -> rectify -> deconv reconstruction.
        return None

In [4]:
def test_deconv():
    """Check that a DeConvNet mirrors the filters of the ConvNet it wraps.

    Runs one forward pass through a fresh ConvNet, builds a DeConvNet from
    it and prints, for layers 1-5, whether each transposed deconv weight
    equals the conv weight, followed by whether each deconv bias equals the
    conv bias (ten ``tensor(True)`` lines when everything matches).
    """
    x = torch.randn((5, 3, 224, 224))
    net = ConvNet(in_channels=3, out_channels=2)

    # Forward pass so the network records its pooling switches.
    net(x)

    mirror = DeConvNet(net)

    layer_pairs = [
        (getattr(mirror, f"deconv{idx}"), getattr(net, f"layer{idx}"))
        for idx in range(1, 6)
    ]

    # Weights first (layers 1..5) ...
    for deconv, conv in layer_pairs:
        swapped = deconv.weight.data.transpose(0, 1)
        print((swapped == conv.weight.data).all())

    # ... then biases (layers 1..5).
    for deconv, conv in layer_pairs:
        print((deconv.bias.data == conv.bias.data).all())


test_deconv()

tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)
tensor(True)


In [5]:
def test():
    """Smoke-test ConvNet on a random batch.

    Feeds five random 3 x 224 x 224 images through a two-class ConvNet and
    prints the input shape, the output shape and the output itself.
    """
    x = torch.randn((5, 3, 224, 224))
    net = ConvNet(in_channels=3, out_channels=2)
    y = net(x)

    for label, tensor in (('x.shape:', x), ('y.shape:', y)):
        print(label, tensor.shape)
    print()

    print(f'y: {y}')
    
test()

x.shape: torch.Size([5, 3, 224, 224])
y.shape: torch.Size([5, 2])

y: tensor([[0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000],
        [0.5000, 0.5000]], grad_fn=<SoftmaxBackward0>)


## Training

- ImageNet: resize the smallest dimension to $256$, then crop the center $256 \times 256$ region
- Subtracting the per-pixel mean
- Then using $10$ different sub-crops of size $224 \times 224$ (corners + center with(out) horizontal flips)
- Multiple different crops and flips of each training sample to boost training set size

**Stochastic Gradient Descent**
- Mini-batch size = $128$
- Learning Rate = $10^{-2}$
- Momentum = $0.9$
- Epochs = $70$

Visualization of the $1^{st}$-layer filters during training reveals that a few of them dominate (Fig. 6(a)). To combat this:
- Renormalize each filter in the convolutional layers whose $RMS$ value exceeds a fixed radius of $10^{-1}$ back to this fixed radius
- This is crucial, especially in the $1^{st}$ layer of the model, where the input images are roughly in the range $[-128, 128]$.