In [2]:
import pathlib
import tensorflow as tf
from tensorflow import keras
from keras import layers
from matplotlib import pyplot as plt 
# import cv2
import numpy as np

In [3]:
# --- Configuration -----------------------------------------------------------
# Target image size; also passed as `image_size` to the dataset loader.
img_height = 224
img_width = 224
# Keras image tensors are laid out (height, width, channels). The previous
# (width, height, 3) ordering only worked because both sides are 224.
input_shape = (img_height, img_width, 3)
batch_size = 32
# Dataset locations (Kaggle-style relative input paths).
train_dir = pathlib.Path('../input/cacao-augmented-224/training_img/training_img/')
test_dir  = pathlib.Path('../input/cacao-augmented-224/testing_img/testing_img/')
# Output locations. NOTE: despite its name, `model_dir` is the path of the
# saved model *file* (model.h5), not a directory.
checkpoint_dir = pathlib.Path('./cacao_CAE/checkpoint')
model_dir = pathlib.Path('./cacao_CAE/model.h5')


In [4]:
# Unlabelled image batches (labels=None) taken from the "training" side of a
# seeded 80/20 split of `train_dir`, with prefetching enabled so input loading
# can overlap with training.
train_ds = keras.utils.image_dataset_from_directory(
    directory=train_dir,
    labels=None,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    shuffle=True,
    seed=252,
    validation_split=0.2,
    subset="training",
).prefetch(tf.data.AUTOTUNE)


In [145]:
# Minimal pass-through model: one Rescaling layer that multiplies the input
# by 1/255 (offset 0).
input_layer = layers.Input(shape=input_shape)
output_layer = layers.Rescaling(scale=1.0 / 255, offset=0)(input_layer)
model = keras.Model(inputs=input_layer, outputs=output_layer)

In [182]:
def reset_model():
    """Build and compile a freshly initialised convolutional autoencoder.

    The network rescales its input by 1/255, encodes it with six unpadded
    convolutions (four of them followed by 2x2 max-pooling) and a dense stack
    down to a 56-unit bottleneck, then decodes it with a mirrored dense stack
    and transposed convolutions. The output is multiplied by 255 so that it is
    on the same scale as the input.

    Reads the module-level ``input_shape``. The shape comments below assume
    the notebook's 224x224x3 setting; the Flatten/Reshape pair (686 = 7*7*14)
    only lines up for that input size.

    Returns:
        keras.Model: model compiled with the Adam optimizer and a
        mean-squared-error loss; MSE is also kept as a metric so that
        ``train_on_batch(..., return_dict=False)`` still returns a list.
    """
    # Build a fresh Input here so the function does not depend on a tensor
    # created in another notebook cell.
    inputs = layers.Input(shape=input_shape)
    x = layers.Rescaling(scale=1./255, offset=0)(inputs)

    # --- encoder (all convolutions use the default "valid" padding) ---
    x = layers.Conv2D(84, kernel_size=(7,7))(x)            # (218,218,84)
    x = layers.LeakyReLU(alpha=0.1)(x)

    x = layers.Conv2D(42, kernel_size=(13,13))(x)          # (206,206,42)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.MaxPooling2D(pool_size=(2,2), strides=2)(x)  # (103,103,42)

    x = layers.Conv2D(42, kernel_size=(9,9))(x)            # (95,95,42)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.MaxPooling2D(pool_size=(2,2), strides=2)(x)  # (47,47,42)

    x = layers.Conv2D(28, kernel_size=(5,5))(x)            # (43,43,28)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.MaxPooling2D(pool_size=(2,2), strides=2)(x)  # (21,21,28)

    x = layers.Conv2D(28, kernel_size=(3,3))(x)            # (19,19,28)
    x = layers.LeakyReLU(alpha=0.1)(x)
    x = layers.MaxPooling2D(pool_size=(2,2), strides=2)(x)  # (9,9,28)

    x = layers.Conv2D(14, kernel_size=(3,3))(x)            # (7,7,14)
    x = layers.LeakyReLU(alpha=0.1)(x)

    x = layers.Flatten()(x)                                 # 7*7*14 = 686
    x = layers.Dense(686, activation="relu")(x)
    x = layers.Dense(343, activation="relu")(x)
    x = layers.Dense(112, activation="relu")(x)

    # --- latent ---
    # NOTE(review): softmax forces the 56 latent values to be positive and sum
    # to 1 - confirm this constraint on the bottleneck is intended.
    latent = layers.Dense(56, activation="softmax")(x)

    # --- decoder ---
    x = layers.Dense(112, activation="relu")(latent)
    x = layers.Dense(343, activation="relu")(x)
    x = layers.Dense(686, activation="relu")(x)
    x = layers.Reshape((7,7,14))(x)

    # Each stride-2 transposed convolution with "same" padding doubles the
    # spatial size: 7 -> 14 -> 28 -> 56 -> 112 -> 224.
    x = layers.Conv2DTranspose(14, kernel_size=(3,3), padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(28, kernel_size=(3,3), padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(28, kernel_size=(5,5), padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(42, kernel_size=(9,9), padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(42, kernel_size=(13,13), padding="same", strides=2)(x)
    x = layers.Conv2DTranspose(84, kernel_size=(7,7), padding="same", strides=1)(x)

    # Project back to 3 channels and undo the input rescaling.
    outputs = layers.Conv2D(3, kernel_size=(5,5), padding="same")(x)
    outputs = layers.Rescaling(scale=255, offset=0)(outputs)

    model = keras.Model(inputs, outputs)
    # Reconstruction is a per-pixel regression, so train on MSE. The previous
    # categorical cross-entropy loss is meant for class-probability targets
    # and is not a meaningful objective for 0..255 pixel values.
    model.compile(optimizer="adam", loss="mse", metrics=["mse"])
    # Print before returning; this line used to sit after the return statement
    # and was never executed.
    print("Model reset done!")
    return model

In [213]:
# Build a freshly initialised, compiled autoencoder and print its layer table.
model = reset_model()
model.summary()

In [206]:
# Resume from a previously saved model when one exists. On a fresh run the
# file has not been written yet, so keep the current model instead of failing.
if model_dir.exists():
    model = keras.models.load_model(model_dir)
else:
    print("No saved model found at", model_dir, "- keeping the current model.")

In [211]:
# Manual training loop: the autoencoder is trained batch by batch with the
# input batch used as its own target.
epochs = 1
i = 0        # running count of processed batches (progress output only)
error = []   # values of score[0] recorded since the last model (re)initialisation
num_batch = 100
for ep in range(epochs):
    print("Epoch",ep+1, "/", epochs)
    for batch in train_ds.take(num_batch):
        i = i+1
        score = model.train_on_batch(x=batch, y=batch, reset_metrics=False, return_dict=False)
        # score[0] is the value of the model's compiled loss; the "MSE" label
        # is accurate only when the model was compiled with an MSE loss.
        print("batch",i, "--MSE:",score[0])
        error.append(score[0])

        # Training blew up: the loss is NaN.
        is_nan = np.isnan(score[0])
        # Training is not converging: the latest value is worse than the one
        # recorded 10 batches earlier and still above 200. Indexing is relative
        # to the end of `error` rather than to `i`, because `error` is cleared
        # on reset while `i` keeps counting (the old error[i-1] lookup raised
        # IndexError on the next epoch after a reset).
        is_stuck = (
            len(error) > 10
            and error[-1] > error[-11]
            and error[-1] > 200
        )
        if is_nan or is_stuck:
            # Start over with new weights and abandon the rest of this epoch.
            model = reset_model()
            error = []
            break
            

# model.save(model_dir)

In [209]:
# Persist the current model to the file path stored in `model_dir`.
model.save(model_dir)

In [212]:
# Grab a single batch from the training set to visualise reconstructions.
for img in train_ds.take(1):
    image = img
    break

result = model.predict(image, batch_size=None)

# Show up to four reconstructions in a 2x2 grid (fewer if fewer predictions
# were returned).
for i in range(min(4, len(result))):
    # The network output is unbounded float data. Clip to the displayable
    # range before the uint8 cast, otherwise out-of-range values wrap around
    # and show up as speckle artefacts.
    img = np.clip(result[i], 0, 255).astype("uint8")
    plt.subplot(2,2,i+1)
    plt.imshow(img)

