In [5]:
import torch
import torch.nn as nn

In [1]:
from torchsummary import summary

In [16]:
# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [34]:
class DepthwiseSeparableConv2D(nn.Module):
    """Depthwise-separable 2D convolution: a per-channel spatial conv followed by a 1x1 pointwise conv.

    Args:
        in_channels: number of input channels (also used as the depthwise ``groups`` count).
        out_channels: number of channels produced by the 1x1 pointwise conv.
        kernel_size: spatial kernel of the depthwise conv (int or pair of ints, as in ``nn.Conv2d``).
        padding: padding of the depthwise conv; ignored when ``auto_padding`` is True.
        stride: stride of the depthwise conv (the pointwise conv always uses stride 1).
        bias: whether both convolutions learn an additive bias.
        auto_padding: if True, pad by ``kernel_size // 2`` in each dimension (overriding
            ``padding``). For stride 1 and odd kernels, this keeps height and width unchanged.
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding = 0, stride = 1, bias = False, auto_padding = False):
        super(DepthwiseSeparableConv2D, self).__init__()

        if auto_padding:
            # This flag used to be accepted but silently ignored; derive "same"-style padding
            # from the kernel instead of relying on the caller to compute it.
            if isinstance(kernel_size, int):
                padding = kernel_size // 2
            else:
                padding = tuple(k // 2 for k in kernel_size)

        # groups == in_channels makes each input channel convolve with its own filter.
        depthwise = nn.Conv2d(in_channels, in_channels,
                              kernel_size = kernel_size, padding = padding, stride = stride,
                              groups = in_channels, bias = bias)
        # The 1x1 conv mixes channels and sets the output width.
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size = 1, bias = bias)

        # Attribute name and layer order are kept stable so saved state_dict keys still load.
        self.depthwise_separable_convolution = nn.Sequential(depthwise,
                                                             pointwise)

    def forward(self, X):
        """Apply the depthwise conv and then the pointwise conv to ``X``."""
        return self.depthwise_separable_convolution(X)
    
    
class MiddleFlow(nn.Module):
    """Xception middle-flow block: three (ReLU -> separable conv) stages plus an identity shortcut.

    Args:
        num_maps: channel count; input and output widths are both ``num_maps`` so the
            identity shortcut can be added directly.
        kernel_size: odd spatial kernel size (int or pair of ints) of the separable convs.
            Padding is ``kernel_size // 2`` so every stage keeps height and width unchanged,
            which the residual addition requires.

    Raises:
        ValueError: if any kernel dimension is even. No symmetric padding then preserves the
            spatial size, so the residual add would fail at forward time.
    """
    def __init__(self, num_maps, kernel_size):
        super(MiddleFlow, self).__init__()
        sizes = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        if any(k % 2 == 0 for k in sizes):
            raise ValueError(f"MiddleFlow needs odd kernel sizes to preserve spatial shape, got {kernel_size}")
        # Padding used to be hard-coded to 1, which only preserved the shape for kernel_size=3.
        padding = tuple(k // 2 for k in sizes)
        self.flow = nn.Sequential(
            nn.ReLU(),
            DepthwiseSeparableConv2D(num_maps, num_maps, kernel_size, padding = padding),
            nn.ReLU(),
            DepthwiseSeparableConv2D(num_maps, num_maps, kernel_size, padding = padding),
            nn.ReLU(),
            DepthwiseSeparableConv2D(num_maps, num_maps, kernel_size, padding = padding)
        )

    def forward(self, X):
        """Return ``X + flow(X)``; the ReLUs are not in-place, so ``X`` is unchanged for the shortcut."""
        return X + self.flow(X)
    
class ConvBlock(nn.Module):
    """A single ``nn.Conv2d`` (with its default bias) followed by a ReLU.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride (default 1).
        padding: zero-padding added to each spatial border (default 0).
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size = kernel_size,
            stride = stride,
            padding = padding,
        )
        self.activation = nn.ReLU()

    def forward(self, X):
        """Convolve ``X``, then clamp negatives to zero."""
        pre_activation = self.conv(X)
        return self.activation(pre_activation)
    
class ResidualDepthwiseSepBlock(nn.Module):
    """Downsampling residual block used in the Xception entry and exit flows.

    Main path: ReLU -> SepConv(in->out, 3x3) -> ReLU -> SepConv(out->out, 3x3)
    -> 3x3 max-pool with stride 2.
    Shortcut: 1x1 convolution with stride 2, projecting in_channels -> out_channels.
    Both paths map an H x W input to ceil(H/2) x ceil(W/2), so they can be summed.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the returned feature map.
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        main_path = [
            nn.ReLU(),
            DepthwiseSeparableConv2D(in_channels, out_channels, 3, padding = 1),
            nn.ReLU(),
            DepthwiseSeparableConv2D(out_channels, out_channels, 3, padding = 1),
            nn.MaxPool2d(3, stride = 2, padding = 1),
        ]
        self.flow = nn.Sequential(*main_path)
        self.residue = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 2)

    def forward(self, X):
        """Sum the downsampled main path and the projected shortcut."""
        downsampled = self.flow(X)
        shortcut = self.residue(X)
        return downsampled + shortcut

In [43]:
class Xception(nn.Module):
    """Xception-style image classifier: entry flow -> middle flow -> exit flow -> linear head.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input image (3 for RGB).
        middle_channels: channel width of the middle flow. The default of 768 keeps this
            notebook's original architecture; the Xception paper uses 728.
        num_middle_blocks: number of repeated ``MiddleFlow`` blocks (8 in the paper).

    Input is (N, in_channels, H, W) with H, W >= 7; output is (N, num_classes) unnormalized logits.
    NOTE(review): unlike the paper, no BatchNorm follows the convolutions here.
    """
    def __init__(self, num_classes = 10, in_channels = 3, middle_channels = 768, num_middle_blocks = 8):
        super(Xception, self).__init__()

        # Stem convs (unpadded) followed by three downsampling residual blocks.
        self.entry_flow = nn.Sequential(
            ConvBlock(in_channels = in_channels, out_channels = 32, kernel_size = 3, stride = 2, padding = 0),
            ConvBlock(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1, padding = 0),
            ResidualDepthwiseSepBlock(64, 128),
            ResidualDepthwiseSepBlock(128, 256),
            ResidualDepthwiseSepBlock(256, middle_channels)
        )

        # Shape-preserving residual blocks.
        self.middle_flow = nn.Sequential(
            *[MiddleFlow(middle_channels, 3) for _ in range(num_middle_blocks)]
        )

        # padding=1 ("same" for 3x3, as in the paper). With padding=0 these two convs shrank
        # the map by 4 px and the model crashed for inputs smaller than 135x135 (e.g. 32x32).
        # Padding does not change the parameter count, so existing weights still load.
        self.exit_flow = nn.Sequential(
            ResidualDepthwiseSepBlock(middle_channels, 1024),
            DepthwiseSeparableConv2D(1024, 1536, 3, padding = 1),
            nn.ReLU(),
            DepthwiseSeparableConv2D(1536, 2048, 3, padding = 1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.classifier = nn.Linear(2048, num_classes)

    def forward(self, X):
        """Map a batch of images to (N, num_classes) logits (no softmax applied)."""
        X = self.entry_flow(X)
        X = self.middle_flow(X)
        X = self.exit_flow(X)     # (N, 2048, 1, 1) after global average pooling
        X = torch.flatten(X, 1)   # (N, 2048)
        X = self.classifier(X)
        return X
    
# Build a 10-class model and move its parameters to the selected device.
model = Xception(num_classes = 10).to(device)

In [44]:
# Per-layer table of output shapes and parameter counts for one 3x299x299 image
# (the report shows the batch dimension as -1).
summary(model, input_size = (3, 299, 299))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 32, 149, 149]             896
              ReLU-2         [-1, 32, 149, 149]               0
         ConvBlock-3         [-1, 32, 149, 149]               0
            Conv2d-4         [-1, 64, 147, 147]          18,496
              ReLU-5         [-1, 64, 147, 147]               0
         ConvBlock-6         [-1, 64, 147, 147]               0
              ReLU-7         [-1, 64, 147, 147]               0
            Conv2d-8         [-1, 64, 147, 147]             576
            Conv2d-9        [-1, 128, 147, 147]           8,192
DepthwiseSeparableConv2D-10        [-1, 128, 147, 147]               0
             ReLU-11        [-1, 128, 147, 147]               0
           Conv2d-12        [-1, 128, 147, 147]           1,152
           Conv2d-13        [-1, 128, 147, 147]          16,384
DepthwiseSeparableConv2D-14     