In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, BatchNormalization, ReLU, GlobalAveragePooling2D, Dense, UpSampling2D
from tensorflow.keras.layers import SeparableConv2D, Input, Flatten, Reshape, Add
from tensorflow.keras.models import Model

def xception_block(inputs, filters, strides=1):
    """Apply a residual block of three 3x3 separable convolutions.

    The main path is SeparableConv2D -> BatchNorm (-> ReLU) repeated three
    times. Only the first convolution uses ``strides``, and the final BatchNorm
    output is not activated before the residual sum. The shortcut is a 1x1
    Conv2D + BatchNorm projection using the same ``strides``, so both paths
    produce ``filters`` channels at the same spatial size. The sum is passed
    through a ReLU.

    Args:
        inputs: Keras tensor feeding both the main path and the shortcut.
        filters: Number of output channels for every convolution in the block.
        strides: Stride of the first separable conv and of the 1x1 projection.

    Returns:
        Keras tensor: ReLU(main_path + projected_shortcut).
    """
    # Projection shortcut. It is always a 1x1 conv (even when strides == 1) so
    # its channel count equals ``filters`` for the Add below.
    shortcut = Conv2D(filters, (1, 1), strides=strides, padding='same')(inputs)
    shortcut = BatchNormalization()(shortcut)

    # Main path. Only the first conv downsamples; the ReLU after the last
    # BatchNorm is deferred until after the residual addition.
    conv_strides = (strides, 1, 1)
    out = inputs
    for idx, step in enumerate(conv_strides):
        out = SeparableConv2D(filters, (3, 3), padding='same', strides=step)(out)
        out = BatchNormalization()(out)
        if idx < len(conv_strides) - 1:
            out = ReLU()(out)

    merged = Add()([out, shortcut])
    return ReLU()(merged)

# Define the middle block
def middle_block(inputs, filters, repeats=4):
    """Stack ``repeats`` stride-1 ``xception_block``s, each with ``filters`` channels.

    Args:
        inputs: Keras tensor to feed into the first block.
        filters: Channel count passed to every ``xception_block``.
        repeats: How many blocks to chain. Defaults to 4, matching the
            original hard-coded depth.

    Returns:
        Keras tensor output of the last block, or ``inputs`` unchanged when
        ``repeats`` is 0.
    """
    x = inputs
    for _ in range(repeats):
        # strides defaults to 1, so spatial size is preserved through the stack.
        x = xception_block(x, filters)
    return x

# Encoder: 256x256x3 input -> two 64-channel stem convs -> two strided Xception
# blocks (spatial 256 -> 128 -> 64) -> middle/exit blocks at 256 channels ->
# global average pool -> 10-unit softmax code.
inputs = Input(shape=(256, 256, 3))

x = Conv2D(64, (3, 3), padding='same')(inputs)
x = BatchNormalization()(x)
x = ReLU()(x)

x = Conv2D(64, (3, 3), padding='same')(x)
x = BatchNormalization()(x)
x = ReLU()(x)

x = xception_block(x, 128, strides=2)  # 256 -> 128
x = xception_block(x, 256, strides=2)  # 128 -> 64

x = middle_block(x, 256)

x = xception_block(x, 256)

x = SeparableConv2D(256, (3, 3), padding='same')(x)
x = BatchNormalization()(x)
x = ReLU()(x)

x = SeparableConv2D(256, (3, 3), padding='same')(x)
x = BatchNormalization()(x)
x = ReLU()(x)

x = GlobalAveragePooling2D()(x)
# NOTE(review): a softmax bottleneck forces the 10-d code to be non-negative
# and sum to 1, which is unusual for an autoencoder. Confirm this is intended
# (e.g. soft clustering) rather than a plain linear/ReLU code.
encoded = Dense(10, activation='softmax', name='encoded')(x)

# Decoder: 10-d code -> 4x4x256 seed -> upsample by 4, 4, 2, 2 (x64 total)
# so the output is back at the input's 256x256 spatial size.
x = Dense(4*4*256)(encoded)
x = Reshape((4, 4, 256))(x)

x = UpSampling2D((4, 4))(x)  # 4 -> 16
x = Conv2D(128, (3, 3), padding='same')(x)
x = BatchNormalization()(x)
x = ReLU()(x)

# Was (3, 4). That produced 48x64 here and a 192x256 final output, which
# cannot be compared with 256x256 targets under MSE.
x = UpSampling2D((4, 4))(x)  # 16 -> 64
x = Conv2D(64, (3, 3), padding='same')(x)
x = BatchNormalization()(x)
x = ReLU()(x)

x = UpSampling2D((2, 2))(x)  # 64 -> 128
x = Conv2D(32, (3, 3), padding='same')(x)
x = BatchNormalization()(x)
x = ReLU()(x)

x = UpSampling2D((2, 2))(x)  # 128 -> 256
# NOTE(review): 1 output channel while the input has 3. That is fine for
# grayscale/mask targets, but reconstructing the RGB input itself would need 3
# channels. Confirm against the training targets.
decoded = Conv2D(1, (3, 3), padding='same', activation='sigmoid')(x)

# Fail fast if the decoder stops reproducing the input's spatial size;
# otherwise the mismatch only surfaces later as a shape error inside fit().
if tuple(decoded.shape[1:3]) != tuple(inputs.shape[1:3]):
    raise ValueError(
        f"decoder output spatial shape {tuple(decoded.shape[1:3])} does not "
        f"match input spatial shape {tuple(inputs.shape[1:3])}"
    )

autoencoder = Model(inputs, decoded)

autoencoder.compile(optimizer='adam', loss='mse')

autoencoder.summary()


2024-05-29 10:21:34.007323: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-29 10:21:34.007430: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-29 10:21:34.141111: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
