In [1]:
import tensorflow as tf
import tensorflow_addons as tfa

In [2]:
mnist = tf.keras.datasets.mnist

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [3]:
model = tf.keras.models.Sequential([
    # Reshape into "channels last" setup.
    tf.keras.layers.Reshape((28, 28, 1), input_shape=(28, 28)),
    tf.keras.layers.Conv2D(filters=10, kernel_size=(3, 3), data_format="channels_last"),
    # Groupnorm Layer
    tfa.layers.GroupNormalization(groups=5, axis=3),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
])

In [5]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
model.fit(x_test, y_test,epochs=5)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fa377dd32b0>

In [6]:
model.evaluate(x_test,  y_test, verbose=2)

313/313 - 1s - loss: 0.0227 - accuracy: 0.9933


[0.022684184834361076, 0.9933000206947327]

In [7]:
!mkdir -p saved_model
model.save('saved_model/my_model')

Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
Instructions for updating:
This property should not be used in TensorFlow 2.0, as updates are applied automatically.
INFO:tensorflow:Assets written to: saved_model/my_model/assets


In [8]:
normalization_model= tf.keras.models.load_model('saved_model/my_model')
normalization_model.summary()


Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
reshape (Reshape)            (None, 28, 28, 1)         0         
_________________________________________________________________
conv2d (Conv2D)              (None, 26, 26, 10)        100       
_________________________________________________________________
group_normalization (GroupNo (None, 26, 26, 10)        20        
_________________________________________________________________
flatten (Flatten)            (None, 6760)              0         
_________________________________________________________________
dense (Dense)                (None, 128)               865408    
_________________________________________________________________
dropout (Dropout)            (None, 128)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 10)                1