In [5]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import mnist
from tensorflow.keras import layers

def get_mnist_model():
    """Build the dense MNIST classifier used throughout this cell.

    The network takes a flat vector of 28 * 28 values, passes it through a
    512-unit ReLU layer followed by 50% dropout, and ends in a 10-unit
    softmax layer.

    Returns:
        An uncompiled functional ``keras.Model``.
    """
    flat_pixels = keras.Input(shape=(28 * 28,))
    hidden = layers.Dense(512, activation="relu")(flat_pixels)
    hidden = layers.Dropout(0.5)(hidden)
    class_probs = layers.Dense(10, activation="softmax")(hidden)
    return keras.Model(flat_pixels, class_probs)

# Load MNIST, flatten every image to a 28 * 28 vector and divide the pixel
# values by 255 after converting them to float32.
(images, labels), (test_images, test_labels) = mnist.load_data()
images = images.reshape((60000, 28 * 28))
images = images.astype("float32") / 255
test_images = test_images.reshape((10000, 28 * 28))
test_images = test_images.astype("float32") / 255

# Hold out the first 10000 training samples for validation; train on the rest.
val_images, val_labels = images[:10000], labels[:10000]
train_images, train_labels = images[10000:], labels[10000:]

class RootMeanSquaredError(keras.metrics.Metric):
    """Streaming RMSE between one-hot encoded labels and class predictions.

    The metric accumulates, over every batch seen since the last reset, the
    squared error summed over all samples and all classes, together with the
    number of samples. ``result()`` is the square root of the accumulated
    squared error divided by the number of samples (not by samples * classes).

    Args:
        name: Name under which the metric is reported (defaults to "rmse").
        **kwargs: Forwarded unchanged to ``keras.metrics.Metric``.
    """

    def __init__(self, name="rmse", **kwargs):
        super().__init__(name=name, **kwargs)
        # Running sum of squared errors across all batches.
        self.mse_sum = self.add_weight(name="mse_sum", initializer="zeros")
        # Running count of samples across all batches.
        self.total_samples = self.add_weight(
            name="total_samples", initializer="zeros", dtype="int32")

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the squared error of one batch.

        Args:
            y_true: Integer class indices, one per sample. A trailing unit
                axis (shape ``(batch, 1)``) is accepted and flattened.
            y_pred: Per-class scores of shape ``(batch, num_classes)``.
            sample_weight: Accepted for API compatibility; it is NOT applied,
                every sample contributes with weight 1.
        """
        # Work in float32 so the subtraction below matches tf.one_hot's
        # default dtype and the float32 accumulator.
        y_pred = tf.cast(y_pred, tf.float32)
        # Flatten the labels: a (batch, 1) tensor would otherwise one-hot to
        # (batch, 1, C) and silently broadcast against (batch, C). The cast
        # makes float-typed labels acceptable to tf.one_hot.
        class_ids = tf.cast(tf.reshape(y_true, [-1]), tf.int32)
        y_true = tf.one_hot(class_ids, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

    def result(self):
        """Return the current RMSE, or 0 when no samples were accumulated."""
        sample_count = tf.cast(self.total_samples, tf.float32)
        # divide_no_nan yields 0 instead of NaN when the count is still zero
        # (before the first update or right after reset_state).
        return tf.sqrt(tf.math.divide_no_nan(self.mse_sum, sample_count))

    def reset_state(self):
        """Clear both accumulators."""
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)

# Build the classifier and attach the optimizer, the loss and both metrics
# (built-in accuracy plus the custom RMSE metric).
model = get_mnist_model()
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy", RootMeanSquaredError()],
)

# Train for three epochs, validating on the held-out split after each one.
model.fit(
    x=train_images,
    y=train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)

# Evaluate on the test split and keep the returned metrics.
test_metrics = model.evaluate(x=test_images, y=test_labels)

Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.9116 - loss: 0.2953 - rmse: 0.3636 - val_accuracy: 0.9596 - val_loss: 0.1478 - val_rmse: 0.2493
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 3s 2ms/step - accuracy: 0.9540 - loss: 0.1597 - rmse: 0.2647 - val_accuracy: 0.9669 - val_loss: 0.1210 - val_rmse: 0.2237
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 2s 2ms/step - accuracy: 0.9632 - loss: 0.1305 - rmse: 0.2381 - val_accuracy: 0.9700 - val_loss: 0.1110 - val_rmse: 0.2142
313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 772us/step - accuracy: 0.9744 - loss: 0.0943 - rmse: 0.2015
