In [None]:
"""
From scratch implementation of the famous ResNet models.
"""

import torch
import torch.nn as nn

class block(nn.Module):
  """Bottleneck residual block: 1x1 conv -> 3x3 conv -> 1x1 conv plus a skip connection.

  The first 1x1 convolution maps ``in_channels`` to ``intermediate_channels``,
  the 3x3 convolution applies ``stride`` (spatial downsampling when > 1), and
  the last 1x1 convolution expands to ``intermediate_channels * expansion``
  channels. Every convolution is followed by batch normalization; the skip
  connection is added before the final ReLU.

  Args:
    in_channels: Number of channels of the input tensor.
    intermediate_channels: Channel width of the two inner convolutions.
    identity_downsample: Optional module applied to the skip path so that its
      shape matches the main path (needed whenever ``stride != 1`` or
      ``in_channels != intermediate_channels * expansion``).
    stride: Stride of the 3x3 convolution.
  """

  # Ratio between the block's output channels and ``intermediate_channels``.
  # Kept at class level so it can be read without building an instance;
  # ``self.expansion`` still resolves to the same value.
  expansion = 4

  def __init__(
      self, in_channels, intermediate_channels, identity_downsample=None, stride=1
  ):
    super().__init__()
    out_channels = intermediate_channels * self.expansion

    # 1x1 convolution: change the channel count, keep the spatial size.
    self.conv1 = nn.Conv2d(
        in_channels,
        intermediate_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False,  # the following BatchNorm has its own shift term
    )
    self.bn1 = nn.BatchNorm2d(intermediate_channels)

    # 3x3 convolution: the only layer on the main path that uses `stride`.
    self.conv2 = nn.Conv2d(
        intermediate_channels,
        intermediate_channels,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )
    self.bn2 = nn.BatchNorm2d(intermediate_channels)

    # 1x1 convolution: expand the channel count by `expansion`.
    self.conv3 = nn.Conv2d(
        intermediate_channels,
        out_channels,
        kernel_size=1,
        stride=1,
        padding=0,
        bias=False,
    )
    self.bn3 = nn.BatchNorm2d(out_channels)

    self.relu = nn.ReLU()
    self.identity_downsample = identity_downsample
    self.stride = stride

  def forward(self, x):
    """Run the block on ``x`` and return ``relu(main_path(x) + skip(x))``.

    Args:
      x: Input tensor accepted by ``nn.Conv2d`` with ``in_channels`` channels.

    Returns:
      Tensor with ``intermediate_channels * expansion`` channels.
    """
    # No copy needed: `x` is rebound to new tensors below, the input itself
    # is never modified in place by this method.
    identity = x

    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.conv3(x)
    x = self.bn3(x)

    # Bring the skip path to the main path's shape when they differ.
    if self.identity_downsample is not None:
      identity = self.identity_downsample(identity)

    x += identity
    x = self.relu(x)
    return x





class ResNet(nn.Module):
  """ResNet network; the part visible here builds the input stem only.

  The stem is a 7x7 stride-2 convolution to 64 channels, batch normalization,
  ReLU, and a 3x3 stride-2 max pooling layer.

  Args:
    block: Residual block class (not used by the layers built here).
    layers: Number of blocks per stage, for example ``[3, 4, 6, 3]``
      (not used by the layers built here).
    image_channels: Number of channels of the input images.
    num_classes: Number of output classes (not used by the layers built here).
  """

  def __init__(self, block, layers, image_channels, num_classes):
    super().__init__()

    stem_width = 64
    # Channel count produced by the stem.
    self.in_channels = stem_width

    self.conv1 = nn.Conv2d(
        in_channels=image_channels,
        out_channels=stem_width,
        kernel_size=7,
        stride=2,
        padding=3,
        bias=False,
    )
    self.bn1 = nn.BatchNorm2d(num_features=stem_width)
    self.relu = nn.ReLU()
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)


**Output size calculation after applying convolution**


0. Input Layer shape = 3 * 224 * 224  -> (color channels, height, width)

1. After applying conv2d with 64 filters of (7*7) stride = 2 and padding = 3:

* Output shape = floor((224 + 2*3 - 7) / 2) + 1 = 111 + 1 = 112

2. After applying Max Pooling (3*3) stride = 2 and padding = 1:

* Output shape = floor((112 + 2*1 - 3) / 2) + 1 = 55 + 1 = 56

___

3. After applying conv2d with 64 filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((56 + 2*0 - 1) / 1) + 1 = 56

4. After applying conv2d with 64 filters of (3*3) stride = 1 and padding = 1:

* Output shape = ((56 + 2*1 - 3) / 1) + 1 = 56

5. After applying conv2d with 256 (64 * 4) filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((56 + 2*0 - 1) / 1) + 1 = 56


---

6. After applying conv2d with 128 filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((56 + 2*0 - 1) / 1) + 1 = 56

7. After applying conv2d with 128 filters of (3*3) stride = 2 and padding = 1:

* Output shape = floor((56 + 2*1 - 3) / 2) + 1 = 27 + 1 = 28

8. After applying conv2d with 512 (128 * 4) filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((28 + 2*0 - 1) / 1) + 1 = 28

___

9. After applying conv2d with 256 filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((28 + 2*0 - 1) / 1) + 1 = 28

10. After applying conv2d with 256 filters of (3*3) stride = 2 and padding = 1:

* Output shape = floor((28 + 2*1 - 3) / 2) + 1 = 13 + 1 = 14

11. After applying conv2d with 1024 (256 * 4) filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((14 + 2*0 - 1) / 1) + 1 = 14

---


12. After applying conv2d with 512 filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((14 + 2*0 - 1) / 1) + 1 = 14

13. After applying conv2d with 512 filters of (3*3) stride = 2 and padding = 1:

* Output shape = floor((14 + 2*1 - 3) / 2) + 1 = 6 + 1 = 7

14. After applying conv2d with 2048 (512 * 4) filters of (1*1) stride = 1 and padding = 0:

* Output shape = ((7 + 2*0 - 1) / 1) + 1 = 7


