<a href="https://colab.research.google.com/github/sefeoglu/AE_Parseval_Network/blob/master/src/notebooks/ResNet_Tensorflow_keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Wide Residual Network (WRN) in tf.keras

In [0]:
import tensorflow as tf

In [0]:

from tensorflow.keras.models import Model
from tensorflow.keras.layers import  Input, Add, Activation, Dropout, Flatten, Dense
from tensorflow.keras.layers import Convolution2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.regularizers import  l2
from tensorflow.keras import backend as K
import warnings

# Silence every Python warning from here on (this also hides deprecation notices).
warnings.filterwarnings("ignore")
# L2 penalty coefficient handed to l2(...) for the convolution and dense kernels below.
weight_decay = 0.0005

In [0]:
def initial_conv(input):
  """Network stem: 3x3 convolution (16 filters) -> batch norm -> ReLU.

  Args:
    input: Keras tensor fed to the first convolution.

  Returns:
    The ReLU-activated feature tensor.
  """
  # The feature axis depends on the backend's image data format.
  bn_axis = 1 if K.image_data_format() == "channels_first" else -1

  stem = Convolution2D(
      filters=16,
      kernel_size=(3, 3),
      padding="same",
      kernel_initializer='he_normal',
      kernel_regularizer=l2(weight_decay),
      use_bias=False,  # no bias: the batch norm that follows supplies the shift
  )(input)
  # NOTE(review): in Keras, `momentum` weights the *previous* moving average, so 0.1
  # tracks the latest batch almost exclusively — confirm this value is intended.
  stem = BatchNormalization(
      axis=bn_axis,
      momentum=0.1,
      epsilon=1e-5,
      gamma_initializer='uniform',
  )(stem)
  return Activation('relu')(stem)

In [0]:
def expand_conv(init, base, k, strides = (1,1)):
  """Residual block that changes the channel count to ``base * k``.

  Main branch: 3x3 conv (using ``strides``) -> batch norm -> ReLU -> 3x3 conv.
  Shortcut branch: 1x1 conv (same ``strides``) applied to ``init`` so both
  branches have the same number of filters before they are summed.

  Args:
    init: Keras tensor entering the block.
    base: Base filter count for this stage.
    k: Width multiplier; every convolution here uses ``base * k`` filters.
    strides: Strides for the first 3x3 conv and for the 1x1 shortcut conv.

  Returns:
    The element-wise sum of the main and shortcut branches.
  """
  width = base * k
  bn_axis = 1 if K.image_data_format() == "channels_first" else -1

  def _conv(kernel, **extra):
    # Builds a new layer (and a new regularizer instance) on every call.
    return Convolution2D(
        width,
        kernel_size=kernel,
        padding="same",
        kernel_initializer="he_normal",
        kernel_regularizer=l2(weight_decay),
        use_bias=False,
        **extra
    )

  # Main branch.
  main = _conv((3, 3), strides=strides)(init)
  main = BatchNormalization(
      axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform'
  )(main)
  main = Activation('relu')(main)
  main = _conv((3, 3))(main)

  # Projection shortcut: created after the main branch, as in the original layer order.
  shortcut = _conv((1, 1), strides=strides)(init)

  return Add()([main, shortcut])

In [0]:
def conv1_block(input, k=1, dropout = 0.0):
  """Pre-activation residual block with ``16 * k`` filters and an identity shortcut.

  Applies (batch norm -> ReLU -> 3x3 conv) twice, with optional dropout
  between the two, and adds the result to the unmodified input.

  Args:
    input: Keras tensor entering the block; it is added back unchanged, so it
      must already match the shape of the ``16 * k``-filter output.
    k: Width multiplier.
    dropout: Dropout rate applied after the first convolution; skipped when
      not greater than 0.

  Returns:
    The sum of ``input`` and the residual branch.
  """
  bn_axis = 1 if K.image_data_format() == "channels_first" else -1
  width = 16 * k

  def _bn_relu_conv(tensor):
    # One pre-activation unit: normalise, activate, then convolve.
    out = BatchNormalization(
        axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform'
    )(tensor)
    out = Activation('relu')(out)
    return Convolution2D(
        width,
        kernel_size=(3, 3),
        padding='same',
        kernel_initializer='he_normal',
        kernel_regularizer=l2(weight_decay),
        use_bias=False,
    )(out)

  residual = _bn_relu_conv(input)
  if dropout > 0.0:
    residual = Dropout(dropout)(residual)
  residual = _bn_relu_conv(residual)

  return Add()([input, residual])

In [0]:
def conv2_block(input, k=1, dropout = 0.0):
  """Pre-activation residual block with ``32 * k`` filters and an identity shortcut.

  Applies (batch norm -> ReLU -> 3x3 conv) twice, with optional dropout
  between the two, and adds the result to the unmodified input.

  Args:
    input: Keras tensor entering the block; it is added back unchanged, so it
      must already match the shape of the ``32 * k``-filter output.
    k: Width multiplier.
    dropout: Dropout rate applied after the first convolution; skipped when
      not greater than 0.

  Returns:
    The sum of ``input`` and the residual branch.
  """
  init = input

  # Derive the feature axis from the backend data format. A fixed axis of 1
  # would normalise over the height dimension under "channels_last".
  channel_axis = 1 if K.image_data_format() == "channels_first" else -1

  x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(input)
  x = Activation('relu')(x)
  x = Convolution2D(32 * k, kernel_size=(3,3), padding='same', kernel_initializer='he_normal',
                    kernel_regularizer=l2(weight_decay), use_bias=False)(x)

  if dropout > 0.0:
    x = Dropout(dropout)(x)

  x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
  x = Activation('relu')(x)
  x = Convolution2D(32 * k, kernel_size=(3,3), padding='same', kernel_initializer='he_normal',
                    kernel_regularizer=l2(weight_decay), use_bias=False)(x)

  # Identity shortcut: no projection, the input is summed as-is.
  m = Add()([init, x])
  return m


In [0]:
def conv3_block(input, k=1, dropout=0.0):
  """Pre-activation residual block with ``64 * k`` filters and an identity shortcut.

  Applies (batch norm -> ReLU -> 3x3 conv) twice, with optional dropout
  between the two, and adds the result to the unmodified input.

  Args:
    input: Keras tensor entering the block; it is added back unchanged, so it
      must already match the shape of the ``64 * k``-filter output.
    k: Width multiplier.
    dropout: Dropout rate applied after the first convolution; skipped when
      not greater than 0.

  Returns:
    The sum of ``input`` and the residual branch.
  """
  init = input

  # Derive the feature axis from the backend data format. A fixed axis of 1
  # would normalise over the height dimension under "channels_last".
  channel_axis = 1 if K.image_data_format() == "channels_first" else -1

  x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(input)
  x = Activation('relu')(x)
  x = Convolution2D(64 * k, kernel_size=(3,3), padding='same', kernel_initializer='he_normal',
                    kernel_regularizer=l2(weight_decay), use_bias=False)(x)

  if dropout > 0.0:
    x = Dropout(dropout)(x)

  x = BatchNormalization(axis=channel_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform')(x)
  x = Activation('relu')(x)
  x = Convolution2D(64 * k, kernel_size=(3,3), padding='same', kernel_initializer='he_normal',
                    kernel_regularizer=l2(weight_decay), use_bias=False)(x)

  # Identity shortcut: no projection, the input is summed as-is.
  m = Add()([init, x])
  return m


In [0]:
def create_wide_residual_network(input_dim, nb_classes=4, N=2, k=2, dropout = 0.0, verbose=1):
  """Assemble the wide residual network as a Keras ``Model``.

  Layout: stem conv, then three stages with base widths 16, 32 and 64. Each
  stage is one ``expand_conv`` block followed by ``N - 1`` identity residual
  blocks and a closing batch norm + ReLU. The head is 8x8 average pooling,
  flatten, and a softmax dense layer.

  Args:
    input_dim: Shape tuple passed to ``Input(shape=...)``.
    nb_classes: Number of units in the final softmax layer.
    N: Residual blocks per stage (one expanding block plus ``N - 1`` identity blocks).
    k: Width multiplier forwarded to every block.
    dropout: Dropout rate forwarded to the identity residual blocks.
    verbose: When truthy, print a one-line summary after the model is built.

  Returns:
    The constructed ``Model`` (not compiled).
  """
  bn_axis = 1 if K.image_data_format() == "channels_first" else -1

  def _bn_relu(tensor):
    # Batch norm + ReLU that closes each stage.
    normed = BatchNormalization(
        axis=bn_axis, momentum=0.1, epsilon=1e-5, gamma_initializer='uniform'
    )(tensor)
    return Activation('relu')(normed)

  net_input = Input(shape=input_dim)
  features = initial_conv(net_input)

  # Depth label used only in the summary message: starts at 4, +2 per residual block.
  conv_count = 4

  stages = (
      (16, conv1_block),
      (32, conv2_block),
      (64, conv3_block),
  )
  for base_width, residual_block in stages:
    # Every stage keeps unit strides, so spatial resolution is never reduced here.
    features = expand_conv(features, base_width, k, strides=(1, 1))
    conv_count += 2

    for _ in range(N - 1):
      features = residual_block(features, k, dropout)
      conv_count += 2

    features = _bn_relu(features)

  pooled = AveragePooling2D((8, 8))(features)
  flat = Flatten()(pooled)
  probs = Dense(nb_classes, kernel_regularizer=l2(weight_decay), activation='softmax')(flat)

  model = Model(net_input, probs)

  if verbose:
    print("Wide Residual Network-%d-%d created." % (conv_count, k))

  return model





In [57]:
if __name__ == "__main__":
    # Smoke test: build the network once to confirm that the graph wires up.
    # NOTE(review): plot_model is not used in this cell, and Input/Model are already
    # imported at the top of the notebook — these imports look redundant.
    from tensorflow.keras.utils import plot_model
    from tensorflow.keras.layers import Input
    from tensorflow.keras.models import Model

    # Shape tuple forwarded to Input(shape=...).
    init = (68, 100,1)
    # With N=4 and k=10 the builder reports itself as "Wide Residual Network-28-10".
    wrn28_10 = create_wide_residual_network(init, nb_classes=4, N=4, k=10, dropout = 0.0, verbose=1)
  

Wide Residual Network-28-10 created.
