In [1]:
import tensorflow as tf
from tensorflow import keras

In [2]:
# Inspect the public callback alias; the displayed repr shows its implementation module.
keras.callbacks.ModelCheckpoint

keras.src.callbacks.model_checkpoint.ModelCheckpoint

In [4]:
# Inspect the public callback alias; the displayed repr shows its implementation module.
keras.callbacks.EarlyStopping

keras.src.callbacks.early_stopping.EarlyStopping

In [5]:
# Inspect the public callback alias; the displayed repr shows its implementation module.
keras.callbacks.LearningRateScheduler

keras.src.callbacks.learning_rate_scheduler.LearningRateScheduler

In [6]:
# Inspect the public callback alias; the displayed repr shows its implementation module.
keras.callbacks.ReduceLROnPlateau

keras.src.callbacks.reduce_lr_on_plateau.ReduceLROnPlateau

In [7]:
# Inspect the public callback alias; the displayed repr shows its implementation module.
keras.callbacks.CSVLogger

keras.src.callbacks.csv_logger.CSVLogger

In [9]:
def get_mnist_model(input_dim=28 * 28, hidden_units=512, dropout_rate=0.5, num_classes=10):
    """Build an uncompiled MLP classifier for flattened image vectors.

    Architecture: Dense(hidden_units, relu) -> Dropout(dropout_rate) -> Dense(num_classes, softmax).
    The defaults reproduce the original 784 -> 512 -> 10 MNIST network, so existing
    calls with no arguments are unchanged.

    Args:
        input_dim: Length of each flattened input vector (28 * 28 for MNIST).
        hidden_units: Width of the hidden ReLU layer.
        dropout_rate: Fraction of hidden activations dropped during training.
        num_classes: Number of output classes (softmax units).

    Returns:
        A `keras.Model` mapping (batch, input_dim) inputs to (batch, num_classes)
        class probabilities. It is not compiled; call `compile` before `fit`.
    """
    inputs = keras.Input(shape=(input_dim,))
    features = keras.layers.Dense(hidden_units, activation='relu')(inputs)
    features = keras.layers.Dropout(dropout_rate)(features)
    outputs = keras.layers.Dense(num_classes, activation='softmax')(features)
    return keras.Model(inputs, outputs)

In [10]:
from keras.datasets import mnist
# Load the MNIST train and test splits, each returned as an (images, labels) pair.
train_split, test_split = mnist.load_data()
images, labels = train_split
test_images, test_labels = test_split

In [11]:
# Flatten each 28x28 training image into a 784-vector and scale raw 0-255 pixels into [0, 1].
# The ndim guard keeps the cell idempotent: this rebinds `images`, so running it a second
# time would otherwise divide the already-scaled floats by 255 again.
if images.ndim == 3:
    images = images.reshape((len(images), 28 * 28)).astype('float32') / 255

In [12]:
# Apply the same flatten + [0, 1] scaling to the test images as to the training images.
# The ndim guard keeps the cell idempotent (re-running must not divide by 255 twice).
if test_images.ndim == 3:
    test_images = test_images.reshape((len(test_images), 28 * 28)).astype('float32') / 255

In [13]:
# Hold out the first 10,000 training samples for validation; train on the rest.
n_val = 10000
val_images, train_images = images[:n_val], images[n_val:]
val_labels, train_labels = labels[:n_val], labels[n_val:]

In [15]:
# Stop training after 2 consecutive epochs with no improvement in val_accuracy.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_accuracy', patience=2)

# Write the model to disk only when val_loss reaches a new best.
# NOTE(review): this monitors val_loss while early stopping monitors val_accuracy,
# so the saved checkpoint is the best-loss epoch, not necessarily the best-accuracy one.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath='checkpoint_path.keras',
    monitor='val_loss',
    save_best_only=True,
)

callbacks_list = [early_stopping, checkpoint]

In [16]:
# Seed Python's `random`, NumPy and the backend RNGs before building the model so that
# weight initialisation, dropout masks and fit-time shuffling repeat on Restart & Run All
# (bit-exact GPU reproducibility may additionally require deterministic ops).
SEED = 42
keras.utils.set_random_seed(SEED)
model = get_mnist_model()

In [17]:
# Integer class labels -> sparse categorical cross-entropy; report accuracy each epoch
# (logged as `accuracy` / `val_accuracy`, which the callbacks monitor).
model.compile(
    optimizer="rmsprop",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

In [18]:
# Train for up to 10 epochs against the held-out validation split; the callbacks may
# stop training early and save the best checkpoint. Keep the returned History (per-epoch
# metrics in `history.history`) instead of discarding it and displaying only its repr.
history = model.fit(
    train_images,
    train_labels,
    epochs=10,
    callbacks=callbacks_list,
    validation_data=(val_images, val_labels),
)

Epoch 1/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 9s 5ms/step - accuracy: 0.8609 - loss: 0.4560 - val_accuracy: 0.9552 - val_loss: 0.1497
Epoch 2/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 7s 4ms/step - accuracy: 0.9507 - loss: 0.1675 - val_accuracy: 0.9691 - val_loss: 0.1128
Epoch 3/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9626 - loss: 0.1378 - val_accuracy: 0.9729 - val_loss: 0.1101
Epoch 4/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9678 - loss: 0.1092 - val_accuracy: 0.9756 - val_loss: 0.0945
Epoch 5/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 6s 4ms/step - accuracy: 0.9709 - loss: 0.1032 - val_accuracy: 0.9749 - val_loss: 0.1021
Epoch 6/10
1563/1563 ━━━━━━━━━━━━━━━━━━━━ 5s 3ms/step - accuracy: 0.9747 - loss: 0.0932 - val_accuracy: 0.9777 - val_loss: 0.0910
Epoch 7/10
[1m1

<keras.src.callbacks.history.History at 0x14cd0569ac0>

In [21]:
# Replace the in-memory model with the checkpoint written during fit. Because that
# ModelCheckpoint used save_best_only=True on val_loss, this is the lowest-val_loss epoch.
model = keras.saving.load_model('checkpoint_path.keras')