In [1]:
import os
# Colab: work from the root of the mounted Google Drive.
os.chdir(os.path.join('/content', 'drive', 'MyDrive', ''))

In [2]:
import random
from os import listdir

import cv2
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from PIL import Image
from tensorflow.keras.layers import Conv2D, MaxPool2D, UpSampling2D, AveragePooling2D, Input
from tensorflow.keras.layers import BatchNormalization, Activation, Add, Multiply, Flatten
from tensorflow.keras.models import Model

%matplotlib inline
%load_ext autoreload

In [3]:
# Project root on Drive; relative paths later in the notebook (checkpoints, saved
# model, output PNGs) resolve against this directory.
os.chdir(os.path.join('/content/drive/MyDrive', 'GR5242_Project_Folder', 'GR5242_Project', ''))

In [4]:
# Input images and matching label images of the `train_data_generated_2` set.
input_dir, label_dir = (
    '/content/drive/MyDrive/GR5242_Project_Folder/GR5242_Project/train_data_generated_2/' + sub
    for sub in ('input', 'label')
)

In [5]:
# Input images and matching label images of the `train_data_generated` set.
input_dir_2, label_dir_2 = (
    '/content/drive/MyDrive/GR5242_Project_Folder/GR5242_Project/train_data_generated/' + sub
    for sub in ('input', 'label')
)

In [6]:
# Training inputs from `train_data_generated_2`, resized to 800x800.
# os.listdir returns entries in arbitrary order, so sort to make the 500-file
# subset deterministic; X/y pairing relies on the input and label directories
# being listed in the same order.
image_list1 = []
for filename in sorted(os.listdir(input_dir))[:500]:
    path = os.path.join(input_dir, filename)
    im = cv2.imread(path)  # None if OpenCV cannot decode the file
    if im is None:
        raise ValueError('could not read image: ' + path)
    image_list1.append(tf.image.resize(im, [800, 800]))

In [7]:
# Training labels from `train_data_generated_2`, resized to 200x200.
# os.listdir returns entries in arbitrary order, so sort to make the 500-file
# subset deterministic; X/y pairing relies on the input and label directories
# being listed in the same order.
image_list2 = []
for filename in sorted(os.listdir(label_dir))[:500]:
    path = os.path.join(label_dir, filename)
    im = cv2.imread(path)  # None if OpenCV cannot decode the file
    if im is None:
        raise ValueError('could not read image: ' + path)
    image_list2.append(tf.image.resize(im, [200, 200]))

In [8]:
# Training inputs from `train_data_generated`, resized to 800x800.
# os.listdir returns entries in arbitrary order, so sort to make the 500-file
# subset deterministic; X/y pairing relies on the input and label directories
# being listed in the same order.
image_list3 = []
for filename in sorted(os.listdir(input_dir_2))[:500]:
    path = os.path.join(input_dir_2, filename)
    im = cv2.imread(path)  # None if OpenCV cannot decode the file
    if im is None:
        raise ValueError('could not read image: ' + path)
    image_list3.append(tf.image.resize(im, [800, 800]))

In [9]:
# Training labels from `train_data_generated`, resized to 200x200.
# os.listdir returns entries in arbitrary order, so sort to make the 500-file
# subset deterministic; X/y pairing relies on the input and label directories
# being listed in the same order.
image_list4 = []
for filename in sorted(os.listdir(label_dir_2))[:500]:
    path = os.path.join(label_dir_2, filename)
    im = cv2.imread(path)  # None if OpenCV cannot decode the file
    if im is None:
        raise ValueError('could not read image: ' + path)
    image_list4.append(tf.image.resize(im, [200, 200]))

In [10]:
# Held-out test set: entries 2000-2099 of `train_data_generated_2` (sorted by
# filename so input i and label i come from the same position in both listings).
# Inputs must match the model's (800, 800, 3) input shape, like the training
# inputs; labels match the 200x200 training labels.
image_list5 = []
for filename in sorted(os.listdir(input_dir))[2000:2100]:
    path = os.path.join(input_dir, filename)
    im = cv2.imread(path)  # None if OpenCV cannot decode the file
    if im is None:
        raise ValueError('could not read image: ' + path)
    image_list5.append(tf.image.resize(im, [800, 800]))

image_list6 = []
for filename in sorted(os.listdir(label_dir))[2000:2100]:
    path = os.path.join(label_dir, filename)
    im = cv2.imread(path)
    if im is None:
        raise ValueError('could not read image: ' + path)
    image_list6.append(tf.image.resize(im, [200, 200]))

In [11]:
# Stack the per-image tensors into float32 NumPy arrays (pixel values stay 0-255).
# Training data = the 500-image subsets of both generated sets; test = held-out slice.
X_train = np.asarray([*image_list1, *image_list3], dtype=np.float32)
y_train = np.asarray([*image_list2, *image_list4], dtype=np.float32)
x_test, y_test = (np.asarray(images, dtype=np.float32) for images in (image_list5, image_list6))

In [12]:
class ResidualAttentionNetwork():
    """Residual Attention Network (Attention-56 layout) built with the Keras functional API.

    The model maps an image of shape ``input_shape`` to a 3-channel map. Two
    stride-2 'out module' residual units reduce each spatial dimension by 4, so an
    (800, 800, 3) input produces a (200, 200, 3) output.

    Args:
        input_shape: shape of a single input image, e.g. (800, 800, 3). Spatial
            dims should be divisible by 32 so every pooling/upsampling pair lines up.
        output_size: stored on the instance. NOTE(review): not used when building
            the model -- the output head always emits 3 channels.
        p: number of residual units applied before and after each attention module.
        t: number of residual units stacked in each trunk branch.
        r: number of residual units between adjacent pooling / upsampling steps in
            the soft-mask branches.
        filter_dic: bottleneck filter triples [1x1 reduce, 3x3, 1x1 expand] keyed by
            's1', 's2', 's3' (attention stages 1-3) and 'se' (final residual units).
            ``None`` selects the default sizes listed in ``__init__``.
    """

    def __init__(self, input_shape, output_size, p=1, t=2, r=1, filter_dic=None):
        self.input_shape = input_shape
        self.output_size = output_size
        self.p = p
        self.t = t
        self.r = r
        # Build a fresh dict per instance: a dict literal as the default argument
        # would be one object shared (and mutable) across every instance.
        if filter_dic is None:
            filter_dic = {'s1': [16, 16, 64],
                          's2': [32, 32, 128],
                          's3': [64, 64, 256],
                          'se': [128, 128, 512]}
        self.filter_dic = filter_dic

    def Attention_56(self):
        """Build and return the (uncompiled) Keras model."""
        filters = self.filter_dic
        input_data = Input(shape=self.input_shape)
        convolution_layer_1 = Conv2D(filters=64, kernel_size=(3, 3), strides=(1, 1), padding='same')(input_data)

        # Residual-Attention Module stage #1 (input resolution)
        residual_unit_1 = self.ResidualUnit(convolution_layer_1, filters=filters['s1'], residual_unit_type='in module')
        attention_module_unit_1 = self.AttentionModuleStage1(residual_unit_1, filters=filters['s1'], learning_mechanism='ARL')

        # Residual-Attention Module stage #2 ('out module' halves the spatial size)
        residual_unit_2 = self.ResidualUnit(attention_module_unit_1, filters=filters['s2'], residual_unit_type='out module')
        attention_module_unit_2 = self.AttentionModuleStage2(residual_unit_2, filters=filters['s2'], learning_mechanism='ARL')

        # Residual-Attention Module stage #3 (halves the spatial size again)
        residual_unit_3 = self.ResidualUnit(attention_module_unit_2, filters=filters['s3'], residual_unit_type='out module')
        attention_module_unit_3 = self.AttentionModuleStage3(residual_unit_3, filters=filters['s3'], learning_mechanism='ARL')

        features = attention_module_unit_3
        for _ in range(2):
            features = self.ResidualUnit(features, filters=filters['se'], residual_unit_type='in module')

        # Output head: 1x1 convolutions down to 3 channels.
        # NOTE(review): the final BatchNormalization follows the ReLU, so outputs
        # are not constrained to be non-negative.
        head = Conv2D(filters=256, kernel_size=(1, 1), padding='same', activation='relu')(features)
        head = BatchNormalization()(head)
        head = Conv2D(filters=256, kernel_size=(1, 1), padding='same', activation='relu')(head)
        head = BatchNormalization()(head)
        head = Conv2D(filters=3, kernel_size=(1, 1), padding='same', activation='relu')(head)
        head = BatchNormalization()(head)

        return Model(inputs=input_data, outputs=head)

    def ResidualUnit(self, residual_input, filters, residual_unit_type='in module'):
        """Bottleneck residual unit: three BN -> ReLU -> Conv steps plus a shortcut.

        Args:
            residual_input: input tensor.
            filters: [f_reduce, f_mid, f_out] for the 1x1, 3x3 and 1x1 convolutions.
            residual_unit_type: 'in module' keeps the spatial size; any other value
                uses stride 2 in the 3x3 conv (and in the shortcut projection),
                halving it.

        Returns:
            ``shortcut + residual`` with ``f_out`` channels.
        """
        f_reduce, f_mid, f_out = filters
        stride = (1, 1) if residual_unit_type == 'in module' else (2, 2)
        identity_x = residual_input

        x = BatchNormalization()(residual_input)
        x = Activation('relu')(x)
        x = Conv2D(filters=f_reduce, kernel_size=(1, 1), padding='same')(x)

        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2D(filters=f_mid, kernel_size=(3, 3), strides=stride, padding='same')(x)

        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = Conv2D(filters=f_out, kernel_size=(1, 1), padding='same')(x)

        # Project the shortcut when channels or spatial size differ so Add() is valid.
        if identity_x.shape != x.shape:
            if residual_unit_type == 'in module':
                identity_x = Conv2D(filters=f_out, kernel_size=(1, 1), strides=(1, 1), padding='same')(identity_x)
            else:
                identity_x = Conv2D(filters=f_out, kernel_size=(3, 3), strides=(2, 2), padding='same')(identity_x)

        return Add()([identity_x, x])

    def AttentionResidualLearning(self, trunk_unit, soft_mask_unit):
        """Combine trunk T and mask M as T * M + T, i.e. (1 + M) * T."""
        output = Multiply()([trunk_unit, soft_mask_unit])
        return Add()([output, trunk_unit])

    def _stack_residual_units(self, x, filters, count):
        """Apply `count` 'in module' residual units in sequence (identity when count == 0)."""
        for _ in range(count):
            x = self.ResidualUnit(x, filters, residual_unit_type='in module')
        return x

    def _soft_mask(self, x):
        """Two channel-preserving 1x1 convs followed by a sigmoid, giving a mask in (0, 1)."""
        channels = x.shape[-1]
        x = Conv2D(filters=channels, kernel_size=(1, 1), padding='same')(x)
        x = Conv2D(filters=channels, kernel_size=(1, 1), padding='same')(x)
        return Activation('sigmoid')(x)

    def AttentionModuleStage1(self, input_unit, filters, learning_mechanism):
        """Attention module whose mask branch pools 3 times, with skips at 1/2 and 1/4 scale.

        ``learning_mechanism`` is accepted for interface compatibility but unused;
        attention residual learning is always applied.
        """
        pre = self._stack_residual_units(input_unit, filters, self.p)

        # Trunk branch.
        trunk_unit = self._stack_residual_units(pre, filters, self.t)

        # Soft-mask branch (down x3, up x3).
        down_1 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(pre)
        down_1 = self._stack_residual_units(down_1, filters, self.r)
        skip_1 = self.ResidualUnit(down_1, filters, residual_unit_type='in module')

        down_2 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(down_1)
        down_2 = self._stack_residual_units(down_2, filters, self.r)
        skip_2 = self.ResidualUnit(down_2, filters, residual_unit_type='in module')

        down_3 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(down_2)
        down_3 = self._stack_residual_units(down_3, filters, self.r * 2)

        up_1 = Add()([UpSampling2D(size=(2, 2))(down_3), skip_2])
        up_1 = self._stack_residual_units(up_1, filters, self.r)

        up_2 = Add()([UpSampling2D(size=(2, 2))(up_1), skip_1])
        up_2 = self._stack_residual_units(up_2, filters, self.r)

        soft_mask_unit = self._soft_mask(UpSampling2D(size=(2, 2))(up_2))
        output_unit = self.AttentionResidualLearning(trunk_unit, soft_mask_unit)
        return self._stack_residual_units(output_unit, filters, self.p)

    def AttentionModuleStage2(self, input_unit, filters, learning_mechanism):
        """Attention module whose mask branch pools 2 times, with one skip at 1/2 scale.

        ``learning_mechanism`` is accepted for interface compatibility but unused.
        """
        pre = self._stack_residual_units(input_unit, filters, self.p)

        # Trunk branch.
        trunk_unit = self._stack_residual_units(pre, filters, self.t)

        # Soft-mask branch (down x2, up x2).
        down_1 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(pre)
        down_1 = self._stack_residual_units(down_1, filters, self.r)
        skip_1 = self.ResidualUnit(down_1, filters, residual_unit_type='in module')

        down_2 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(down_1)
        down_2 = self._stack_residual_units(down_2, filters, self.r * 2)

        up_1 = Add()([UpSampling2D(size=(2, 2))(down_2), skip_1])
        up_1 = self._stack_residual_units(up_1, filters, self.r)

        soft_mask_unit = self._soft_mask(UpSampling2D(size=(2, 2))(up_1))
        output_unit = self.AttentionResidualLearning(trunk_unit, soft_mask_unit)
        return self._stack_residual_units(output_unit, filters, self.p)

    def AttentionModuleStage3(self, input_unit, filters, learning_mechanism):
        """Attention module whose mask branch pools once, with no skip connection.

        ``learning_mechanism`` is accepted for interface compatibility but unused.
        """
        pre = self._stack_residual_units(input_unit, filters, self.p)

        # Trunk branch.
        trunk_unit = self._stack_residual_units(pre, filters, self.t)

        # Soft-mask branch (down x1, up x1).
        down_1 = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(pre)
        down_1 = self._stack_residual_units(down_1, filters, self.r)

        soft_mask_unit = self._soft_mask(UpSampling2D(size=(2, 2))(down_1))
        output_unit = self.AttentionResidualLearning(trunk_unit, soft_mask_unit)
        return self._stack_residual_units(output_unit, filters, self.p)

In [13]:
# Model configuration: 800x800 3-channel input images; 3 output channels.
input_shape, output_size = (800, 800, 3), 3

In [14]:
# Keep the weights with the lowest training loss seen so far.
# `filepath` must be a file prefix, not a directory: TF-format weight checkpoints
# append suffixes such as `.index` to it, so the bare directory
# '/content/drive/MyDrive/' produced hidden, unnamed files at the Drive root.
checkpoint_filepath = '/content/drive/MyDrive/residual_attention_network_ckpt'
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
    monitor='loss',
    mode='auto',
    save_best_only=True)

In [18]:
def scheduler(epoch, lr, decay_every=10, decay_factor=0.8):
    """Step-decay schedule for ``tf.keras.callbacks.LearningRateScheduler``.

    Multiplies the current learning rate by ``decay_factor`` at every
    ``decay_every``-th epoch (epochs 10, 20, ... with the defaults). Epoch 0 is
    excluded so training does not start from an already-decayed rate.

    Args:
        epoch: zero-based epoch index supplied by Keras.
        lr: current learning rate.
        decay_every: positive interval, in epochs, between decays.
        decay_factor: multiplier applied at each decay step.

    Returns:
        The learning rate to use for this epoch.
    """
    if epoch > 0 and epoch % decay_every == 0:
        return lr * decay_factor
    return lr

# Build the network and train it with MSE on raw 0-255 pixel values.
# NOTE(review): batch_size=1 means every BatchNormalization layer computes its
# batch statistics from a single image.
model = ResidualAttentionNetwork(input_shape=input_shape, output_size=output_size).Attention_56()
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='mean_squared_error')

# Stop once training loss stalls for 10 epochs; adjust the LR with `scheduler`.
callback_loss = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=10)
callback_lr = tf.keras.callbacks.LearningRateScheduler(scheduler)

history = model.fit(
    X_train,
    y_train,
    batch_size=1,
    epochs=20,
    callbacks=[callback_loss, callback_lr, model_checkpoint_callback],
)

(None, 200, 200, 3)
Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [28]:
# Train the same model for another 20 epochs; epoch numbering restarts at 0,
# so `scheduler` receives epochs 0-19 again.
history2 = model.fit(X_train, y_train,
                     batch_size=1,
                     epochs=20,
                     callbacks=[callback_loss, callback_lr, model_checkpoint_callback])

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


In [34]:
# The `.h5` suffix selects Keras's HDF5 format; the path is relative to the
# current working directory.
model.save(filepath='Residual_attention_network.h5')

In [23]:
def load_model_input(filename, size=(800, 800)):
    """Read ``filename`` from ``input_dir`` as a (1, H, W, 3) batch for ``model.predict``.

    The image is read with ``cv2.imread`` (the same reader used for the training
    data) and resized to ``size``.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    path = os.path.join(input_dir, filename)
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError('could not read image: ' + path)
    batch = np.expand_dims(image.astype('int32'), axis=0)
    return tf.image.resize(batch, list(size))


# The 'simple', 'mid' and 'hard' sample inputs.
simple_img = load_model_input('3762.png')
mid_img = load_model_input('2870.png')
hard_img = load_model_input('3950.png')

In [29]:
# Run the model on each sample; every result keeps the leading batch axis of size 1.
simple_pred, mid_pred, hard_pred = (model.predict(batch) for batch in (simple_img, mid_img, hard_img))

In [None]:
# Predictions follow OpenCV's BGR channel order (training images came from
# cv2.imread); reverse the channel axis for matplotlib's RGB and clip to 0-255.
plt.imshow(np.clip(simple_pred[0, :, :, ::-1], 0, 255).astype('uint8'))
plt.title('simple')
plt.axis('off');

In [None]:
# Show the 'mid' prediction (this cell previously re-displayed `simple_pred`).
# Predictions follow OpenCV's BGR channel order; reverse channels for
# matplotlib's RGB and clip to 0-255.
plt.imshow(np.clip(mid_pred[0, :, :, ::-1], 0, 255).astype('uint8'))
plt.title('mid')
plt.axis('off');

In [None]:
# Show the 'hard' prediction (this cell previously re-displayed `simple_pred`).
# Predictions follow OpenCV's BGR channel order; reverse channels for
# matplotlib's RGB and clip to 0-255.
plt.imshow(np.clip(hard_pred[0, :, :, ::-1], 0, 255).astype('uint8'))
plt.title('hard')
plt.axis('off');

In [None]:
# Save the three predictions as 8-bit PNGs. cv2.imwrite expects BGR, which is
# already the channel order of the predictions. Clip to 0-255 explicitly rather
# than relying on OpenCV's implicit int32 -> uint8 conversion, and fail loudly:
# cv2.imwrite reports failure only through its return value.
for name, pred in (('simple', simple_pred), ('mid', mid_pred), ('hard', hard_pred)):
    out_path = 'model20th_2_' + name + '_140_epoch.png'
    if not cv2.imwrite(out_path, np.clip(pred[0], 0, 255).astype('uint8')):
        raise IOError('failed to write ' + out_path)