# Writing your own metrics

## Listing 7.18 Implementing a custom metric by subclassing the Metric class

In [22]:
import tensorflow as tf
from tensorflow import keras

In [23]:
class RootMeanSquaredError(keras.metrics.Metric):
    """Streaming RMSE between one-hot encoded integer labels and predictions.

    The metric accumulates, across batches, the sum of squared differences
    between the one-hot encoding of ``y_true`` and ``y_pred`` (summed over
    both the batch and the class axis) together with the number of samples
    seen. ``result()`` returns ``sqrt(sum_of_squared_errors / num_samples)``.

    Note: despite its name, ``mse_sum`` stores a running *sum of squared
    errors*, not a mean; the name is kept for backward compatibility.
    """

    def __init__(self, name='rmse', **kwargs):
        super().__init__(name=name, **kwargs)
        # Running sum of squared errors over every sample seen so far.
        self.mse_sum = self.add_weight(name='mse_sum', initializer='zeros')
        # Running sample count, stored as an integer.
        self.total_samples = self.add_weight(name='total_samples', initializer='zeros', dtype='int32')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the squared error of one batch.

        Args:
            y_true: Integer class labels, one per sample. Any shape that
                flattens to ``(batch,)`` is accepted (e.g. ``(batch,)`` or
                ``(batch, 1)``); non-integer dtypes are cast to int32.
            y_pred: Per-class scores of shape ``(batch, num_classes)``.
            sample_weight: Accepted for API compatibility with
                ``keras.metrics.Metric`` but not applied; every sample
                contributes equally.
        """
        # Work in float32 so the subtraction below matches the dtype of the
        # one-hot tensor, whatever dtype the model emits.
        y_pred = tf.cast(y_pred, tf.float32)
        # Flatten the labels first: labels shaped (batch, 1) would otherwise
        # one-hot to (batch, 1, depth) and silently broadcast against
        # (batch, depth) predictions. tf.one_hot also requires integer
        # indices, hence the cast.
        y_true = tf.reshape(tf.cast(y_true, tf.int32), [-1])
        y_true = tf.one_hot(y_true, depth=tf.shape(y_pred)[1])
        mse = tf.reduce_sum(tf.square(y_true - y_pred))
        self.mse_sum.assign_add(mse)
        num_samples = tf.shape(y_pred)[0]
        self.total_samples.assign_add(num_samples)

    def result(self):
        """Return the RMSE over all samples seen since the last reset.

        Returns 0 instead of NaN when no sample has been accumulated yet.
        """
        squared_error_sum = tf.cast(self.mse_sum, tf.float32)
        sample_count = tf.cast(self.total_samples, tf.float32)
        # divide_no_nan yields 0 for a zero denominator, so querying the
        # metric before the first update (or right after a reset) is safe.
        return tf.sqrt(tf.math.divide_no_nan(squared_error_sum, sample_count))

    def reset_state(self):
        """Zero both accumulators."""
        self.mse_sum.assign(0.)
        self.total_samples.assign(0)

In [24]:
def get_mnist_model():
    """Build an uncompiled dense classifier for flattened 28x28 images.

    Architecture: 784 inputs -> Dense(512, relu) -> Dropout(0.5)
    -> Dense(10, softmax).

    Returns:
        A ``keras.Model`` mapping a 784-vector to 10 class probabilities.
    """
    pixels = keras.Input(shape=(28 * 28, ))
    hidden = keras.layers.Dense(512, activation='relu')(pixels)
    # Dropout is applied to the hidden activations.
    hidden = keras.layers.Dropout(0.5)(hidden)
    class_probs = keras.layers.Dense(10, activation='softmax')(hidden)
    return keras.Model(pixels, class_probs)

In [25]:
# Load MNIST. `load_data()` returns a (train, test) pair, each itself an
# (images, labels) pair; the arrays are left untouched here and are
# reshaped / rescaled in the following cells.
from keras.datasets import mnist
(images, labels), (test_images, test_labels) = mnist.load_data()

In [26]:
# Flatten every training image to a 784-vector and divide by 255 (assumes
# 8-bit pixel values, so the result lies in [0, 1]).
# - `-1` lets NumPy infer the sample count instead of hard-coding 60000.
# - The dtype guard makes the cell idempotent: once `images` is float32 it
#   has already been rescaled, so re-running the cell must not divide again.
if images.dtype != 'float32':
    images = images.reshape((-1, 28 * 28)).astype('float32') / 255

In [27]:
# Flatten every test image to a 784-vector and divide by 255 (assumes
# 8-bit pixel values, so the result lies in [0, 1]).
# - `-1` lets NumPy infer the sample count instead of hard-coding 10000.
# - The dtype guard makes the cell idempotent: once `test_images` is float32
#   it has already been rescaled, so re-running must not divide again.
if test_images.dtype != 'float32':
    test_images = test_images.reshape((-1, 28 * 28)).astype('float32') / 255

In [28]:
# Hold out the first 10,000 samples for validation; everything after them
# is used for training. Images and labels are sliced at the same index so
# they stay aligned.
val_images, train_images = images[:10000], images[10000:]
val_labels, train_labels = labels[:10000], labels[10000:]

In [29]:
# Build a fresh, not-yet-compiled model from the factory defined above.
model = get_mnist_model()

In [30]:
# Labels are integer class ids, hence the *sparse* cross-entropy loss.
# The custom RMSE metric is tracked alongside the built-in accuracy.
model.compile(
    optimizer='rmsprop',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy', RootMeanSquaredError()],
)

In [31]:
# Train for three epochs, reporting loss and metrics on the held-out
# validation split after each epoch.
model.fit(
    train_images,
    train_labels,
    epochs=3,
    validation_data=(val_images, val_labels),
)

Epoch 1/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 5ms/step - accuracy: 0.8661 - loss: 0.4400 - rmse: 0.4352 - val_accuracy: 0.9565 - val_loss: 0.1484 - val_rmse: 0.2574
Epoch 2/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9511 - loss: 0.1659 - rmse: 0.2730 - val_accuracy: 0.9658 - val_loss: 0.1202 - val_rmse: 0.2272
Epoch 3/3
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9617 - loss: 0.1313 - rmse: 0.2417 - val_accuracy: 0.9736 - val_loss: 0.0991 - val_rmse: 0.2024


<keras.src.callbacks.history.History at 0x27324b11dc0>

In [33]:
# Score the trained model on the test split, which was not used for
# training or validation above. The return value of `evaluate` is kept in
# `test_metrics` (presumably the loss followed by the compiled metrics --
# confirm the ordering against the Keras documentation).
test_metrics = model.evaluate(test_images, test_labels)

313/313 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - accuracy: 0.9718 - loss: 0.1006 - rmse: 0.2094
