In [1]:
import tensorflow as tf
from tensorflow.keras import Model, layers

In [2]:
# Layer factories: each helper returns a function that applies its layer(s) to an input tensor.
def _batch_norm(name):
    """Build a function that applies batch normalization to a tensor.

    Args:
        name: Name given to the ``BatchNormalization`` layer.

    Returns:
        A one-argument function. Each call creates a new
        ``BatchNormalization`` layer (``center=True``, ``scale=False``,
        ``epsilon=1e-4``) and returns that layer applied to the argument.
    """
    def _apply(tensor):
        # The layer is instantiated on every call, so separate calls never
        # share one layer object.
        bn_options = dict(name=name, center=True, scale=False, epsilon=1e-4)
        normalizer = layers.BatchNormalization(**bn_options)
        return normalizer(tensor)

    return _apply


def _conv(name, kernel, stride, filters):
    """Build a function that applies Conv2D -> batch norm -> ReLU6.

    Args:
        name: Prefix for the layer names ("<name>/conv", "<name>/conv/bn",
            "<name>/relu6").
        kernel: Kernel size passed to ``Conv2D``.
        stride: Stride passed to ``Conv2D``.
        filters: Number of output filters of the convolution.

    Returns:
        A one-argument function that runs the three layers on its input
        tensor and returns the activated output.
    """
    def _conv_layer(layer_input):
        # No bias and no built-in activation: the batch norm that follows
        # supplies the offset, and the activation is applied afterwards.
        output = layers.Conv2D(
            name="{}/conv".format(name),
            filters=filters,
            kernel_size=kernel,
            strides=stride,
            padding="same",
            use_bias=False,
            activation=None,
        )(layer_input)
        output = _batch_norm("{}/conv/bn".format(name))(output)
        # ReLU capped at 6, expressed as a Keras layer rather than a raw
        # tf.nn op on a symbolic tensor, so the activation is a named layer
        # of the model.
        output = layers.ReLU(
            max_value=6.0,
            name="{}/relu6".format(name),
        )(output)
        return output

    return _conv_layer


def _separable_conv(name, kernel, stride, filters):
    """Build a function that applies a depthwise-separable convolution.

    The returned function runs, in order: DepthwiseConv2D -> batch norm ->
    ReLU6 -> 1x1 Conv2D -> batch norm -> ReLU6.

    Args:
        name: Prefix for the layer names ("<name>/depthwise_conv...",
            "<name>/pointwise_conv...").
        kernel: Kernel size of the depthwise convolution.
        stride: Stride of the depthwise convolution (the pointwise
            convolution always uses stride 1).
        filters: Number of output filters of the pointwise convolution.

    Returns:
        A one-argument function that applies the layer stack to its input
        tensor and returns the activated output.
    """
    def _separable_conv_layer(layer_input):
        # Depthwise stage: per-channel spatial filtering.
        output = layers.DepthwiseConv2D(
            name="{}/depthwise_conv".format(name),
            kernel_size=kernel,
            strides=stride,
            depth_multiplier=1,
            padding="same",
            use_bias=False,
            activation=None,
        )(layer_input)
        output = _batch_norm("{}/depthwise_conv/bn".format(name))(output)
        # ReLU capped at 6, expressed as a Keras layer rather than a raw
        # tf.nn op on a symbolic tensor, so the activation is a named layer
        # of the model.
        output = layers.ReLU(
            max_value=6.0,
            name="{}/depthwise_conv/relu6".format(name),
        )(output)

        # Pointwise stage: 1x1 convolution that mixes channels and sets the
        # output width to `filters`.
        output = layers.Conv2D(
            name="{}/pointwise_conv".format(name),
            filters=filters,
            kernel_size=(1, 1),
            strides=1,
            padding="same",
            use_bias=False,
            activation=None,
        )(output)
        output = _batch_norm("{}/pointwise_conv/bn".format(name))(output)
        output = layers.ReLU(
            max_value=6.0,
            name="{}/pointwise_conv/relu6".format(name),
        )(output)
        return output

    return _separable_conv_layer

In [3]:
# Layer stack definition: one (layer_function, kernel, stride, num_filters)
# tuple per stage. The first stage is a regular convolution; every later
# stage is a separable convolution with a 3x3 kernel, listed below as
# (stride, num_filters).
_YAMNET_LAYER_DEFS = [(_conv, [3, 3], 2, 32)] + [
    (_separable_conv, [3, 3], stride, num_filters)
    for stride, num_filters in (
        (1, 64),
        (2, 128),
        (1, 128),
        (2, 256),
        (1, 256),
        (2, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (1, 512),
        (2, 1024),
        (1, 1024),
    )
]

In [10]:
# Model input with a fixed batch size of 1 and per-sample shape (96, 64).
# NOTE(review): the variable and layer name say "waveform", but the shape
# looks like a 2-D feature patch rather than raw audio - confirm. The name
# string is left unchanged in case downstream tooling looks the input up by it.
waveform = layers.Input(
    batch_shape=(1, 96, 64),
    dtype=tf.float32,
    name="waveform_binary 0"
)

# Append a trailing channel axis: (1, 96, 64) -> (1, 96, 64, 1).
# Reshape's target shape must NOT include the batch axis. The previous target
# (1, 96, 64, 1) therefore produced a 5-D tensor (1, 1, 96, 64, 1), which the
# 2-D depthwise-convolution and pooling layers of the stack cannot consume.
net = layers.Reshape((96, 64, 1))(waveform)

# TODO: the rest of the model is still disabled. With the 4-D tensor above it
# can be re-enabled as written.
# for (i, (layer_fun, kernel, stride, filters)) in enumerate(_YAMNET_LAYER_DEFS):
#     net = layer_fun(
#         "layer{}".format(i + 1),
#         kernel,
#         stride,
#         filters,
#     )(net)

# embeddings = layers.GlobalAveragePooling2D()(net)
# logits = layers.Dense(units=521, use_bias=True)(embeddings)
# predictions = layers.Activation(activation="sigmoid")(logits)


In [12]:
# Stock Keras MobileNet, presumably loaded as a reference architecture.
# The keyword arguments spell out the library defaults this cell relies on
# (width multiplier 1.0, classifier head included, ImageNet weights); the
# weights file is downloaded the first time this runs.
base_model = tf.keras.applications.mobilenet.MobileNet(
    alpha=1.0,
    include_top=True,
    weights="imagenet",
)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet/mobilenet_1_0_224_tf.h5


In [13]:
# Print the per-layer table (layer type, output shape, parameter count) of the
# MobileNet model loaded in the previous cell.
base_model.summary()

Model: "mobilenet_1.00_224"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 224, 224, 3)]     0         
_________________________________________________________________
conv1 (Conv2D)               (None, 112, 112, 32)      864       
_________________________________________________________________
conv1_bn (BatchNormalization (None, 112, 112, 32)      128       
_________________________________________________________________
conv1_relu (ReLU)            (None, 112, 112, 32)      0         
_________________________________________________________________
conv_dw_1 (DepthwiseConv2D)  (None, 112, 112, 32)      288       
_________________________________________________________________
conv_dw_1_bn (BatchNormaliza (None, 112, 112, 32)      128       
_________________________________________________________________
conv_dw_1_relu (ReLU)        (None, 112, 112, 32