In [1]:
import os
import re
import numpy as np
from skimage.transform import resize, rescale
from scipy import ndimage, misc
from matplotlib import pyplot
import matplotlib.pyplot as plt
np.random.seed(0)

In [2]:
from tensorflow.keras.layers import Conv2D, Dense, Input, MaxPooling2D, Dropout
from tensorflow.keras.layers import Conv2DTranspose, UpSampling2D, add
from tensorflow.keras.models import Model
from tensorflow.keras import regularizers
import tensorflow as tf
print(tf.__version__)
# np.random.seed(0) in the first cell only seeds NumPy. Keras weight
# initialisation (and, in TF 2.x, shuffling inside fit(shuffle=True)) draws
# from TensorFlow's own RNG, so seed it as well for reproducible training runs.
tf.random.set_seed(0)

2.1.0


In [17]:
# Encoder: 3x3 'same'-padded ReLU convolutions with a tiny L1 activity penalty,
# downsampled twice by 2x max-pooling (256 -> 128 -> 64), ending in 152 channels.
def _encoder_conv(n_filters):
    """Build one encoder Conv2D layer (3x3 kernel, 'same' padding, ReLU, L1 activity penalty)."""
    return Conv2D(n_filters, (3, 3), padding='same', activation='relu',
                  activity_regularizer=regularizers.l1(10e-10))

input_data = Input(shape=(256, 256, 3))

# Stage 1: full resolution, 32 channels.
enc_l1 = _encoder_conv(32)(input_data)
enc_l2 = _encoder_conv(32)(enc_l1)

# Stage 2: first 2x downsample, 96 channels.
enc_l3 = MaxPooling2D(padding='same')(enc_l2)
enc_l4 = _encoder_conv(96)(enc_l3)
enc_l5 = _encoder_conv(96)(enc_l4)

# Bottleneck: second 2x downsample, 152 channels.
enc_l6 = MaxPooling2D(padding='same')(enc_l5)
enc_l7 = _encoder_conv(152)(enc_l6)

encData = Model(input_data, enc_l7)

In [5]:
# Equivalently, each MaxPooling2D(padding='same') layer above could be written with the low-level TF op:
# def max_pool_2x2(x):
#   return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')

In [8]:
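# Similarly, each Conv2D layer above could be written with the low-level TF op
# (this covers the convolution with weights W only, not the bias or ReLU; stride 1 matches the layers above):
# def conv2D(x, W):
#   return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')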

In [15]:
# Print the encoder's layer-by-layer output shapes and parameter counts.
encData.summary()

Model: "model_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_6 (InputLayer)         [(None, 256, 256, 3)]     0         
_________________________________________________________________
conv2d_17 (Conv2D)           (None, 256, 256, 32)      896       
_________________________________________________________________
conv2d_18 (Conv2D)           (None, 256, 256, 32)      9248      
_________________________________________________________________
max_pooling2d_6 (MaxPooling2 (None, 128, 128, 32)      0         
_________________________________________________________________
conv2d_19 (Conv2D)           (None, 128, 128, 96)      27744     
_________________________________________________________________
conv2d_20 (Conv2D)           (None, 128, 128, 96)      83040     
_________________________________________________________________
max_pooling2d_7 (MaxPooling2 (None, 64, 64, 96)        0   

In [16]:
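# Decoder plan
# - upsample the encoder output (enc_l7) by 2x per stage
# - define the merge step for each stage
# - merge via residual/skip connections (element-wise add)
# - take each skip input from the encoder stage at the matching resolution, in reverse encoder order (deepest first)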

In [26]:
# Decoder: mirrors the encoder. Each stage upsamples 2x, applies two 3x3 convs,
# then adds the encoder feature map of the same resolution (skip connection).
def _decoder_conv(n_filters):
    """Build one decoder Conv2D layer (3x3 kernel, 'same' padding, ReLU, L1 activity penalty)."""
    return Conv2D(n_filters, (3, 3), padding='same', activation='relu',
                  activity_regularizer=regularizers.l1(10e-10))

# Stage 1: upsample the bottleneck, reduce to 96 channels, merge with enc_l5.
dec_l8 = UpSampling2D()(enc_l7)      # 2x upsample
dec_l9 = _decoder_conv(96)(dec_l8)
dec_l10 = _decoder_conv(96)(dec_l9)
dec_l11 = add([enc_l5, dec_l10])

# Stage 2: upsample again, reduce to 32 channels, merge with enc_l2.
dec_l12 = UpSampling2D()(dec_l11)
dec_l13 = _decoder_conv(32)(dec_l12)
dec_l14 = _decoder_conv(32)(dec_l13)
dec_l15 = add([enc_l2, dec_l14])

# Project back to 3 channels; the full model maps input_data to this output.
dec_l16 = _decoder_conv(3)(dec_l15)
decData = Model(input_data, dec_l16)

In [27]:
# Print the full encoder-decoder model's layers, output shapes, parameter counts
# and inbound connections (including the two add() skip merges).
decData.summary()

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_7 (InputLayer)            [(None, 256, 256, 3) 0                                            
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 256, 256, 32) 896         input_7[0][0]                    
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 256, 256, 32) 9248        conv2d_22[0][0]                  
__________________________________________________________________________________________________
max_pooling2d_8 (MaxPooling2D)  (None, 128, 128, 32) 0           conv2d_23[0][0]                  
____________________________________________________________________________________________

In [31]:
# Adam optimizer with mean-squared-error loss between the model output and the
# target image passed to fit() (train_batches passes the full-resolution image).
decData.compile(optimizer='adam', loss='mean_squared_error')

In [45]:
def is_image_file(filename):
    """Return True if `filename` ends in .jpg, .jpeg or .png (case-insensitive)."""
    # The previous pattern ".\(jpeg|jpg|JPEG|png\)$" escaped the group parentheses,
    # so `|` split the whole regex: any name containing "jpg"/"JPEG" matched and
    # ".png" files never did.
    return re.search(r"\.(jpe?g|png)$", filename, re.IGNORECASE) is not None


def load_training_pair(fpath, size=(256, 256)):
    """Load one image and build a (high_res, low_res) training pair.

    Args:
        fpath: path of the image file to read.
        size: (height, width) both arrays are resized to.

    Returns:
        (high_res, low_res) arrays of shape size + (3,), or None when the image
        has no colour channel axis (or fewer than 3 channels) and is skipped.
    """
    image = pyplot.imread(fpath)
    if image.ndim < 3 or image.shape[2] < 3:
        return None
    # Drop an alpha channel (e.g. RGBA PNG) so every sample matches the
    # 3-channel model input.
    image = image[..., :3]
    high_res = resize(image, size)
    half = (size[0] // 2, size[1] // 2)
    # resize() only changes the spatial axes. rescale() without
    # multichannel/channel_axis also scaled the channel axis (3 -> 2 -> 4).
    low_res = resize(resize(high_res, half), size)
    return high_res, low_res


def train_batches(just_load_dataset=False,
                  data_dir="/home/probe/ImageSuperResEnhancer/training_data",
                  images_per_batch=256, epochs=10, max_batches=-1):
    """Load training images chunk by chunk and fit `decData` on each full chunk.

    Walks `data_dir` recursively. Each .jpg/.jpeg/.png image becomes a
    (high_res, low_res) pair via `load_training_pair`. Every `images_per_batch`
    pairs, `decData.fit(low_res, high_res, ...)` is called on that chunk.

    Args:
        just_load_dataset: if True, skip training. Return the first full chunk,
            or all loaded pairs if there are fewer than `images_per_batch`.
        data_dir: root directory searched for images.
        images_per_batch: number of images held in memory per fit() call.
        epochs: epochs passed to each fit() call.
        max_batches: stop once this many chunks have been processed; -1 means no limit.

    Returns:
        (x_train_n, x_train_down): high-res targets and low-res inputs of the
        last chunk produced (empty lists if none was produced).
    """
    hr_buffer, lr_buffer = [], []
    hr_batch, lr_batch = [], []
    batch_index = 0

    for root, _dirs, files in os.walk(data_dir):
        for filename in files:
            if not is_image_file(filename):
                continue
            if batch_index == max_batches:
                return hr_batch, lr_batch
            pair = load_training_pair(os.path.join(root, filename))
            if pair is None:
                continue
            hr_buffer.append(pair[0])
            lr_buffer.append(pair[1])

            if len(hr_buffer) == images_per_batch:
                batch_index += 1
                hr_batch = np.array(hr_buffer)
                lr_batch = np.array(lr_buffer)

                if just_load_dataset:
                    return hr_batch, lr_batch
                # Was `batch_nb`, an undefined name (NameError on the first chunk).
                print('Training batch : ', batch_index, '(', images_per_batch, ')')

                decData.fit(lr_batch, hr_batch, epochs=epochs, batch_size=20, shuffle=True, validation_split=0.10)

                hr_buffer, lr_buffer = [], []

    if just_load_dataset and hr_buffer:
        # Fewer usable images than one chunk: return them rather than empty lists.
        return np.array(hr_buffer), np.array(lr_buffer)
    # NOTE: in training mode a trailing partial chunk (< images_per_batch images) is not fit.
    return hr_batch, lr_batch

In [47]:
# Walk the training directory and fit decData on each chunk of images.
# The returned pair is the last chunk built: (high-res targets, low-res inputs).
x_train_n, x_train_down = train_batches(just_load_dataset=False)