In [None]:
import tensorflow as tf
import numpy as np

In [None]:
from keras.layers import BatchNormalization, Activation
from keras.layers import Conv2D, MaxPool2D, UpSampling2D
from keras.layers import Add, Multiply

from keras.layers import Input, Dense, AveragePooling2D, Flatten
from keras.models import Model
from keras.regularizers import l2

from keras.preprocessing.image import load_img, img_to_array, ImageDataGenerator

In [None]:
from google.colab import drive

# Attach Google Drive to the Colab runtime so the dataset and checkpoints
# stored there are reachable; force_remount=True redoes the mount even if a
# previous one is still registered.
drive.mount(mountpoint='/content/drive/', force_remount=True)

## Data loading

In [None]:
# Target image geometry (in pixels) and number of colour channels.
img_width, img_height, channels = 160, 160, 3
# Channels-last layout (height, width, channels): same axis order as the
# target_size=(img_height, img_width) passed to the data generators, and built
# from `channels` rather than a second hard-coded 3.
input_shape = (img_height, img_width, channels)
# Number of images per batch yielded by the data generators.
batch_size = 64

In [None]:
# Root directory of the image dataset; it is passed to flow_from_directory
# below, which (per the Keras docs) expects one sub-directory per class.
merged_imgs_dir = '<project_path>/data'

In [None]:
# Images preprocessing
# Pixel values are rescaled by 1/255 and then standardised feature-wise with
# dataset statistics; `validation_fraction` of the images is held out for
# validation.
validation_fraction = 0.08
imgs_datagen = ImageDataGenerator(rescale=1. / 255,
                                  featurewise_center=True,
                                  featurewise_std_normalization=True,
                                  validation_split=validation_fraction)

# featurewise_center / featurewise_std_normalization only take effect once the
# generator has been fitted on sample data; without fit() Keras emits a warning
# and skips the normalisation. The statistics are estimated from a few raw
# (un-rescaled) batches of the training subset -- per the Keras docs, fit()
# applies `rescale` itself before computing them.
# Memory note: 8 batches of 64 float32 160x160x3 images are roughly 160 MB.
n_stats_batches = 8
stats_generator = ImageDataGenerator(
    validation_split=validation_fraction).flow_from_directory(
        merged_imgs_dir,
        target_size=(img_height, img_width),
        batch_size=batch_size,
        class_mode=None,   # yield image arrays only, no labels
        shuffle=True,
        seed=0,            # reproducible sample
        subset='training')
imgs_datagen.fit(
    np.concatenate([next(stats_generator) for _ in range(n_stats_batches)]))

# Both generators share the fitted preprocessing and draw from disjoint
# subsets of the same directory.
train_generator = imgs_datagen.flow_from_directory(
    merged_imgs_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='training')

validation_generator = imgs_datagen.flow_from_directory(
    merged_imgs_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation')

## Model structure

In [None]:
def pre_activation_residual_unit(unit_input, n_input_ch=None,
                                 n_output_ch=None, stride=1):
    """Stack a pre-activation bottleneck residual unit on top of `unit_input`.

    The residual path is BN -> ReLU -> 1x1 conv -> BN -> ReLU -> 3x3 conv ->
    BN -> ReLU -> 1x1 conv; its result is added to the shortcut path.

    Args:
        unit_input: Input tensor; its last axis is treated as the channel axis.
        n_input_ch: Number of filters of the first two (bottleneck)
            convolutions. Defaults to `n_output_ch // 4`.
        n_output_ch: Number of filters of the last convolution, i.e. the
            channel count of the returned tensor. Defaults to the channel
            count of `unit_input`.
        stride: Stride of the first 1x1 convolution and of the shortcut
            projection.

    Returns:
        The output tensor of the `Add` layer that merges the residual and
        shortcut paths.
    """
    if n_output_ch is None:
        n_output_ch = unit_input.get_shape()[-1]
    if n_input_ch is None:
        n_input_ch = n_output_ch // 4

    # The residual path must start from the unit's own input tensor (the
    # previous code passed the Python builtin `input` here by mistake).
    x = BatchNormalization()(unit_input)
    x = Activation('relu')(x)
    x = Conv2D(n_input_ch, (1, 1), padding='same', strides=stride)(x)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(n_input_ch, (3, 3), padding='same', strides=1)(x)

    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Conv2D(n_output_ch, (1, 1), padding='same', strides=1)(x)

    # NOTE(review): `n_input_ch` is the bottleneck width, not the channel count
    # of `unit_input`, so with the default n_output_ch // 4 this condition is
    # true and the shortcut is always a 1x1 projection. Kept as is so that the
    # architecture (and any saved weights) stay unchanged -- confirm intent.
    shortcut = unit_input
    if n_input_ch != n_output_ch or stride != 1:
        shortcut = Conv2D(n_output_ch, (1, 1),
                          padding='same',
                          strides=(stride, stride))(unit_input)

    return Add()([x, shortcut])

In [None]:
def attention_module(input, encoder_depth=1):
    """Build one attention module (trunk branch + soft mask branch).

    The input first goes through `p` residual units and is then split:
      * the trunk branch applies `t` residual units;
      * the soft mask branch is an encoder/decoder made of `encoder_depth`
        max-pooling steps and as many up-sampling steps (with skip connections
        between matching resolutions), closed by two 1x1 convolutions and a
        sigmoid.
    The branches are merged as `trunk + trunk * mask`, followed by `p` more
    residual units.

    Args:
        input: Input tensor; its last axis is treated as the channel axis.
        encoder_depth: Number of down-sampling (and up-sampling) steps in the
            soft mask branch.

    Returns:
        Output tensor of the module.

    Note:
        Every MaxPool2D (default 2x2, padding='same') halves the spatial size
        rounding up and every UpSampling2D (default 2x2) doubles it, so height
        and width should be divisible by 2 ** encoder_depth for the Add /
        Multiply merges to receive equally sized tensors.
    """
    # Hyperparameters according to the article:
    # number of preprocessing Residual Units before splitting into trunk branch and mask branch
    p = 1
    # number of Residual Units in trunk branch
    t = 2
    # number of Residual Units between adjacent pooling layers in the mask branch
    r = 1

    n_input_ch = input.get_shape()[-1]

    def residual_stack(tensor, n_units):
        # Chain `n_units` residual units with default channel settings.
        for _ in range(n_units):
            tensor = pre_activation_residual_unit(tensor)
        return tensor

    # First p Residual Units
    residual_output = residual_stack(input, p)

    # ---------------------------- Trunk Branch part ----------------------------
    trunk_output = residual_stack(residual_output, t)

    # -------------------------- Soft Mask Branch part --------------------------
    # First down sampling, followed by r Residual Units
    mask = MaxPool2D(padding='same')(residual_output)
    mask = residual_stack(mask, r)

    # Down sampling part: remember one skip connection per extra resolution
    skip_connections = []
    for _ in range(encoder_depth - 1):
        # create skip connection between bottom-up and top-down parts
        skip_connections.append(pre_activation_residual_unit(mask))

        # apply down sampling, then r Residual Units
        mask = MaxPool2D(padding='same')(mask)
        mask = residual_stack(mask, r)

    # Up sampling part: skip connections are consumed in reverse order
    # (deepest resolution first). The merged tensor is carried into the next
    # iteration so that every up-sampling step builds on the previous one.
    for skip_connection in reversed(skip_connections):
        # apply r Residual Units, then up sampling
        mask = residual_stack(mask, r)
        mask = UpSampling2D()(mask)

        # adding skip connection
        mask = Add()([mask, skip_connection])

    # Final r Residual Units and final up sampling (undoes the first pooling)
    mask = residual_stack(mask, r)
    mask = UpSampling2D()(mask)

    mask = Conv2D(n_input_ch, (1, 1))(mask)
    mask = Conv2D(n_input_ch, (1, 1))(mask)
    soft_mask_output = Activation('sigmoid')(mask)

    # --------------- Trunk and Soft Mask Branches merging part -----------------
    # (1 + mask) * trunk, written as trunk + trunk * mask
    output = Multiply()([trunk_output, soft_mask_output])
    output = Add()([trunk_output, output])

    # Final p Residual Units
    return residual_stack(output, p)

In [None]:
def att_resnet_56(shape=(160, 160, 3), n_channels=64,
                      n_classes=9, l2_par=0.01):
    """Assemble the attention ResNet classifier as a Keras `Model`.

    Layout: 7x7 stride-2 convolution and 3x3 stride-2 max-pooling stem, three
    stages of (residual unit + attention module), three closing residual
    units, average pooling over the whole remaining feature map and a softmax
    `Dense` head.

    Args:
        shape: Shape of one input image.
        n_channels: Number of filters of the stem convolution; the later
            stages use 4x, 8x, 16x and 32x this value.
        n_classes: Number of units of the softmax output layer.
        l2_par: L2 regularisation factor applied to the output layer kernel.

    Returns:
        The (uncompiled) Keras model.
    """
    head_regularizer = l2(l2_par)

    # (channel multiplier, stride, attention encoder depth) of each stage.
    # For the default 160x160 input the stages work at 40x40, 20x20 and 10x10.
    attention_stages = ((4, 1, 2), (8, 2, 2), (16, 2, 1))
    # (channel multiplier, stride) of the closing residual units (5x5 maps
    # for the default input).
    closing_units = ((32, 2), (32, 1), (32, 1))

    model_input = Input(shape=shape)
    # Stem: 160x160 -> 80x80 -> 40x40 for the default input.
    features = Conv2D(n_channels, (7, 7), strides=(2, 2), padding='same')(model_input)
    features = MaxPool2D(pool_size=(3, 3), strides=(2, 2), padding='same')(features)

    for multiplier, unit_stride, depth in attention_stages:
        features = pre_activation_residual_unit(
            features, n_output_ch=n_channels * multiplier, stride=unit_stride)
        features = attention_module(features, encoder_depth=depth)

    for multiplier, unit_stride in closing_units:
        features = pre_activation_residual_unit(
            features, n_output_ch=n_channels * multiplier, stride=unit_stride)

    # Pool over the full spatial extent so a single value per channel remains.
    full_window = (features.shape[1], features.shape[2])
    features = AveragePooling2D(pool_size=full_window, strides=(1, 1))(features)
    features = Flatten()(features)

    class_probabilities = Dense(n_classes, kernel_regularizer=head_regularizer,
                                activation='softmax')(features)
    return Model(model_input, class_probabilities)

## Model training

In [None]:
from keras.optimizers import Adam
from keras.callbacks import ReduceLROnPlateau, EarlyStopping, ModelCheckpoint

In [None]:
# Build the network from what the training generator actually yields: its
# image shape (target_size + channels) and one output unit per class folder it
# discovered, instead of relying on the default shape and a hard-coded 9.
att_resnet_model = att_resnet_56(shape=train_generator.image_shape,
                                 n_classes=train_generator.num_classes)

In [None]:
# Categorical cross-entropy matches the one-hot labels produced by
# class_mode='categorical'. `learning_rate` replaces the deprecated `lr`
# argument name of the optimizer.
att_resnet_model.compile(optimizer=Adam(learning_rate=0.0001),
                         loss='categorical_crossentropy',
                         metrics=['accuracy'])

In [None]:
# All three callbacks watch the same validation metric.
monitored_metric = 'val_accuracy'

# Multiply the learning rate by 0.2 after 5 epochs without an improvement of
# at least 0.001, never going below 1e-6.
reducer = ReduceLROnPlateau(
    monitor=monitored_metric,
    factor=0.2,
    patience=5,
    min_delta=0.001,
    min_lr=1e-6,
    verbose=1,
)

# Stop training after 5 epochs without any improvement.
stopper = EarlyStopping(
    monitor=monitored_metric,
    patience=5,
    min_delta=0,
    verbose=1,
)

# Save a checkpoint whenever the monitored metric reaches a new maximum; the
# epoch number and metric value are substituted into the file name by Keras.
# NOTE(review): the name says "weights" but save_weights_only is not set, so
# presumably the full model is saved -- confirm this is intended.
filepath = '<project_path>/att_resnet_best_weights.{epoch:02d}-{val_accuracy:.4f}'
checkpoint = ModelCheckpoint(
    filepath,
    monitor=monitored_metric,
    mode='max',
    save_best_only=True,
    verbose=1,
)

model_callbacks = [reducer, stopper, checkpoint]

In [None]:
# Step counts come from the generators themselves (len() is the number of
# batches per pass, including a final partial batch). The previous code
# reassigned batch_size to 128 here, although the generators were created with
# the earlier batch size, so `samples // batch_size` covered only about half
# of the data per epoch. To train with another batch size, change it before
# the generators are created.
# NOTE(review): fit_generator is deprecated in recent TensorFlow/Keras 2.x
# releases in favour of fit(); switch once the minimum Keras version is known.
att_resnet_model.fit_generator(train_generator,
                    steps_per_epoch=len(train_generator),
                    epochs=5,
                    validation_data=validation_generator,
                    validation_steps=len(validation_generator),
                    callbacks=model_callbacks, initial_epoch=0)