In [1]:
import tensorflow as tf

In [2]:
import os
import pickle
from typing import List

import numpy as np

from keras import backend as K
from keras.layers import Input, Conv2D, LeakyReLU, BatchNormalization, Dropout
from keras.layers import Flatten, Dense, Reshape, Conv2DTranspose, Activation
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import plot_model

In [None]:
class Autoencoder:
    """Convolutional autoencoder assembled from Keras layers.

    Constructing an instance immediately builds three Keras models:

    * ``encoder`` -- a stack of ``Conv2D`` blocks whose output is flattened
      and projected by a dense layer onto ``z_dim`` units;
    * ``decoder`` -- a dense layer reshaped back to the encoder's
      pre-flatten shape, followed by a stack of ``Conv2DTranspose`` blocks,
      the last of which ends in a sigmoid;
    * ``model`` -- the encoder input wired through the decoder.
    """

    # Rate used by every Dropout layer when ``use_dropout`` is enabled.
    DROPOUT_RATE = 0.25

    def __init__(self,
                 input_dim,
                 encoder_conv_filters: List[int],
                 encoder_conv_kernel_size: List[int],
                 encoder_conv_strides: List[int],
                 decoder_conv_t_filters: List[int],
                 decoder_conv_t_kernel_size: List[int],
                 decoder_conv_t_strides: List[int],
                 z_dim: int,
                 use_batch_norm=False,
                 use_dropout=False,
                 ):
        """Store the hyper-parameters and build the three models.

        Args:
            input_dim: Shape handed to the encoder's ``Input`` layer,
                e.g. ``(28, 28, 1)``.
            encoder_conv_filters: ``filters`` of each encoder ``Conv2D``
                layer; its length sets the number of encoder layers.
            encoder_conv_kernel_size: ``kernel_size`` of each encoder layer.
            encoder_conv_strides: ``strides`` of each encoder layer.
            decoder_conv_t_filters: ``filters`` of each decoder
                ``Conv2DTranspose`` layer; its length sets the number of
                decoder layers.
            decoder_conv_t_kernel_size: ``kernel_size`` of each decoder layer.
            decoder_conv_t_strides: ``strides`` of each decoder layer.
            z_dim: Number of units of the encoder's dense output, which is
                also the length of the decoder's input vector.
            use_batch_norm: Insert ``BatchNormalization`` after each
                ``LeakyReLU``.
            use_dropout: Insert ``Dropout`` after each activation (and after
                batch normalisation when that is enabled).

        Raises:
            ValueError: If a kernel-size or stride list has fewer entries
                than the filter list of the same stage.
        """
        self.name = 'autoencoder'
        self.input_dim = input_dim
        self.encoder_conv_filters = encoder_conv_filters
        self.encoder_conv_kernel_size = encoder_conv_kernel_size
        self.encoder_conv_strides = encoder_conv_strides
        self.decoder_conv_t_filters = decoder_conv_t_filters
        self.decoder_conv_t_kernel_size = decoder_conv_t_kernel_size
        self.decoder_conv_t_strides = decoder_conv_t_strides

        self.use_batch_norm = use_batch_norm
        self.use_dropout = use_dropout

        self.z_dim = z_dim

        self.n_layers_encoder = len(encoder_conv_filters)
        self.n_layers_decoder = len(decoder_conv_t_filters)

        # Fail early with a readable message rather than with an IndexError
        # halfway through graph construction.
        self._require_entry_per_layer(
            'encoder', self.n_layers_encoder,
            encoder_conv_kernel_size=encoder_conv_kernel_size,
            encoder_conv_strides=encoder_conv_strides)
        self._require_entry_per_layer(
            'decoder', self.n_layers_decoder,
            decoder_conv_t_kernel_size=decoder_conv_t_kernel_size,
            decoder_conv_t_strides=decoder_conv_t_strides)

        self._build()

    @staticmethod
    def _require_entry_per_layer(stage, n_layers, **per_layer_lists):
        """Raise ``ValueError`` if any list is shorter than ``n_layers``."""
        for arg_name, values in per_layer_lists.items():
            if len(values) < n_layers:
                raise ValueError(
                    f'{arg_name} has {len(values)} entries but the {stage} '
                    f'has {n_layers} layers')

    def _activate_and_regularise(self, x):
        """Apply LeakyReLU, then the optional batch-norm and dropout layers."""
        x = LeakyReLU()(x)
        if self.use_batch_norm:
            x = BatchNormalization()(x)
        if self.use_dropout:
            x = Dropout(rate=self.DROPOUT_RATE)(x)
        return x

    def _build(self):
        """Create ``self.encoder``, ``self.decoder`` and ``self.model``."""
        # ========
        # Encoder
        # ========
        encoder_input = Input(shape=self.input_dim, name='encoder_input')
        x = encoder_input

        encoder_specs = zip(self.encoder_conv_filters,
                            self.encoder_conv_kernel_size,
                            self.encoder_conv_strides)
        for i, (filters, kernel_size, strides) in enumerate(encoder_specs):
            x = Conv2D(filters=filters,
                       kernel_size=kernel_size,
                       strides=strides,
                       padding='same',
                       name=f'encoder_conv_{i}')(x)
            x = self._activate_and_regularise(x)

        # int_shape is (batch, height, width, channel); the batch axis is
        # dropped so the decoder can reshape back to (height, width, channel).
        shape_before_flattening = K.int_shape(x)[1:]

        x = Flatten()(x)
        encoder_output = Dense(units=self.z_dim, name='encoder_output')(x)

        self.encoder = Model(inputs=encoder_input, outputs=encoder_output)

        # ========
        # Decoder
        # ========
        decoder_input = Input(shape=(self.z_dim,), name='decoder_input')

        # Dense layer sized so its output can be reshaped to the encoder's
        # pre-flatten feature map.
        x = Dense(int(np.prod(shape_before_flattening)))(decoder_input)
        x = Reshape(target_shape=shape_before_flattening)(x)

        decoder_specs = zip(self.decoder_conv_t_filters,
                            self.decoder_conv_t_kernel_size,
                            self.decoder_conv_t_strides)
        final_layer = self.n_layers_decoder - 1
        for i, (filters, kernel_size, strides) in enumerate(decoder_specs):
            x = Conv2DTranspose(filters=filters,
                                kernel_size=kernel_size,
                                strides=strides,
                                padding='same',
                                name=f'decoder_conv_t_{i}')(x)

            if i < final_layer:
                # Hidden layers: LeakyReLU plus optional regularisation.
                x = self._activate_and_regularise(x)
            else:
                # Output layer: sigmoid keeps values in [0, 1] so they fall
                # within the image pixel value range.
                x = Activation('sigmoid')(x)

        decoder_output = x

        self.decoder = Model(inputs=decoder_input, outputs=decoder_output)

        # Full autoencoder: encoder input -> latent vector -> decoder.
        self.model = Model(inputs=encoder_input,
                           outputs=self.decoder(encoder_output))

    def compile(self, learning_rate):
        """Compile ``self.model`` with Adam and a mean-squared-error loss.

        Args:
            learning_rate: Learning rate for the Adam optimizer; it is also
                stored on the instance as ``self.learning_rate``.
        """
        self.learning_rate = learning_rate

        # Passed positionally: the keyword was renamed from `lr` to
        # `learning_rate` across Keras releases, while the first positional
        # slot has stayed the learning rate.
        optimizer = Adam(learning_rate)

        def r_loss(y_true, y_pred):
            # Mean squared error over axes 1-3, one value per sample.
            # No square root is taken, so this is MSE rather than RMSE.
            return K.mean(K.square(y_true - y_pred), axis=[1, 2, 3])

        self.model.compile(optimizer=optimizer, loss=r_loss)

    def save(self, folder):
        """Write the constructor arguments and model diagrams to ``folder``.

        Creates ``folder`` and its ``viz``, ``weights`` and ``images``
        sub-folders when they are missing, pickles the constructor arguments
        (in constructor order) to ``params.pkl`` and then calls
        :meth:`plot_model`. Model weights are not written by this method.

        Args:
            folder: Destination directory.
        """
        # Create each sub-folder individually so that a pre-existing
        # `folder` still ends up with the layout plot_model() writes into.
        for subfolder in ('viz', 'weights', 'images'):
            os.makedirs(os.path.join(folder, subfolder), exist_ok=True)

        params = [
            self.input_dim,
            self.encoder_conv_filters,
            self.encoder_conv_kernel_size,
            self.encoder_conv_strides,
            self.decoder_conv_t_filters,
            self.decoder_conv_t_kernel_size,
            self.decoder_conv_t_strides,
            self.z_dim,
            self.use_batch_norm,
            self.use_dropout,
        ]
        with open(os.path.join(folder, 'params.pkl'), 'wb') as f:
            pickle.dump(params, f)

        self.plot_model(folder)

    def load_weights(self, filepath):
        """Load weights from ``filepath`` into ``self.model``."""
        self.model.load_weights(filepath)

    def plot_model(self, run_folder):
        """Render diagrams of the three models into ``run_folder/viz``.

        Writes ``model.png``, ``encoder.png`` and ``decoder.png``. The
        ``viz`` sub-folder is not created here; :meth:`save` creates it.

        Args:
            run_folder: Directory that contains the ``viz`` sub-folder.
        """
        diagrams = (
            ('model.png', self.model),
            ('encoder.png', self.encoder),
            ('decoder.png', self.decoder),
        )
        for filename, network in diagrams:
            # Inside a method body the bare name resolves to the
            # module-level Keras function, not to this method.
            plot_model(network,
                       to_file=os.path.join(run_folder, 'viz', filename),
                       show_shapes=True,
                       show_layer_names=True)

    def train(self, x_train, batch_size, epochs, run_folder, print_every_n_batches=100, initial_epoch=0, lr_decay=1):
        """Train the autoencoder (not implemented yet).

        The signature is in place but no training loop has been written,
        so calling this method always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError('Autoencoder.train is not implemented yet')

In [None]:
# Instantiate the autoencoder; the constructor builds the encoder, decoder
# and combined model straight away.
ae = Autoencoder(
    input_dim=(28, 28, 1),
    # Encoder: one list entry per Conv2D layer (four layers).
    encoder_conv_filters=[32, 64, 64, 64],
    encoder_conv_kernel_size=[3, 3, 3, 3],
    encoder_conv_strides=[1, 2, 2, 1],
    # Decoder: one list entry per Conv2DTranspose layer (four layers).
    decoder_conv_t_filters=[64, 64, 32, 1],
    decoder_conv_t_kernel_size=[3, 3, 3, 3],
    decoder_conv_t_strides=[1, 2, 2, 1],
    # Two units in the encoder's dense output.
    z_dim=2,
    use_batch_norm=True,
    use_dropout=True,
)