In [1]:

import tensorflow as tf
import numpy as np
import os
from tensorflow.keras.layers import *

In [25]:
class H_x(tf.keras.layers.Layer):

  def __init__(self,num_filter=None,drop_rate = None ):
    super(H_x,self).__init__()
    
    self.bn1 = BatchNormalization()
    self.act1 = Activation('relu')
    self.conv1 = Conv2D(filters = num_filter,kernel_size=3,padding='same')
    self.drop = Dropout(rate = drop_rate)

  def call(self,x):

    x = self.bn1(x)
    x = self.act1(x)
    x = self.conv1(x)
    x = self.drop(x)

    return x

In [26]:
class Transition_Block(tf.keras.layers.Layer):

  def __init__(self,drop_rate=None,projection = None,c_rate = None):

    super(Transition_Block,self).__init__()
    self.projection = projection
    self.c_rate = c_rate
    self.bn = BatchNormalization()
    self.act = Activation('relu')
    self.conv = Conv2D(filters = np.int(self.projection*self.c_rate),kernel_size = 1,strides = 2,padding='same')
    self.drop = Dropout(rate = drop_rate)
    

  def call(self,x):
    
    m = x.shape[3]
    self.projection = m 
    x = self.bn(x)
    x = self.act(x)
    x = self.conv(x)
    x = self.drop(x)


    return x

In [27]:

class dense_block(tf.keras.layers.Layer):

  def __init__(self,num_of_layer=None,num_filters=None,growth_rate=None,drop_rate = None):

    super(dense_block,self).__init__()

    self.num_filter = num_filters

    self.growth_rate = growth_rate


  def call(self,x):

    for i in range(num_of_layer):

      y =  H_x(num_filter=self.num_filter,drop_rate = drop_rate )(x)

      x = Concatenate()([y,x])

      self.num_filter = self.num_filter + self.growth_rate


    return x


In [28]:
class classifier(tf.keras.layers.Layer):

  def __init__(self,num_class=None,fc_unit=None):
    super(classifier,self).__init__()

    self.pool = GlobalAveragePooling2D()
    self.dense = Dense(fc_unit,activation='relu')
    self.dropout = Dropout(0.45)
    self.class_logits = Dense(num_class)

  def call(self,x):

    x = self.pool(x)
    x = self.dense(x)
    x = self.dropout(x)
    x = self.class_logits(x)

    return x

In [34]:
def  Densnet_model(shape=(None,None,None),num_of_block = None,num_of_layer=None,growth_rate=None,num_class = None,fc_unit=None,drop_rate=None):
                   
  X = Input(shape=shape)
  x = Conv2D(32,3,padding = 'same')(X)
  x = BatchNormalization()(x)
  x = ReLU()(x)
  c_rate = 0.75
  growth_rate = growth_rate

  for i in range(num_of_block):

    x = dense_block(num_of_layer=num_of_layer,num_filters=64,growth_rate=growth_rate,drop_rate = drop_rate)(x)

    num_filter = x.shape[3]

    x = Transition_Block(drop_rate=0.4,projection = num_filter,c_rate = c_rate)(x)

  y = classifier(num_class = num_class,fc_unit =fc_unit)(x)

  model = tf.keras.models.Model(inputs = X,outputs = y)

  model.summary()

  return model 



In [35]:
num_of_layer = 4
num_of_block = 7
fc_unit = 1024
num_class = 121
growth_rate = 16 
drop_rate = 0.4
model = Densnet_model(shape=(224,224,3),num_of_block = num_of_block,num_of_layer=num_of_layer,growth_rate=growth_rate,num_class = num_class,fc_unit=fc_unit,drop_rate=drop_rate)


Model: "model_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_15 (InputLayer)        [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv2d_37 (Conv2D)           (None, 224, 224, 32)      896       
_________________________________________________________________
batch_normalization_37 (Batc (None, 224, 224, 32)      128       
_________________________________________________________________
re_lu_14 (ReLU)              (None, 224, 224, 32)      0         
_________________________________________________________________
dense_block_32 (dense_block) (None, 224, 224, 384)     0         
_________________________________________________________________
transition__block_23 (Transi (None, 112, 112, 288)     112416    
_________________________________________________________________
dense_block_33 (dense_block) (None, 112, 112, 640)     0   