# Imports

In [1]:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import *

# Super-resolution CNN architecture (U-Net-style encoder–decoder with skip connections)

In [2]:
def _conv_stack(x, filters, n_convs):
  """Apply `n_convs` 3x3 ReLU convolutions ('same' padding), then BatchNormalization.

  Spatial size is preserved; the channel count becomes `filters`.
  """
  for _ in range(n_convs):
    x = Conv2D(filters, (3, 3), activation='relu', padding='same')(x)
  return BatchNormalization()(x)


def _upsample_bn(x, filters, strides):
  """Upsample by `strides` with a 3x3 transposed convolution (no activation), then BatchNormalization."""
  x = Conv2DTranspose(filters, (3, 3), strides=strides, padding="same")(x)
  return BatchNormalization()(x)


def _upsample_conv(x, filters, strides):
  """Upsample by `strides` with a 3x3 transposed convolution, then one ReLU conv + BatchNormalization."""
  x = Conv2DTranspose(filters, (3, 3), strides=strides, padding="same")(x)
  return _conv_stack(x, filters, n_convs=1)


def build_model(input_layer, start_neurons):
  """Build a U-Net-style super-resolution network on top of `input_layer`.

  Three encoder stages (each halving height and width via 2x2 max-pooling)
  feed a bottleneck. The decoder upsamples with transposed convolutions and
  concatenates the matching encoder features (skip connections) at its first
  three stages, returning to the input resolution. Its last three stages then
  upsample further with strides (1, 2), (1, 2) and (2, 2), so the output is
  2x the input height and 8x the input width.

  Args:
    input_layer: Keras tensor, presumably (batch, H, W, C) in channels_last
      layout. H and W should be multiples of 8 so the skip-connection
      concatenations see matching spatial shapes.
    start_neurons: Base filter count. Stage widths are multiples of it:
      encoder 4x/8x/8x, bottleneck 16x, decoder 8x/4x/2x/1x/1x/1x.

  Returns:
    Keras tensor of shape (batch, 2*H, 8*W, 1) with sigmoid-activated values.
  """
  # *** ENCODER ***
  # Each stage keeps its pre-pooling features for a decoder skip connection.
  x1 = _conv_stack(input_layer, start_neurons*4, n_convs=2)
  pool1 = MaxPooling2D((2, 2))(x1)

  x2 = _conv_stack(pool1, start_neurons*8, n_convs=2)
  pool2 = MaxPooling2D((2, 2))(x2)

  x3 = _conv_stack(pool2, start_neurons*8, n_convs=2)
  pool3 = MaxPooling2D((2, 2))(x3)

  # *** MIDDLE ***
  center = _conv_stack(pool3, start_neurons*16, n_convs=1)

  # *** DECODER ***
  # Stages 4 and 3: upsample + BN, merge the skip connection, then refine.
  y4 = _upsample_bn(center, start_neurons*8, strides=(2, 2))
  y4 = concatenate([y4, x3])
  y4 = _conv_stack(y4, start_neurons*8, n_convs=2)

  y3 = _upsample_bn(y4, start_neurons*4, strides=(2, 2))
  y3 = concatenate([y3, x2])
  y3 = _conv_stack(y3, start_neurons*4, n_convs=2)

  # Stage 2 returns to the input resolution and merges the first encoder stage.
  y2 = _upsample_conv(y3, start_neurons*2, strides=(2, 2))
  y2 = concatenate([y2, x1])

  # Past the input resolution: width x2, width x2, then height and width x2.
  y1 = _upsample_conv(y2, start_neurons*1, strides=(1, 2))
  y0 = _upsample_conv(y1, start_neurons*1, strides=(1, 2))
  y = _upsample_conv(y0, start_neurons*1, strides=(2, 2))

  # 1x1 projection to a single sigmoid-activated channel.
  output_layer = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(y)

  return output_layer

## Network definition

In [None]:
# Low-resolution input size (both should be multiples of 8 for the U-Net skips).
# `build_model` returns a map 2x taller and 8x wider than its input, so the mask
# input is given that high-resolution shape: 8x8 in -> 16x64 out.
LR_HEIGHT, LR_WIDTH = 8, 8
START_NEURONS = 4
HR_SHAPE = (2 * LR_HEIGHT, 8 * LR_WIDTH, 1)

input_img = Input(shape=(LR_HEIGHT, LR_WIDTH, 1))  # adapt this if using `channels_first` image data format
mask_img = Input(shape=HR_SHAPE)
out_layer = build_model(input_img, START_NEURONS)

# The model's output is the prediction multiplied element-wise by the mask input
# (presumably a 0/1 mask — confirm against the training data).
masked_output = Multiply()([out_layer, mask_img])
model_64 = Model(inputs=[input_img, mask_img], outputs=masked_output)

model_64.compile(loss='mean_squared_error', optimizer='adam', metrics=['mean_squared_error'])

model_64.summary()