In [20]:
import speedup
import tensorflow as tf
import numpy as np
import random
# Seed the global random generators of Python's `random`, NumPy and TensorFlow
# with the same fixed value (0).
random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)

In [None]:
def reshape_tensor(x, heads, flag):
    """Reshape ``x`` around a ``heads`` axis, then swap axes 1 and 2.

    Args:
        x: Tensor to rearrange; its dynamic shape is read with ``tf.shape``.
        heads: Size of the head axis introduced by the reshape.
        flag: Selects the reshape target.
            * truthy: ``(shape[0], shape[1], heads, -1)``
            * falsy:  ``(-1, heads, shape[1], shape[2])``

    Returns:
        The rank-4 reshaped tensor with axes 1 and 2 transposed.
    """
    dims = tf.shape(x)
    if flag:
        target = (dims[0], dims[1], heads, -1)
    else:
        # NOTE(review): this branch does not undo the `flag=True` branch (an
        # inverse would transpose first and then merge the last two axes).
        # Confirm the intended layout before enabling the code that uses it.
        target = (-1, heads, dims[1], dims[2])
    return tf.transpose(tf.reshape(x, shape=target), perm=(0, 2, 1, 3))

In [None]:
#Multi-DConv Head Transposed Self-Attention (MDTA)
def MDTA(x, n_filters, num_heads):
  """Build the query/key/value branches of the (unfinished) MDTA block.

  Each branch is a 1x1 ``Conv2D`` with ``n_filters`` filters followed by a 3x3
  ``DepthwiseConv2D``, applied to the output of a ``Normalization`` layer, and
  is then passed through ``reshape_tensor(..., num_heads, True)``.

  Args:
    x: Rank-4 Keras tensor (the shape is unpacked into four values).
    n_filters: Number of filters of the three 1x1 projections.
    num_heads: Head count forwarded to ``reshape_tensor``.

  Returns:
    Only the reshaped query branch. The key and value branches are built but
    not used yet, because the attention path below is still commented out.
  """
  # The unpacked values are unused; the unpacking only fails for non-rank-4 input.
  _dim0, _dim1, _dim2, _dim3 = x.shape

  # NOTE(review): `Normalization` is a preprocessing layer that is normally
  # adapted or given mean/variance; confirm it is the intended normalisation.
  normed = tf.keras.layers.Normalization()(x)

  # Layers are created in the order q, k, v for each stage.
  pointwise = [
      tf.keras.layers.Conv2D(n_filters, (1, 1), padding="same")(normed)
      for _ in range(3)
  ]
  depthwise = [
      tf.keras.layers.DepthwiseConv2D((3, 3), padding="same")(branch)
      for branch in pointwise
  ]
  q, k, v = [reshape_tensor(branch, num_heads, True) for branch in depthwise]

  # TODO: unfinished attention path, kept as originally sketched:
  #k = tf.transpose(k, [-2, -1])
  #attn = tf.matmul(q, k)
  #attn = tf.keras.activations.softmax(attn)
  #out = tf.matmul(attn, v)
  #out = reshape_tensor(v, num_heads, False)
  #out = tf.keras.layers.Conv2D(n_filters, (1, 1), padding="same")(out)
  #x = x + out
  return q




In [None]:
#Gated-Dconv Feed-Forward Network (GDFN)
def GDFN(x, n_filters):
  """Gated depthwise-conv feed-forward block with a residual connection.

  The input passes through a ``Normalization`` layer and is then projected by
  two parallel 1x1 convolutions to ``2 * n_filters`` channels, each followed by
  a 3x3 depthwise convolution. One branch goes through GELU and gates the
  other by element-wise multiplication. A final 1x1 convolution projects the
  product back to ``n_filters`` channels, and the result is added to ``x``.

  Args:
    x: Input Keras tensor. It must be addable to an ``n_filters``-channel
      tensor of the same spatial size, because it is summed with the output.
    n_filters: Number of output channels of the final projection.

  Returns:
    ``x`` plus the gated feed-forward output.
  """
  # NOTE(review): `Normalization` is a preprocessing layer that is normally
  # adapted or given mean/variance; confirm it is the intended normalisation.
  normed = tf.keras.layers.Normalization()(x)

  gate = tf.keras.layers.Conv2D(n_filters * 2, (1, 1), padding="same")(normed)
  value = tf.keras.layers.Conv2D(n_filters * 2, (1, 1), padding="same")(normed)

  gate = tf.keras.layers.DepthwiseConv2D((3, 3), padding="same")(gate)
  value = tf.keras.layers.DepthwiseConv2D((3, 3), padding="same")(value)

  gated = tf.keras.activations.gelu(gate) * value
  projected = tf.keras.layers.Conv2D(n_filters, (1, 1), padding="same")(gated)

  # Fix: the residual sum used to be computed and then discarded, because the
  # function returned the un-summed branch. Return the sum so the skip
  # connection is actually part of the model.
  return x + projected

  

In [None]:
def Transformer_block(x, n_filters, num_heads):
  """Apply Keras multi-head self-attention followed by ``GDFN``.

  Args:
    x: Input Keras tensor, used as both query and value.
    n_filters: Filter count forwarded to ``GDFN``.
    num_heads: Number of attention heads.

  Returns:
    The output of ``GDFN`` applied to the attention result.
  """
  # `attention_axes` is the plain integer 3 (parentheses alone do not make a
  # tuple). NOTE(review): confirm axis 3 is the intended attention axis.
  self_attention = tf.keras.layers.MultiHeadAttention(
      num_heads=num_heads, key_dim=2, attention_axes=3)
  attended = self_attention(x, x)
  # Disabled alternative to the built-in attention layer:
  #x = MDTA(x, n_filters, num_heads)
  return GDFN(attended, n_filters)


In [None]:
def seq_transformer_blocks(x, n_filters, num_heads, num_blocks):
  """Chain ``num_blocks`` calls of ``Transformer_block``.

  Args:
    x: Input tensor fed to the first block.
    n_filters: Filter count forwarded to every block.
    num_heads: Head count forwarded to every block.
    num_blocks: How many blocks to stack; with 0 the input is returned as is.

  Returns:
    The output of the last block.
  """
  features = x
  for _ in range(num_blocks):
    features = Transformer_block(features, n_filters, num_heads)
  return features

In [None]:
def downsample_block(x, n_filters, num_heads, num_blocks):
   """Encoder stage: transformer blocks, then channel doubling and 2x2 pooling.

   Args:
      x: Input tensor of the stage.
      n_filters: Filter count of the transformer blocks; the 1x1 convolution
         that follows uses ``2 * n_filters`` filters.
      num_heads: Head count forwarded to the transformer blocks.
      num_blocks: Number of transformer blocks in the stage.

   Returns:
      ``(features, pooled)``: the transformer-block output before pooling and
      the pooled tensor that feeds the next stage.
   """
   features = seq_transformer_blocks(x, n_filters, num_heads, num_blocks)
   widened = tf.keras.layers.Conv2D(n_filters * 2, (1, 1), padding="same")(features)
   pooled = tf.keras.layers.MaxPool2D((2, 2))(widened)
   return features, pooled

In [None]:
def upsample_block(x, conv_features, n_filters, num_heads, num_blocks):
   """Decoder stage: upsample, merge with skip features, run transformer blocks.

   Args:
      x: Tensor coming from the previous (coarser) stage.
      conv_features: Skip tensor that is concatenated with the upsampled ``x``.
      n_filters: Filter count of the transposed convolution, of the 1x1
         convolution after the concatenation, and of the transformer blocks.
      num_heads: Head count forwarded to the transformer blocks.
      num_blocks: Number of transformer blocks in the stage.

   Returns:
      The output of the stage's transformer blocks.
   """
   # A 3x3 transposed convolution with stride (2, 2).
   upsampled = tf.keras.layers.Conv2DTranspose(n_filters, (3, 3), (2, 2), padding="same")(x)
   merged = tf.keras.layers.concatenate([upsampled, conv_features])
   # A 1x1 convolution brings the concatenated tensor to `n_filters` channels.
   fused = tf.keras.layers.Conv2D(n_filters, (1, 1), padding="same")(merged)
   return seq_transformer_blocks(fused, n_filters, num_heads, num_blocks)

In [None]:
# Height and width of the model input.
imageSize = 512
# Number of input channels; also the base filter count of the network.
m = 3

def constructModel():
    """Build the ``CNN_Transformer`` Keras model.

    The network is an encoder-decoder with skip connections. A 3x3 depthwise
    convolution is followed by three ``downsample_block`` stages, a bottleneck
    of transformer blocks, three ``upsample_block`` stages that consume the
    encoder features in reverse order, a refinement stack and a final 3x3
    depthwise convolution. The model output is the input plus that result.

    Returns:
        An uncompiled ``tf.keras.Model`` with input shape
        ``(imageSize, imageSize, m)``.
    """
    inputs = tf.keras.Input(shape=(imageSize, imageSize, m))
    features = tf.keras.layers.DepthwiseConv2D((3, 3), padding="same")(inputs)

    # (n_filters, num_heads, num_blocks) for each encoder stage; the decoder
    # reuses the same settings from the deepest stage back to the first.
    stage_settings = ((m, 1, 4), (m * 2, 2, 6), (m * 4, 4, 6))

    skips = []
    for filters, heads, depth in stage_settings:
        skip, features = downsample_block(features, filters, heads, depth)
        skips.append(skip)

    features = seq_transformer_blocks(features, m * 8, 8, 8)

    for (filters, heads, depth), skip in zip(reversed(stage_settings), reversed(skips)):
        features = upsample_block(features, skip, filters, heads, depth)

    features = seq_transformer_blocks(features, m, 1, 4)
    correction = tf.keras.layers.DepthwiseConv2D((3, 3), padding="same")(features)

    # Global residual: the network output is added to its own input.
    return tf.keras.Model(inputs, inputs + correction, name="CNN_Transformer")

In [None]:
model_instance = constructModel()
model_instance.summary()

Model: "CNN_Transformer"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_3 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 depthwise_conv2d_178 (Depthwis  (None, 512, 512, 3)  30         ['input_3[0][0]']                
 eConv2D)                                                                                         
                                                                                                  
 multi_head_attention_88 (Multi  (None, 512, 512, 3)  33         ['depthwise_conv2d_178[0][0]',   
 HeadAttention)                                                   'depthwise_conv2d_

In [None]:
# Folder with the source PNGs that gen() reads (files named imageNNNN.png).
image_path = '/content/drive/MyDrive/source2'
# Folder prefix for the weight checkpoint and the CSV training log.
models_path = '/content/drive/MyDrive/models/'

In [None]:
import itertools
import tensorflow as tf
from speedup import generate_out_images3
import numpy as np
from random import randint, uniform
import imageio
import time


# Upper bound of the image index drawn in gen(). randint includes this bound,
# so image2799.png can be requested. NOTE(review): confirm that file exists.
source_num = 2799
# Side length of the generated samples; used for the `src` array, the dataset
# tensor shapes and as the first argument of generate_out_images3.
dim = 512
# Passed unchanged to generate_out_images3 as its fourth positional argument.
stationary_defocus = 0.05


def gen():
    """Endlessly yield ``(out, src)`` sample pairs for ``tf.data``.

    For every sample, three image indices are drawn with
    ``randint(0, source_num)`` and the matching ``imageNNNN.png`` files are
    read from ``image_path``. Channel 0 of each image becomes one channel of
    ``src``, which is then shifted and scaled so that its minimum is 0 and its
    maximum is 1. Ten random parameters are drawn and passed, together with
    ``src``, to ``generate_out_images3``; element ``[1]`` of its result is
    divided by its own maximum to give ``out``.

    Yields:
        ``(out, src)`` where ``src`` is a ``(dim, dim, m)`` float64 array.

    NOTE(review): both normalisations divide by ``np.amax`` without guarding
    against a zero maximum (for example three identical constant images).
    """
    # (low, high) bounds in the positional order expected by
    # generate_out_images3: a_10, a_01, b_20, b_11, b_02, c_30, c_21, c_12, c_03.
    coefficient_ranges = (
        (-1e3, 1e3),
        (-1e3, 1e3),
        (1, 1.5),
        (-0.1, 0.1),
        (1, 1.5),
        (-1.5e-6, 1.5e-6),
        (-2e-6, 2e-6),
        (-2e-6, 2e-6),
        (-1.5e-6, 1.5e-6),
    )

    while True:
        # Draw all three indices before any file is read.
        layer_numbers = [randint(0, source_num) for _ in range(3)]
        layers = [
            imageio.imread(image_path + '/image' + str(number).zfill(4) + '.png')
            for number in layer_numbers
        ]

        src = np.zeros((dim, dim, m), np.double)
        for channel, layer in enumerate(layers):
            src[:, :, channel] = layer[:, :, 0]
        src = src - np.amin(src)
        src = src / np.amax(src)

        w = uniform(0.05, 0.5)
        coefficients = [uniform(low, high) for low, high in coefficient_ranges]

        out = generate_out_images3(dim, m, w, stationary_defocus, *coefficients, src)[1]
        out = out / np.amax(out)

        #src[src > 0] = 1.

        yield (out, src)


# Both datasets read the (out, src) pairs yielded by gen(): two float64 tensors
# of shape (dim, dim, m).
pair_dtypes = (tf.float64, tf.float64)
pair_shapes = (tf.TensorShape([dim, dim, m]), tf.TensorShape([dim, dim, m]))

# Training stream: endless, batches of 2, prefetch buffer of 8.
tr_dataset = (
    tf.data.Dataset.from_generator(gen, pair_dtypes, pair_shapes)
    .batch(batch_size=2)
    .prefetch(buffer_size=8)
)

# Validation set: the first 128 generated samples, cached after the first pass
# and served as 64 batches of 2 (matching validation_steps below).
val_dataset = (
    tf.data.Dataset.from_generator(gen, pair_dtypes, pair_shapes)
    .take(count=128)
    .cache()
    .batch(batch_size=2)
)

opt = tf.keras.optimizers.Adam()

# Callbacks: keep the best weights, log every epoch to CSV, halve the learning
# rate on a plateau and stop early. None sets `monitor`, so the Keras defaults apply.
save_best_callback = tf.keras.callbacks.ModelCheckpoint(
    models_path + 'bestmodel_cnn_transformer.hdf5',
    save_weights_only=True,
    save_best_only=True,
    verbose=True,
)
csv_logger_callback = tf.keras.callbacks.CSVLogger(models_path + 'log_cnn_transformer.csv')
lr_reduce_callback = tf.keras.callbacks.ReduceLROnPlateau(factor=0.5, min_delta=5e-4, patience=5)
early_stop_callback = tf.keras.callbacks.EarlyStopping(patience=25)

model_instance = constructModel()
model_instance.compile(loss='mse', optimizer=opt, metrics=['binary_crossentropy', 'mse'])
model_instance.fit(
    x=tr_dataset,
    validation_data=val_dataset,
    validation_steps=64,
    steps_per_epoch=256,
    epochs=200,
    verbose=1,
    callbacks=[
        save_best_callback,
        csv_logger_callback,
        lr_reduce_callback,
        early_stop_callback,
    ],
)