In [2]:
import tensorflow as tf


In [3]:
from tensorflow.keras import layers, Model
from tensorflow.keras.layers import Layer
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import Adam

In [None]:
class ChannelAttention(layers.Layer):
    """Channel-wise attention gate for 4-D feature maps.

    Two parallel paths -- one starting with global average pooling, the other
    with global max pooling, each followed by its own pair of 1x1
    convolutions -- produce one score per channel.  The two scores are added,
    passed through a sigmoid and multiplied back onto the input, so the
    output has the same shape as the input.

    The last axis of the input is treated as the channel axis.

    Args:
        reduction_ratio: Divisor applied to the channel count to size the
            intermediate (bottleneck) 1x1 convolution of each path.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, reduction_ratio=16, **kwargs):
        super().__init__(**kwargs)
        self.reduction_ratio = reduction_ratio

    def build(self, input_shape):
        """Create the two pooling paths once the channel count is known.

        Raises:
            ValueError: If the channel (last) dimension of ``input_shape``
                is not statically known.
        """
        channel = input_shape[-1]
        if channel is None:
            raise ValueError(
                "ChannelAttention requires a statically known channel "
                "dimension (last axis of the input).")

        # With fewer channels than `reduction_ratio` the integer division
        # would yield 0 filters, which Conv2D rejects; keep at least one.
        bottleneck = max(1, channel // self.reduction_ratio)

        # `tf.keras` is used explicitly: only `tf` is imported by the
        # notebook, there is no bare `keras` name in scope.
        self.average_path = tf.keras.Sequential([
            layers.GlobalAveragePooling2D(keepdims=True),
            layers.Conv2D(bottleneck, 1, activation='relu'),
            layers.Conv2D(channel, 1)])

        self.maxpool_path = tf.keras.Sequential([
            layers.GlobalMaxPooling2D(keepdims=True),
            layers.Conv2D(bottleneck, 1, activation='relu'),
            layers.Conv2D(channel, 1)])

        self.sigmoid = layers.Activation('sigmoid')

        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` rescaled channel by channel."""
        # Per-channel scores from the average-pooled descriptor.
        avg_pool = self.average_path(inputs)

        # Per-channel scores from the max-pooled descriptor.
        max_pool = self.maxpool_path(inputs)

        # keepdims=True in the pooling layers keeps the spatial axes at
        # size 1, so the gate broadcasts over height and width.
        final = self.sigmoid(avg_pool + max_pool)
        return inputs * final

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({
            'reduction_ratio': self.reduction_ratio
        })
        return config
    



In [None]:
class SpatialAttention(layers.Layer):
    """Spatial attention gate for 4-D feature maps.

    The input is reduced along its last (channel) axis with both a mean and
    a max, the two single-channel maps are concatenated, and a convolution
    with one filter and a sigmoid activation turns them into a per-position
    gate that is multiplied back onto the input.  The output has the same
    shape as the input.

    Args:
        kernel_size: Kernel size of the convolution that produces the gate.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size

    def build(self, input_shape):
        """Create the pooling, concatenation and convolution sub-layers."""
        # Channel-wise pooling: collapse the last axis but keep it (size 1)
        # so the two maps can be concatenated along it.
        self.avg_pool = layers.Lambda(
            lambda x: tf.reduce_mean(x, axis=-1, keepdims=True))
        self.max_pool = layers.Lambda(
            lambda x: tf.reduce_max(x, axis=-1, keepdims=True))

        # Stacks the mean map and the max map into a 2-channel tensor.
        self.concat = layers.Concatenate(axis=-1)

        # padding='same' keeps the spatial size, so the 1-channel gate can
        # be broadcast against the input in `call`.
        self.conv = layers.Conv2D(
            filters=1,
            kernel_size=self.kernel_size,
            padding='same',
            activation='sigmoid')

        super().build(input_shape)

    def call(self, inputs):
        """Return ``inputs`` rescaled position by position."""
        avg_pool = self.avg_pool(inputs)
        max_pool = self.max_pool(inputs)

        pooled = self.concat([avg_pool, max_pool])
        spatial_attention = self.conv(pooled)

        return inputs * spatial_attention

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({
            'kernel_size': self.kernel_size
        })
        return config


In [None]:
class ResidualAttentionLayer(layers.Layer):
    """Attention followed by a 3x3 convolution, with a projected skip path.

    The input goes through ``ChannelAttention`` then ``SpatialAttention``
    and a 3x3 convolution with ``filters`` output channels.  In parallel the
    raw input is projected to ``filters`` channels by a 1x1 convolution, and
    the two results are added.

    Args:
        filters: Number of output channels of both the main and skip path.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, filters=64, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters

        # `tf.keras` is used explicitly: only `tf` is imported by the
        # notebook, there is no bare `keras` name in scope.
        self.block = tf.keras.Sequential([ChannelAttention(), SpatialAttention()])

        # Feature extraction applied to the attended features.
        self.conv = layers.Conv2D(filters, 3, padding='same')

        # 1x1 projection so the skip path matches `filters` channels even
        # when the input has a different channel count.
        self.channel_adjust = layers.Conv2D(filters, 1)

    def build(self, input_shape):
        # Sub-layers are created in __init__ and build themselves on first
        # call; nothing else to set up here.
        super().build(input_shape)

    def call(self, inputs):
        """Return ``conv(attention(inputs)) + channel_adjust(inputs)``."""
        # Main path: attention block, then feature extraction.
        attention = self.block(inputs)
        x = self.conv(attention)

        # Skip path: project the untouched input to the same channel count.
        residual = self.channel_adjust(inputs)

        return x + residual

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({
            'filters': self.filters
        })
        return config

In [None]:
class ResidualBlocks(layers.Layer):
    """A chain of ``ResidualAttentionLayer`` wrapped in an identity skip.

    ``num_layers`` default-configured ``ResidualAttentionLayer`` instances
    are applied in sequence, followed by a 3x3 convolution with 64 filters;
    the original input is then added to the result.

    Because the closing convolution has 64 filters and the skip connection
    is a plain addition, the input must also have 64 channels.

    Args:
        num_layers: Number of ``ResidualAttentionLayer`` instances chained.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, num_layers=8, **kwargs):
        super().__init__(**kwargs)
        self.num_layers = num_layers

        # Group of `num_layers` attention layers forming one residual group.
        # `tf.keras` is used explicitly: only `tf` is imported by the
        # notebook, there is no bare `keras` name in scope.
        self.blocks = tf.keras.Sequential(
            [ResidualAttentionLayer() for _ in range(self.num_layers)])

        self.conv = layers.Conv2D(64, 3, padding='same')

    def build(self, input_shape):
        """Build the inner sequential stack for ``input_shape``."""
        self.blocks.build(input_shape)
        super().build(input_shape)

    def call(self, inputs):
        """Return ``conv(blocks(inputs)) + inputs``."""
        x = self.blocks(inputs)
        return self.conv(x) + inputs

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({
            'num_layers': self.num_layers
        })
        return config

     

In [None]:
class PixelShuffleLayer(layers.Layer):
    """Upsampling block: 3x3 convolution followed by depth-to-space.

    The convolution expands the input to ``filters * scale_factor**2``
    channels; ``tf.nn.depth_to_space`` with block size ``scale_factor`` then
    rearranges those channels into spatial positions, leaving ``filters``
    channels and multiplying height and width by ``scale_factor``.

    Args:
        scale_factor: Block size passed to ``tf.nn.depth_to_space``.
        filters: Number of channels remaining after the shuffle.
        **kwargs: Forwarded to ``layers.Layer``.
    """

    def __init__(self, scale_factor, filters=256, **kwargs):
        super().__init__(**kwargs)
        self.scale_factor = scale_factor
        self.filters = filters

        # Sub-layers are created here (not in build) so they exist as soon
        # as the layer is instantiated.
        self.conv = layers.Conv2D(
            filters=filters * (scale_factor ** 2),
            kernel_size=3,
            padding='same')

        self.pixel_shuffle = layers.Lambda(
            lambda x: tf.nn.depth_to_space(x, scale_factor)
        )

    def build(self, input_shape):
        # Nothing to create here; the sub-layers build on first call.
        super().build(input_shape)

    def call(self, inputs):
        """Return the convolved input rearranged by depth-to-space."""
        x = self.conv(inputs)
        return self.pixel_shuffle(x)

    def get_config(self):
        """Return the constructor arguments needed to re-create the layer."""
        config = super().get_config()
        config.update({
            'scale_factor': self.scale_factor,
            'filters': self.filters
        })
        return config

    



In [None]:
class Generator(Model):
    """Image-to-image generator that upsamples its input.

    Pipeline, in order:
      1. 3x3 convolution to 64 feature channels,
      2. four ``ResidualBlocks`` groups followed by a 3x3 convolution,
      3. a long skip connection adding the features from step 1,
      4. two ``PixelShuffleLayer(2)`` stages (each doubles height and width),
      5. 3x3 convolution to 3 channels with ``tanh`` activation.

    Args:
        scale_factor: Stored and reported by ``get_config``.
        input_shape: Per-sample input shape (without the batch axis).  When
            not ``None`` the model is built immediately for this shape.
        filters: Stored and reported by ``get_config``.
        **kwargs: Forwarded to ``Model`` (e.g. ``name``).

    NOTE(review): the layer stack uses fixed sizes (64 feature channels, two
    x2 upsampling stages) and does not read ``scale_factor`` or ``filters``
    -- TODO confirm whether they are meant to drive the architecture.
    """

    def __init__(self, scale_factor=2, input_shape=(64, 64, 3), filters=64, **kwargs):
        super().__init__(**kwargs)

        # `input_shape` is a read-only property on Keras layers/models, so
        # the constructor argument is kept under a private name instead of
        # being assigned to `self.input_shape`.
        self._init_input_shape = None if input_shape is None else tuple(input_shape)
        self.filters = filters
        self.scale_factor = scale_factor

        self.initial_extraction = layers.Conv2D(64, 3, padding='same')
        # `tf.keras` is used explicitly: only `tf` is imported by the
        # notebook, there is no bare `keras` name in scope.
        self.residual_group = tf.keras.Sequential([ResidualBlocks() for _ in range(4)])
        self.conv = layers.Conv2D(64, 3, padding='same')
        self.upsampling_block = tf.keras.Sequential([PixelShuffleLayer(2) for _ in range(2)])
        self.output_layer = layers.Conv2D(3, 3, padding='same', activation='tanh')

        # Build eagerly when the input shape is known.  tuple() lets a list
        # (e.g. after a JSON config round trip) be accepted as well.
        if input_shape is not None:
            self.build(input_shape=(None,) + tuple(input_shape))

    def build(self, input_shape):
        """Build every sub-layer for ``input_shape`` (batch axis included)."""
        input_shape = tuple(input_shape)

        # The upsampling stages are read before the block is built, so the
        # list holds exactly the PixelShuffleLayer instances.
        stages = list(self.upsampling_block.layers)

        # First extraction layer, then the shape it produces.
        self.initial_extraction.build(input_shape)
        feature_shape = input_shape[:-1] + (self.initial_extraction.filters,)

        # Residual groups and the convolution that follows them both see
        # tensors with the feature channel count.
        self.residual_group.build(feature_shape)
        self.conv.build(feature_shape)

        self.upsampling_block.build(feature_shape)

        # Derive the upsampled shape from the stages themselves: every stage
        # multiplies height/width by its scale factor and the last stage
        # determines the remaining channel count.
        total_scale = 1
        for stage in stages:
            total_scale *= stage.scale_factor
        upsampled_channels = stages[-1].filters

        batch, height, width = feature_shape[0], feature_shape[1], feature_shape[2]
        upsampled_shape = (
            batch,
            None if height is None else height * total_scale,
            None if width is None else width * total_scale,
            upsampled_channels,
        )

        self.output_layer.build(upsampled_shape)

        super().build(input_shape)

    def call(self, inputs):
        """Run the generator and return a 3-channel tensor in ``[-1, 1]``."""
        x = self.initial_extraction(inputs)
        residual = x

        # Pass the features through the residual attention groups.
        x = self.residual_group(x)
        x = self.conv(x)

        # Long skip connection around the residual groups, then upsample.
        x = self.upsampling_block(x + residual)
        x = self.output_layer(x)

        return x

    def get_config(self):
        """Return the constructor arguments needed to re-create the model."""
        config = super().get_config()

        stored_shape = self._init_input_shape
        config.update({
            "input_shape": None if stored_shape is None else tuple(stored_shape),
            "filters": self.filters,
            "scale_factor": self.scale_factor,
        })
        return config
        
        
        
        
        

In [None]:
# Instantiate the generator with its default constructor arguments.
gen_model=Generator()

In [None]:
# Print the model summary, expanding nested models.
gen_model.summary(expand_nested=True)