In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.utils import plot_model
from tensorflow.keras.applications.xception import Xception

In [2]:
def depthwise_separable_conv(x, filters, kernel_size=3, strides=1):
    """Apply a depthwise convolution followed by a 1x1 pointwise convolution.

    Both convolutions are bias-free, use "same" padding, and are each followed
    by batch normalization and a ReLU.

    Args:
        x: Input tensor.
        filters: Number of output channels of the pointwise (1x1) convolution.
        kernel_size: Kernel size of the depthwise convolution.
        strides: Stride of the depthwise convolution; the pointwise
            convolution always uses stride 1.

    Returns:
        The tensor produced by the final ReLU.
    """
    # Layers are instantiated in application order, then folded over the input.
    pipeline = (
        # Depthwise stage: spatial filtering with the caller-supplied kernel/stride.
        layers.DepthwiseConv2D(kernel_size=kernel_size, strides=strides, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
        # Pointwise stage: 1x1 convolution projecting to `filters` channels.
        layers.Conv2D(filters, kernel_size=1, strides=1, padding="same", use_bias=False),
        layers.BatchNormalization(),
        layers.ReLU(),
    )

    out = x
    for layer in pipeline:
        out = layer(out)
    return out

In [9]:
def entry_flow(inputs):
    """Build the entry flow: a two-convolution stem plus three residual stages.

    The stem is a stride-2 3x3 convolution (32 filters) followed by a 3x3
    convolution (64 filters), each with batch normalization and ReLU. Three
    residual stages follow with 128, 256 and 728 filters. Each stage runs two
    depthwise-separable convolutions and a stride-2 max-pool on the main
    branch and adds a stride-2 1x1 projection of the stage input.

    Args:
        inputs: Input image tensor.

    Returns:
        Output tensor of the last residual addition (728 channels, spatially
        reduced by four stride-2 operations).
    """
    # Stem. Convolutions that feed BatchNormalization carry no bias: the
    # normalization's own learned offset makes it redundant, and this matches
    # the bias-free convolutions in depthwise_separable_conv.
    x = layers.Conv2D(32, (3, 3), strides=2, padding="same", use_bias=False)(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    x = layers.Conv2D(64, (3, 3), padding="same", use_bias=False)(x)
    x = layers.BatchNormalization()(x)
    x = layers.ReLU()(x)

    # The three stages differ only in width, so build them in one loop.
    for filters in (128, 256, 728):
        # Shortcut: strided 1x1 projection so channels and resolution match
        # the pooled main branch at the Add below.
        residual = layers.Conv2D(filters, (1, 1), strides=2, padding="same", use_bias=False)(x)
        residual = layers.BatchNormalization()(residual)

        x = depthwise_separable_conv(x, filters)
        x = depthwise_separable_conv(x, filters)
        x = layers.MaxPooling2D((3, 3), strides=2, padding="same")(x)

        x = layers.Add()([x, residual])

    return x

In [10]:
def middle_flow(x, num_blocks=8, filters=728):
    """Build the middle flow: a stack of identity-shortcut residual blocks.

    Each block applies three depthwise-separable convolutions and adds the
    block input back onto the result.

    Args:
        x: Input tensor. Its channel count must equal `filters`, because the
            shortcut is an identity connection added to the block output.
        num_blocks: Number of residual blocks to stack (default 8).
        filters: Channel width of every separable convolution (default 728).

    Returns:
        Output tensor of the last block, or `x` unchanged when `num_blocks`
        is 0.
    """
    for _ in range(num_blocks):
        residual = x

        # Three identical separable convolutions form the main branch.
        for _ in range(3):
            x = depthwise_separable_conv(x, filters)

        x = layers.Add()([x, residual])

    return x

In [11]:
def exit_flow(x):
    residual = layers.Conv2D(1024, (1, 1), strides=2, padding='same')(x)
    residual = layers.BatchNormalization()(residual)
    
    x = depthwise_separable_conv(x, 728)
    x = depthwise_separable_conv(x, 1024)
    x = layers.MaxPooling2D((3, 3), strides=2, padding='same')(x)
    
    x = layers.Add()([x, residual])
    
    x = depthwise_separable_conv(x, 1536)
    x = depthwise_separable_conv(x, 2048)
    
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(1000, activation='softmax')(x)
    
    return x

In [12]:
def build_xception(input_shape=(299, 299, 3)):
    """Assemble the full network from its entry, middle and exit flows.

    Args:
        input_shape: Shape of one input sample, passed to `layers.Input`
            (default `(299, 299, 3)`).

    Returns:
        A `models.Model` mapping the input tensor to the output of
        `exit_flow`.
    """
    image = layers.Input(shape=input_shape)

    # Chain the three stages; each consumes the previous stage's output.
    features = entry_flow(image)
    features = middle_flow(features)
    predictions = exit_flow(features)

    return models.Model(image, predictions)

In [None]:
# Instantiate the hand-built network at the same input size as the Keras
# reference model and print its layer table.
xception_scratch_model = build_xception(input_shape=(299, 299, 3))
xception_scratch_model.summary()

In [None]:
# Reference Xception shipped with Keras, including its classification head,
# for comparison against the hand-built model's summary.
# NOTE(review): weights="imagenet" requests pretrained weights; presumably
# Keras fetches them on first use, so this cell may need network access or a
# populated local cache. Confirm before running offline.
xception_tf_model = Xception(
    weights="imagenet",
    include_top=True,
    input_shape=(299, 299, 3),
)
xception_tf_model.summary()