developed by : __*aipythoner@gmail.com*__

[pytorch version](https://github.com/achaiah/pywick/blob/aa69ad19aae74fae65e5fa04bbca93cd916739e2/pywick/models/classification/polynet.py)

[Poster, Structure](http://mmlab.ie.cuhk.edu.hk/projects/cu_deeplink/)

[Paper (arxiv)](https://arxiv.org/abs/1611.05725)

In [1]:
import warnings; warnings.simplefilter('ignore')
import keras, math, numpy as np
import tensorflow as tf
from keras.layers import Layer, Input, Dropout, Conv2D, MaxPool2D, AvgPool2D, Flatten, Dense
from keras.layers import BatchNormalization, concatenate, add, Activation, GlobalAvgPool2D, Lambda
from keras.models import Model
from keras.regularizers import l2
from keras.optimizers import SGD
from keras.callbacks import LearningRateScheduler
import keras.backend as K

Using TensorFlow backend.


In [2]:
num_classes = 1000 # 1000 for ImageNet

In [5]:
def Conv2D_BN(input_x, filters, kernel_size, strides=(1,1), padding='valid', relu=True, name=None):
    conv = Conv2D(filters, kernel_size, strides=strides, padding=padding, 
                  kernel_regularizer=l2(0.00004), name=name)(input_x)
    bn = BatchNormalization(momentum=0.9997)(conv)
    if relu:
        res = Activation('relu')(bn)
        return res
    else:
        return bn

In [6]:
def stem(input_x):
    conv3x3 = Conv2D_BN(input_x, 32, (3,3), strides=(2,2))
    conv3x3_2 = Conv2D_BN(conv3x3, 32, (3,3))
    conv3x3_3 = Conv2D_BN(conv3x3_2, 64, (3,3), padding="same")
    
    b1_maxpool = MaxPool2D((3,3), strides=(2,2))(conv3x3_3)
    b2_conv3x3 = Conv2D_BN(conv3x3_3, 96, (3,3), strides=(2,2))
    conc1 = concatenate([b1_maxpool, b2_conv3x3], name="Input-Module_Concat1")
#     First Branch
    b1_conv1x1 = Conv2D_BN(conc1, 64, (1,1), padding='same')
    b1_conv3x3 = Conv2D_BN(b1_conv1x1, 96, (3,3))
#     Second Branch
    b2_conv1x1 = Conv2D_BN(conc1, 64, (1,1), padding='same')
    b2_conv7x1 = Conv2D_BN(b2_conv1x1, 64, (7,1), padding="same")
    b2_conv1x7 = Conv2D_BN(b2_conv7x1, 64, (1,7), padding="same")
    b2_conv3x3_1 = Conv2D_BN(b2_conv1x7, 96, (3,3))
    conc2 = concatenate([b1_conv3x3, b2_conv3x3_1], name="Input-Module_Concat2")
    
#     paper problem(need stride of 2)
    b1_conv3x3_1 = Conv2D_BN(conc2, 192, (3,3), strides=(2,2))
    b2_maxpool2 = MaxPool2D((3,3), strides=(2,2))(conc2)
    conc3 = concatenate([b1_conv3x3_1, b2_maxpool2], name="Stem")
    
    return conc3

In [7]:
def blocka(input_x, name='Block-A'):
    
    b1_conv1x1 = Conv2D_BN(input_x, 32, (1,1), padding='same')
    b1_conv3x3 = Conv2D_BN(b1_conv1x1, 48, (3,3), padding='same')
    b1_conv3x3_1 = Conv2D_BN(b1_conv3x3, 64, (3,3), padding='same')
    
    b2_conv1x1 = Conv2D_BN(input_x, 32, (1,1), padding='same')
    b2_conv3x3 = Conv2D_BN(b2_conv1x1, 32, (3,3), padding='same')
   
    b3_conv1x1 = Conv2D_BN(input_x, 32, (1,1), padding='same')
    
    concat = concatenate([b1_conv3x3_1, b2_conv3x3, b3_conv1x1])
    conv1x1 = Conv2D_BN(concat, 384, (1,1), padding='same', relu=False)
    
    return conv1x1

In [8]:
def blockb(input_x, name='Block-B'):
    
    b1_conv1x1 = Conv2D_BN(input_x, 128, (1,1), padding='same')
    b1_conv1x7 = Conv2D_BN(b1_conv1x1, 160, (1,7), padding='same') #before 'same'
    b1_conv7x1 = Conv2D_BN(b1_conv1x7, 192, (7,1), padding='same') #before 'same'
    
    b2_conv1x1 = Conv2D_BN(input_x, 192, (1,1), padding='same')
    
    concat = concatenate([b1_conv7x1, b2_conv1x1])
    conv1x1 = Conv2D_BN(concat, 1152, (1,1), padding='same', relu=False)
    
    return conv1x1

In [9]:
def blockc(input_x, name='Block-C'):
    
    b1_conv1x1 = Conv2D_BN(input_x, 192, (1,1), padding='same')
    b1_conv1x3 = Conv2D_BN(b1_conv1x1, 224, (1,3), padding='same')
    b1_conv3x1 = Conv2D_BN(b1_conv1x3, 256, (3,1), padding='same')
    
    b2_conv1x1 = Conv2D_BN(input_x, 192, (1,1), padding='same')
    
    concat = concatenate([b1_conv3x1, b2_conv1x1])
    conv1x1 = Conv2D_BN(concat, 2048, (1,1), padding='same', relu=False)
    
    return conv1x1

In [10]:
def reduction_a(input_x, name="Reduction_A"):

    b1_conv1x1 = Conv2D_BN(input_x, 256, (1,1), padding="same")
    b1_conv3x3 = Conv2D_BN(b1_conv1x1, 256, (3,3), padding="same")
    b1_conv3x3_1 = Conv2D_BN(b1_conv3x3, 384, (3,3), strides=(2,2))
    
    b2_conv3x3 = Conv2D_BN(input_x, 384, (3,3), strides=(2,2))
    
    b3_maxpool = MaxPool2D((3,3), strides=(2,2))(input_x)
    
    res = concatenate([b3_maxpool, b2_conv3x3, b1_conv3x3_1], name=name)
    
    return res

In [11]:
def reduction_b(input_x, name="Reduction_B"):
    
    b1_conv1x1 = Conv2D_BN(input_x, 256, (1,1))
    b1_conv3x3 = Conv2D_BN(b1_conv1x1, 256, (3,3), padding='same')
    b1_conv3x3_1 = Conv2D_BN(b1_conv3x3, 256, (3,3), strides=(2,2))
    
    b2_conv1x1 = Conv2D_BN(input_x, 256, (1,1))
    b2_conv3x3 = Conv2D_BN(b2_conv1x1, 256, (3,3), strides=(2,2))
    
    b3_conv1x1 = Conv2D_BN(input_x, 256, (1,1))
    b3_conv3x3 = Conv2D_BN(b2_conv1x1, 384, (3,3), strides=(2,2))
    
    b4_maxpool = MaxPool2D((3,3), strides=(2,2))(input_x)
    
    res = concatenate([b1_conv3x3_1, b2_conv3x3, b3_conv3x3, b4_maxpool])
    return res

In [12]:
def multiway(input_x, scale, block, num_block):
    """Constructing N-way modules."""
    assert num_block >=1, 'num_block should greater or equal 1'
    blocks = [block for _ in range(num_block)]
    out = input_x
    for block in blocks:
        out = Lambda(lambda x: out + x * scale)(block(out))
    out = Activation('relu')(out)
    return out

In [13]:
# def PolyConv2d(input_x, num_blocks):
#     """block of poly-N """
# NOT IMPLEMENTED problem : # [block_index]

In [14]:
def inceptionResnetBpoly(input_x, scale, num_blocks):
    """Base class for constructing poly-N Inception-ResNet-B modules."""
    assert num_blocks >=1, 'num_blocks >= 1.'
    
    out = input_x
    for indx in range(num_blocks):
        x = blockb(out)
        out = Lambda(lambda x: out + x * scale)(x)
        x = Activation('relu')(x)
    out = Activation('relu')(out)
    return out

In [15]:
def inceptionResnetCpoly(input_x, scale, num_blocks):
    """Base class for constructing poly-N Inception-ResNet-C modules."""
    assert num_blocks >=1, 'num_blocks >= 1.'
    
    out = input_x
    for indx in range(num_blocks):
        x = blockc(out)
        out = Lambda(lambda x: out + x * scale)(x)
        x = Activation('relu')(x)
    out = Activation('relu')(out)
    return out

In [16]:
def stageA2way(input_x, scale):
    res = multiway(input_x, scale, block=blocka, num_block=2)
    return res

In [17]:
def stageBPoly3(input_x, scale):
    res = inceptionResnetBpoly(input_x, scale, num_blocks=3)
    return res

In [18]:
def stageB2way(input_x, scale):
    res = multiway(input_x, scale, block=blockb, num_block=2)
    return res

In [19]:
def stageCPoly3(input_x, scale):
    res = inceptionResnetCpoly(input_x, scale, num_blocks=3)
    return res

In [20]:
def stageC2way(input_x, scale):
    res = multiway(input_x, scale, block=blockc, num_block=2)
    return res

In [21]:
scalea = [1, 0.992308, 0.984615, 0.976923, 0.969231, 0.961538, \
          0.953846, 0.946154, 0.938462, 0.930769]

scaleb = [0.923077, 0.915385, 0.907692, 0.9, 0.892308, 0.884615, \
          0.876923, 0.869231, 0.861538, 0.853846, 0.846154, \
          0.838462, 0.830769, 0.823077, 0.815385, 0.807692, 0.8, \
          0.792308, 0.784615, 0.776923]

scalec = [0.769231, 0.761538, 0.753846, 0.746154, 0.738462, \
          0.730769, 0.723077, 0.715385, 0.707692, 0.7]

In [22]:
input_layer = Input((331,331,3))
x = stem(input_layer)

for scale in scalea:
    x = stageA2way(x, scale)   

x = reduction_a(x)

for i in range(0, len(scaleb), 2):
    x = stageBPoly3(x, scaleb[i])
    x = stageB2way(x, scaleb[i+1])
    
x = reduction_b(x)

for i in range(0, len(scalec), 2):
    x = stageCPoly3(x, scalec[i])
    x = stageC2way(x, scalec[i+1])
    
avgpool = GlobalAvgPool2D()(x)
drop = Dropout(rate=0.2)(avgpool)
fc = Dense(num_classes, activation="softmax", name="Final_Prob")(drop)

Instructions for updating:
Colocations handled automatically by placer.


In [23]:
model = Model(inputs=input_layer, outputs=fc)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 331, 331, 3)  0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 165, 165, 32) 896         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 165, 165, 32) 128         conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 165, 165, 32) 0           batch_normalization_1[0][0]      
____________________________________________________________________________________________

concatenate_46 (Concatenate)    (None, 19, 19, 384)  0           activation_273[0][0]             
                                                                 activation_274[0][0]             
__________________________________________________________________________________________________
conv2d_285 (Conv2D)             (None, 19, 19, 1152) 443520      concatenate_46[0][0]             
__________________________________________________________________________________________________
batch_normalization_285 (BatchN (None, 19, 19, 1152) 4608        conv2d_285[0][0]                 
__________________________________________________________________________________________________
lambda_46 (Lambda)              (None, 19, 19, 1152) 0           batch_normalization_285[0][0]    
__________________________________________________________________________________________________
conv2d_286 (Conv2D)             (None, 19, 19, 128)  147584      lambda_46[0][0]                  
__________

__*Parameters(Pytorch deploy): 95,366,600*__  
__*Parameters(MyKeras version): 120,667,144*__

In [32]:
print("\033[1m \033[91m" + 'Should find better explain of network' + "\033[0m")

[1m [91mShould find better explain of network[0m
