In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import Model, layers
import tensorflow.keras as keras
import glob
from PIL import Image, ImageFilter

import io


In [None]:
# Restrict CUDA to device index 0 for this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
# glob returns matches in arbitrary, filesystem-dependent order; sorting gives
# the same starting list on every run and machine.
filenames = sorted(glob.glob('../data/train/*/*'))

In [None]:
# Intended split sizes -- presumably image counts; TODO confirm (only
# VAL_LENGTH is referenced again later in this notebook).
TRAIN_LENGTH = 2500
VAL_LENGTH = 500

# Size every image is resized to before use.
IMAGE_SIZE = (128, 128)

# Shape handed to model.build(): leading dimension, the two spatial
# dimensions, then 3 channels.
INPUT_SHAPE = (TRAIN_LENGTH + VAL_LENGTH,) + IMAGE_SIZE + (3,)

In [None]:
import random

# Seed every RNG this notebook draws from so that, for a given file list, the
# shuffle order, the NumPy-drawn blur radii and the TensorFlow weight
# initialisation repeat across kernel restarts.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Shuffle in place so later cells see the files in random order.
random.shuffle(filenames)

In [None]:
# Preview up to the first 10 (shuffled) files at the size used for training.
# Slicing instead of range(10) avoids an IndexError when fewer files exist.
for preview_path in filenames[:10]:
    # The context manager closes the file once the resized copy has been made.
    with Image.open(preview_path) as raw_image:
        thumbnail = raw_image.resize(IMAGE_SIZE)
    plt.imshow(thumbnail)
    plt.title(preview_path)
    plt.show()

# Layers


In [None]:
class FeatureEncoder(layers.Layer):
    """Encoder stage: three same-padded ReLU convolutions, then a stride-2 one.

    Args:
        out_channels: Number of filters used by every convolution in the stage.
        kernel_size: Kernel size shared by every convolution in the stage.
    """

    def __init__(self, out_channels, kernel_size):
        super().__init__()
        conv_kwargs = dict(filters=out_channels, kernel_size=kernel_size,
                           padding='same', activation='relu')

        # Stack of three convolutions applied before downsampling.
        self.conv = keras.Sequential([layers.Conv2D(**conv_kwargs) for _ in range(3)])

        # Same configuration with strides=2: produces the downsampled output.
        self.fe_down = layers.Conv2D(strides=2, **conv_kwargs)

    def call(self, x):
        """Return ``(features, downsampled)`` for input ``x``.

        ``features`` is the output of the three-convolution stack and
        ``downsampled`` is that output after the stride-2 convolution.
        """
        features = self.conv(x)
        downsampled = self.fe_down(features)
        return features, downsampled
    
class FeatureDecoder(layers.Layer):
    """Decoder stage: upsample, merge with a skip tensor, then refine.

    The input is upsampled by a stride-2 transposed convolution, concatenated
    channel-wise with ``down_tensor`` and passed through a 1x1 convolution,
    three ReLU convolutions and a final convolution without activation.

    Args:
        out_channels: Number of filters used by every (transposed) convolution.
        kernel_size: Kernel size of all convolutions except the 1x1 one.
    """

    def __init__(self, out_channels, kernel_size):
        super(FeatureDecoder, self).__init__()
        self.de_up = layers.Conv2DTranspose(filters=out_channels, kernel_size=kernel_size, strides=2,
                                            padding='same', output_padding=1)

        # 1x1 convolution applied right after the concatenation.
        self.conv_first = layers.Conv2D(filters=out_channels, kernel_size=1, padding='same', activation='relu')
        self.conv = keras.Sequential([
            layers.Conv2D(filters=out_channels, kernel_size=kernel_size, padding='same', activation='relu')
            for _ in range(3)
        ])
        self.conv_last = layers.Conv2D(filters=out_channels, kernel_size=kernel_size, padding='same')

    @staticmethod
    def _center_crop(tensor, target_h, target_w):
        """Crop axes 1 and 2 of a 4-D ``tensor`` to ``(target_h, target_w)``.

        When the surplus is odd, the extra row/column is removed from the
        top/left (ceil before, floor after).
        """
        _, height, width, _ = tensor.shape
        top = int(np.ceil((height - target_h) / 2))
        left = int(np.ceil((width - target_w) / 2))
        return tensor[:, top:top + target_h, left:left + target_w, :]

    def call(self, x, down_tensor):
        """Upsample ``x``, align it with ``down_tensor`` and refine the result.

        Args:
            x: 4-D tensor to upsample.
            down_tensor: 4-D skip tensor concatenated with the upsampled ``x``
                along axis 3.

        Returns:
            The output of the final convolution.

        Raises:
            ValueError: If a spatial dimension of either tensor is not
                statically known, since the crop offsets are computed in Python.
        """
        x = self.de_up(x)

        _, skip_h, skip_w, _ = down_tensor.shape
        _, up_h, up_w, _ = x.shape
        if None in (skip_h, skip_w, up_h, up_w):
            raise ValueError(
                "FeatureDecoder needs statically known spatial dimensions, got "
                "upsampled {} and skip {}".format(x.shape, down_tensor.shape))

        # Either tensor can be the larger one (the upsampled tensor is, for
        # example, when the skip tensor has an odd size), so crop whichever
        # exceeds the common size instead of only ever cropping the skip tensor.
        target_h, target_w = min(skip_h, up_h), min(skip_w, up_w)
        if (up_h, up_w) != (target_h, target_w):
            x = self._center_crop(x, target_h, target_w)
        if (skip_h, skip_w) != (target_h, target_w):
            down_tensor = self._center_crop(down_tensor, target_h, target_w)

        x = layers.concatenate([x, down_tensor], axis=3)

        x = self.conv_first(x)
        x = self.conv(x)
        x = self.conv_last(x)
        return x

# Model

In [None]:
class DPDNN(Model):
    """Unrolled encoder/decoder model with learnable per-iteration blending.

    Each iteration runs the same four encoder stages and four decoder stages
    (the layers are shared between iterations), adds the result to the current
    estimate and blends it with the current estimate and the original input
    using that iteration's ``delta``/``eta`` scalars.

    Args:
        num_iterations: Number of unrolled iterations. One ``delta_k``/``eta_k``
            pair is created per iteration (k = 1..num_iterations).
    """

    def __init__(self, num_iterations=6):
        super(DPDNN, self).__init__()
        self.num_iterations = num_iterations

        self.fe1 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe2 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe3 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe4 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe_end = layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')

        self.de4 = FeatureDecoder(out_channels=64, kernel_size=3)
        self.de3 = FeatureDecoder(out_channels=64, kernel_size=3)
        self.de2 = FeatureDecoder(out_channels=64, kernel_size=3)
        self.de1 = FeatureDecoder(out_channels=64, kernel_size=3)
        # NOTE(review): de_end emits a single channel, which is broadcast over
        # the input's channels when added to x in call(); confirm this is
        # intended for the 3-channel input configured in this notebook.
        self.de_end = layers.Conv2D(filters=1, kernel_size=3, padding='same')

        # Learnable blending scalars, exposed as delta_1..delta_N / eta_1..eta_N.
        for step in range(1, num_iterations + 1):
            delta_name = "delta_{}".format(step)
            eta_name = "eta_{}".format(step)
            setattr(self, delta_name, self._scalar_weight(delta_name, 0.1))
            setattr(self, eta_name, self._scalar_weight(eta_name, 0.9))

    def _scalar_weight(self, name, value):
        """Create a trainable scalar weight initialised to ``value``.

        add_weight registers the scalar with Keras explicitly; a raw
        tf.Variable attribute is only auto-tracked by Keras 2, so under Keras 3
        it would be left out of the trainable weights.
        """
        return self.add_weight(name=name, shape=(),
                               initializer=keras.initializers.Constant(value),
                               trainable=True)

    def call(self, x):
        """Run ``num_iterations`` refinement passes over ``x`` and return the result."""
        y = x  # the original input, blended back in at every iteration

        for i in range(self.num_iterations):
            f1, out = self.fe1(x)
            f2, out = self.fe2(out)
            f3, out = self.fe3(out)
            f4, out = self.fe4(out)
            out = self.fe_end(out)

            out = self.de4(out, f4)
            out = self.de3(out, f3)
            out = self.de2(out, f2)
            out = self.de1(out, f1)
            v = self.de_end(out)

            # Residual connection around the encoder/decoder pass.
            v = v + x
            x = self.reconnect(v, x, y, i)

        return x

    def reconnect(self, v, x, y, i):
        """Blend ``v``, ``x`` and ``y`` with the scalars of iteration ``i``.

        Args:
            v: Output of the current encoder/decoder pass (after the residual add).
            x: Estimate that was fed into the current pass.
            y: Original model input.
            i: Zero-based iteration index; selects ``delta_{i+1}``/``eta_{i+1}``.

        Returns:
            ``(1 - delta - eta) * v + eta * x + delta * y``.

        Raises:
            ValueError: If ``i`` is outside ``0..num_iterations - 1``.
        """
        if not 0 <= i < self.num_iterations:
            raise ValueError(
                "iteration index {} out of range for {} iterations".format(i, self.num_iterations))
        step = i + 1
        delta = getattr(self, "delta_{}".format(step))
        eta = getattr(self, "eta_{}".format(step))

        recon = tf.multiply((1 - delta - eta), v) + tf.multiply(eta, x) + tf.multiply(delta, y)
        return recon

# Dataset

In [None]:
def _generate_pair(filename):
    im = Image.open(filename)
    im = im.resize(IMAGE_SIZE)
    y = np.array(im)
    y_min, y_max = np.min(y), np.max(y)
    y_normalized = (y - y_min) / (y_max - y_min)

    blur_radius = np.random.uniform(1, 5)
    blurred_im = im.filter(ImageFilter.GaussianBlur(blur_radius))
    x = np.array(blurred_im)
    x_min, x_max = np.min(x), np.max(x)
    x_normalized = (x - x_min) / (x_max - x_min)
    
    return (x_normalized, y_normalized)
    

In [None]:
def create_datasets(filenames):
    """Build stacked ``(blurred, sharp)`` arrays from a list of image paths.

    ``filenames`` is shuffled in place first, so after the call
    ``filenames[i]`` is the file that produced row ``i`` of both arrays.

    Args:
        filenames: Mutable list of image paths; reordered as a side effect.

    Returns:
        Tuple ``(x, y)`` of arrays with one row per file. Both are empty
        arrays when ``filenames`` is empty.

    Raises:
        ValueError: If the generated images do not all share one shape.
    """
    random.shuffle(filenames)
    if not filenames:
        # zip(*[]) yields nothing to unpack, so handle the empty case explicitly.
        return np.array([]), np.array([])

    blurred, sharp = zip(*(_generate_pair(name) for name in filenames))
    # np.stack fails loudly on mismatched shapes instead of building a ragged array.
    return np.stack(blurred), np.stack(sharp)
    

In [None]:
# Build the (blurred, sharp) arrays for every file. Note that create_datasets
# shuffles `filenames` in place.
x_train, y_train = create_datasets(filenames)

# Logging

In [None]:
def plot_to_image(figure):
    """Converts the matplotlib plot specified by 'figure' to a PNG image and
    returns it. The supplied figure is closed and inaccessible after this call.

    Args:
        figure: The matplotlib figure to render.

    Returns:
        The decoded 4-channel PNG as a TF image tensor with a leading batch
        dimension added.
    """
    # Save the plot to a PNG in memory. Saving through the figure itself makes
    # sure the figure that was passed in is rendered, not whichever figure
    # pyplot currently considers active.
    buf = io.BytesIO()
    figure.savefig(buf, format='png')
    # Closing the figure prevents it from being displayed directly inside
    # the notebook.
    plt.close(figure)
    buf.seek(0)
    # Convert PNG buffer to TF image
    image = tf.image.decode_png(buf.getvalue(), channels=4)
    # Add the batch dimension
    image = tf.expand_dims(image, 0)
    return image


def plot_image_tensorboard(epoch, logs):
    """Log a prediction / ground-truth comparison image for the current epoch.

    Reads the module-level ``model``, ``val_sample`` (an ``(input, target)``
    pair of arrays) and ``file_writer``.

    Args:
        epoch: Epoch index, used as the summary step.
        logs: Metrics dict passed by the callback; unused.
    """
    sample_input, sample_target = val_sample

    # Call the model to get the prediction; predict() needs a batch axis.
    pred = model.predict(sample_input[np.newaxis, ...])

    # Create a mpl figure with one panel per image
    figure, (ax_pred, ax_truth) = plt.subplots(1, 2, figsize=(10, 10))

    # Plot the prediction. The model output is unbounded, so clip it to the
    # [0, 1] range that imshow displays for float RGB data.
    ax_pred.set_title("prediction")
    ax_pred.imshow(np.clip(pred[0].astype(np.float32), 0.0, 1.0))
    # Plot groundtruth
    ax_truth.set_title("ground truth")
    ax_truth.imshow(sample_target.astype(np.float32))

    plot_image = plot_to_image(figure)
    with file_writer.as_default():
        tf.summary.image("Prediction vs Ground Truth", plot_image, step=epoch)

import datetime

# One timestamped run directory below logs/fit/. The timestamp is passed as its
# own path component; concatenating it onto the joined path would produce
# "logs/fit<timestamp>" with no separator.
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = os.path.join('logs', 'fit', run_stamp)
# NOTE(review): profile_batch=1000000 presumably pushes profiling beyond any
# batch that is actually reached, i.e. disables it -- confirm.
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1, profile_batch=1000000)
# Writer used by plot_image_tensorboard for its image summaries.
file_writer = tf.summary.create_file_writer(log_dir)

# Log the comparison image at the end of every epoch.
plot_image_tensorboard_cb = keras.callbacks.LambdaCallback(on_epoch_end=plot_image_tensorboard)

# Training

In [None]:
# Instantiate the network defined in the "Model" section above.
model = DPDNN()

In [None]:
# NOTE(review): the fit() call in this notebook passes its own epochs value and
# does not read EPOCHS, STEPS_PER_EPOCH or VAL_STEPS -- confirm which is intended.
EPOCHS = 100
BATCH_SIZE = 8
# Derived from the dataset-size constants instead of repeating the literal 3000.
STEPS_PER_EPOCH = (TRAIN_LENGTH + VAL_LENGTH) // BATCH_SIZE
VAL_STEPS = VAL_LENGTH // BATCH_SIZE

In [None]:
# Adam optimizer with a 5e-4 learning rate; the other arguments are spelled out.
adam = tf.keras.optimizers.Adam(
    learning_rate=5e-4,
    beta_1=0.9,
    beta_2=0.999,
    amsgrad=False,
)
model.compile(
    optimizer=adam,
    loss='mean_squared_error',
    metrics=['mean_squared_error'],
)

# Build the model for the configured input shape, then print its summary.
model.build(INPUT_SHAPE)
model.summary()

In [None]:
# fit() is called with validation_split, which holds out the LAST samples of
# the arrays. Take the last pair so the image logged to TensorBoard really is a
# validation sample (index 0 belongs to the training portion).
val_sample = (x_train[-1], y_train[-1])

In [None]:
# Keep the History object so the loss curves can be inspected afterwards.
# NOTE(review): epochs is 400 here while EPOCHS above is 100 -- confirm which
# value is intended before replacing the literal.
history = model.fit(
    x_train,
    y_train,
    epochs=400,
    batch_size=BATCH_SIZE,
    validation_split=0.1,
    callbacks=[plot_image_tensorboard_cb, tensorboard_callback],
)

In [None]:
# Index of the sample inspected in the cells below.
i = 3
# A length-1 slice keeps the leading batch axis that predict() needs.
a = model.predict(x_train[i:i + 1])

In [None]:
# Target image for sample i.
plt.imshow(y_train[i])
plt.title("ground truth (sample {})".format(i))
plt.show()

In [None]:
# Blurred network input for sample i.
plt.imshow(x_train[i])
plt.title("blurred input (sample {})".format(i))
plt.show()

In [None]:
# Model output for sample i. The output is unbounded, so clip it to the [0, 1]
# range that imshow displays for float RGB data.
plt.imshow(np.clip(np.squeeze(a), 0.0, 1.0))
plt.title("prediction (sample {})".format(i))
plt.show()