# AutoEncoderの学習

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.keras import layers, models, callbacks, optimizers, losses

In [None]:
# Side length (in pixels) of the square every input image is resized to.
IMAGE_HEIGHT = 256
IMAGE_WIDTH = 256

# Length of the latent feature vector produced by the encoder (128 * 128).
IMAGE_FEATURE_DIM = 128 * 128

# Number of images per training / validation batch.
BATCH_SIZE = 32

In [None]:
def read_and_preprocess( _image_path ):
    """Load one image file and return it twice as an autoencoder (input, target) pair.

    The file is decoded as a 3-channel image (first frame only for animated
    formats), resized to (IMAGE_HEIGHT, IMAGE_WIDTH), converted to float32 and
    scaled by 1/255.

    Args:
        _image_path: path of the image file (string or scalar string tensor).

    Returns:
        A tuple (image, image) of identical float32 tensors with shape
        (IMAGE_HEIGHT, IMAGE_WIDTH, 3).
    """
    raw_bytes = tf.io.read_file( _image_path )
    pixels = tf.image.decode_image( raw_bytes, channels = 3, expand_animations = False )
    resized = tf.cast( tf.image.resize( pixels, (IMAGE_HEIGHT, IMAGE_WIDTH) ), tf.float32 )
    normalized = resized / 255.0

    # An autoencoder learns to reproduce its input, so the target is the image itself.
    return normalized, normalized

In [None]:
# Load the numbers of the images flagged as bad, one integer per line.
# The largest flagged number is also used as the cut-off for which images are
# considered annotated (images numbered above it are excluded when paths are
# collected) — presumably those were never reviewed; confirm with the annotator.
with open("bokete_image_annotations.csv", "r", encoding = "utf-8") as f:
    annotation_lines = f.readlines()

bad_images = set()
for line in annotation_lines:
    # strip() drops the newline and surrounding spaces; "\ufeff" is a UTF-8 BOM,
    # which str.strip() does not consider whitespace, so it is removed explicitly.
    token = line.strip().strip("\ufeff")
    if token:  # skip blank lines (e.g. an extra trailing line) instead of crashing on int("")
        bad_images.add(int(token))

if not bad_images:
    raise ValueError("bokete_image_annotations.csv contains no image numbers")
last_annotation_number = max(bad_images)

#
len(bad_images), last_annotation_number

In [None]:
def _image_number(file_name):
    """Return the integer id of an image file name such as "123.jpg", or None.

    The id is the part before the first ".". Entries whose id is not an integer
    (e.g. ".DS_Store" or ".ipynb_checkpoints") yield None so they can be skipped
    instead of raising ValueError.
    """
    try:
        return int(file_name.split(".")[0])
    except ValueError:
        return None


# Keep every image in bokete_image/ that is within the annotated range and was not
# flagged as bad. The listing is sorted because os.listdir order is arbitrary,
# which would otherwise make the resulting list (and any seeded split) unstable.
image_paths = []
for file_name in sorted(os.listdir("bokete_image")):
    image_number = _image_number(file_name)
    if image_number is None:
        continue
    if image_number in bad_images or image_number > last_annotation_number:
        continue
    image_paths.append("bokete_image/" + file_name)

#
len(image_paths)

In [None]:
AUTOTUNE = tf.data.experimental.AUTOTUNE

def generate_dataset( _image_paths, shuffle = False ):
    """Build an unbatched tf.data pipeline of (image, image) pairs for the given files.

    Args:
        _image_paths: sequence of image file paths to load.
        shuffle: if True, the file order is reshuffled on every epoch
            (intended for the training set). Defaults to False.

    Returns:
        A prefetched tf.data.Dataset yielding the (image, image) tuples produced by
        read_and_preprocess. Batching is left to the caller.
    """
    # Slice the argument, not the global image_paths: using the global made the train
    # and test datasets both contain every image, so validation leaked training data.
    _ds = tf.data.Dataset.from_tensor_slices( _image_paths )
    if shuffle:
        # Shuffle the (cheap) path strings before decoding; a buffer as large as the
        # dataset gives a uniform shuffle. max(..., 1) keeps buffer_size positive.
        _ds = _ds.shuffle( buffer_size = max(len(_image_paths), 1), reshuffle_each_iteration = True )
    _ds = _ds.map( read_and_preprocess, num_parallel_calls = AUTOTUNE )
    _ds = _ds.prefetch( AUTOTUNE )

    return _ds

# Hold out 1% of the images for validation.
train_image_paths, test_image_paths = train_test_split(image_paths, test_size = 0.01)
train_dataset = generate_dataset(train_image_paths, shuffle = True)
test_dataset = generate_dataset(test_image_paths)

#
len(train_image_paths), len(test_image_paths)

In [None]:
def build_encoder():
    """Build the encoder: a (IMAGE_HEIGHT, IMAGE_WIDTH, 3) image -> IMAGE_FEATURE_DIM vector.

    Five Conv2D(kernel 3, stride 2, "same") + BatchNormalization + LeakyReLU stages
    halve the spatial size each time while widening the channels 32 -> 512. The final
    feature map is flattened and projected to IMAGE_FEATURE_DIM by a
    Dense + BatchNormalization + LeakyReLU head.

    NOTE(review): with 256x256 input the flattened map has 8*8*512 = 32768 values, so
    the Dense head alone holds about 32768 * 16384 ≈ 537M weights — confirm this
    latent size is intended.
    """
    image_in = layers.Input(shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3))

    features = image_in
    for n_filters in (32, 64, 128, 256, 512):
        features = layers.Conv2D(filters = n_filters, kernel_size = 3, strides = 2, padding = "same")(features)
        features = layers.BatchNormalization()(features)
        features = layers.LeakyReLU()(features)

    features = layers.Flatten()(features)
    features = layers.Dense(units = IMAGE_FEATURE_DIM)(features)
    features = layers.BatchNormalization()(features)
    latent = layers.LeakyReLU()(features)

    return models.Model(image_in, latent, name = "encoder")

def build_decoder():
    """Build the decoder: an IMAGE_FEATURE_DIM vector -> (IMAGE_HEIGHT, IMAGE_WIDTH, 3) image.

    The latent vector is projected by Dense + BatchNormalization + LeakyReLU onto a
    (IMAGE_HEIGHT / 32, IMAGE_WIDTH / 32, 512) grid. Five stride-2 Conv2DTranspose
    stages then upsample it by 32 in total, mirroring the five stride-2 convolutions
    of the encoder, so the output resolution matches the input. A final 3-filter
    Conv2D with sigmoid keeps every pixel in [0, 1].

    The start grid is derived from IMAGE_HEIGHT / IMAGE_WIDTH rather than hard-coded
    to 8x8, so a different input size still yields a matching output shape.

    Raises:
        ValueError: if IMAGE_HEIGHT or IMAGE_WIDTH is not a multiple of 32.
    """
    upsample_factor = 2 ** 5  # five stride-2 Conv2DTranspose stages below
    if IMAGE_HEIGHT % upsample_factor or IMAGE_WIDTH % upsample_factor:
        raise ValueError(
            f"IMAGE_HEIGHT and IMAGE_WIDTH must be multiples of {upsample_factor}, "
            f"got {IMAGE_HEIGHT}x{IMAGE_WIDTH}"
        )
    start_shape = (IMAGE_HEIGHT // upsample_factor, IMAGE_WIDTH // upsample_factor, 512)

    inputs = layers.Input(shape = (IMAGE_FEATURE_DIM, ))
    x = layers.Dense(units = int(np.prod(start_shape)))(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.LeakyReLU()(x)
    x = layers.Reshape(target_shape = start_shape)(x)

    for filters in (256, 128, 64, 32, 16):
        x = layers.Conv2DTranspose(filters = filters, kernel_size = 3, strides = 2, padding = "same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.LeakyReLU()(x)

    # Map the 16 feature channels to RGB and squash to [0, 1], the range of the inputs.
    x = layers.Conv2D(filters = 3, kernel_size = 3, strides = 1, padding = "same")(x)
    outputs = layers.Activation("sigmoid")(x)

    return models.Model(inputs, outputs, name = "decoder")

def build_autoencoder(encoder, decoder):
    """Chain *encoder* and *decoder* into a single image -> reconstructed-image model.

    Args:
        encoder: model mapping a (IMAGE_HEIGHT, IMAGE_WIDTH, 3) image to a latent vector.
        decoder: model mapping that latent vector back to an image.

    Returns:
        A Keras Model named "autoencoder".
    """
    image_in = layers.Input(shape = (IMAGE_HEIGHT, IMAGE_WIDTH, 3))
    reconstruction = decoder(encoder(image_in))
    return models.Model(image_in, reconstruction, name = "autoencoder")

# Build the two halves (encoder first, so auto-generated layer names keep the same
# order) and wire them into the end-to-end model.
encoder, decoder = build_encoder(), build_decoder()
autoencoder = build_autoencoder(encoder, decoder)

# Print the layer / parameter overview of the assembled model.
autoencoder.summary()

In [None]:
class ShowProgress(callbacks.Callback):
    """Keras callback that plots one reconstruction at the end of every epoch.

    A random image from the global test_image_paths is passed through the model
    being trained and shown next to the original input.
    """

    def on_epoch_end(self, epoch, logs = None):
        # With no held-out images there is nothing to show; np.random.randint(0) would raise.
        if not test_image_paths:
            return

        i = np.random.randint(len(test_image_paths))
        image = read_and_preprocess(test_image_paths[i])[0]

        # Add a batch axis for predict() (silenced so it adds no progress bar per epoch),
        # then drop it again for plotting.
        pred = self.model.predict(np.reshape(image, (1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)), verbose = 0)
        pred = np.reshape(pred, (IMAGE_HEIGHT, IMAGE_WIDTH, 3))

        # Two square panels side by side, so use a wide (not tall) figure.
        fig, (ax_input, ax_pred) = plt.subplots(1, 2, figsize = (10, 5))
        ax_input.imshow(image)
        ax_input.set_title("input")

        ax_pred.imshow(pred)
        ax_pred.set_title("pred")

        plt.show()
        # Release the figure so one figure per epoch does not pile up over long runs.
        plt.close(fig)

# Save the full model to ./autoencoder.h5 at the end of each epoch in which val_loss
# reaches a new minimum. `save_freq="epoch"` replaces the deprecated `period=1`
# argument, which newer Keras versions no longer accept.
check_point = callbacks.ModelCheckpoint(filepath = './autoencoder.h5',
                                        monitor = 'val_loss', verbose = 1,
                                        save_best_only = True, save_weights_only = False,
                                        mode = 'min', save_freq = 'epoch')

# Stop once val_loss has not improved for 5 consecutive epochs. restore_best_weights
# puts the best epoch's weights back when stopping is triggered, so the model left
# in memory after fit() corresponds to the best epoch rather than the last one.
early_stopping = callbacks.EarlyStopping(monitor = 'val_loss', patience = 5, verbose = 1, mode = 'min',
                                         restore_best_weights = True)

In [None]:
# Adam with a small step size. `learning_rate` is the supported keyword; the old
# `lr` alias is deprecated and is not accepted by recent Keras versions.
optimizer = optimizers.Adam(learning_rate = 0.0001)
# Pixel-wise mean squared error between the input image and its reconstruction.
loss = losses.MeanSquaredError()

autoencoder.compile(loss = loss, optimizer = optimizer)
autoencoder.fit(train_dataset.batch(BATCH_SIZE), epochs = 150, validation_data = test_dataset.batch(BATCH_SIZE),
                callbacks = [ShowProgress(), check_point, early_stopping])