In [1]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, layers
import tensorflow.keras as keras


In [2]:
# Expose only GPU 0 to CUDA-based libraries.
# NOTE(review): TensorFlow is imported in the previous cell; this most likely only
# takes effect if no GPU op has run yet — consider setting it before `import tensorflow`.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

# Layers

In [3]:
class FeatureEncoder(layers.Layer):
    """One encoder stage: three same-padded ReLU convolutions, then a stride-2 conv.

    Args:
        out_channels: number of filters for every convolution in the stage.
        kernel_size: spatial kernel size for every convolution in the stage.

    ``call(x)`` returns a pair ``(features, downsampled)``:
      * ``features`` — output of the three-conv stack at the input resolution
        (DPDNN passes it to the matching decoder stage as the skip tensor);
      * ``downsampled`` — ``features`` passed through the stride-2 conv.
    """

    def __init__(self, out_channels, kernel_size):
        super(FeatureEncoder, self).__init__()
        conv_args = dict(filters=out_channels, kernel_size=kernel_size,
                         padding='same', activation='relu')
        # The three convs are created before the Sequential wrapper and the
        # downsampling conv, keeping the auto-generated layer names stable.
        self.conv = keras.Sequential([layers.Conv2D(**conv_args) for _ in range(3)])
        self.fe_down = layers.Conv2D(strides=2, **conv_args)

    def call(self, x):
        features = self.conv(x)
        downsampled = self.fe_down(features)
        return features, downsampled
    
class FeatureDecoder(layers.Layer):
    """One decoder stage: upsample, fuse with the encoder skip tensor, refine.

    Args:
        out_channels: number of filters for every convolution in the stage.
        kernel_size: spatial kernel size for the transposed conv and the three
            refinement convs (the fusion conv is always 1x1).
    """

    def __init__(self, out_channels, kernel_size):
        super(FeatureDecoder, self).__init__()
        # Stride-2 transposed conv that upsamples x towards the skip tensor's resolution.
        self.de_up = layers.Conv2DTranspose(filters=out_channels, kernel_size=kernel_size, strides=2,
                                            padding='same', output_padding=1)

        # 1x1 conv that fuses the concatenated [upsampled, skip] channels back to out_channels.
        self.conv_first = layers.Conv2D(filters=out_channels, kernel_size=1, padding='same', activation='relu')
        self.conv = keras.Sequential([
            layers.Conv2D(filters=out_channels, kernel_size=kernel_size, padding='same', activation='relu')
            for _ in range(3)
        ])
        # Final conv has no activation, so the stage output is linear.
        self.conv_last = layers.Conv2D(filters=out_channels, kernel_size=kernel_size, padding='same')

    @staticmethod
    def _center_crop(t, target_h, target_w):
        """Crop ``t`` (channels-last, N x H x W x C) to ``target_h`` x ``target_w``.

        Surplus rows/columns are split as ceil(diff / 2) removed before and
        floor(diff / 2) removed after. This is the same split the previous
        ``Cropping2D`` call used. Requires static H/W (plain Python ints).
        """
        _, h, w, _ = t.shape
        top = (h - target_h + 1) // 2    # == ceil((h - target_h) / 2) for diff >= 0
        left = (w - target_w + 1) // 2
        return t[:, top:top + target_h, left:left + target_w, :]

    def call(self, x, down_tensor):
        """Upsample ``x``, concatenate it with ``down_tensor`` and refine the result.

        Args:
            x: lower-resolution decoder input (channels-last).
            down_tensor: skip tensor from the matching encoder stage (channels-last).

        Returns:
            Tensor with ``out_channels`` channels at the common spatial size of the
            upsampled ``x`` and ``down_tensor``.
        """
        x = self.de_up(x)

        # Centre-crop both tensors to their overlapping spatial extent. If the
        # skip tensor is larger, this is exactly the previous behaviour. If the
        # upsampled tensor is larger (odd-sized encoder inputs round up on the
        # way down), it is trimmed instead of producing a negative crop.
        _, h_x, w_x, _ = x.shape
        _, h_d, w_d, _ = down_tensor.shape
        target_h, target_w = min(h_x, h_d), min(w_x, w_d)
        x = self._center_crop(x, target_h, target_w)
        down_tensor = self._center_crop(down_tensor, target_h, target_w)

        # Plain tf ops: no throw-away Keras layers are created on every call.
        x = tf.concat([x, down_tensor], axis=3)

        x = self.conv_first(x)
        x = self.conv(x)
        x = self.conv_last(x)
        return x

# Model

In [4]:
class DPDNN(Model):
    """Unrolled network that repeatedly applies one shared encoder/decoder denoiser.

    Every stage runs the same encoder/decoder (with skip connections) on the
    current estimate and adds the estimate back as a residual. It then blends
    the result with the current estimate and the original input using that
    stage's learnable weights ``delta_k`` / ``eta_k`` (see ``reconnect``).

    Args:
        num_stages: number of unrolled stages (default 6, the original setting).
        delta_init: initial value of every ``delta_k`` weight.
        eta_init: initial value of every ``eta_k`` weight.

    Raises:
        ValueError: if ``num_stages`` is smaller than 1.
    """

    def __init__(self, num_stages=6, delta_init=0.1, eta_init=0.9):
        super(DPDNN, self).__init__()
        if num_stages < 1:
            raise ValueError(f"num_stages must be >= 1, got {num_stages}")
        self.num_stages = num_stages

        self.fe1 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe2 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe3 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe4 = FeatureEncoder(out_channels=64, kernel_size=3)
        self.fe_end = layers.Conv2D(filters=64, kernel_size=3, padding='same', activation='relu')

        self.de4 = FeatureDecoder(out_channels=64, kernel_size=3)
        self.de3 = FeatureDecoder(out_channels=64, kernel_size=3)
        self.de2 = FeatureDecoder(out_channels=64, kernel_size=3)
        self.de1 = FeatureDecoder(out_channels=64, kernel_size=3)
        # Single output channel, so the residual `v + x` in `call` needs a 1-channel input.
        self.de_end = layers.Conv2D(filters=1, kernel_size=3, padding='same')

        # Per-stage learnable blending weights. They are stored as delta_1..delta_N and
        # eta_1..eta_N, in the same order as the former hand-written fields, so
        # attribute/checkpoint names and the variable order are unchanged.
        for k in range(1, num_stages + 1):
            setattr(self, f"delta_{k}", tf.Variable(delta_init, trainable=True))
            setattr(self, f"eta_{k}", tf.Variable(eta_init, trainable=True))

    def call(self, x):
        """Run ``num_stages`` denoise-and-blend iterations starting from ``x``."""
        y = x  # original input, re-injected at every stage by `reconnect`
        for i in range(self.num_stages):
            v = self._denoise(x) + x  # residual connection around the shared denoiser
            x = self.reconnect(v, x, y, i)
        return x

    def _denoise(self, x):
        """Apply the shared encoder/decoder (weights reused by every stage) to ``x``."""
        f1, out = self.fe1(x)
        f2, out = self.fe2(out)
        f3, out = self.fe3(out)
        f4, out = self.fe4(out)
        out = self.fe_end(out)

        out = self.de4(out, f4)
        out = self.de3(out, f3)
        out = self.de2(out, f2)
        out = self.de1(out, f1)
        return self.de_end(out)

    def reconnect(self, v, x, y, i):
        """Blend the denoised ``v`` with the current estimate ``x`` and the input ``y``.

        Computes ``(1 - delta - eta) * v + eta * x + delta * y`` with the
        learnable weights of stage ``i`` (0-based, i.e. ``delta_{i+1}``/``eta_{i+1}``).

        Raises:
            ValueError: if ``i`` is outside ``[0, num_stages)``. The previous
                if-chain failed here with an UnboundLocalError instead.
        """
        stage = i + 1
        if not 1 <= stage <= self.num_stages:
            raise ValueError(f"stage index {i} out of range for {self.num_stages} stages")
        delta = getattr(self, f"delta_{stage}")
        eta = getattr(self, f"eta_{stage}")
        return tf.multiply(1 - delta - eta, v) + tf.multiply(eta, x) + tf.multiply(delta, y)

In [5]:
# Shape used to build the model: (batch, height, width, channels), channels-last.
# Single channel, matching the 1-filter output conv used for the residual in DPDNN.
input_shape = (100, 128, 128, 1)

In [6]:
# Instantiate the unrolled network. Its weights are created when `model.build(...)` runs below.
model = DPDNN()

In [7]:
# Adam optimizer with an explicit learning rate and moment decay rates.
adam = tf.keras.optimizers.Adam(
    learning_rate=1e-3,
    beta_1=0.9,
    beta_2=0.999,
    amsgrad=False,
)

# MSE is both the training loss and a reported metric.
model.compile(
    optimizer=adam,
    loss='mean_squared_error',
    metrics=['mean_squared_error'],
)

# Build on `input_shape` so that summary() has weights to report.
model.build(input_shape)
model.summary()

Model: "dpdnn"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
feature_encoder (FeatureEnco multiple                  111424    
_________________________________________________________________
feature_encoder_1 (FeatureEn multiple                  147712    
_________________________________________________________________
feature_encoder_2 (FeatureEn multiple                  147712    
_________________________________________________________________
feature_encoder_3 (FeatureEn multiple                  147712    
_________________________________________________________________
conv2d_16 (Conv2D)           multiple                  36928     
_________________________________________________________________
feature_decoder (FeatureDeco multiple                  192896    
_________________________________________________________________
feature_decoder_1 (FeatureDe multiple                  192896