<a href="https://colab.research.google.com/github/yavuzkayacan/my_colab/blob/main/conv_att_tf.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

In [2]:
from keras.layers import Layer, Conv2D, Conv3D, UpSampling2D, MaxPooling2D, MaxPool2D, MaxPool3D, Conv3DTranspose, Concatenate, Conv2DTranspose,Add, SpatialDropout2D, BatchNormalization, Dense
from tensorflow.keras.initializers import Zeros
from tensorflow.keras.activations import softmax

In [121]:
class Patches(layers.Layer):
    """Split a batch of images into non-overlapping patches.

    Args:
        patch_size: Indexable pair ``(patch_height, patch_width)``.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Extract patches from ``images``.

        Args:
            images: 4-D tensor whose static shape unpacks as
                ``(batch, height, width, channels)``.

        Returns:
            Tensor reshaped to
            ``(-1, num_patches, patch_height, patch_width, channels)``.

        Raises:
            ValueError: If height, width or channels is not statically known.
        """
        patch_h, patch_w = int(self.patch_size[0]), int(self.patch_size[1])
        _, height, width, channels = images.shape
        if height is None or width is None or channels is None:
            raise ValueError(
                "Patches requires statically known height, width and "
                f"channels, got input shape {images.shape}."
            )

        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, patch_h, patch_w, 1],
            strides=[1, patch_h, patch_w, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )

        # Attributes are still recorded on the layer, as before.
        self.ch = channels
        # "VALID" padding only yields whole patches, so count them per axis.
        # Dividing the total pixel count over-counts whenever height or width
        # is not a multiple of the patch size.
        self.num_patches = (height // patch_h) * (width // patch_w)

        return tf.reshape(
            patches, (-1, self.num_patches, patch_h, patch_w, self.ch)
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"patch_size": [int(p) for p in self.patch_size]})
        return config

In [122]:
class SelfAttention(Layer):
    """Residual attention block built from three 1x1x1 ``Conv3D`` projections.

    The output is ``gamma * attention(x) + x``. ``gamma`` is a trainable
    scalar initialised to zero, so the layer starts out as the identity.

    Args:
        n_channels: Number of filters of the query/key/value projections.
            The residual addition needs this to match the channel count of
            the input.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``).
    """

    def __init__(self, n_channels, **kwargs):
        super().__init__(**kwargs)
        self.n_channels = n_channels
        self.query = self._make_projection()
        self.key = self._make_projection()
        self.value = self._make_projection()
        # ``name`` is passed by keyword: the first positional parameter of
        # ``add_weight`` is ``name`` in Keras 2 but ``shape`` in Keras 3.
        self.gamma = self.add_weight(
            name='gamma', shape=[1], initializer=Zeros(), trainable=True
        )

    def _make_projection(self):
        """Build one pointwise Conv3D projection with ReLU activation."""
        return Conv3D(
            self.n_channels,
            kernel_size=1,
            kernel_initializer='he_normal',
            use_bias=True,
            padding="same",
            activation="relu",
        )

    def call(self, x):
        """Apply the attention block.

        Args:
            x: 5-D tensor accepted by ``Conv3D``; presumably
                ``(batch, num_patches, patch_h, patch_w, channels)`` as
                produced by ``Patches`` -- confirm against the caller.

        Returns:
            Tensor with the same shape as ``x``.
        """
        f, g, h = self.query(x), self.key(x), self.value(x)
        # Transpose the last two axes of the query, then multiply with the key.
        mat_mul = tf.matmul(tf.transpose(f, perm=[0, 1, 2, 4, 3]), g)
        # NOTE(review): the softmax is taken over axis 1 (the patch axis under
        # the shape assumption above), not over an axis of the similarity
        # matrix itself. Kept unchanged; confirm this is intended.
        beta = softmax(mat_mul, axis=1)
        return self.gamma * tf.matmul(h, beta) + x

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"n_channels": int(self.n_channels)})
        return config

In [123]:
class Reconstruct(layers.Layer):
    """Reassemble a grid of patches into full feature maps.

    Args:
        patch_size: Indexable pair ``(patch_height, patch_width)``.
        num_patches: Indexable pair ``(patches_along_height,
            patches_along_width)``.
        ch: Number of channels of every patch.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, patch_size, num_patches, ch, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.ch = ch

    def call(self, patches):
        """Merge ``patches`` back into images.

        Args:
            patches: Tensor holding ``num_patches[0] * num_patches[1]``
                patches of ``patch_size`` with ``ch`` channels per sample.

        Returns:
            Tensor of shape ``(-1, num_patches[0] * patch_size[0],
            num_patches[1] * patch_size[1], ch)``.
        """
        grid_h, grid_w = self.num_patches[0], self.num_patches[1]
        # The patch width is patch_size[1]; using patch_size[0] for both axes
        # is only correct for square patches.
        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        grid = tf.reshape(
            patches, (-1, grid_h, grid_w, patch_h, patch_w, self.ch)
        )
        # Put each patch-row axis next to its grid-row axis so the pairs can
        # be merged: (batch, grid_h, patch_h, grid_w, patch_w, ch).
        grid = tf.transpose(grid, [0, 1, 3, 2, 4, 5])
        return tf.reshape(
            grid, (-1, grid_h * patch_h, grid_w * patch_w, self.ch)
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({
            "patch_size": [int(p) for p in self.patch_size],
            "num_patches": [int(n) for n in self.num_patches],
            "ch": int(self.ch),
        })
        return config

In [124]:
from tensorflow.keras.layers import Layer
class UnPixelShuffle(Layer):
    """Fold ``factor x factor`` spatial blocks into the channel axis.

    Height and width shrink by ``upscale_factor`` and the channel count grows
    by ``upscale_factor ** 2``.

    Args:
        upscale_factor: Integer block size.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``).
    """

    def __init__(self, upscale_factor, **kwargs):
        super().__init__(**kwargs)
        self.upscale_factor = upscale_factor

    def call(self, inputs, **kwargs):
        """Rearrange ``inputs`` from space to depth.

        Args:
            inputs: 4-D tensor whose static shape unpacks as
                ``(batch, height, width, channels)``.

        Returns:
            Tensor of shape ``(-1, height // factor, width // factor,
            channels * factor ** 2)``. Output channels are ordered as
            ``(row_offset, col_offset, input_channel)``.

        Raises:
            ValueError: If a dimension is not statically known, or height or
                width is not divisible by ``upscale_factor``.
        """
        factor = self.upscale_factor
        _, height, width, channels = inputs.shape
        if height is None or width is None or channels is None:
            raise ValueError(
                "UnPixelShuffle requires statically known height, width and "
                f"channels, got input shape {inputs.shape}."
            )
        # Without this check the -1 batch axis of the reshape can absorb the
        # mismatch and silently scramble the data.
        if height % factor or width % factor:
            raise ValueError(
                f"Height ({height}) and width ({width}) must be divisible by "
                f"upscale_factor ({factor})."
            )

        # Attributes are still recorded on the layer, as before.
        self.height = height
        self.width = width
        self.target_channels = channels * (factor ** 2)

        out_h, out_w = height // factor, width // factor
        # Split each spatial axis into (coarse position, offset inside block).
        blocks = tf.reshape(
            inputs, [-1, out_h, factor, out_w, factor, channels]
        )
        # Move both in-block offsets next to the channel axis.
        blocks = tf.transpose(blocks, [0, 1, 3, 2, 4, 5])
        return tf.reshape(blocks, [-1, out_h, out_w, self.target_channels])

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"upscale_factor": int(self.upscale_factor)})
        return config

In [125]:
class PixelShuffle(Layer):
    """Unfold channels into ``factor x factor`` spatial blocks.

    Height and width grow by ``upscale_factor`` and the channel count shrinks
    by ``upscale_factor ** 2``.

    Args:
        upscale_factor: Integer block size.
        **kwargs: Forwarded to ``Layer`` (e.g. ``name``).
    """

    def __init__(self, upscale_factor, **kwargs):
        super().__init__(**kwargs)
        self.upscale_factor = upscale_factor

    def call(self, inputs, **kwargs):
        """Rearrange ``inputs`` from depth to space.

        Args:
            inputs: 4-D tensor whose static shape unpacks as
                ``(batch, height, width, channels)``.

        Returns:
            Tensor of shape ``(-1, height * factor, width * factor,
            channels // factor ** 2)``.

        Raises:
            ValueError: If a dimension is not statically known, or the
                channel count is not divisible by ``upscale_factor ** 2``.
        """
        factor = self.upscale_factor
        _, height, width, channels = inputs.shape
        if height is None or width is None or channels is None:
            raise ValueError(
                "PixelShuffle requires statically known height, width and "
                f"channels, got input shape {inputs.shape}."
            )
        # Without this check the -1 batch axis of the reshape can absorb the
        # mismatch and silently scramble the data.
        if channels % (factor ** 2):
            raise ValueError(
                f"Channel count ({channels}) must be divisible by "
                f"upscale_factor ** 2 ({factor ** 2})."
            )

        # Attributes are still recorded on the layer, as before.
        self.height = height
        self.width = width
        self.target_channels = channels // (factor ** 2)

        # Input channels are read as (output_channel, row_offset, col_offset).
        # NOTE(review): UnPixelShuffle in this file orders channels as
        # (row_offset, col_offset, channel), so the two layers are not exact
        # inverses of each other. Kept unchanged; confirm this is acceptable.
        blocks = tf.reshape(
            inputs,
            [-1, height, width, self.target_channels, factor, factor],
        )
        # Interleave each offset axis with its spatial axis:
        # (batch, height, row_offset, width, col_offset, output_channel).
        blocks = tf.transpose(blocks, [0, 1, 4, 2, 5, 3])
        return tf.reshape(
            blocks,
            [-1, height * factor, width * factor, self.target_channels],
        )

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"upscale_factor": int(self.upscale_factor)})
        return config

In [126]:
# Model geometry: input image shape (height, width, channels) and patch size.
input_shape = (256, 64, 1)
img_size = input_shape  # alias of the same tuple
patch_size = np.array([4, 4])

# Patch grid as (patches along height, patches along width) ...
num_patch = np.array(
    [extent // side for extent, side in zip(img_size[:2], patch_size)]
)
# ... and the total number of patches per image.
num_patches = num_patch[0] * num_patch[1]

In [128]:
def ConvAttention(input_shape, filter_num, kernel_size, patch_size, num_patch,
                  depth, *, verbose=True):
    """Build an encoder/decoder model of patch-wise attention stages.

    Each encoder stage applies Patches -> SelfAttention -> Reconstruct ->
    UnPixelShuffle(2). Each decoder stage applies Patches -> SelfAttention ->
    Reconstruct -> PixelShuffle(2) and concatenates the matching encoder skip
    tensor (deepest skip first).

    Args:
        input_shape: ``(height, width, channels)`` of the model input.
        filter_num: Filters of the initial ``Conv2D``.
        kernel_size: Kernel size of the initial ``Conv2D``.
        patch_size: Indexable pair ``(patch_height, patch_width)``.
        num_patch: Indexable pair with the patch grid at full resolution,
            i.e. ``(height // patch_height, width // patch_width)``.
        depth: Number of encoder stages (and of decoder stages).
        verbose: When true (the default, matching the previous behaviour)
            print the tensor shape after every decoder ``PixelShuffle``.

    Returns:
        A ``keras.Model`` mapping the input to a single-channel output.

    Raises:
        ValueError: If ``depth > 0`` and the spatial input size is unknown,
            is not divisible by ``patch_size * 2 ** depth``, or does not
            match ``num_patch * patch_size``.
    """
    if depth > 0:
        for axis, label in ((0, "height"), (1, "width")):
            extent = input_shape[axis]
            side = int(patch_size[axis])
            if extent is None:
                raise ValueError(f"Input {label} must be statically known.")
            # Every stage halves the resolution and must still tile exactly
            # into patches, down to the deepest level.
            if extent % (side * 2 ** depth):
                raise ValueError(
                    f"Input {label} ({extent}) must be divisible by "
                    f"patch size * 2**depth ({side * 2 ** depth})."
                )
            if int(num_patch[axis]) * side != extent:
                raise ValueError(
                    f"num_patch[{axis}] ({int(num_patch[axis])}) * patch size "
                    f"({side}) does not match input {label} ({extent})."
                )

    inputs = layers.Input(shape=input_shape)

    # Shallow feature extraction
    x = Conv2D(filter_num, kernel_size, padding="same", activation="relu")(inputs)

    channel_num = x.shape[-1]
    # Plain ints, so the grid arithmetic works for lists, tuples and arrays.
    grid = (int(num_patch[0]), int(num_patch[1]))

    # Encoder: remember each stage input together with its channel count.
    skips = []
    for _ in range(depth):
        skips.append((x, channel_num))

        x = Patches(patch_size)(x)
        x = SelfAttention(channel_num)(x)
        x = Reconstruct(patch_size, grid, channel_num)(x)
        x = UnPixelShuffle(2)(x)

        channel_num = channel_num * 4
        grid = (grid[0] // 2, grid[1] // 2)

    # Decoder: consume the skips from deepest to shallowest.
    for skip, skip_channels in reversed(skips):
        x = Patches(patch_size)(x)
        x = SelfAttention(channel_num)(x)
        x = Reconstruct(patch_size, grid, channel_num)(x)
        x = PixelShuffle(2)(x)
        if verbose:
            print(x.shape)

        x = Concatenate()([skip, x])

        # PixelShuffle(2) divides the channels by 4; the skip adds its own.
        channel_num = channel_num // 4 + skip_channels
        grid = (grid[0] * 2, grid[1] * 2)

    out = Conv2D(1, 3, padding="same", activation="relu")(x)

    return keras.Model(inputs=inputs, outputs=out)





In [129]:
# Build the network: 16 initial filters, 3x3 kernel, 3 encoder/decoder stages.
model = ConvAttention(
    input_shape=input_shape,
    filter_num=16,
    kernel_size=3,
    patch_size=patch_size,
    num_patch=num_patch,
    depth=3,
)

(None, 64, 16, 256)
(None, 128, 32, 128)
(None, 256, 64, 48)


In [130]:
# Print the layer-by-layer summary (output shapes, parameter counts, connections).
model.summary()

Model: "model_2"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_31 (InputLayer)       [(None, 256, 64, 1)]         0         []                            
                                                                                                  
 conv2d_22 (Conv2D)          (None, 256, 64, 16)          160       ['input_31[0][0]']            
                                                                                                  
 patches_94 (Patches)        (None, 1024, 4, 4, 16)       0         ['conv2d_22[0][0]']           
                                                                                                  
 self_attention_79 (SelfAtt  (None, 1024, 4, 4, 16)       817       ['patches_94[0][0]']          
 ention)                                                                                    