<a href="https://colab.research.google.com/github/omar178/resenet-implementation/blob/main/ResNet_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
from torch import Tensor
import torch.nn as nn

In [9]:
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """Build a bias-free 3x3 ``nn.Conv2d`` whose padding equals ``dilation``.

    Padding by ``dilation`` exactly compensates the dilated 3x3 kernel's
    footprint, so with ``stride=1`` the spatial size is preserved.
    """
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,  # followed by BatchNorm in this file, so a bias is redundant
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
    """Build a bias-free pointwise (1x1) ``nn.Conv2d``.

    It mixes channels at each spatial position; with ``stride > 1`` it also
    subsamples the feature map.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=1,
        stride=stride,
        bias=False,
    )

class Block(nn.Module):
    """Residual block used to build :class:`ResNet`.

    ``nums_layers`` (the depth of the whole network) selects the variant:

    * ``nums_layers <= 34``: basic block with two 3x3 convolutions and
      ``expansion = 1``.
    * ``nums_layers > 34``: bottleneck block with a 1x1 reduce, a 3x3 and a
      1x1 expand convolution, and ``expansion = 4``.

    The block outputs ``out_channels * self.expansion`` channels.

    Args:
        nums_layers: Depth of the network this block belongs to.
        in_channels: Channel count of the input tensor.
        out_channels: Width of the block's inner convolutions.
        identity: Optional module applied to the input to form the shortcut
            branch, e.g. a strided 1x1 conv followed by BatchNorm. It must be
            supplied whenever ``stride != 1`` or
            ``in_channels != out_channels * expansion``; otherwise the
            residual addition fails on a shape mismatch. ``None`` adds the
            input unchanged.
        stride: Stride of the block. It is applied once, on the 3x3
            convolution that opens the block (basic) or sits in the middle
            (bottleneck).
    """

    def __init__(self, nums_layers, in_channels, out_channels, identity=None, stride=1):
        super(Block, self).__init__()
        self.bottleneck = nums_layers > 34
        if self.bottleneck:
            self.expansion = 4
            # 1x1 reduce -> 3x3 (carries the stride) -> 1x1 expand.
            self.conv1 = conv1x1(in_channels, out_channels)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.conv2 = conv3x3(out_channels, out_channels, stride)
            self.bn2 = nn.BatchNorm2d(out_channels)
            self.conv3 = conv1x1(out_channels, out_channels * self.expansion)
            self.bn3 = nn.BatchNorm2d(out_channels * self.expansion)
        else:
            self.expansion = 1
            # Only the first 3x3 is strided, so the block downsamples once, not twice.
            self.conv1 = conv3x3(in_channels, out_channels, stride)
            self.bn1 = nn.BatchNorm2d(out_channels)
            self.conv2 = conv3x3(out_channels, out_channels)
            self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        # Shortcut projection module (or None for a pure identity shortcut).
        self.identity = identity

    def forward(self, x):
        """Return ``relu(F(x) + shortcut(x))`` for the residual branch ``F``."""
        # Compute the shortcut from the untouched input before the main branch.
        shortcut = x if self.identity is None else self.identity(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.bottleneck:
            out = self.relu(out)
            out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)


class ResNet(nn.Module):
    """ResNet-18/34/50/101/152 image classifier.

    Args:
        num_layers: Network depth; must be one of 18, 34, 50, 101 or 152.
        block: Residual block class. It is constructed as
            ``block(nums_layers=..., in_channels=..., out_channels=...,
            identity=..., stride=...)`` and must output
            ``out_channels * expansion`` channels.
        input_channels: Channel count of the input images (e.g. 3 for RGB).
        num_classes: Number of output logits.
    """

    def __init__(self, num_layers, block, input_channels, num_classes):
        assert num_layers in [18, 34, 50, 101, 152], f'ResNet{num_layers}: Unknown architecture! Number of layers has ' \
                                                     f'to be 18, 34, 50, 101, or 152 '
        super(ResNet, self).__init__()
        self.num_layers = num_layers
        self.infeature_maps = 64  # width of the stem convolution
        self.dilation = 1
        # input channels, e.g. 3 for colour images
        self.input_channels = input_channels
        self.num_classes = num_classes
        # Mirrors Block's rule: bottleneck blocks (depth > 34) widen their
        # output 4x. Computed here because Block sets `expansion` per
        # instance, not on the class.
        self.expansion = 4 if num_layers > 34 else 1

        # Number of residual blocks in each of the 4 groups.
        layers = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[num_layers]

        # bias=False: the following BatchNorm makes a conv bias redundant.
        self.conv1 = nn.Conv2d(in_channels=self.input_channels, out_channels=self.infeature_maps,
                               kernel_size=7, stride=2, padding=3, dilation=self.dilation, bias=False)
        self.bn1 = nn.BatchNorm2d(self.infeature_maps)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Running channel count entering the next group; _make_layer updates it.
        self.inplanes = self.infeature_maps
        self.group1 = self._make_layer(block, 64, layers[0])
        # Groups 2-4 halve the spatial resolution in their first block.
        self.group2 = self._make_layer(block, 128, layers[1], stride=2)
        self.group3 = self._make_layer(block, 256, layers[2], stride=2)
        self.group4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * self.expansion, num_classes)

    def _make_layer(self, block, outplanes, n_blocks, stride=1):
        """Build one group of ``n_blocks`` residual blocks of width ``outplanes``.

        Only the first block is strided and, if needed, receives a projection
        shortcut. That shortcut is needed when the stride or the channel count
        changes. Later blocks keep the shape and use a plain identity shortcut.
        Advances ``self.inplanes`` to the group's output channel count.
        """
        out_channels = outplanes * self.expansion
        shortcut = None
        if stride != 1 or self.inplanes != out_channels:
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )

        layers = [block(nums_layers=self.num_layers, in_channels=self.inplanes,
                        out_channels=outplanes, identity=shortcut, stride=stride)]
        self.inplanes = out_channels
        for _ in range(1, n_blocks):
            layers.append(block(nums_layers=self.num_layers, in_channels=self.inplanes,
                                out_channels=outplanes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map a batched image tensor ``(N, input_channels, H, W)`` to logits ``(N, num_classes)``."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.group1(x)
        x = self.group2(x)
        x = self.group3(x)
        x = self.group4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten (C, 1, 1)
        return self.fc(x)


            





