In [None]:
    ")\n",
        "\n",
        "model.summary()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Callbacks and training\n",
        "os.makedirs(os.path.join('..','models'), exist_ok=True)\n",
        "checkpoint_path = os.path.join('..','models', f'{model_choice}_best.h5')\n",
        "\n",
        "# Every callback tracks the same metric so that the weights EarlyStopping restores,\n",
        "# the checkpoint written to `checkpoint_path`, and the LR schedule all agree on\n",
        "# which epoch counts as 'best'.\n",
        "MONITOR = 'val_loss'\n",
        "EARLY_STOP_PATIENCE = 6  # epochs without improvement before training stops\n",
        "LR_PATIENCE = 3          # epochs without improvement before the LR is scaled by LR_FACTOR\n",
        "LR_FACTOR = 0.5\n",
        "MIN_LR = 1e-7            # lower bound for the reduced learning rate\n",
        "\n",
        "callbacks = [\n",
        "    tf.keras.callbacks.EarlyStopping(monitor=MONITOR, patience=EARLY_STOP_PATIENCE, restore_best_weights=True, verbose=1),\n",
        "    tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor=MONITOR, save_best_only=True, verbose=1),\n",
        "    tf.keras.callbacks.ReduceLROnPlateau(monitor=MONITOR, factor=LR_FACTOR, patience=LR_PATIENCE, min_lr=MIN_LR, verbose=1)\n",
        "]\n",
        "\n",
        "epochs = 10\n",
        "start_time = time.time()\n",
        "history = model.fit(\n",
        "    train_ds,\n",
        "    validation_data=valid_ds,\n",
        "    epochs=epochs,\n",
        "    callbacks=callbacks\n",
        ")\n",
        "print(f'Training time: {time.time() - start_time:.1f}s')\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# Save final model and plot training curves\n",
        "final_path = os.path.join('..','models', f'{model_choice}_final.h5')\n",
        "model.save(final_path)\n",
        "print('Saved model to', final_path)\n",
        "\n",
        "def plot_history(history, out_path=None, metric='accuracy'):\n",
        "    \"\"\"Plot per-epoch training and validation curves from a Keras History.\n",
        "\n",
        "    Draws two side-by-side panels, `metric` (left) and loss (right), each with\n",
        "    the training series and its 'val_'-prefixed validation counterpart.\n",
        "    A series missing from the history is drawn empty rather than raising.\n",
        "\n",
        "    Args:\n",
        "        history: object returned by `model.fit`; only its `.history` dict is read.\n",
        "        out_path: optional file path; when truthy the figure is also saved there.\n",
        "        metric: history key for the left panel, e.g. 'accuracy'.\n",
        "    \"\"\"\n",
        "    h = history.history\n",
        "    fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n",
        "    for ax, key in zip(axes, (metric, 'loss')):\n",
        "        for split, prefix in (('train', ''), ('val', 'val_')):\n",
        "            values = h.get(prefix + key, [])\n",
        "            # x-axis counts epochs from 1 rather than from list index 0.\n",
        "            ax.plot(range(1, len(values) + 1), values, label=f'{split}_{key}')\n",
        "        ax.set_title(key.capitalize())\n",
        "        ax.set_xlabel('Epoch')\n",
        "        ax.set_ylabel(key)\n",
        "        ax.legend()\n",
        "    fig.tight_layout()\n",
        "    if out_path:\n",
        "        # Save before showing: show() may close the figure.\n",
        "        fig.savefig(out_path)\n",
        "    plt.show()\n",
        "\n",
        "plot_out = os.path.join('..','models', f'{model_choice}_training.png')\n",
        "plot_history(history, out_path=plot_out)\n",
        "print('Saved training plot to', plot_out)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Notes:\n",
        "- If `preprocessing.create_data_generators` returns class names, the notebook will print them.\n",
        "- Run this notebook in VS Code (select a Python kernel) or in Jupyter. Relative paths such as `../models` assume the kernel's working directory is `notebooks/`; the best and final models and the training plot are written to `../models/`.\n",
        "- Adjust `model_choice`, the batch size, and `epochs` (set in the training cell) as needed."
      ]
    }
  ]
}