In [None]:
class DepthwiseSeparableConv2D(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution (int or pair).
        padding: explicit padding of the depthwise convolution; ignored when
            ``auto_padding`` is True.
        stride: stride of the depthwise convolution.
        bias: whether both convolutions learn an additive bias.
        auto_padding: if True, pad by ``kernel_size // 2`` per dimension so
            that, for odd kernels and stride 1, H x W is preserved.
    """

    def __init__(self, in_channels, out_channels, kernel_size, padding = 0, stride = 1, bias = False, auto_padding = False):
        super(DepthwiseSeparableConv2D, self).__init__()
        # Only override the caller's padding when "same" padding is requested;
        # previously the explicit `padding` argument was always discarded.
        if auto_padding:
            if isinstance(kernel_size, int):
                padding = kernel_size // 2
            else:
                padding = tuple(k // 2 for k in kernel_size)

        # groups=in_channels -> one spatial filter per input channel.
        depthwise = nn.Conv2d(in_channels, in_channels,
                              kernel_size = kernel_size, padding = padding, stride = stride,
                              groups = in_channels, bias = bias)
        # 1x1 convolution that mixes channels.
        pointwise = nn.Conv2d(in_channels, out_channels, kernel_size = 1, bias = bias)

        self.depthwise_separable_convolution = nn.Sequential(depthwise,
                                                             pointwise)

    def forward(self, X):
        return self.depthwise_separable_convolution(X)
    
    
class MiddleFlow(nn.Module):
    """Residual block of three (ReLU -> depthwise-separable conv) stages.

    The stages are summed with the unmodified input, so the separable convs
    use ``auto_padding=True`` and keep ``num_maps`` channels; the spatial size
    is assumed to be preserved by that padding (TODO confirm for even kernels).

    Args:
        num_maps: number of channels in and out.
        kernel_size: kernel size of every depthwise convolution.
    """

    def __init__(self, num_maps, kernel_size):
        super().__init__()
        layers = []
        for _ in range(3):
            layers.append(nn.ReLU())
            layers.append(
                DepthwiseSeparableConv2D(num_maps, num_maps, kernel_size, auto_padding= True)
            )
        self.flow = nn.Sequential(*layers)

    def forward(self, X):
        residual = self.flow(X)
        return X + residual
    
# Name -> activation module lookup; ConvBlock selects its activation from here.
Activations = nn.ModuleDict({
    'relu': nn.ReLU(),
    'identity': nn.Identity(),
})
    
class ConvBlock(nn.Module):
    """Conv2d -> BatchNorm2d -> activation.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size (int or pair).
        stride: convolution stride.
        padding: explicit convolution padding; ignored when ``auto_padding``
            is True.
        activation: key into the module-level ``Activations`` dict
            (e.g. ``'relu'``); ``None`` selects the identity.
        auto_padding: if True, pad by ``kernel_size // 2`` per dimension
            ("same" padding for odd kernels at stride 1).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, padding = 0, activation = None, auto_padding = False):
        super(ConvBlock, self).__init__()

        # Previously referenced undefined names (`use_autoPadding`, `AutoPadding`).
        if auto_padding:
            if isinstance(kernel_size, int):
                padding = kernel_size // 2
            else:
                padding = tuple(k // 2 for k in kernel_size)

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        # forward() normalizes the conv output; this module was never created before.
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.activation = Activations['identity' if activation is None else activation]

    def forward(self, X):
        return self.activation(self.batch_norm(self.conv(X)))
    
class ResidualDepthwiseSepBlock(nn.Module):
    """Downsampling residual block (Xception entry/exit flow style).

    Main path: ReLU -> separable conv -> ReLU -> separable conv -> 3x3 max-pool
    with stride 2. Shortcut: 1x1 conv with stride 2. The pool (padding 1) and
    the 1x1 strided conv both map H to ceil(H / 2), so the paths can be summed,
    provided the separable convs preserve H x W.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels of both separable convs and of the shortcut.
    """

    def __init__(self, in_channels, out_channels):
        super(ResidualDepthwiseSepBlock, self).__init__()
        self.flow = nn.Sequential(
            nn.ReLU(),
            # auto_padding (as used by MiddleFlow) is relied on to keep H x W,
            # so that only the pool below downsamples.
            DepthwiseSeparableConv2D(in_channels, out_channels, 3, auto_padding = True),
            nn.ReLU(),
            DepthwiseSeparableConv2D(out_channels, out_channels, 3, auto_padding = True),
            # MaxPool2d requires padding <= kernel_size // 2; padding=3 was rejected.
            nn.MaxPool2d(3, stride = 2, padding = 1)
        )
        # 1x1 projection so the shortcut matches the main path's channels and size.
        self.residue = nn.Conv2d(in_channels, out_channels, kernel_size = 1, stride = 2)

    def forward(self, X):
        return self.flow(X) + self.residue(X)

In [None]:
class Xception(nn.Module):
    """Xception-style classifier: entry flow, 8 middle-flow modules, exit flow, linear head.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes = 10):
        super(Xception, self).__init__()

        self.entry_flow = nn.Sequential(
            ConvBlock(in_channels = 3, out_channels = 32, kernel_size = 3, stride = 2, padding = 0, activation = 'relu'),
            # The Xception stem applies ReLU after both initial convolutions.
            ConvBlock(in_channels = 32, out_channels = 64, kernel_size = 3, stride = 1, padding = 0, activation = 'relu'),
            ResidualDepthwiseSepBlock(64, 128),
            ResidualDepthwiseSepBlock(128, 256),
            ResidualDepthwiseSepBlock(256, 728)
        )

        # Channel count must match the 728 maps produced by the entry flow.
        self.middle_flow = nn.Sequential(
            *[MiddleFlow(728, 3) for _ in range(8)]
        )

        self.exit_flow = nn.Sequential(
            ResidualDepthwiseSepBlock(728, 1024),
            # "same" padding so small inputs (e.g. 32x32, where the map is
            # already 1x1 here) do not shrink below the 3x3 kernel.
            DepthwiseSeparableConv2D(1024, 1536, 3, auto_padding = True),
            nn.ReLU(),
            DepthwiseSeparableConv2D(1536, 2048, 3, auto_padding = True),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((1, 1))
        )

        self.classifier = nn.Linear(2048, num_classes)

    def forward(self, X):
        X = self.entry_flow(X)
        X = self.middle_flow(X)
        X = self.exit_flow(X)
        # (N, 2048, 1, 1) -> (N, 2048): nn.Linear acts on the last dimension.
        X = X.flatten(1)
        X = self.classifier(X)
        return X