# Summary of a NN model from the [SMD anomaly detection](https://www.mdpi.com/1424-8220/18/5/1308) paper.

The reference implementation is [here](https://github.com/DongYuls/SMD_Anomaly_Detection). **This model is not designed for time-series data**, but it appears to be quite compute-intensive.

In [15]:
import os, sys
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath('')), 'python'))
from nns.nns import ModelSummary
from nns.model import Model
from tensorflow.python.keras import models
from tensorflow.python.keras.layers import Conv2D, Conv2DTranspose, BatchNormalization, Activation

In [16]:
class SMDAnomalyDetection(Model):
    """Convolutional autoencoder from the SMD anomaly detection paper.

    Reference implementation:
    https://github.com/DongYuls/SMD_Anomaly_Detection/blob/master/model.py

    The encoder maps a (1024, 32, 1) input to a (1, 1, 192) code and the
    decoder mirrors it back to (1024, 32, 1). Layer names ('enc/convXX',
    'dec/convXX') are kept stable so per-layer summaries stay comparable.
    """

    def __init__(self):
        super().__init__('SMDAnomalyDetection')

    @staticmethod
    def _layer_block(layer_cls, input_, num_filters, kernel, strides, padding,
                     activation, name, bn):
        """Apply ``layer_cls`` (bias-free), then optional BatchNorm and activation.

        Shared body of ``conv_block`` and ``conv_transpose_block``.

        Sub-layers are named ``<name>/conv``, ``<name>/bn`` and
        ``<name>/<activation>``. When ``name`` is None, the names are left to
        Keras (auto-generated) instead of raising a TypeError.

        Returns the output tensor of the last applied layer.
        """
        def sub_name(suffix):
            return None if name is None else '{}/{}'.format(name, suffix)

        # use_bias=False: a bias would be cancelled by BatchNorm's mean
        # subtraction. Note it is also omitted when bn=False, as before.
        x = layer_cls(num_filters, kernel, strides=strides, padding=padding,
                      use_bias=False, name=sub_name('conv'))(input_)
        if bn:
            # scale=False: per the TF docs, scaling can be disabled when the
            # next layer is linear or ReLU.
            x = BatchNormalization(scale=False, name=sub_name('bn'))(x)
        if activation is not None:
            # Keras also accepts callables; label those by their __name__
            # rather than failing on string concatenation.
            if isinstance(activation, str):
                act_label = activation
            else:
                act_label = getattr(activation, '__name__', 'activation')
            x = Activation(activation=activation, name=sub_name(act_label))(x)
        return x

    @staticmethod
    def conv_block(input_, num_filters, kernel, strides=(1, 1), padding='same', activation='relu', name=None, bn=True):
        """Conv2D (no bias) -> optional BatchNorm -> optional activation."""
        return SMDAnomalyDetection._layer_block(
            Conv2D, input_, num_filters, kernel, strides, padding, activation, name, bn)

    @staticmethod
    def conv_transpose_block(input_, num_filters, kernel, strides=(1, 1), padding='same', activation='relu', name=None, bn=True):
        """Conv2DTranspose (no bias) -> optional BatchNorm -> optional activation."""
        return SMDAnomalyDetection._layer_block(
            Conv2DTranspose, input_, num_filters, kernel, strides, padding, activation, name, bn)

    def create(self):
        """Build the Keras autoencoder and return it as ``models.Model``.

        Prints the input, code and output shapes while building.
        """
        input_ = Model.Input((1024, 32, 1))
        x = input_
        print("Input shape: {}".format(str(x.shape)))

        # Encoder: five stride-(2, 1) blocks halve only the height
        # (1024 -> 32), then five stride-(2, 2) blocks halve both dims
        # (32x32 -> 1x1).
        conv = SMDAnomalyDetection.conv_block
        for idx, num_channels in enumerate([64, 64, 96, 96, 128], start=1):
            x = conv(x, num_channels, (5, 5), (2, 1), name='enc/conv{:02d}'.format(idx))
        for idx, num_channels in enumerate([128, 160, 160], start=6):
            x = conv(x, num_channels, (4, 4), (2, 2), name='enc/conv{:02d}'.format(idx))
        x = conv(x, 192, 3, (2, 2), name='enc/conv09')
        # Code layer is linear: no BatchNorm, no activation.
        x = conv(x, 192, 3, (2, 2), name='enc/conv10', activation=None, bn=False)
        print("Code shape: {}".format(str(x.shape)))

        # Decoder: mirror of the encoder using transposed convolutions.
        deconv = SMDAnomalyDetection.conv_transpose_block
        x = deconv(x, 192, (3, 3), (2, 2), name='dec/conv01')
        x = deconv(x, 160, (3, 3), (2, 2), name='dec/conv02')
        for idx, num_channels in enumerate([160, 128, 128], start=3):
            x = deconv(x, num_channels, (4, 4), (2, 2), name='dec/conv{:02d}'.format(idx))
        for idx, num_channels in enumerate([96, 96, 64, 64], start=6):
            x = deconv(x, num_channels, (5, 5), (2, 1), name='dec/conv{:02d}'.format(idx))
        # Output layer: single-channel linear reconstruction.
        x = deconv(x, 1, (5, 5), (2, 1), name='dec/conv10', bn=False, activation=None)
        print("Output shape: {}".format(str(x.shape)))

        return models.Model(input_, x, name=self.name)

In [17]:
# Build the autoencoder and print its per-layer cost table (GFLOPs, params).
smd_summary = ModelSummary(SMDAnomalyDetection().create())
smd_summary.summary()

Input shape: (None, 1024, 32, 1)
Code shape: (None, 1, 1, 192)
Output shape: (None, 1024, 32, 1)
SMDAnomalyDetection
               name    gflops  nparams  params_mb
0   enc/conv01/conv  0.026214     1600   0.006104
1   enc/conv02/conv  0.838861   102400   0.390625
2   enc/conv03/conv  0.629146   153600   0.585938
3   enc/conv04/conv  0.471859   230400   0.878906
4   enc/conv05/conv  0.314573   307200   1.171875
5   enc/conv06/conv  0.067109   262144   1.000000
6   enc/conv07/conv  0.020972   327680   1.250000
7   enc/conv08/conv  0.006554   409600   1.562500
8   enc/conv09/conv  0.001106   276480   1.054688
9   enc/conv10/conv  0.000332   331776   1.265625
10  dec/conv01/conv  0.001327   331776   1.265625
11  dec/conv02/conv  0.004424   276480   1.054688
12  dec/conv03/conv  0.026214   409600   1.562500
13  dec/conv04/conv  0.083886   327680   1.250000
14  dec/conv05/conv  0.268435   262144   1.000000
15  dec/conv06/conv  0.629146   307200   1.171875
16  dec/conv07/conv  0.943718   2