In [40]:
from keras import backend as K
from keras.layers import Conv2D, Add, Input, BatchNormalization, Activation, MaxPooling2D, GlobalAveragePooling2D,SeparableConv2D, Dense
from keras.models import Model

def projection_stride(residual_size, conv_size, default=2):
    """Infer the stride mapping a residual's spatial size onto the main-branch size.

    Args:
        residual_size: Static spatial size (height or width) of the residual
            tensor, or None when it is unknown.
        conv_size: Static spatial size of the main-branch output, or None.
        default: Stride returned when either size is unknown.

    Returns:
        int: ``round(residual_size / conv_size)`` clamped to at least 1, i.e. 1
        when the sizes already match and 2 after a stride-2 downsampling.
    """
    if residual_size is None or conv_size is None:
        return default
    return max(1, int(round(residual_size / conv_size)))


def shortcut(conv, residual, strides=None):
    """Add a residual connection, projecting the residual when shapes differ.

    Args:
        conv: Output tensor of the main (convolutional) branch; channels-last
            is assumed (index 3 is read as the channel count).
        residual: Tensor to add back (typically the block input).
        strides: Optional ``(sh, sw)`` stride for the 1x1 projection. When None
            it is inferred from the static spatial shapes; if those are unknown
            it falls back to ``(2, 2)``, the previous hard-coded value.

    Returns:
        Tensor: ``conv + residual``, where ``residual`` went through a 1x1
        Conv2D first if its shape did not match ``conv``.
    """
    conv_shape = K.int_shape(conv)
    residual_shape = K.int_shape(residual)
    if conv_shape != residual_shape:
        if strides is None:
            # Shapes may differ only in channels (stride 1) or also spatially
            # (stride 2); a fixed (2, 2) would break the channels-only case.
            strides = (projection_stride(residual_shape[1], conv_shape[1]),
                       projection_stride(residual_shape[2], conv_shape[2]))
        residual = Conv2D(filters=conv_shape[3], kernel_size=(1, 1), strides=strides,
                          kernel_initializer='he_normal')(residual)
    return Add()([conv, residual])

def block(x, filter, stride=1, cardinality=32):
    """ResNeXt-style bottleneck block: 1x1 conv -> separable 3x3 -> 1x1, plus shortcut.

    The 3x3 stage is a depthwise-separable convolution whose depth multiplier is
    ``filter // cardinality``; this stands in for a grouped convolution.

    Args:
        x: Input tensor (``shortcut`` assumes channels-last).
        filter: Number of output channels of each conv in the block. (The name
            shadows the builtin but is kept for keyword-argument compatibility.)
        stride: Stride of the first 1x1 conv; 2 halves the spatial size.
        cardinality: Divisor used to derive the depthwise multiplier.

    Returns:
        Tensor with ``filter`` channels after the residual add and final ReLU.

    Raises:
        ValueError: if ``cardinality`` is not positive, or ``filter`` is smaller
            than ``cardinality`` (which would give ``depth_multiplier=0``).
    """
    if cardinality < 1:
        raise ValueError(f"cardinality must be a positive integer, got {cardinality}")
    multiplier = filter // cardinality
    if multiplier < 1:
        raise ValueError(
            f"filter ({filter}) must be >= cardinality ({cardinality}) "
            f"so that depth_multiplier is at least 1")

    conv = Conv2D(filters=filter, kernel_size=(1, 1), strides=(stride, stride), padding="same", kernel_initializer='he_normal')(x)
    conv = BatchNormalization()(conv)
    conv = Activation("relu")(conv)
    conv = SeparableConv2D(filters=filter, kernel_size=(3, 3), strides=(1, 1), padding="same", depth_multiplier=multiplier, kernel_initializer='he_normal')(conv)
    conv = BatchNormalization()(conv)
    conv = Activation("relu")(conv)
    conv = Conv2D(filters=filter, kernel_size=(1, 1), strides=(1, 1), padding="same", kernel_initializer='he_normal')(conv)
    conv = BatchNormalization()(conv)
    # The ReLU is applied after the residual add (post-activation ResNet ordering).
    conv = shortcut(conv, x)
    conv = Activation("relu")(conv)
    return conv


class ResNeXt50:
    """ResNeXt-50-style image classifier built with the Keras functional API.

    Attributes:
        input_shape: Shape of a single input sample (no batch dimension).
        nb_classes: Number of units in the softmax output layer.
        model: The built Keras functional model.
    """

    # (filters, number of blocks, stride of the stage's first block) per stage.
    # 3 + 4 + 6 + 3 = 16 blocks of 3 convs, plus the stem conv and the dense head = 50.
    _STAGES = ((64, 3, 1), (128, 4, 2), (256, 6, 2), (512, 3, 2))

    def __init__(self, input_shape, nb_classes):
        self.input_shape = input_shape
        self.nb_classes = nb_classes
        self.model = self.make_model()

    def make_model(self):
        """Build the network graph and return it as a Keras ``Model``.

        Returns:
            A Keras functional model mapping ``Input(self.input_shape)`` to a
            softmax over ``self.nb_classes`` classes.
        """
        # Local alias so the model class cannot be shadowed by a notebook
        # variable that happens to be named ``Model``.
        from keras.models import Model as KerasModel

        inputs = Input(self.input_shape)
        x = Conv2D(filters=64, kernel_size=(7, 7), strides=(2, 2), padding="same", kernel_initializer='he_normal')(inputs)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)
        x = MaxPooling2D(pool_size=(3, 3), strides=(2, 2), padding="same")(x)

        for filters, n_blocks, first_stride in self._STAGES:
            # Only the first block of a stage may downsample.
            x = block(x, filters, stride=first_stride)
            for _ in range(n_blocks - 1):
                x = block(x, filters)

        x = GlobalAveragePooling2D()(x)
        output = Dense(units=self.nb_classes, activation='softmax')(x)
        return KerasModel(inputs=inputs, outputs=output)

def build(input_shape, nb_classes):
    """Factory helper: construct a ResNeXt50 and hand back its Keras model.

    Args:
        input_shape: Per-sample input shape forwarded to ``ResNeXt50``.
        nb_classes: Number of output classes forwarded to ``ResNeXt50``.

    Returns:
        The ``model`` attribute of a freshly constructed ``ResNeXt50``.
    """
    network = ResNeXt50(input_shape, nb_classes)
    return network.model
    
#model = ResNeXt50(input_shape=(224,224,3), nb_classes=1000).model

In [41]:
# Build the network for 1024x1024 single-channel inputs. A lowercase name keeps the
# imported ``keras.models.Model`` class from being shadowed by a model instance.
resnext_model = ResNeXt50(input_shape=(1024, 1024, 1), nb_classes=1000).model
# A functional model is already built from its Input layer, so no ``.build()`` call is needed.
resnext_model.summary()

Model: "model_5"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_29 (InputLayer)           [(None, 1024, 1024,  0                                            
__________________________________________________________________________________________________
conv2d_621 (Conv2D)             (None, 512, 512, 64) 3200        input_29[0][0]                   
__________________________________________________________________________________________________
batch_normalization_828 (BatchN (None, 512, 512, 64) 256         conv2d_621[0][0]                 
__________________________________________________________________________________________________
activation_817 (Activation)     (None, 512, 512, 64) 0           batch_normalization_828[0][0]    
____________________________________________________________________________________________