In [None]:
import tensorflow as tf
from tensorflow.keras.layers import ReLU,Conv3D,BatchNormalization,Activation,MaxPooling3D,ZeroPadding3D,AveragePooling3D,Dropout,Input,Dense,Flatten
import math

In [None]:
class spatio_temp_conv(tf.keras.layers.Layer):
  """(2+1)D convolution block: a spatial Conv3D followed by a temporal Conv3D.

  The input passes through a (1, k, k) convolution and then a (k, 1, 1)
  convolution. Each convolution uses 'same' padding and is followed by
  BatchNormalization and ReLU.

  Args:
    filter: Number of output channels of both convolutions. Overridden with
      45 when `first_conv` is True.
    strides: Three-element sequence of strides, one per Conv3D axis. The first
      entry is applied by the temporal convolution, the last two by the
      spatial convolution.
    k: Kernel size, used for both the spatial (k x k) and the temporal (k)
      kernel.
    first_conv: When True, both convolutions use 45 filters regardless of
      `filter`.
  """

  def __init__(self,filter=None,strides = (None,None,None),k = None,first_conv=False):

    super(spatio_temp_conv,self).__init__()

    # Keep the two spatial strides separate. The previous code bound both
    # strides[1] and strides[2] to one name, so strides[1] was silently ignored.
    t_stride, h_stride, w_stride = strides[0], strides[1], strides[2]
    t_kernel, s_kernel = k, k

    self.first_conv = first_conv
    # NOTE(review): 45 is the width hard-coded for the first-conv case; presumably
    # the stem's intermediate channel count -- confirm against the reference model.
    self.filter = 45 if first_conv else filter

    # Spatial half: the kernel spans only the last two spatial axes.
    self.spatial_conv1 = Conv3D(filters=self.filter,kernel_size=(1,s_kernel,s_kernel),strides=(1,h_stride,w_stride),padding='same')
    self.bn1 = BatchNormalization()
    self.act1 = ReLU()

    # Temporal half: the kernel spans only the first spatial axis.
    self.temporal_conv1 = Conv3D(filters=self.filter,kernel_size=(t_kernel,1,1),strides=(t_stride,1,1),padding='same')
    self.bn2 = BatchNormalization()
    self.act2 = ReLU()

  def call(self,x):
    """Apply spatial conv -> BN -> ReLU, then temporal conv -> BN -> ReLU."""
    # Both modes run the same pipeline; `first_conv` only affects the filter
    # count chosen in __init__. `self.filter` is deliberately not modified here
    # so it keeps reflecting the configured number of filters.
    x = self.act1(self.bn1(self.spatial_conv1(x)))
    x = self.act2(self.bn2(self.temporal_conv1(x)))
    return x

In [None]:
class spatio_temp_Res_layer(tf.keras.layers.Layer):
  """Residual unit made of two spatio_temp_conv blocks and a projected shortcut.

  With `downsample=True` the first main-path block and the shortcut both use
  strides (1, 2, 2); otherwise every block uses strides (1, 1, 1). In both
  modes the shortcut is a spatio_temp_conv with kernel size 1 followed by
  BatchNormalization, and the sum of both paths goes through a final ReLU.

  Args:
    f: Number of filters passed to every spatio_temp_conv in the unit.
    k: Kernel size of the two main-path blocks.
    downsample: Selects the strided variant described above.
  """

  def __init__(self,f=None,k = None,downsample=False):
    super(spatio_temp_Res_layer,self).__init__()

    self.downsample = downsample

    if self.downsample:
      # Main path, first block: strided.
      self.down_con = spatio_temp_conv(filter=f,strides = (1,2,2),k = k)
      self.down_bn = BatchNormalization()
      self.down_act1 = ReLU()

      # Main path, second block: unit stride.
      self.conv1 = spatio_temp_conv(filter=f,strides = (1,1,1),k = k)
      self.bn1 = BatchNormalization()
      self.act1 = ReLU()  # applied to the sum of both paths in call()

      # Shortcut: kernel-size-1 block with the same strides as down_con.
      self.conv_skip = spatio_temp_conv(filter=f,strides = (1,2,2),k = 1)
      self.skipbn1 = BatchNormalization()
      self.skipact1 = ReLU()  # created but not used by call()
    else:
      # Main path, first block.
      self.conv2 = spatio_temp_conv(filter=f,strides = (1,1,1),k = k)
      self.bn2 = BatchNormalization()
      self.act2 = ReLU()

      # Main path, second block.
      self.conv3 = spatio_temp_conv(filter=f,strides = (1,1,1),k = k)
      self.bn3 = BatchNormalization()
      self.act3 = ReLU()  # created but not used by call()

      # Shortcut: kernel-size-1 block.
      self.conv4 = spatio_temp_conv(filter=f,strides = (1,1,1),k = 1)
      self.bn4 = BatchNormalization()
      self.act4 = ReLU()  # applied to the sum of both paths in call()

  def call(self,x):
    """Return ReLU(shortcut(x) + main_path(x)) for the configured variant."""
    if not self.downsample:
      residual = self.conv2(x)
      residual = self.act2(self.bn2(residual))
      residual = self.bn3(self.conv3(residual))

      shortcut = self.bn4(self.conv4(x))
      return self.act4(shortcut + residual)

    residual = self.down_con(x)
    residual = self.down_act1(self.down_bn(residual))
    residual = self.bn1(self.conv1(residual))

    shortcut = self.skipbn1(self.conv_skip(x))
    return self.act1(shortcut + residual)


class SpatialTemporalResBlock(tf.keras.layers.Layer):
  """A stack of `num_of_block` downsampling spatio_temp_Res_layer units.

  Args:
    num_of_block: Number of residual layers to chain (must be an int).
    filters: Number of filters given to every residual layer.
  """

  def __init__(self,num_of_block=None,filters = None):

    super(SpatialTemporalResBlock,self).__init__()

    self.num_of_block = num_of_block
    self.filters = filters

    # Create the sub-layers once, here, and keep them on the instance. Layers
    # instantiated inside call() are not tracked by Keras, so their weights were
    # missing from the model's variables (and from the parameter count) and
    # would never have been trained or saved.
    # Every unit keeps downsample=True (k = 3), exactly as before, so each one
    # applies strides (1, 2, 2).
    self.res_layers = [
        spatio_temp_Res_layer(f=self.filters,k = 3,downsample=True)
        for _ in range(self.num_of_block)
    ]

  def call(self,x):
    """Feed `x` through the residual layers in order and return the result."""
    for res_layer in self.res_layers:
      x = res_layer(x)

    return x




class R2Plus1D(tf.keras.layers.Layer):
  """Classifier: (2+1)D conv stem, three residual stages and a dense head.

  Pipeline: spatio_temp_conv (64 filters, k = 5, strides (1, 2, 2)) ->
  SpatialTemporalResBlock stages with 128, 256 and 512 filters (two layers
  each) -> AveragePooling3D (default arguments) -> Flatten ->
  Dense(num_units) -> Dropout(0.3) -> Dense(num_classes).

  No activation is passed to either Dense layer.

  Args:
    num_units: Width of the hidden Dense layer.
    num_classes: Width of the output Dense layer.
  """

  def __init__(self,num_units=None,num_classes = None):

    super(R2Plus1D,self).__init__()

    self.conv1 = spatio_temp_conv(filter=64,strides = (1,2,2),k = 5)

    # Only the stages that call() actually applies are constructed. A further
    # 1024-filter stage (`conv5`) used to be created here but was never called,
    # so it contributed nothing to the output. It has been dropped rather than
    # wired in, because adding another downsampling stage would change the
    # network's output shapes.
    self.conv2 = SpatialTemporalResBlock(num_of_block = 2,filters = 128)
    self.conv3 = SpatialTemporalResBlock(num_of_block = 2,filters = 256)
    self.conv4 = SpatialTemporalResBlock(num_of_block = 2,filters = 512)

    self.pool = AveragePooling3D()
    self.flatten = Flatten()
    self.dense1 = Dense(num_units)
    self.drop1 = Dropout(0.3)
    self.dense2 = Dense(num_classes)

  def call(self,x):
    """Run the stem, the residual stages, pooling and the dense head on `x`."""
    # Feature extractor.
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    x = self.conv4(x)
    x = self.pool(x)

    # Classification head.
    x = self.flatten(x)
    x = self.dense1(x)
    x = self.drop1(x)
    x = self.dense2(x)

    return x




def r2plus1_model(shape = None,num_classes=None):
  """Wrap an R2Plus1D layer in a functional Keras model.

  Args:
    shape: Per-sample input shape (without the batch axis), passed to `Input`.
    num_classes: Output width passed to R2Plus1D; its hidden Dense layer is
      fixed at 2048 units.

  Returns:
    The constructed `tf.keras.models.Model`. Its summary is printed as a side
    effect before returning.
  """
  inputs = Input(shape = shape)
  network = R2Plus1D(num_units=2048,num_classes = num_classes)
  outputs = network(inputs)

  model = tf.keras.models.Model(inputs = inputs, outputs = outputs)
  model.summary()

  return model








In [None]:
# Build the classifier for inputs of shape (16, 256, 256, 3) and 11 classes.
# The call prints the model summary; the returned model is the cell's output.
shape = (16, 256, 256, 3)
r2plus1_model(num_classes=11, shape=shape)


Model: "model_12"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_31 (InputLayer)        [(None, 16, 256, 256, 3)] 0         
_________________________________________________________________
r2_plus1d_21 (R2Plus1D)      (None, 11)                8439115   
Total params: 8,439,115
Trainable params: 8,438,859
Non-trainable params: 256
_________________________________________________________________


<tensorflow.python.keras.engine.functional.Functional at 0x7f5599b1cd50>

In [None]:
# Smoke-test one downsampling residual layer on a random input.
strides = (1,2,2)

# This cell used to read an `x` that no visible cell defines, so it failed on a
# fresh kernel. The shape below is an assumption, picked so that the
# (1, 2, 2)-strided layer yields the (1, 8, 8, ...) leading dimensions of the
# recorded output -- adjust it to the intended input.
x = tf.random.normal((1, 8, 16, 16, 32))

# NOTE(review): with f=128 both paths of the layer end in 128-filter
# convolutions, so the last dimension should be 128; the recorded output showing
# 32 channels looks stale -- re-run the cell to refresh it.
y = spatio_temp_Res_layer(f=128,k = 3,downsample=True)(x)
y.shape

TensorShape([1, 8, 8, 8, 32])