From e616f48bd97b9a06e099969fe0c844072f6b8f96 Mon Sep 17 00:00:00 2001 From: Ruomei Yan Date: Wed, 11 Mar 2020 18:02:42 +0000 Subject: [PATCH] Create Jupyter notebooks for clustering --- .../clustering_comprehensive_guide.ipynb | 471 ++++++++++++ .../guide/clustering/clustering_example.ipynb | 672 ++++++++++++++++++ 2 files changed, 1143 insertions(+) create mode 100755 tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb create mode 100755 tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb new file mode 100755 index 000000000..c617630ee --- /dev/null +++ b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb @@ -0,0 +1,471 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "celltoolbar": "Raw Cell Format", + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + }, + "colab": { + "name": "clustering_comprehensive_guide.ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "826IBSWMN4rr", + "colab_type": "text" + }, + "source": [ + "**Copyright 2020 The TensorFlow Authors.**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ITj3u97-tNR7", + "colab_type": "code", + "colab": {} + }, + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BYwfpc4wN4rt", + "colab_type": "text" + }, + "source": [ + "# Weight clustering comprehensive guide" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IFva_Ed5N4ru", + "colab_type": "text" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tidmcl3sN4rv", + "colab_type": "text" + }, + "source": [ + "Welcome to the comprehensive guide for *weight clustering*, part of the TensorFlow Model Optimization toolkit.\n", + "\n", + "This page documents various use cases and shows how to use the API for each one. Once you know which APIs you need, find the parameters and the low-level details in the [API docs](https://www.tensorflow.org/model_optimization/api_docs/python/tfmot/clustering):\n", + "\n", + "* If you want to see the benefits of weight clustering and what's supported, check the [overview](https://www.tensorflow.org/model_optimization/guide/clustering).\n", + "* For a single end-to-end example, see the [weight clustering example](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_example.ipynb).\n", + "\n", + "In this guide, the following use cases are covered:\n", + "* Define a clustered model.\n", + "* Checkpoint and deserialize a clustered model.\n", + "* Improve the accuracy of the clustered model.\n", + "* For deployment only, you must take steps to see compression benefits.\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RRtKxbo8N4rv", + "colab_type": "text" + }, + "source": [ + "## Setup\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true, + "id": "08dJRvOqN4rw", + "colab_type": "code", + "colab": {} + }, + "source": [ + "! pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple tensorflow-model-optimization==0.4.0.dev2\n", + "\n", + "import tensorflow as tf\n", + "import numpy as np\n", + "import tempfile\n", + "import os\n", + "import tensorflow_model_optimization as tfmot\n", + "\n", + "input_dim = 20\n", + "output_dim = 20\n", + "x_train = np.random.randn(1, input_dim).astype(np.float32)\n", + "y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=output_dim)\n", + "\n", + "def setup_model():\n", + " model = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(input_dim, input_shape=[input_dim]),\n", + " tf.keras.layers.Flatten()\n", + " ])\n", + " return model\n", + "\n", + "def train_model(model):\n", + " model.compile(\n", + " loss=tf.keras.losses.categorical_crossentropy,\n", + " optimizer='adam',\n", + " metrics=['accuracy']\n", + " )\n", + " model.summary()\n", + " model.fit(x_train, y_train)\n", + " return model\n", + "\n", + "def save_model_weights(model):\n", + " _, pretrained_weights = tempfile.mkstemp('.h5')\n", + " model.save_weights(pretrained_weights)\n", + " return pretrained_weights\n", + "\n", + "def setup_pretrained_weights():\n", + " model= setup_model()\n", + " model = train_model(model)\n", + " pretrained_weights = save_model_weights(model)\n", + " return pretrained_weights\n", + "\n", + "def setup_pretrained_model():\n", + " model = setup_model()\n", + " pretrained_weights = setup_pretrained_weights()\n", + " model.load_weights(pretrained_weights)\n", + " return model\n", + "\n", + "def save_model_file(model):\n", + " _, keras_file = tempfile.mkstemp('.h5') \n", + " model.save(keras_file, include_optimizer=False)\n", + " return keras_file\n", + " \n", + "def get_gzipped_model_size(model):\n", + " # It returns the size of the gzipped model in bytes.\n", + " import os\n", + " import zipfile\n", + "\n", + " keras_file = save_model_file(model)\n", + "\n", + " _, zipped_file = tempfile.mkstemp('.zip')\n", + " with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n", + " f.write(keras_file)\n", + " return os.path.getsize(zipped_file)\n", + "\n", + "setup_model()\n", + "pretrained_weights = setup_pretrained_weights()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ARd37qONN4rz", + "colab_type": "text" + }, + "source": [ + "# Define a clustered model\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zHB3pkU3N4r0", + "colab_type": "text" + }, + "source": [ + "### Cluster a whole model (sequential and functional)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ig-il1lmN4r1", + "colab_type": "text" + }, + "source": [ + "**Tips** for better model accuracy:\n", + "\n", + "* You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy. \n", + "* In some cases, clustering certain layers has a detrimental effect on model accuracy. Check \"Cluster some layers\" to see how to skip clustering the layers that affect accuracy the most.\n", + "\n", + "To cluster all layers, apply `tfmot.clustering.keras.cluster_weights` to the model.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true, + "id": "29g7OADjN4r1", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import tensorflow_model_optimization as tfmot\n", + "\n", + "cluster_weights = tfmot.clustering.keras.cluster_weights\n", + "CentroidInitialization = tfmot.clustering.keras.CentroidInitialization\n", + "\n", + "clustering_params = {\n", + " 'number_of_clusters': 3,\n", + " 'cluster_centroids_init': CentroidInitialization.DENSITY_BASED\n", + "}\n", + "\n", + "model = setup_model()\n", + "model.load_weights(pretrained_weights)\n", + "\n", + "clustered_model = cluster_weights(model, **clustering_params)\n", + "\n", + "clustered_model.summary()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zEOHK4OON4r7", + "colab_type": "text" + }, + "source": [ + "### Cluster some layers (sequential and functional models)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ENscQ7ZWN4r8", + "colab_type": "text" + }, + "source": [ + "**Tips** for better model accuracy:\n", + "\n", + "* You must pass a pre-trained model with acceptable accuracy to this API. Training models from scratch with clustering results in subpar accuracy.\n", + "* Cluster later layers with more redundant parameters (e.g. `tf.keras.layers.Dense`, `tf.keras.layers.Conv2D`), as opposed to the early layers.\n", + "* Freeze early layers prior to the clustered layers during fine-tuning. Treat the number of frozen layers as a hyperparameter. Empirically, freezing most early layers is ideal for the current clustering API.\n", + "* Avoid clustering critical layers (e.g. attention mechanism).\n", + "\n", + "**More**: the `tfmot.clustering.keras.cluster_weights` API docs provide details on how to vary the clustering configuration per layer." + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true, + "id": "IqBdl3uJN4r_", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Create a base model\n", + "base_model = setup_model()\n", + "base_model.load_weights(pretrained_weights)\n", + "\n", + "# Helper function uses `cluster_weights` to make only \n", + "# the Dense layers train with clustering\n", + "def apply_clustering_to_dense(layer):\n", + " if isinstance(layer, tf.keras.layers.Dense):\n", + " return cluster_weights(layer, **clustering_params)\n", + " return layer\n", + "\n", + "# Use `tf.keras.models.clone_model` to apply `apply_clustering_to_dense` \n", + "# to the layers of the model.\n", + "clustered_model = tf.keras.models.clone_model(\n", + " base_model,\n", + " clone_function=apply_clustering_to_dense,\n", + ")\n", + "\n", + "clustered_model.summary()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hN0DgpvD5Add", + "colab_type": "text" + }, + "source": [ + "# Checkpoint and deserialize a clustered model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "hfji5KWN6XCF", + "colab_type": "text" + }, + "source": [ + "**Your use case:** this code is only needed for the HDF5 model format (not HDF5 weights or other formats)." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "w7P67mPk6RkQ", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Define the model.\n", + "base_model = setup_model()\n", + "base_model.load_weights(pretrained_weights)\n", + "clustered_model = cluster_weights(base_model, **clustering_params)\n", + "\n", + "# Save or checkpoint the model.\n", + "_, keras_model_file = tempfile.mkstemp('.h5')\n", + "clustered_model.save(keras_model_file, include_optimizer=True)\n", + "\n", + "# `cluster_scope` is needed for deserializing HDF5 models.\n", + "with tfmot.clustering.keras.cluster_scope():\n", + " loaded_model = tf.keras.models.load_model(keras_model_file)\n", + "\n", + "loaded_model.summary()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cUv-scK-N4sN", + "colab_type": "text" + }, + "source": [ + "# Improve the accuracy of the clustered model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-fZZopDBN4sO", + "colab_type": "text" + }, + "source": [ + "For your specific use case, there are tips you can consider:\n", + "\n", + "* Centroid initialization plays a key role in the final optimized model accuracy. In general, linear initialization outperforms density and random initialization since it does not tend to miss large weights. However, density initialization has been observed to give better accuracy for the case of using very few clusters on weights with bimodal distributions.\n", + "\n", + "* Set a learning rate that is lower than the one used in training when fine-tuning the clustered model.\n", + "\n", + "* For general ideas to improve model accuracy, look for tips for your use case(s) under \"Define a clustered model\"." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4DXw7YbyN4sP", + "colab_type": "text" + }, + "source": [ + "# Deployment" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5Y5zLfPzN4sQ", + "colab_type": "text" + }, + "source": [ + "## Export model with size compression" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wX4OrHD9N4sQ", + "colab_type": "text" + }, + "source": [ + "**Common mistake**: both `strip_clustering` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of clustering." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ZvuiCBsVN4sR", + "colab_type": "code", + "colab": {} + }, + "source": [ + "model = setup_model()\n", + "clustered_model = cluster_weights(model, **clustering_params)\n", + "\n", + "clustered_model.compile(\n", + " loss=tf.keras.losses.categorical_crossentropy,\n", + " optimizer='adam',\n", + " metrics=['accuracy']\n", + ")\n", + "\n", + "clustered_model.fit(\n", + " x_train,\n", + " y_train\n", + ")\n", + "\n", + "final_model = tfmot.clustering.keras.strip_clustering(clustered_model)\n", + "\n", + "print(\"final model\")\n", + "final_model.summary()\n", + "\n", + "print(\"\\n\")\n", + "print(\"Size of gzipped clustered model without stripping: %.2f bytes\" \n", + " % (get_gzipped_model_size(clustered_model)))\n", + "print(\"Size of gzipped clustered model with stripping: %.2f bytes\" \n", + " % (get_gzipped_model_size(final_model)))" + ], + "execution_count": null, + "outputs": [] + } + ] +} \ No newline at end of file diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb new file mode 100755 index 000000000..6647eeff2 --- /dev/null +++ b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb @@ -0,0 +1,672 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "clustering_example.ipynb", + "provenance": [], + "collapsed_sections": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.7" + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "m7hbib3bSGO9", + "colab_type": "text" + }, + "source": [ + "**Copyright 2020 The TensorFlow Authors.**" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "both", + "id": "mEE8NFIMSGO-", + "colab_type": "code", + "colab": {} + }, + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SyiSRgdtSGPC", + "colab_type": "text" + }, + "source": [ + "# Weight clustering in Keras example" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kW3os956SGPD", + "colab_type": "text" + }, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dKnJyAaASGPD", + "colab_type": "text" + }, + "source": [ + "## Overview\n", + "\n", + "Welcome to the end-to-end example for *weight clustering*, part of the TensorFlow Model Optimization Toolkit.\n", + "\n", + "### Other pages\n", + "\n", + "For an introduction to what weight clustering is and to determine if you should use it (including what's supported), see the [overview page](https://www.tensorflow.org/model_optimization/guide/clustering).\n", + "\n", + "To quickly find the APIs you need for your use case (beyond fully clustering a model with 16 clusters), see the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/comprehensive_guide.md).\n", + "\n", + "### Contents\n", + "\n", + "In the tutorial, you will:\n", + "\n", + "1. Train a `tf.keras` model for the MNIST dataset from scratch.\n", + "2. Fine-tune the model by applying the weight clustering API and see the accuracy.\n", + "3. Create a 6x smaller TF and TFLite models from clustering.\n", + "4. Create a 8x smaller TFLite model from combining weight clustering and post-training quantization.\n", + "5. See the persistence of accuracy from TF to TFLite." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RgcQznnZSGPE", + "colab_type": "text" + }, + "source": [ + "## Setup\n", + "\n", + "You can run this Jupyter Notebook in your local [virtualenv](https://www.tensorflow.org/install/pip?lang=python3#2.-create-a-virtual-environment-recommended) or [colab](https://colab.sandbox.google.com/). For details of setting up dependencies, please refer to the [installation guide](https://www.tensorflow.org/model_optimization/guide/install). " + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true, + "id": "3asgXMqnSGPE", + "colab_type": "code", + "colab": {} + }, + "source": [ + "! pip install --index-url https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple tensorflow-model-optimization==0.4.0.dev2" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "gL6JiLXkSGPI", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "\n", + "import numpy as np\n", + "import tempfile\n", + "import zipfile\n", + "import os" + ], + "execution_count": 34, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dKzOfl5FSGPL", + "colab_type": "text" + }, + "source": [ + "## Train a tf.keras model for MNIST without clustering" + ] + }, + { + "cell_type": "code", + "metadata": { + "scrolled": true, + "id": "w7Fd6jZ7SGPL", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Load MNIST dataset\n", + "mnist = keras.datasets.mnist\n", + "(train_images, train_labels), (test_images, test_labels) = mnist.load_data()\n", + "\n", + "# Normalize the input image so that each pixel value is between 0 to 1.\n", + "train_images = train_images / 255.0\n", + "test_images = test_images / 255.0\n", + "\n", + "# Define the model architecture.\n", + "model = keras.Sequential([\n", + " keras.layers.InputLayer(input_shape=(28, 28)),\n", + " keras.layers.Reshape(target_shape=(28, 28, 1)),\n", + " keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),\n", + " keras.layers.MaxPooling2D(pool_size=(2, 2)),\n", + " keras.layers.Flatten(),\n", + " keras.layers.Dense(10)\n", + "])\n", + "\n", + "# Train the digit classification model\n", + "model.compile(optimizer='adam',\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " metrics=['accuracy'])\n", + "\n", + "model.fit(\n", + " train_images,\n", + " train_labels,\n", + " validation_split=0.1,\n", + " epochs=10\n", + ")" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rBOQ8MeESGPO", + "colab_type": "text" + }, + "source": [ + "### Evaluate the baseline model and save it for later usage" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "HYulekocSGPP", + "colab_type": "code", + "colab": {} + }, + "source": [ + "_, baseline_model_accuracy = model.evaluate(\n", + " test_images, test_labels, verbose=0)\n", + "\n", + "print('Baseline test accuracy:', baseline_model_accuracy)\n", + "\n", + "_, keras_file = tempfile.mkstemp('.h5')\n", + "print('Saving model to: ', keras_file)\n", + "tf.keras.models.save_model(model, keras_file, include_optimizer=False)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "cWPgcnjKSGPR", + "colab_type": "text" + }, + "source": [ + "## Fine-tune the pre-trained model with clustering" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Y2wKK7w9SGPS", + "colab_type": "text" + }, + "source": [ + "Apply the `cluster_weights()` API to a whole pre-trained model to demonstrate its effectiveness in reducing the model size after applying zip while keeping decent accuracy. For how best to balance the accuracy and compression rate for your use case, please refer to the per layer example in the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/comprehensive_guide.md).\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ea40z522SGPT", + "colab_type": "text" + }, + "source": [ + "### Define the model and apply the clustering API" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7aOB5vjOZMTS", + "colab_type": "text" + }, + "source": [ + "Before you pass the model to the clustering API, make sure it is trained and shows some acceptable accuracy." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "OzqKKt0mSGPT", + "colab_type": "code", + "colab": {} + }, + "source": [ + "import tensorflow_model_optimization as tfmot\n", + "\n", + "cluster_weights = tfmot.clustering.keras.cluster_weights\n", + "CentroidInitialization = tfmot.clustering.keras.CentroidInitialization\n", + "\n", + "clustering_params = {\n", + " 'number_of_clusters': 16,\n", + " 'cluster_centroids_init': CentroidInitialization.LINEAR\n", + "}\n", + "\n", + "# Cluster a whole model\n", + "clustered_model = cluster_weights(model, **clustering_params)\n", + "\n", + "# Use smaller learning rate for fine-tuning clustered model\n", + "opt = tf.keras.optimizers.Adam(learning_rate=1e-5)\n", + "\n", + "clustered_model.compile(\n", + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", + " optimizer=opt,\n", + " metrics=['accuracy'])\n", + "\n", + "clustered_model.summary()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ev4MyClmSGPW", + "colab_type": "text" + }, + "source": [ + "### Fine-tune the model and evaluate the accuracy against baseline" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vQoy9CcASGPX", + "colab_type": "text" + }, + "source": [ + "Fine-tune the model with clustering for 1 epoch." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "jn29-coXSGPX", + "colab_type": "code", + "colab": {} + }, + "source": [ + "# Fine-tune model\n", + "clustered_model.fit(\n", + " train_images,\n", + " train_labels,\n", + " batch_size=500,\n", + " epochs=1,\n", + " validation_split=0.1)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dvaZKoxtTORx", + "colab_type": "text" + }, + "source": [ + "For this example, there is minimal loss in test accuracy after clustering, compared to the baseline." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "bE7MxpWLTaQ1", + "colab_type": "code", + "colab": {} + }, + "source": [ + "_, clustered_model_accuracy = clustered_model.evaluate(\n", + " test_images, test_labels, verbose=0)\n", + "\n", + "print('Baseline test accuracy:', baseline_model_accuracy)\n", + "print('Clustered test accuracy:', clustered_model_accuracy)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VXfPMa6ISGPd", + "colab_type": "text" + }, + "source": [ + "## Create **6x** smaller models from clustering" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1zr_QIhcUeuC", + "colab_type": "text" + }, + "source": [ + "Both `strip_clustering` and applying a standard compression algorithm (e.g. via gzip) are necessary to see the compression benefits of clustering. \n", + "\n", + "First, create a compressible model for TensorFlow. Here, `strip_clustering` removes all variables (e.g. `tf.Variable` for storing the cluster centroids and the indices) that clustering only needs during training, which would otherwise add to model size during inference." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "4h6tSvMzSGPd", + "colab_type": "code", + "colab": {} + }, + "source": [ + "final_model = tfmot.clustering.keras.strip_clustering(clustered_model)\n", + "\n", + "_, clustered_keras_file = tempfile.mkstemp('.h5')\n", + "print('Saving clustered model to: ', clustered_keras_file)\n", + "tf.keras.models.save_model(final_model, clustered_keras_file, \n", + " include_optimizer=False)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "jZcotzPSVBtu", + "colab_type": "text" + }, + "source": [ + "Then, create compressible models for TFLite. You can convert the clustered model to a format that's runnable on your targeted backend. TensorFlow Lite is an example you can use to deploy to mobile devices." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "v2N47QW6SGPh", + "colab_type": "code", + "colab": {} + }, + "source": [ + "clustered_tflite_file = '/tmp/clustered_mnist.tflite'\n", + "converter = tf.lite.TFLiteConverter.from_keras_model(final_model)\n", + "tflite_clustered_model = converter.convert()\n", + "with open(clustered_tflite_file, 'wb') as f:\n", + " f.write(tflite_clustered_model)\n", + "print('Saved clustered TFLite model to:', clustered_tflite_file)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "S7amG_9XV-w9", + "colab_type": "text" + }, + "source": [ + "Define a helper function to actually compress the models via gzip and measure the zipped size." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "1XJ4QBMpW5JB", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def get_gzipped_model_size(file):\n", + " # It returns the size of the gzipped model in bytes.\n", + " import os\n", + " import zipfile\n", + "\n", + " _, zipped_file = tempfile.mkstemp('.zip')\n", + " with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:\n", + " f.write(file)\n", + "\n", + " return os.path.getsize(zipped_file)" + ], + "execution_count": 44, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "INeAOWRBSGPj", + "colab_type": "text" + }, + "source": [ + "Compare and see that the models are **6x** smaller from clustering" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "SG1MgZCeSGPk", + "colab_type": "code", + "colab": {} + }, + "source": [ + "print(\"Size of gzipped baseline Keras model: %.2f bytes\" % (get_gzipped_model_size(keras_file)))\n", + "print(\"Size of gzipped clustered Keras model: %.2f bytes\" % (get_gzipped_model_size(clustered_keras_file)))\n", + "print(\"Size of gzipped clustered TFlite model: %.2f bytes\" % (get_gzipped_model_size(clustered_tflite_file)))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5TOgpEGfSGPn", + "colab_type": "text" + }, + "source": [ + "## Create an **8x** smaller TFLite model from combining weight clustering and post-training quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "BQb50aC3SGPn", + "colab_type": "text" + }, + "source": [ + "You can apply post-training quantization to the clustered model for additional benefits." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "XyHC8euLSGPo", + "colab_type": "code", + "colab": {} + }, + "source": [ + "converter = tf.lite.TFLiteConverter.from_keras_model(final_model)\n", + "converter.optimizations = [tf.lite.Optimize.DEFAULT]\n", + "tflite_quant_model = converter.convert()\n", + "\n", + "_, quantized_and_clustered_tflite_file = tempfile.mkstemp('.tflite')\n", + "\n", + "with open(quantized_and_clustered_tflite_file, 'wb') as f:\n", + " f.write(tflite_quant_model)\n", + "\n", + "print('Saved quantized and clustered TFLite model to:', quantized_and_clustered_tflite_file)\n", + "print(\"Size of gzipped baseline Keras model: %.2f bytes\" % (get_gzipped_model_size(keras_file)))\n", + "print(\"Size of gzipped clustered and quantized TFlite model: %.2f bytes\" % (get_gzipped_model_size(quantized_and_clustered_tflite_file)))" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "U-yBcocGSGPv", + "colab_type": "text" + }, + "source": [ + "## See the persistence of accuracy from TF to TFLite" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Jh_pcf0XSGPv", + "colab_type": "text" + }, + "source": [ + "Define a helper function to evaluate the TFLite model on the test dataset." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "EJ9B7pRISGPw", + "colab_type": "code", + "colab": {} + }, + "source": [ + "def eval_model(interpreter):\n", + " input_index = interpreter.get_input_details()[0][\"index\"]\n", + " output_index = interpreter.get_output_details()[0][\"index\"]\n", + "\n", + " # Run predictions on every image in the \"test\" dataset.\n", + " prediction_digits = []\n", + " for i, test_image in enumerate(test_images):\n", + " if i % 1000 == 0:\n", + " print('Evaluated on {n} results so far.'.format(n=i))\n", + " # Pre-processing: add batch dimension and convert to float32 to match with\n", + " # the model's input data format.\n", + " test_image = np.expand_dims(test_image, axis=0).astype(np.float32)\n", + " interpreter.set_tensor(input_index, test_image)\n", + "\n", + " # Run inference.\n", + " interpreter.invoke()\n", + "\n", + " # Post-processing: remove batch dimension and find the digit with highest\n", + " # probability.\n", + " output = interpreter.tensor(output_index)\n", + " digit = np.argmax(output()[0])\n", + " prediction_digits.append(digit)\n", + "\n", + " print('\\n')\n", + " # Compare prediction results with ground truth labels to calculate accuracy.\n", + " prediction_digits = np.array(prediction_digits)\n", + " accuracy = (prediction_digits == test_labels).mean()\n", + " return accuracy" + ], + "execution_count": 47, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0swuxbPmSGPy", + "colab_type": "text" + }, + "source": [ + "You evaluate the model, which has been clustered and quantized, and then see the accuracy from TensorFlow persists to the TFLite backend." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "RFD4LXjpSGPz", + "colab_type": "code", + "colab": {} + }, + "source": [ + "interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)\n", + "interpreter.allocate_tensors()\n", + "\n", + "test_accuracy = eval_model(interpreter)\n", + "\n", + "print('Clustered and quantized TFLite test_accuracy:', test_accuracy)\n", + "print('Clustered TF test accuracy:', clustered_model_accuracy)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JgXTEXC7SGP1", + "colab_type": "text" + }, + "source": [ + "## Conclusion" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7JhbpowqSGP1", + "colab_type": "text" + }, + "source": [ + "In this tutorial, you saw how to create clustered models with the TensorFlow Model Optimization Toolkit API. More specifically, you've been through an end-to-end example for creating an 8x smaller model for MNIST with minimal accuracy difference. We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.\n" + ] + } + ] +} \ No newline at end of file