<center><h1>Two-stage image classifier that uses full-size image information</h1></center>

In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Activation,Multiply,Dense,Add
from tensorflow.keras.layers import UpSampling2D, Flatten,add, Concatenate, MaxPooling2D, AveragePooling2D,GlobalAveragePooling2D,Dropout
from tensorflow.keras.layers import TimeDistributed


In [None]:

# Squeeze and Excitation
def se_block(input_t, channels=None, r=4):
    """Squeeze-and-Excitation: rescale each channel of ``input_t`` by a learned gate.

    Args:
        input_t: 4-D feature map (GlobalAveragePooling2D requires rank 4);
            channels are assumed to be on the last axis (channels_last —
            TODO confirm if another data_format is ever used).
        channels: number of channels in ``input_t``. When omitted it is read
            from ``input_t.shape[-1]``.
        r: reduction ratio of the bottleneck Dense layer.

    Returns:
        Tensor with the same shape as ``input_t``, each channel multiplied by
        a sigmoid weight in (0, 1).
    """
    if channels is None:
        channels = int(input_t.shape[-1])
    # Squeeze: one averaged descriptor per channel -> (batch, channels)
    x = GlobalAveragePooling2D()(input_t)
    # Excitation: bottleneck MLP. Clamp to >= 1 unit so channel counts smaller
    # than r do not produce a zero-width Dense layer (which Keras rejects).
    x = Dense(max(1, channels // r), activation="relu")(x)
    x = Dense(channels, activation="sigmoid")(x)
    # Keras' Multiply expands the lower-rank (batch, channels) gate so it
    # broadcasts over the spatial axes of input_t.
    return Multiply()([input_t, x])


In [None]:


# Number of output classes, and the side length of the square 3-channel input
# fed to the patch-wise extractor (Input shape is (IMG_SIZE, IMG_SIZE, 3)).
NUM_CLASSES, IMG_SIZE = 5, 200

def _residual_se_stage(x, filters, act, downsample, shortcut_feeds_main):
    """One extractor stage: residual pair of 3x3 convs, a transition conv, then SE.

    Args:
        x: input feature map (channels_last).
        filters: output channels of every convolution in the stage.
        act: activation applied after each BatchNormalization.
        downsample: if True the transition conv is 2x2 / stride 2 / "valid"
            (halves H and W, rounding down); otherwise it is 3x3 / stride 1 /
            "same" (spatial size unchanged).
        shortcut_feeds_main: if True the main path starts from the shortcut
            conv's output; if False it starts from the stage input ``x``.

    Returns:
        Feature map with ``filters`` channels, rescaled by an SE block.
    """
    shortcut = Conv2D(filters, kernel_size=3, strides=1, padding="same")(x)
    y = BatchNormalization()(shortcut if shortcut_feeds_main else x)
    y = Activation(act)(y)
    y = Conv2D(filters, kernel_size=3, strides=1, padding="same")(y)
    y = Add()([shortcut, y])
    y = BatchNormalization()(y)
    y = Activation(act)(y)
    if downsample:
        y = Conv2D(filters, kernel_size=2, strides=2, padding="valid")(y)
    else:
        y = Conv2D(filters, kernel_size=3, strides=1, padding="same")(y)
    y = BatchNormalization()(y)
    y = Activation(act)(y)
    return se_block(y, int(y.shape[-1]))


def patch_wise_extractor(img_size=None, act=tf.nn.swish):
    """Build the per-patch feature extractor used by the two-stage classifier.

    Five residual conv stages with Squeeze-and-Excitation, a final 1x1 conv
    (+SE) and global average pooling map one (img_size, img_size, 3) input to
    a 256-dimensional feature vector.

    Args:
        img_size: side length of the square input. Defaults to the global
            ``IMG_SIZE``, read at call time.
        act: activation used after every BatchNormalization (default swish).

    Returns:
        A Keras ``Model`` named "patch_wise".
    """
    if img_size is None:
        img_size = IMG_SIZE
    input_img = Input(shape=(img_size, img_size, 3))

    # (filters, downsample, shortcut_feeds_main) for each stage, in order.
    # NOTE(review): stages 3-5 feed the raw stage input (not the shortcut
    # conv's output) into the main-path BatchNormalization, unlike stages 1-2.
    # This reproduces the original architecture exactly; confirm it is intended.
    stages = (
        (24, True, True),
        (48, True, True),
        (64, True, False),
        (128, False, False),
        (256, False, False),
    )
    x = input_img
    for filters, downsample, shortcut_feeds_main in stages:
        x = _residual_se_stage(x, filters, act, downsample, shortcut_feeds_main)

    # Final point-wise (1x1) convolution, without BN/activation, then SE.
    x = Conv2D(256, kernel_size=1, strides=1, padding="same")(x)
    x = se_block(x, int(x.shape[-1]))

    ext_out = GlobalAveragePooling2D()(x)
    return Model(input_img, ext_out, name="patch_wise")


pw_ext = patch_wise_extractor()
pw_ext.summary()

In [None]:
# Number of patches stacked per sample (presumably tiles of one full-size
# image — confirm against the data pipeline).
NUM_PATCHES = 12

input_stack = Input(shape=(NUM_PATCHES, IMG_SIZE, IMG_SIZE, 3))

# Stage 1: run the same extractor (shared weights) on every patch
# -> (batch, NUM_PATCHES, features).
ext_stack = TimeDistributed(pw_ext)(input_stack)

# Stage 2: concatenate per-patch features (so patch order matters) and classify.
flat = Flatten()(ext_stack)
drop = Dropout(0.5)(flat)
out = Dense(NUM_CLASSES, activation="softmax")(drop)

model = Model(input_stack, out)

model.summary()