In [None]:
from Model.DCVAE import DCVAE
from Model.Utils import PlotHistory, PlotDataAE, Save_Model
from Model.Utils import load_from_tfrecords, convert_to_tfrecords, load_numpy
from keras.optimizers import RMSprop,Adam
import numpy as np
from sklearn.model_selection import train_test_split
import tensorflow as tf
import os
%matplotlib notebook

# Load DataSet

In [None]:
# --- Run configuration ------------------------------------------------------
# Mini-batch size; passed to load_from_tfrecords and to model.fit further down.
batch_size = 128

# Data locations: the TFRecord directory and the raw NumPy file.
# NOTE(review): path_npy is an absolute, machine-specific path.
path_tfr = 'DataSet/MPS100'
path_npy = '/share_delta/GeoFacies/DataSet/MPS100/MPS100.npy'

# Pipeline switches:
#   load_tfrecords   -- True: read the data through load_from_tfrecords;
#                       False: load the .npy file with load_numpy instead.
#   create_tfrecords -- True: build the TFRecords from the .npy file before
#                       loading them (only consulted when load_tfrecords is True).
create_tfrecords = False
load_tfrecords = True

In [None]:
# Populate the model inputs according to the configuration flags:
#   * load_tfrecords=True  -> defines gen_train / gen_test via load_from_tfrecords
#   * load_tfrecords=False -> defines x_train / x_test via load_numpy
if load_tfrecords:
    if create_tfrecords:
        # Build the TFRecord files from the raw array first, then read them back.
        x_train, x_test = load_numpy(path_npy, random_state=0, split_data=0.30)
        convert_to_tfrecords(path_tfr, x_train, x_test)
        gen_train, gen_test = load_from_tfrecords(path_tfr, batch_size)
    else:
        try:
            gen_train, gen_test = load_from_tfrecords(path_tfr, batch_size)
        except Exception:
            # Catch Exception rather than using a bare `except:` so that
            # KeyboardInterrupt / SystemExit are not swallowed. Print the hint
            # and re-raise: carrying on would leave gen_train / gen_test
            # undefined, and the training cell would then fail with an
            # unrelated NameError that hides the real cause.
            print("Data not found. Change 'create_tfrecords' to True")
            raise
else:
    x_train, x_test = load_numpy(path_npy, random_state=0, split_data=0.30)

# Create Convolutional Variational Autoencoder

In [None]:
# File path handed to the model as `filepath`; it lives next to the TFRecord data.
# NOTE(review): presumably the checkpoint/weights file -- confirm in Model.DCVAE.
path_weights = os.path.join(path_tfr, 'w100.hdf5')

# Build the DCVAE with one hyper-parameter per line so each is easy to spot and tune.
model = DCVAE(
    input_shape=(100, 100, 2),
    filters=[32, 32, 16],
    strides=[2, 2, 1],
    hidden_dim=5000,
    KernelDim=(5, 5, 3),
    latent_dim=500,
    opt=Adam(1e-4),
    dropout=0.1,
    epochs_drop=200,
    filepath=path_weights,
)

# Training Network

In [None]:
num_epochs = 500

if not load_tfrecords:
    # In-memory path: fit directly on the arrays returned by load_numpy.
    model.fit(
        x_train,
        num_epochs=num_epochs,
        batch_size=batch_size,
        x_v=x_test,
        verbose=1,
    )
else:
    # TFRecord path: obtain generators from the TFRecord loaders and size
    # each epoch by the loaders' own length.
    x_train = gen_train.mps_generator()
    x_val = gen_test.mps_generator()
    model.fit_generator(
        x_train,
        num_epochs=num_epochs,
        verbose=1,
        steps_per_epoch=len(gen_train),
        val_set=x_val,
        validation_steps=len(gen_test),
    )

# Plotting Training History

In [None]:
# Plot the 'val_loss' and 'loss' entries of the recorded training history.
PlotHistory(model.history.history, listKeys=['val_loss', 'loss'])

In [None]:
# Plot the 'lr' entry of the recorded training history.
PlotHistory(model.history.history, listKeys=['lr'])

In [None]:
# Plot the 'val_acc_pred' and 'acc_pred' entries of the recorded training history.
PlotHistory(model.history.history, listKeys=['val_acc_pred', 'acc_pred'])

# Evaluate Model with Test dataset 

In [None]:
# When the data comes from TFRecords, take the evaluation inputs from the test
# loader; otherwise x_test is whatever the data-loading cell defined.
# NOTE(review): get_numpy_batch presumably returns a single batch -- confirm.
if load_tfrecords:
    x_test = gen_test.get_numpy_batch()

# Reconstruct the test inputs with the full autoencoder.
x_rec = model.model.predict(x_test)

# Collapse the last axis with argmax on both inputs and reconstructions, then
# hand the pair to PlotDataAE.
PlotDataAE(
    np.argmax(x_test[:, :, :, :], axis=-1),
    np.argmax(x_rec[:, :, :, :], axis=-1),
    digit_size=(100, 100),
)

# Generate Random Samples

In [None]:
# Ask the model for generated samples (binary=True) and plot them.
x_gen = model.generate(binary=True)

# NOTE(review): an empty list is passed as the first argument together with
# Only_Result=False -- confirm this is the intended PlotDataAE usage.
PlotDataAE(
    [],
    x_gen[:, :, :],
    Only_Result=False,
    digit_size=(100, 100),
)

# Save Encoder and Decoder

In [None]:
# Persist the encoder part of the model under Model/TrainModel.
Save_Model(
    model.encoder,
    'Model/TrainModel/CVAE100_encoder',
)

In [None]:
# Persist the generator (decoder) part of the model under Model/TrainModel.
Save_Model(
    model.generator,
    'Model/TrainModel/CVAE100_decoder',
)