diff --git a/site/en/tutorials/keras/save_and_load.ipynb b/site/en/tutorials/keras/save_and_load.ipynb index 6e48a08a4a6..7621bea1aaa 100644 --- a/site/en/tutorials/keras/save_and_load.ipynb +++ b/site/en/tutorials/keras/save_and_load.ipynb @@ -385,12 +385,17 @@ "\n", "batch_size = 32\n", "\n", + "# Calculate the number of batches per epoch\n", + "import math\n", + "n_batches = len(train_images) / batch_size\n", + "n_batches = math.ceil(n_batches) # round up the number of batches to the nearest whole integer\n", + "\n", "# Create a callback that saves the model's weights every 5 epochs\n", "cp_callback = tf.keras.callbacks.ModelCheckpoint(\n", " filepath=checkpoint_path, \n", " verbose=1, \n", " save_weights_only=True,\n", - " save_freq=5*batch_size)\n", + " save_freq=5*n_batches)\n", "\n", "# Create a new model instance\n", "model = create_model()\n",