<a href="https://colab.research.google.com/github/toanpt74/COLAB_RD/blob/main/VAE_Train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import random
import datetime
import matplotlib.pyplot as plt
from IPython import display
import numpy as np
# Expose only GPU 0 to CUDA.
# NOTE(review): this runs after `import tensorflow`; it only takes effect if
# TensorFlow has not initialised the GPU yet. Consider moving it above the
# TensorFlow import, next to TF_CPP_MIN_LOG_LEVEL.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Seed every RNG the script touches (Python's `random`, NumPy, TensorFlow)
# so file shuffling and weight initialisation are repeatable between runs.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# Let tf.data choose parallelism / prefetch depth automatically.
AUTO = tf.data.experimental.AUTOTUNE

# ---------------------------------------------------------------------------
# Parameters for building the model and training
# ---------------------------------------------------------------------------
BATCH_SIZE = 128   # images per training step
LATENT_DIM = 128   # length of the VAE latent vector
COL = 496          # image width in pixels
ROW = 64           # image height in pixels

# ---------------------------------------------------------------------------
# Function definitions
# ---------------------------------------------------------------------------
def get_dataset(image_dir):
    image_file_list = os.listdir(image_dir)
    image_paths = [os.path.join(image_dir, fname) for fname in image_file_list]
    random.shuffle(image_paths)
    train_data = tf.data.Dataset.from_tensor_slices((image_paths))
    print(train_data)
    print("Training dataset: {} images".format(len(image_paths)))
    return train_data, len(image_paths)
#
def pre_image(image_filename):
    """Load one BMP file and turn it into a model-ready single-channel tensor.

    Steps:
    1. Decode the BMP file.
    2. Resize it to ``ROW`` x ``COL`` (height x width).
    3. Scale it to ``[0, 1]``.
    4. Reduce it to one channel.
    5. Transpose it to ``(COL, ROW, 1)``, the input shape the VAE is built
       with in ``Train_VAE``.

    Args:
        image_filename: scalar string tensor holding the file path.

    Returns:
        float32 tensor of shape ``(COL, ROW, 1)`` with values in ``[0, 1]``.
    """
    img_raw = tf.io.read_file(image_filename)
    image = tf.image.decode_bmp(img_raw)  # uint8, (height, width, channels)

    image = tf.cast(image, dtype=tf.float32)
    image = tf.image.resize(image, (ROW, COL))  # -> (ROW, COL, channels)
    image = image / 255.0

    # decode_bmp cannot be asked for a single channel, so colour images are
    # collapsed here: drop a possible alpha channel and average the rest.
    # For a 1-channel BMP this is an exact no-op. Previously a 3/4-channel
    # BMP made the final reshape fail on an element-count mismatch.
    image = tf.reduce_mean(image[..., :3], axis=-1, keepdims=True)

    # The model expects (COL, ROW, 1). A plain tf.reshape of a (ROW, COL, 1)
    # tensor re-flows the pixel buffer row-major and scrambles the image.
    # A transpose swaps the axes and keeps neighbouring pixels adjacent, which
    # the convolutions rely on.
    image = tf.transpose(image, perm=[1, 0, 2])
    return image


class Sampling(layers.Layer):
    """Draw a latent sample using the reparameterisation trick.

    Called with ``(mu, sigma)``, where ``sigma`` is treated as a log-variance.
    The result is ``mu + exp(0.5 * sigma) * eps``, with ``eps`` standard
    normal noise of shape ``(batch, dim)`` taken from ``mu``.
    """

    def call(self, inputs):
        mean, log_var = inputs
        mean_shape = tf.shape(mean)
        noise = keras.backend.random_normal(shape=(mean_shape[0], mean_shape[1]))
        std_dev = tf.exp(0.5 * log_var)
        return mean + std_dev * noise


def encoder_layers(inputs, latent_dim):
    """Apply the convolutional encoder stack to *inputs*.

    The stack has three ``Conv2D(kernel 3, stride 2, 'same', relu) ->
    BatchNormalization`` stages with 4, 16 and 32 filters. Each stage shrinks
    both spatial dimensions by a factor of 2 (rounding up). The last feature
    map is then flattened and fed to two dense heads.

    Args:
        inputs: Keras tensor holding the input image.
        latent_dim: width of each dense head.

    Returns:
        ``(mu, sigma, conv_shape)``:
        - ``mu`` and ``sigma`` are the two heads. ``sigma`` is later used as
          a log-variance by ``Sampling``.
        - ``conv_shape`` is the static shape (batch axis included) of the last
          feature map. The decoder uses it to undo the flatten.
    """
    features = inputs
    for n_filters in (4, 16, 32):
        features = layers.Conv2D(filters=n_filters, kernel_size=3, strides=2,
                                 padding='same', activation='relu')(features)
        features = layers.BatchNormalization()(features)
    last_conv_shape = features.shape

    flat = layers.Flatten()(features)
    mu = layers.Dense(latent_dim, name='latent_mu')(flat)
    sigma = layers.Dense(latent_dim, name='latent_sigma')(flat)
    return mu, sigma, last_conv_shape


def encoder_model(latent_dim, input_shape):
    """Build the encoder as a standalone Keras model named ``'Encoder'``.

    The model maps an image of *input_shape* to ``[mu, sigma, z]``, where
    ``z`` comes from the ``Sampling`` layer. It also prints a summary and
    writes an architecture diagram to ``encoder.png``.

    Returns:
        ``(model, conv_shape)``. ``conv_shape`` is the static shape of the
        last convolutional feature map and is needed to size the decoder.
    """
    image_in = layers.Input(shape=input_shape)
    mu, sigma, conv_shape = encoder_layers(image_in, latent_dim=latent_dim)
    z = Sampling()((mu, sigma))
    encoder = keras.Model(image_in, outputs=[mu, sigma, z], name='Encoder')
    encoder.summary()
    # plot_model requires the pydot package and the graphviz binaries.
    keras.utils.plot_model(encoder, to_file='encoder.png',
                           show_shapes=True, show_layer_names=True)
    return encoder, conv_shape


def decoder_layers(inputs, conv_shape):
    """Map a latent tensor back to a single-channel image.

    Steps:
    1. A dense layer expands the latent vector to the number of elements in
       the encoder's last feature map.
    2. The result is reshaped to ``conv_shape[1:]``.
    3. Three ``Conv2DTranspose(stride 2) -> BatchNormalization`` stages with
       32, 16 and 4 filters each double both spatial dimensions.
    4. A final stride-1 ``Conv2DTranspose`` with a sigmoid produces one
       channel in ``[0, 1]``.

    Args:
        inputs: Keras tensor holding the latent vector.
        conv_shape: static shape of the encoder's last feature map, batch axis
            included (as returned by ``encoder_layers``).

    Returns:
        Keras tensor with spatial size 8x that of ``conv_shape`` and 1 channel.
    """
    height, width, channels = conv_shape[1], conv_shape[2], conv_shape[3]
    x = layers.Dense(height * width * channels, activation='relu')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Reshape((height, width, channels))(x)

    for n_filters in (32, 16, 4):
        x = layers.Conv2DTranspose(filters=n_filters, kernel_size=3, strides=2,
                                   padding='same', activation='relu')(x)
        x = layers.BatchNormalization()(x)

    return layers.Conv2DTranspose(filters=1, kernel_size=3, strides=1,
                                  padding='same', activation='sigmoid')(x)


def decoder_model(latent_dim, conv_shape):
    """Build the decoder as a standalone Keras model named ``'Decoder'``.

    The input is a vector of length *latent_dim*. The output is the image
    produced by ``decoder_layers``. The function also prints a summary and
    writes an architecture diagram to ``decoder.png``.
    """
    latent_in = layers.Input(shape=(latent_dim,))
    decoder = keras.Model(latent_in, decoder_layers(latent_in, conv_shape), name='Decoder')
    decoder.summary()
    # plot_model requires the pydot package and the graphviz binaries.
    keras.utils.plot_model(decoder, to_file='decoder.png',
                           show_shapes=True, show_layer_names=True)
    return decoder


def kl_reconstruction_loss(inputs, outputs, mu, sigma):
    """Return the KL-divergence term of the VAE loss.

    Despite the name, only the KL term is computed. *inputs* and *outputs*
    are accepted for signature compatibility but unused. The reconstruction
    term is added separately in ``Train_VAE``'s training loop.

    With ``sigma`` treated as a log-variance, this returns
    ``-0.5 * mean(1 + sigma - mu**2 - exp(sigma))``. The mean runs over the
    batch axis and the latent axis together.

    NOTE(review): the textbook ELBO sums the KL over the latent dimensions.
    Averaging divides it by ``latent_dim`` relative to the pixel-summed
    reconstruction term. Confirm this weighting is intended.
    """
    del inputs, outputs  # unused; kept for the public signature
    per_element = 1 + sigma - tf.square(mu) - tf.math.exp(sigma)
    return tf.reduce_mean(per_element) * -0.5


def vae_model(encoder, decoder, input_shape):
    """Chain *encoder* and *decoder* into the end-to-end VAE model.

    The encoder's sample ``z`` feeds the decoder, and the model outputs the
    reconstruction. The KL term is attached with ``add_loss``, so it shows up
    in ``model.losses``. The reconstruction term is computed by the caller.

    Args:
        encoder: model returning ``[mu, sigma, z]``.
        decoder: model mapping ``z`` to an image.
        input_shape: image shape without the batch axis.

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = keras.layers.Input(shape=input_shape)
    mu, sigma, z = encoder(inputs)
    reconstructed = decoder(z)
    model = keras.Model(inputs=inputs, outputs=reconstructed)
    # Pass the reconstruction, not the latent sample z, as ``outputs`` so the
    # arguments line up with the loss function's (inputs, outputs, ...) order.
    loss = kl_reconstruction_loss(inputs, reconstructed, mu, sigma)
    model.add_loss(loss)
    return model


def get_models(input_shape, latent_dim):
    """Build the encoder, the decoder and the VAE that composes them.

    The VAE calls the encoder and decoder models directly, so all three share
    the same weights.

    Returns:
        ``(encoder, decoder, vae)``.
    """
    enc, last_conv_shape = encoder_model(latent_dim=latent_dim, input_shape=input_shape)
    dec = decoder_model(latent_dim=latent_dim, conv_shape=last_conv_shape)
    return enc, dec, vae_model(enc, dec, input_shape=input_shape)


def generate_and_save_images(model, epoch, step, test_input):
    """Run *test_input* through *model* and save the outputs as one PNG grid.

    Args:
        model: object whose ``predict`` returns a rank-4 array; channel 0 of
            each item is drawn in grey scale.
        epoch: used in the figure title and the file name.
        step: used in the figure title and the file name.
        test_input: batch handed to ``model.predict`` (e.g. latent vectors).

    The file is written to
    ``image_at_epoch_{epoch:04d}_step{step:04d}.png`` in the working directory.

    The grid stays 4x4 for up to 16 images. For larger batches it grows to the
    smallest square that fits; ``plt.subplot`` previously raised a ValueError
    for more than 16 images.
    """
    predictions = model.predict(test_input)
    n_images = predictions.shape[0]
    grid = max(4, int(np.ceil(np.sqrt(n_images))))

    fig = plt.figure(figsize=(4, 4))
    for i in range(n_images):
        plt.subplot(grid, grid, i + 1)
        img = (predictions[i, :, :, 0] * 255).astype('int32')
        plt.imshow(img, cmap='gray')
        plt.axis('off')

    fig.suptitle("epoch: {}, step: {}".format(epoch, step))
    plt.savefig('image_at_epoch_{:04d}_step{:04d}.png'.format(epoch, step))
    # Release the figure so repeated calls during training don't leak memory.
    fig.clear()
    plt.close(fig)


"""
VAE model
"""


def Train_VAE(datapath="", epochs=2000, use_transferlearning=False, model_path=""):
    """Train the convolutional VAE on the images found in *datapath*.

    Args:
        datapath: directory of training images (passed to ``get_dataset``).
        epochs: number of passes over the dataset.
        use_transferlearning: if True, continue training the model loaded
            from *model_path* instead of building a fresh one.
        model_path: path handed to ``tf.keras.models.load_model`` when
            *use_transferlearning* is True.

    Every 5th epoch (starting at epoch 0) the VAE is saved in TensorFlow
    SavedModel format to ``model_vae/model_<epoch>__<mean epoch loss>``.
    """
    if use_transferlearning:
        # Sampling is a custom layer, so Keras must be told how to rebuild it.
        vae = tf.keras.models.load_model(model_path, custom_objects={'Sampling': Sampling})
    else:
        _, _, vae = get_models(input_shape=(COL, ROW, 1,), latent_dim=LATENT_DIM)
    optimizer = keras.optimizers.Adam(learning_rate=0.0005)
    loss_metric = keras.metrics.Mean()
    mse_loss = keras.losses.MeanSquaredError()

    # Input pipeline: reshuffle paths each epoch, decode in parallel, batch,
    # and prefetch.
    train_data, no_train = get_dataset(datapath)
    train_ds = (train_data
                .shuffle(no_train)
                .map(pre_image, num_parallel_calls=AUTO)
                .batch(BATCH_SIZE)
                .prefetch(buffer_size=AUTO))

    os.system("nvidia-smi")  # log GPU status before training starts

    for epoch in range(epochs):
        print('Start of epoch %d' % (epoch,))
        # Report a per-epoch mean rather than a mean over the whole run.
        loss_metric.reset_states()
        for step, x_batch_train in enumerate(train_ds):
            with tf.GradientTape() as tape:
                # training=True makes BatchNormalization use batch statistics
                # and update its moving averages. Without it, the layers run
                # in inference mode during training.
                reconstructed = vae(x_batch_train, training=True)
                flattened_inputs = tf.reshape(x_batch_train, shape=[-1])
                flattened_outputs = tf.reshape(reconstructed, shape=[-1])
                # MSE over all pixels times pixels-per-image gives the
                # per-image sum of squared errors, averaged over the batch.
                loss = mse_loss(flattened_inputs, flattened_outputs) * COL * ROW
                loss += sum(vae.losses)  # KL term registered via add_loss

            grads = tape.gradient(loss, vae.trainable_weights)
            optimizer.apply_gradients(zip(grads, vae.trainable_weights))

            loss_metric(loss)
            print('Epoch: %s step: %s mean loss = %s' % (epoch, step, loss_metric.result().numpy()))
        if epoch % 5 == 0:
            # Use a plain float (not a tf.Tensor repr) and a portable separator
            # in the directory name. This also works if the epoch had no steps.
            epoch_loss = float(loss_metric.result())
            save_path = os.path.join('model_vae', 'model_{}__{:.4f}'.format(epoch, epoch_loss))
            vae.save(save_path, save_format="tf")


if __name__ == "__main__":
    # Hard-coded Windows paths; adjust them for the machine running the script.
    # Named data_dir (not `dir`) to avoid shadowing the builtin.
    data_dir = r'E:\ToanPT\1.Code_train_Unet\data\train'

    Train_VAE(datapath=data_dir, epochs=20001, use_transferlearning=False,
              model_path=r"E:\ToanPT\1.Code_train_Unet\models")

    print("DONE")
