In [None]:
# Mount Google Drive
# https://towardsdatascience.com/google-drive-google-colab-github-dont-just-read-do-it-5554d5824228
# Colab-only setup cell: mounts Google Drive at ROOT so the project files
# under /content/drive/MyDrive (see the %run path below) are reachable.
from google.colab import drive
ROOT = "/content/drive"
drive.mount(ROOT)
# %pwd %ls
# Run the GitHub settings helper notebook in this kernel.
# NOTE(review): presumably it defines GIT_PATH (used by the push cell) and
# changes into the repository directory — verify against Colab_Helper.ipynb.
%run /content/drive/MyDrive/CNNStanford/pytorch/pytorch_sandbox/Colab_Helper.ipynb

In [None]:
# Commit message consumed by the commit cell that follows.
MESSAGE = "clean file & gitignore again"
# Author identity for commits; --global writes it to the current user's git config.
!git config --global user.email "ronyginosar@mail.huji.ac.il"
!git config --global user.name "ronyginosar"
# Stage every change under the current directory.
# Assumes the kernel's working directory is inside the repository — TODO confirm.
!git add .

In [None]:
# Commit the staged changes; IPython substitutes {MESSAGE} into the shell command.
!git commit -m "{MESSAGE}"
# GIT_PATH is not defined in this notebook; presumably set by the %run helper — verify.
# NOTE(review): if GIT_PATH embeds an access token, avoid echoing it in cell output.
!git push "{GIT_PATH}"

In [None]:
import torch
import torch.nn as nn
import sys
import numpy as np
from torchinfo import summary
import torch.nn.functional as F


#%% Build a simple classification CNN, using custom Module/s
class RB(nn.Module):
    """Residual block written out layer by layer (no nn.Sequential).

    Main path: conv -> BN -> ReLU -> conv -> BN. The block input is added to
    the main path, through a 1x1 projection when its shape differs, and the
    sum goes through a final ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of the FIRST conv only. The second conv always uses
            stride 1 so the main path downsamples exactly once, by the same
            factor as the 1x1 shortcut.
        padding: zero-padding of both convs. The shortcut assumes it preserves
            spatial size at stride 1, i.e. padding == (kernel - 1) // 2
            (1 for kernel=3).
        kernel: kernel size of both convs.
    """
    def __init__(self, in_channels, out_channels, stride=1, padding=1, kernel=3):
        super().__init__()
        # NOTE: a conv bias right before BatchNorm is redundant (BN re-centres
        # the activations); kept so parameter counts stay unchanged.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding)
        self.conv1_bn = nn.BatchNorm2d(out_channels)
        # Stride 1 here: striding both convs would shrink the main path by
        # stride**2 while the shortcut shrinks only by `stride`.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel, 1, padding)
        self.conv2_bn = nn.BatchNorm2d(out_channels)
        # The identity can only be added when neither the channel count nor
        # the spatial size changes; otherwise project it with a 1x1 conv.
        self.residual = None
        if in_channels != out_channels or stride != 1:
            self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv1_bn(out)
        out = F.relu(out)  # could be an nn.ReLU module instead
        out = self.conv2(out)
        out = self.conv2_bn(out)
        if self.residual is not None:
            out += self.residual(x)
        else:
            out += x
        out = F.relu(out)
        return out


#%%
class ResidualBlock(nn.Module):
    """Residual block built from ``conv_block`` units inside an nn.Sequential.

    Main path: (conv -> BN) -> ReLU -> (conv -> BN). The block input is added
    to the main path, through a 1x1 projection when its shape differs, and the
    sum goes through a final ReLU.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the block.
        stride: stride of the FIRST conv only; the second conv uses stride 1
            so the main path and the 1x1 shortcut downsample by the same factor.
        padding: zero-padding of both convs. The shortcut assumes it preserves
            spatial size at stride 1 (1 for kernel=3).
        kernel: kernel size of both convs.
    """
    def __init__(self, in_channels, out_channels, stride=1, padding=1, kernel=3):
        super().__init__()
        self.blocks = nn.Sequential(
            conv_block(in_channels, out_channels, kernel, stride, padding),
            nn.ReLU(),
            # Stride 1: striding twice would make the main path smaller than the shortcut.
            conv_block(out_channels, out_channels, kernel, 1, padding),
        )
        # Project the input whenever it cannot be added as-is: channel count
        # changes or the spatial size changes (stride != 1).
        self.residual = None
        if in_channels != out_channels or stride != 1:
            self.residual = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x):
        out = self.blocks(x)
        if self.residual is not None:
            out += self.residual(x)
        else:
            out += x
        return F.relu(out)


def conv_block(in_f, out_f, *args, **kwargs):
    """Return ``nn.Conv2d(in_f, out_f, *args, **kwargs)`` followed by ``nn.BatchNorm2d(out_f)``.

    Extra positional/keyword arguments are forwarded unchanged to nn.Conv2d
    (kernel_size, stride, padding, ...).
    """
    return nn.Sequential(
        nn.Conv2d(in_f, out_f, *args, **kwargs),
        nn.BatchNorm2d(out_f),
    )


#%%
class RBPool(nn.Module):
    """A ``ResidualBlock`` followed by an optional pooling layer.

    Args:
        in_channels: input channels, forwarded to ResidualBlock.
        out_channels: output channels, forwarded to ResidualBlock.
        stride, padding, kernel: forwarded to ResidualBlock's convolutions.
        pool: 'max' -> 2x2 max-pool with stride 2 (halves H and W);
              'avg' -> adaptive average pool to 1x1 (global average);
              any other value -> no pooling (nn.Identity).
    """
    def __init__(self, in_channels, out_channels, stride=1, padding=1, kernel=3, pool='max'):
        super().__init__()
        if pool == 'max':
            self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        elif pool == 'avg':
            self.pool = nn.AdaptiveAvgPool2d((1, 1))
        else:
            self.pool = nn.Identity()

        # Forward the caller's conv settings; they were previously hard-coded
        # to the defaults and silently ignored.
        self.rb_pool = nn.Sequential(
            ResidualBlock(in_channels, out_channels, stride=stride, padding=padding, kernel=kernel),
            self.pool,
        )

    def forward(self, x):
        return self.rb_pool(x)


#%%
class Ex1Net(nn.Module):
    """Classification CNN: a stack of RBPool stages followed by a linear head.

    Stage i maps in_channels[i] -> out_channels[i] and applies pools[i].
    The head is nn.Linear(out_channels[-1], num_classes), so the last stage
    must reduce the spatial size to 1x1 (e.g. pool='avg').

    Args:
        in_channels: per-stage input channel counts.
        out_channels: per-stage output channel counts.
        pools: per-stage pool type, passed as RBPool's ``pool`` argument.
        num_classes: number of output logits.

    Raises:
        ValueError: if the three per-stage lists are empty or differ in length.
    """
    def __init__(self, in_channels, out_channels, pools, num_classes):
        super().__init__()
        # zip() would silently drop stages if the lists disagree in length.
        if not in_channels or not (len(in_channels) == len(out_channels) == len(pools)):
            raise ValueError(
                "in_channels, out_channels and pools must have the same non-zero length, "
                f"got {len(in_channels)}, {len(out_channels)}, {len(pools)}")

        # nn.Sequential over the stages (an nn.ModuleList plus a loop in
        # forward() would be equivalent).
        layers = [RBPool(in_c, out_c, pool=p)
                  for in_c, out_c, p in zip(in_channels, out_channels, pools)]
        self.layers = nn.Sequential(*layers)

        self.linear = nn.Linear(out_channels[-1], num_classes)
        self.linear_in_dim = out_channels[-1]

    def forward(self, x):
        x = self.layers(x)
        # Keep the batch dimension and flatten the rest for the FC layer.
        # Unlike reshape(-1, C), this never folds leftover spatial positions
        # into the batch: if the last stage did not pool to 1x1, nn.Linear
        # raises a shape error instead of returning extra rows.
        x = torch.flatten(x, 1)
        return self.linear(x)

#%%
def count_params(net, *, trainable_only=False):
    """Return the number of scalar parameters in ``net``.

    Args:
        net: any nn.Module.
        trainable_only: if True, count only parameters with requires_grad=True
            (frozen weights are skipped). Defaults to False, i.e. all parameters.

    Returns:
        int: sum of ``numel()`` over the selected tensors from ``net.parameters()``.
    """
    return sum(p.numel() for p in net.parameters()
               if p.requires_grad or not trainable_only)

#%%
def main(argv):
    """Build an Ex1Net, run one forward pass and print its size and summary.

    Args:
        argv: command-line arguments; currently unused, kept for the
            ``main(sys.argv)`` entry point.
    """
    # Create a Ex1Net instance
    net = Ex1Net(in_channels=[3, 32, 64, 128],
                 out_channels=[32, 64, 128, 256],
                 pools=['max', 'max', 'max', 'avg'],
                 num_classes=5)

    # Inference only: eval mode makes BatchNorm use its running statistics.
    net.eval()

    # Random input batch [N x 3 x 32 x 32].
    in_images = 10
    x = torch.randn((in_images, 3, 32, 32))

    # Pick the device once; the network and the input must live on the same one.
    if torch.cuda.is_available():
        print("running on GPU")
        device = torch.device('cuda:0')
    else:
        print("running on CPU")
        device = torch.device('cpu')
    net = net.to(device)
    x = x.to(device)

    # Forward-only check, so no autograd graph is needed.
    with torch.no_grad():
        y = net(x)
    print(f"Output shape: {tuple(y.shape)}")

    num_params = count_params(net)
    print(f"Total number of parameters: {num_params}")

    # Summary table with input size, output size, #params and #MACs per layer.
    # verbose=0 stops torchinfo from printing on its own (its default outside
    # notebooks), so the explicit print() shows the table exactly once in both
    # a script and a notebook.
    print(summary(net,
                  input_size=(in_images, 3, 32, 32),
                  col_names=["input_size", "output_size", "num_params", "mult_adds"],
                  col_width=15,
                  depth=10,
                  device=device,
                  verbose=0))


#%%
if __name__ == "__main__":
    # Entry point for `python <file>` as well as for running this notebook
    # cell, where IPython's module name is also "__main__".
    main(argv=sys.argv)
