[![deep-learning-notes](https://github.com/semilleroCV/deep-learning-notes/raw/main/assets/banner-notebook.png)](https://github.com/semilleroCV/deep-learning-notes)

# <font color='#4C5FDA'>**Xception from scratch** </font>

In [None]:
#@title **Install required packages**

%%capture
! pip install torchinfo

In [3]:
#@title **Import required libraries**.

# Pytorch essentials
import torch # 2.2.1
import torch.nn as nn
from torchinfo import summary

In [3]:
# Show the installed PyTorch version (the import cell notes 2.2.1 as the version used).
print(torch.__version__)

2.2.1


**The Xception architecture**: the data first goes through the entry flow, then through the middle flow, which is repeated eight times, and finally through the exit flow. Note that all Convolution and SeparableConvolution layers are followed by batch normalization (not included in the diagram). All SeparableConvolution layers use a depth multiplier of 1 (no depth expansion).

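Conceptually, an Xception ("extreme Inception") module first applies a 1x1 (pointwise) convolution and then a 3x3 convolution on each resulting channel separately. The `SeparableConv2d` layer implemented in this notebook uses the conventional depthwise-separable order instead: the 3x3 depthwise convolution first, followed by the 1x1 pointwise convolution. The Xception paper argues that this difference in ordering is unimportant because the layers are stacked.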

<div align="center"> <img src="https://miro.medium.com/v2/resize:fit:720/format:webp/1*J8dborzVBRBupJfvR7YhuA.png" width=600>  </div>

The entire architecture looks like this:

<div align="center"> <img src="https://maelfabien.github.io/assets/images/xception.jpg" width=800>  </div>



#### <font color='#8203b1'>**Entry flow**</font>

In [4]:
class DoubleConvBlock(nn.Module):
  """Network stem: two unpadded 3x3 Conv -> BatchNorm -> ReLU stages.

  The first convolution uses stride 2 and produces ``out_channels // 2``
  feature maps; the second uses stride 1 and produces ``out_channels``.
  Neither convolution has a bias term (each is followed by batch norm).

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the second convolution.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    mid_channels = out_channels // 2

    stem_layers = [
      # Stage 1: strided convolution to half the target width.
      nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=2, bias=False),
      nn.BatchNorm2d(mid_channels),
      nn.ReLU(inplace=True),
      # Stage 2: stride-1 convolution up to the full target width.
      nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, bias=False),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
    ]
    self.double_conv = nn.Sequential(*stem_layers)

  def forward(self, x):
    """Apply both convolution stages to ``x`` and return the result."""
    stem_out = self.double_conv(x)
    return stem_out
  
class SeparableConv2d(nn.Module):
  """Depthwise separable convolution: per-channel 3x3 conv, then 1x1 conv.

  Both convolutions are bias-free. The 3x3 stage uses ``padding=1`` with the
  default stride, so the spatial size of the input is preserved.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the 1x1 convolution.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # groups == in_channels gives every input channel its own 3x3 filter.
    self.depth_wise_conv = nn.Conv2d(
      in_channels,
      in_channels,
      kernel_size=3,
      padding=1,
      groups=in_channels,
      bias=False,
    )
    # The pointwise convolution mixes channels and sets the output width.
    self.one_by_one = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

  def forward(self, x):
    """Apply the depthwise stage followed by the pointwise stage."""
    return self.one_by_one(self.depth_wise_conv(x))

In [5]:
class XceptionModule(nn.Module):
  """Downsampling block with a projected shortcut.

  Main path: (optional ReLU) -> SeparableConv2d -> BatchNorm -> ReLU ->
  SeparableConv2d -> BatchNorm -> 3x3 max-pool with stride 2.
  Shortcut path: 1x1 convolution with stride 2 -> BatchNorm.
  The two paths are summed element-wise.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by both paths.
    relu_at_start: when true, the main path begins with a (non in-place) ReLU.
  """

  def __init__(self, in_channels: int, out_channels: int, relu_at_start=True):
    super().__init__()

    # Shortcut: a strided 1x1 projection so it matches the main path's output.
    self.one_by_one = nn.Sequential(
      nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False),
      nn.BatchNorm2d(out_channels),
    )

    # The only difference between the two variants is the leading ReLU.
    main_layers = [nn.ReLU(inplace=False)] if relu_at_start else []
    main_layers += [
      SeparableConv2d(in_channels, out_channels),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=False),
      SeparableConv2d(out_channels, out_channels),
      nn.BatchNorm2d(out_channels),
      nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]
    self.double_depth_wise_conv = nn.Sequential(*main_layers)

  def forward(self, x):
    """Return the sum of the shortcut projection and the main path."""
    shortcut = self.one_by_one(x)
    main = self.double_depth_wise_conv(x)
    return torch.add(shortcut, main)

In [6]:
class EntryFlowModule(nn.Module):
  """Entry flow: convolutional stem followed by three downsampling blocks.

  Channel widths go ``in_channels`` -> 64 (stem) -> 128 -> 256 ->
  ``out_channels``. The first block is built with ``relu_at_start=False``;
  the stem it follows already ends in a ReLU.

  Args:
    in_channels: number of channels of the input image tensor.
    out_channels: number of channels produced by the last block.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # Stem: two plain convolutions ending at 64 channels.
    self.double_conv = DoubleConvBlock(in_channels, 64)

    # Three downsampling blocks with shortcuts.
    self.block1 = XceptionModule(64, 128, relu_at_start=False)
    self.block2 = XceptionModule(128, 256)
    self.block3 = XceptionModule(256, out_channels)

  def forward(self, x):
    """Run the stem and then the three blocks in order."""
    features = self.double_conv(x)
    for block in (self.block1, self.block2, self.block3):
      features = block(features)
    return features

In [7]:
# Smoke-test the entry flow: push a random batch through it and print the
# size and dtype of the input and of the output.

input_image = torch.rand([2, 1, 299, 299])
# The dtype is formatted directly; wrapping it in an extra pair of braces
# inside the f-string expression builds a set and prints "{torch.float32}".
print(f"Entrada: {input_image.size(), input_image.dtype}")
model = EntryFlowModule(in_channels=1, out_channels=728)
ouput = model(input_image)
print(f"Salida: {ouput.size(), ouput.dtype}")

Entrada: (torch.Size([2, 1, 299, 299]), {torch.float32})
Salida: (torch.Size([2, 728, 19, 19]), torch.float32)


#### <font color='#8203b1'>**Middle flow**</font>

In [8]:
class XceptionMiddleModule(nn.Module):
  """Middle-flow residual block.

  Main path: three (ReLU -> SeparableConv2d -> BatchNorm) stages, all at the
  input's spatial resolution. The block input is added back to the main-path
  output (skip connection), as in the Xception middle flow.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by the block. When it differs
      from ``in_channels`` the skip connection is a 1x1 convolution followed
      by batch norm; otherwise it is the identity and adds no parameters.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # triple separable conv 2d
    self.triple_depth_wise_conv = nn.Sequential(
      # Not in-place: the unmodified input is still needed for the skip
      # connection, and the caller's tensor must not be overwritten.
      nn.ReLU(inplace=False),
      SeparableConv2d(in_channels, out_channels),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
      SeparableConv2d(out_channels, out_channels),
      nn.BatchNorm2d(out_channels),
      nn.ReLU(inplace=True),
      SeparableConv2d(out_channels, out_channels),
      nn.BatchNorm2d(out_channels),
    )

    # Identity skip when the channel counts already match; otherwise project
    # the input so both branches can be summed.
    if in_channels == out_channels:
      self.shortcut = nn.Identity()
    else:
      self.shortcut = nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),
        nn.BatchNorm2d(out_channels),
      )

  def forward(self, x):
    """Return ``shortcut(x) + main_path(x)``."""
    residual = self.shortcut(x)
    x = self.triple_depth_wise_conv(x)
    return residual + x

In [9]:
class MiddleFlowModule(nn.Module):
  """Middle flow: a chain of ``num_blocks`` XceptionMiddleModule blocks.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by every block.
    num_blocks: how many blocks to chain (defaults to 8).

  Raises:
    ValueError: if ``num_blocks`` is smaller than 1.
  """

  def __init__(self, in_channels, out_channels, num_blocks: int = 8):
    super().__init__()
    if num_blocks < 1:
      raise ValueError(f"num_blocks must be >= 1, got {num_blocks}")

    # Only the first block sees ``in_channels``; every later block receives
    # the previous block's ``out_channels``, so the chain stays valid even
    # when the two widths differ.
    blocks = [
      XceptionMiddleModule(in_channels if index == 0 else out_channels, out_channels)
      for index in range(num_blocks)
    ]
    self.middle_flow = nn.Sequential(*blocks)

  def forward(self, x):
    """Run ``x`` through all blocks in order."""
    return self.middle_flow(x)


In [10]:
# Smoke-test the middle flow: push a random batch through it and print the
# size and dtype of the input and of the output.

input_image = torch.rand([2, 728, 19, 19])
# The dtype is formatted directly; wrapping it in an extra pair of braces
# inside the f-string expression builds a set and prints "{torch.float32}".
print(f"Entrada: {input_image.size(), input_image.dtype}")
model = MiddleFlowModule(in_channels=728, out_channels=728)
ouput = model(input_image)
print(f"Salida: {ouput.size(), ouput.dtype}")

Entrada: (torch.Size([2, 728, 19, 19]), {torch.float32})
Salida: (torch.Size([2, 728, 19, 19]), torch.float32)


#### <font color='#8203b1'>**Exit flow**</font>

In [11]:
class XceptionExitModule(nn.Module):
  """Downsampling block of the exit flow, with a projected shortcut.

  Main path: ReLU -> SeparableConv2d (keeps ``in_channels``) -> BatchNorm ->
  ReLU -> SeparableConv2d (to ``out_channels``) -> BatchNorm -> 3x3 max-pool
  with stride 2.
  Shortcut path: 1x1 convolution with stride 2 -> BatchNorm.
  The two paths are summed element-wise.

  Args:
    in_channels: number of channels of the input tensor.
    out_channels: number of channels produced by both paths.
  """

  def __init__(self, in_channels: int, out_channels: int):
    super().__init__()

    # Shortcut: a strided 1x1 projection to the output width.
    shortcut_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)
    self.one_by_one = nn.Sequential(shortcut_conv, nn.BatchNorm2d(out_channels))

    # Main path: the channel count only widens in the second separable conv.
    main_layers = [
      nn.ReLU(inplace=False),
      SeparableConv2d(in_channels, in_channels),
      nn.BatchNorm2d(in_channels),
      nn.ReLU(inplace=False),
      SeparableConv2d(in_channels, out_channels),
      nn.BatchNorm2d(out_channels),
      nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
    ]
    self.double_depth_wise_conv = nn.Sequential(*main_layers)

  def forward(self, x):
    """Return the sum of the shortcut projection and the main path."""
    shortcut = self.one_by_one(x)
    main = self.double_depth_wise_conv(x)
    return torch.add(shortcut, main)

In [12]:
class ExitFlowModule(nn.Module):
  """Exit flow: downsampling block, two separable convs, pooling, classifier.

  Channel widths go ``in_channels`` -> 1024 -> 1536 -> 2048. The result is
  global-average-pooled to 1x1, flattened, and mapped to ``n_classes`` values
  by a linear layer. No activation is applied to the linear output.

  Args:
    in_channels: number of channels of the input tensor.
    n_classes: number of output features of the final linear layer.
  """

  def __init__(self, in_channels, n_classes):
    super().__init__()

    # Downsampling block with a projected shortcut.
    self.block1 = XceptionExitModule(in_channels, 1024)

    # Two SeparableConv2d -> BatchNorm -> ReLU stages.
    widening_layers = [
      SeparableConv2d(1024, 1536),
      nn.BatchNorm2d(1536),
      nn.ReLU(inplace=True),
      SeparableConv2d(1536, 2048),
      nn.BatchNorm2d(2048),
      nn.ReLU(inplace=True),
    ]
    self.block2 = nn.Sequential(*widening_layers)

    # Classification head.
    self.gap = nn.AdaptiveAvgPool2d((1, 1))
    self.last_fc = nn.Linear(2048, n_classes)

  def forward(self, x):
    """Return the linear-layer output, one row per batch element."""
    features = self.block2(self.block1(x))
    pooled = self.gap(features)
    batch_size = pooled.size(0)
    flat = pooled.view(batch_size, -1)
    return self.last_fc(flat)

In [13]:
# Smoke-test the exit flow: push a random batch through it and print the
# size and dtype of the input and of the output.

input_image = torch.rand([2, 728, 19, 19])
# The dtype is formatted directly; wrapping it in an extra pair of braces
# inside the f-string expression builds a set and prints "{torch.float32}".
print(f"Entrada: {input_image.size(), input_image.dtype}")
model = ExitFlowModule(in_channels=728, n_classes=1)
ouput = model(input_image)
print(f"Salida: {ouput.size(), ouput.dtype}")

Entrada: (torch.Size([2, 728, 19, 19]), {torch.float32})
Salida: (torch.Size([2, 1]), torch.float32)


#### <font color='#8203b1'>**Full model**</font>

In [14]:
class Xception(nn.Module):
  """Full model: entry flow -> middle flow -> exit flow.

  The entry flow outputs 728 channels, the middle flow keeps 728 channels,
  and the exit flow maps them to ``n_classes`` output values.

  Args:
    n_channels: number of channels of the input image tensor.
    n_classes: number of output features of the final linear layer.
  """

  def __init__(self, n_channels, n_classes):
    super().__init__()
    self.entry_flow = EntryFlowModule(n_channels, 728)
    self.middle_flow = MiddleFlowModule(728, 728)
    self.exit_flow = ExitFlowModule(728, n_classes)

  def forward(self, x):
    """Run the three flows in sequence and return the exit-flow output."""
    entry_out = self.entry_flow(x)
    middle_out = self.middle_flow(entry_out)
    return self.exit_flow(middle_out)

In [15]:
# End-to-end smoke test: push a random batch through the full model and print
# the size and dtype of the input and of the output.

input_image = torch.rand(2, 1, 299, 299)
print(f"Input: {input_image.size(), input_image.dtype}")

model = Xception(n_channels=1, n_classes=1)
ouput = model(input_image)
print(f"Ouput: {ouput.size(), ouput.dtype}")

Input: (torch.Size([2, 1, 299, 299]), torch.float32)
Ouput: (torch.Size([2, 1]), torch.float32)


In [16]:
# Print the module tree of the model built in the previous cell.
print(model)

Xception(
  (entry_flow): EntryFlowModule(
    (double_conv): DoubleConvBlock(
      (double_conv): Sequential(
        (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU(inplace=True)
        (3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
        (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (5): ReLU(inplace=True)
      )
    )
    (block1): XceptionModule(
      (one_by_one): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (double_depth_wise_conv): Sequential(
        (0): SeparableConv2d(
          (depth_wise_conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=64, bias=False)
          (one_by_one): Conv2d(64, 128, ker

In [17]:
# torchinfo summary of the model for an input of the same size as `input_image`.
summary(model, input_image.size())

Layer (type:depth-idx)                             Output Shape              Param #
Xception                                           [2, 1]                    --
├─EntryFlowModule: 1-1                             [2, 728, 19, 19]          --
│    └─DoubleConvBlock: 2-1                        [2, 64, 147, 147]         --
│    │    └─Sequential: 3-1                        [2, 64, 147, 147]         18,912
│    └─XceptionModule: 2-2                         [2, 128, 74, 74]          --
│    │    └─Sequential: 3-2                        [2, 128, 74, 74]          8,448
│    │    └─Sequential: 3-3                        [2, 128, 74, 74]          26,816
│    └─XceptionModule: 2-3                         [2, 256, 37, 37]          --
│    │    └─Sequential: 3-4                        [2, 256, 37, 37]          33,280
│    │    └─Sequential: 3-5                        [2, 256, 37, 37]          102,784
│    └─XceptionModule: 2-4                         [2, 728, 19, 19]          --
│    │    └─Seq