diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb index fdee905bf..96b906b0c 100644 --- a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb @@ -1,33 +1,4 @@ { - "nbformat": 4, - "nbformat_minor": 0, - "metadata": { - "celltoolbar": "Raw Cell Format", - "colab": { - "name": "clustering_comprehensive_guide.ipynb", - "provenance": [], - "private_outputs": true, - "collapsed_sections": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.6.9" - } - }, "cells": [ { "cell_type": "markdown", @@ -41,12 +12,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { "cellView": "form", + "colab": {}, "colab_type": "code", - "id": "ITj3u97-tNR7", - "colab": {} + "id": "ITj3u97-tNR7" }, + "outputs": [], "source": [ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", "# you may not use this file except in compliance with the License.\n", @@ -59,9 +32,7 @@ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", "# See the License for the specific language governing permissions and\n", "# limitations under the License." - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -80,20 +51,20 @@ "id": "IFva_Ed5N4ru" }, "source": [ - "\n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " View on TensorFlow.org\n", - " \n", - " Run in Google Colab\n", - " \n", - " View source on GitHub\n", - " \n", - " Download notebook\n", - "
" + "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/model_optimization/guide/clustering/clustering_comprehensive_guide\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/model-optimization/blob/master/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\n", + " \u003c/td\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/model-optimization/tensorflow_model_optimization/g3doc/guide/clustering/clustering_comprehensive_guide.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" ] }, { @@ -114,8 +85,7 @@ "* Define a clustered model.\n", "* Checkpoint and deserialize a clustered model.\n", "* Improve the accuracy of the clustered model.\n", - "* For deployment only, you must take steps to see compression benefits.\n", - "\n" + "* For deployment only, you must take steps to see compression benefits.\n" ] }, { @@ -130,12 +100,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { + "colab": {}, "colab_type": "code", "id": "08dJRvOqN4rw", - "scrolled": true, - "colab": {} + "scrolled": true }, + "outputs": [], "source": [ "! pip install -q tensorflow-model-optimization\n", "\n", @@ -203,9 +175,7 @@ "\n", "setup_model()\n", "pretrained_weights = setup_pretrained_weights()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -244,12 +214,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { + "colab": {}, "colab_type": "code", "id": "29g7OADjN4r1", - "scrolled": true, - "colab": {} + "scrolled": true }, + "outputs": [], "source": [ "import tensorflow_model_optimization as tfmot\n", "\n", @@ -267,9 +239,7 @@ "clustered_model = cluster_weights(model, **clustering_params)\n", "\n", "clustered_model.summary()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -278,8 +248,7 @@ "id": "zEOHK4OON4r7" }, "source": [ - "### Cluster some layers (sequential and functional models)\n", - "\n" + "### Cluster some layers (sequential and functional models)\n" ] }, { @@ -301,12 +270,14 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { + "colab": {}, "colab_type": "code", "id": "IqBdl3uJN4r_", - "scrolled": true, - "colab": {} + "scrolled": true }, + "outputs": [], "source": [ "# Create a base model\n", "base_model = setup_model()\n", @@ -327,9 +298,7 @@ ")\n", "\n", "clustered_model.summary()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -353,11 +322,13 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { + "colab": {}, "colab_type": "code", - "id": "w7P67mPk6RkQ", - "colab": {} + "id": "w7P67mPk6RkQ" }, + "outputs": [], "source": [ "# Define the model.\n", "base_model = setup_model()\n", @@ -373,9 +344,7 @@ " loaded_model = tf.keras.models.load_model(keras_model_file)\n", "\n", "loaded_model.summary()" - ], - "execution_count": null, - "outputs": [] + ] }, { "cell_type": "markdown", @@ -435,11 +404,13 @@ }, { "cell_type": "code", + "execution_count": null, "metadata": { + "colab": {}, "colab_type": "code", - "id": "ZvuiCBsVN4sR", - "colab": {} + "id": "ZvuiCBsVN4sR" }, + "outputs": [], "source": [ "model = setup_model()\n", "clustered_model = cluster_weights(model, **clustering_params)\n", @@ -465,9 +436,21 @@ " % (get_gzipped_model_size(clustered_model)))\n", "print(\"Size of gzipped clustered model with stripping: %.2f bytes\" \n", " % (get_gzipped_model_size(final_model)))" - ], - "execution_count": null, - "outputs": [] + ] } - ] + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "clustering_comprehensive_guide.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb index 43dffc86f..d71b55da4 100755 --- a/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/clustering/clustering_example.ipynb @@ -239,8 +239,7 @@ "id": "Y2wKK7w9SGPS" }, "source": [ - "Apply the `cluster_weights()` API to a whole pre-trained model to demonstrate its effectiveness in reducing the model size after applying zip while keeping decent accuracy. For how best to balance the accuracy and compression rate for your use case, please refer to the per layer example in the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_comprehensive_guide).\n", - "\n" + "Apply the `cluster_weights()` API to a whole pre-trained model to demonstrate its effectiveness in reducing the model size after applying zip while keeping decent accuracy. For how best to balance the accuracy and compression rate for your use case, please refer to the per layer example in the [comprehensive guide](https://www.tensorflow.org/model_optimization/guide/clustering/clustering_comprehensive_guide).\n" ] }, { @@ -646,26 +645,12 @@ "colab": { "collapsed_sections": [], "name": "clustering_example.ipynb", - "private_outputs": true, "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3", - "language": "python", "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.7" } }, "nbformat": 4, diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb index 081b273fa..aedf169b7 100644 --- a/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/pruning/comprehensive_guide.ipynb @@ -12,9 +12,9 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "cellView": "both", + "cellView": "form", "colab": {}, "colab_type": "code", "id": "IcfrhafzkZbH" @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "both", "colab": {}, @@ -210,13 +210,12 @@ "* Try \"Prune some layers\" to skip pruning the layers that reduce accuracy the most.\n", "* It's generally better to finetune with pruning as opposed to training from scratch.\n", "\n", - "To make the whole model train with pruning, apply `tfmot.sparsity.keras.prune_low_magnitude` to the model.\n", - "\n" + "To make the whole model train with pruning, apply `tfmot.sparsity.keras.prune_low_magnitude` to the model.\n" ] }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -264,7 +263,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -305,7 +304,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -351,7 +350,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -380,7 +379,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -426,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -483,7 +482,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -554,7 +553,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -664,7 +663,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -695,7 +694,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -743,7 +742,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -795,7 +794,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -820,8 +819,7 @@ "collapsed_sections": [ "Tce3stUlHN0L" ], - "name": "Pruning comprehensive guide", - "private_outputs": true, + "name": "comprehensive_guide.ipynb", "provenance": [], "toc_visible": true }, diff --git a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb index 51de544ce..4d5f7d5bc 100644 --- a/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/pruning/pruning_with_keras.ipynb @@ -690,8 +690,7 @@ "\n", "You created a 10x smaller model for MNIST, with minimal accuracy difference.\n", "\n", - "We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.\n", - "\n" + "We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments.\n" ] } ], @@ -699,8 +698,7 @@ "accelerator": "GPU", "colab": { "collapsed_sections": [], - "name": "Pruning in Keras example", - "private_outputs": true, + "name": "pruning_with_keras.ipynb", "provenance": [], "toc_visible": true }, diff --git a/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb b/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb index b6a9ac752..23900d19f 100644 --- a/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/quantization/training_comprehensive_guide.ipynb @@ -12,9 +12,9 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { - "cellView": "both", + "cellView": "form", "colab": {}, "colab_type": "code", "id": "IcfrhafzkZbH" @@ -116,7 +116,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "both", "colab": {}, @@ -228,13 +228,12 @@ "id": "_Zhzx_azO1WR" }, "source": [ - "To make the whole model aware of quantization, apply `tfmot.quantization.keras.quantize_model` to the model.\n", - "\n" + "To make the whole model aware of quantization, apply `tfmot.quantization.keras.quantize_model` to the model.\n" ] }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -274,8 +273,7 @@ "**Tips for better model accuracy:**\n", "* It's generally better to finetune with quantization aware training as opposed to training from scratch.\n", "* Try quantizing the later layers instead of the first layers.\n", - "* Avoid quantizing critical layers (e.g. attention mechanism).\n", - "\n" + "* Avoid quantizing critical layers (e.g. attention mechanism).\n" ] }, { @@ -290,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -334,7 +332,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -377,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -412,7 +410,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -455,7 +453,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -504,7 +502,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -571,14 +569,12 @@ "Below is an example that defines the same `QuantizeConfig` used for the `Dense` layer in the API defaults.\n", "\n", "During the forward propagation in this example, the `LastValueQuantizer` returned in `get_weights_and_quantizers` is called with `layer.kernel` as the input, producing an output. The output replaces `layer.kernel`\n", - "in the original forward propagation of the `Dense` layer, via the logic defined in `set_quantize_weights`. The same idea applies to the activations and outputs.\n", - "\n", - "\n" + "in the original forward propagation of the `Dense` layer, via the logic defined in `set_quantize_weights`. The same idea applies to the activations and outputs.\n" ] }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -623,8 +619,7 @@ "id": "8vJeoGQG9ZX0" }, "source": [ - "### Quantize custom Keras layer\n", - "\n" + "### Quantize custom Keras layer\n" ] }, { @@ -640,13 +635,12 @@ "the \"Experiment with quantization\" use cases.\n", " * Apply `tfmot.quantization.keras.quantize_annotate_layer` to the `CustomLayer` and pass in the `QuantizeConfig`.\n", " * Use\n", - "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults.\n", - "\n" + "`tfmot.quantization.keras.quantize_annotate_model` to continue to quantize the rest of the model with the API defaults.\n" ] }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -702,7 +696,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -736,7 +730,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -766,8 +760,7 @@ "id": "bJMKgzh84CCs" }, "source": [ - "### Modify parts of layer to quantize\n", - "\n" + "### Modify parts of layer to quantize\n" ] }, { @@ -782,7 +775,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -821,7 +814,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -851,8 +844,7 @@ "id": "yD0sIR6tmmRx" }, "source": [ - "### Use custom quantization algorithm\n", - "\n" + "### Use custom quantization algorithm\n" ] }, { @@ -873,7 +865,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -923,7 +915,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -952,8 +944,7 @@ "collapsed_sections": [ "Tce3stUlHN0L" ], - "name": "Quantization aware training comprehensive guide", - "private_outputs": true, + "name": "training_comprehensive_guide.ipynb", "provenance": [], "toc_visible": true }, diff --git a/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb b/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb index 6cabe3da1..759594420 100644 --- a/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb +++ b/tensorflow_model_optimization/g3doc/guide/quantization/training_example.ipynb @@ -12,7 +12,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "cellView": "form", "colab": {}, @@ -109,7 +109,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -124,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -152,7 +152,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -227,7 +227,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -272,7 +272,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -299,7 +299,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -339,7 +339,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -375,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -427,7 +427,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -467,7 +467,7 @@ }, { "cell_type": "code", - "execution_count": 0, + "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", @@ -515,8 +515,7 @@ "You saw a 4x model size compression benefit for a model for MNIST, with minimal accuracy\n", "difference. To see the latency benefits on mobile, try out the TFLite examples [in the TFLite app repository](https://www.tensorflow.org/lite/models).\n", "\n", - "We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments. \n", - "\n" + "We encourage you to try this new capability, which can be particularly important for deployment in resource-constrained environments. \n" ] } ], @@ -524,8 +523,7 @@ "accelerator": "GPU", "colab": { "collapsed_sections": [], - "name": "Quantization aware training in Keras example", - "private_outputs": true, + "name": "training_example.ipynb", "provenance": [], "toc_visible": true }, diff --git a/tensorflow_model_optimization/python/core/clustering/keras/BUILD b/tensorflow_model_optimization/python/core/clustering/keras/BUILD index 56b3df1e9..1cf283dbb 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/BUILD +++ b/tensorflow_model_optimization/python/core/clustering/keras/BUILD @@ -74,6 +74,16 @@ py_library( ], ) +py_library( + name = "clustering_callbacks", + srcs = ["clustering_callbacks.py"], + srcs_version = "PY3", + visibility = ["//visibility:public"], + deps = [ + # tensorflow dep1, + ], +) + py_test( name = "cluster_test", size = "medium", diff --git a/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py b/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py new file mode 100644 index 000000000..3aa3e01cb --- /dev/null +++ b/tensorflow_model_optimization/python/core/clustering/keras/clustering_callbacks.py @@ -0,0 +1,94 @@ +# Copyright 2020 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Callbacks for Clustering.""" + +import tensorflow as tf + +from tensorflow import keras +from tensorflow_model_optimization.python.core.keras import compat + + +class ClusteringSummaries(keras.callbacks.TensorBoard): + """Helper class to create tensorboard summaries for the clustering progress. + + This class is derived from tf.keras.callbacks.TensorBoard and just adds + functionality to write histograms with batch-wise frequency. + + Arguments: + log_dir: The path to the directory where the log files are saved + cluster_update_freq: determines the frequency of updates of the + clustering histograms. Same behaviour as parameter update_freq of the + base class, i.e. it accepts `'batch'`, `'epoch'` or integer. + """ + + def __init__(self, log_dir='logs', cluster_update_freq='epoch', **kwargs): + super(ClusteringSummaries, self).__init__(log_dir=log_dir, **kwargs) + + if not isinstance(log_dir, str) or not log_dir: + raise ValueError( + '`log_dir` must be a non-empty string. You passed `log_dir`=' + '{input}.'.format(input=log_dir)) + + self.cluster_update_freq = \ + 1 if cluster_update_freq == 'batch' else cluster_update_freq + + if compat.is_v1_apis(): # TF 1.X + self.writer = tf.compat.v1.summary.FileWriter(log_dir) + else: # TF 2.X + self.writer = tf.summary.create_file_writer(log_dir) + + self.continuous_batch = 0 + + def on_train_batch_begin(self, batch, logs=None): + super().on_train_batch_begin(batch, logs) + # Count batches manually to get a continuous batch count spanning + # epochs, because the function parameter 'batch' is reset to zero + # every epoch. + self.continuous_batch += 1 + + def on_train_batch_end(self, batch, logs=None): + assert self.continuous_batch >= batch, \ + ('Continuous batch count must always be greater or equal than the ' + 'batch count from the parameter in the current epoch.') + + super().on_train_batch_end(batch, logs) + + if self.cluster_update_freq == 'epoch': + return + elif self.continuous_batch % self.cluster_update_freq != 0: + return # skip this batch + + self._write_summary() + + def on_epoch_end(self, epoch, logs=None): + super().on_epoch_end(epoch, logs) + if self.cluster_update_freq == 'epoch': + self._write_summary() + + def _write_summary(self): + with self.writer.as_default(): + for layer in self.model.layers: + if not hasattr(layer, 'layer') or not hasattr( + layer.layer, 'get_clusterable_weights'): + continue # skip layer + clusterable_weights = layer.layer.get_clusterable_weights() + if len(clusterable_weights) < 1: + continue # skip layers without clusterable weights + prefix = 'clustering/' + # Log variables + for var in layer.variables: + success = tf.summary.histogram( + prefix + var.name, var, step=self.continuous_batch) + assert success diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD index 33e126ff2..dd41ba082 100644 --- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/BUILD @@ -22,5 +22,7 @@ py_binary( # python/keras tensorflow dep2, # python/keras/optimizer_v2 tensorflow dep2, "//tensorflow_model_optimization/python/core/clustering/keras:cluster", + "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config", + "//tensorflow_model_optimization/python/core/clustering/keras:clustering_callbacks", ], ) diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py index a9559ab89..cb4a48cee 100644 --- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py @@ -13,26 +13,28 @@ # limitations under the License. # ============================================================================== # pylint: disable=missing-docstring -"""Train a simple convnet on the MNIST dataset and cluster it. - -This example is based on the sample that can be found here: -https://www.tensorflow.org/model_optimization/guide/quantization/training_example -""" +"""Train a simple convnet on the MNIST dataset.""" from __future__ import print_function +import datetime +import os + from absl import app as absl_app from absl import flags import tensorflow as tf from tensorflow_model_optimization.python.core.clustering.keras import cluster from tensorflow_model_optimization.python.core.clustering.keras import cluster_config +from tensorflow_model_optimization.python.core.clustering.keras import clustering_callbacks keras = tf.keras +l = keras.layers FLAGS = flags.FLAGS batch_size = 128 +num_classes = 10 epochs = 12 epochs_fine_tuning = 4 @@ -41,137 +43,158 @@ 'Output directory to hold tensorboard events') -def load_mnist_dataset(): - mnist = keras.datasets.mnist - (train_images, train_labels), (test_images, test_labels) = mnist.load_data() - - # Normalize the input image so that each pixel value is between 0 to 1. - train_images = train_images / 255.0 - test_images = test_images / 255.0 - - return (train_images, train_labels), (test_images, test_labels) - - -def build_sequential_model(): - "Define the model architecture." - - return keras.Sequential([ - keras.layers.InputLayer(input_shape=(28, 28)), - keras.layers.Reshape(target_shape=(28, 28, 1)), - keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'), - keras.layers.MaxPooling2D(pool_size=(2, 2)), - keras.layers.Flatten(), - keras.layers.Dense(10) +def build_sequential_model(input_shape): + return tf.keras.Sequential([ + l.Conv2D( + 32, 5, padding='same', activation='relu', input_shape=input_shape), + l.MaxPooling2D((2, 2), (2, 2), padding='same'), + l.BatchNormalization(), + l.Conv2D(64, 5, padding='same', activation='relu'), + l.MaxPooling2D((2, 2), (2, 2), padding='same'), + l.Flatten(), + l.Dense(1024, activation='relu'), + l.Dropout(0.4), + l.Dense(num_classes, activation='softmax') ]) -def train_model(model, x_train, y_train, x_test, y_test): - model.compile( - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer='adam', - metrics=['accuracy']) - - # Print the model summary. - model.summary() - - # Model needs to be clustered after initial training - # and having achieved good accuracy - model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs, - verbose=1, - validation_split=0.1) - - score = model.evaluate(x_test, y_test, verbose=0) - print('Test loss:', score[0]) - print('Test accuracy:', score[1]) - - return model - - -def cluster_model(model, x_train, y_train, x_test, y_test): - print('Clustering model') - - clustering_params = { - 'number_of_clusters': 8, - 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED - } - - # Cluster model - clustered_model = cluster.cluster_weights(model, **clustering_params) - - # Use smaller learning rate for fine-tuning - # clustered model - opt = tf.keras.optimizers.Adam(learning_rate=1e-5) - - clustered_model.compile( - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=opt, - metrics=['accuracy']) - - # Fine-tune clustered model - clustered_model.fit( - x_train, - y_train, - batch_size=batch_size, - epochs=epochs_fine_tuning, - verbose=1, - validation_split=0.1) - - score = clustered_model.evaluate(x_test, y_test, verbose=0) - print('Clustered model test loss:', score[0]) - print('Clustered model test accuracy:', score[1]) - - return clustered_model - - -def test_clustered_model(clustered_model, x_test, y_test): - # Ensure accuracy persists after serializing/deserializing the model - clustered_model.save('clustered_model.h5') - # To deserialize the clustered model, use the clustering scope - with cluster.cluster_scope(): - loaded_clustered_model = keras.models.load_model('clustered_model.h5') - - # Checking that the deserialized model's accuracy matches the clustered model - score = loaded_clustered_model.evaluate(x_test, y_test, verbose=0) - print('Deserialized model test loss:', score[0]) - print('Deserialized model test accuracy:', score[1]) - - # Ensure accuracy persists after stripping the model - stripped_model = cluster.strip_clustering(loaded_clustered_model) - stripped_model.compile( - loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), +def build_functional_model(input_shape): + inp = tf.keras.Input(shape=input_shape) + x = l.Conv2D(32, 5, padding='same', activation='relu')(inp) + x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) + x = l.BatchNormalization()(x) + x = l.Conv2D(64, 5, padding='same', activation='relu')(x) + x = l.MaxPooling2D((2, 2), (2, 2), padding='same')(x) + x = l.Flatten()(x) + x = l.Dense(1024, activation='relu')(x) + x = l.Dropout(0.4)(x) + out = l.Dense(num_classes, activation='softmax')(x) + + return tf.keras.models.Model([inp], [out]) + +def train_and_save(models, x_train, y_train, x_test, y_test): + for model in models: + model.compile( + loss=tf.keras.losses.categorical_crossentropy, + optimizer='adam', + metrics=['accuracy']) + + # Print the model summary. + model.summary() + + # Model needs to be clustered after initial training + # and having achieved good accuracy + model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs, + verbose=1, + validation_data=(x_test, y_test)) + score = model.evaluate(x_test, y_test, verbose=0) + print('Test loss:', score[0]) + print('Test accuracy:', score[1]) + + print('Clustering model') + + clustering_params = { + 'number_of_clusters': 8, + 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED + } + + # Cluster model + clustered_model = cluster.cluster_weights(model, **clustering_params) + + # Use smaller learning rate for fine-tuning + # clustered model + opt = tf.keras.optimizers.Adam(learning_rate=1e-5) + + clustered_model.compile( + loss=tf.keras.losses.categorical_crossentropy, + optimizer=opt, + metrics=['accuracy']) + + # Add callback for tensorboard summaries + log_dir = os.path.join( + FLAGS.output_dir, + datetime.datetime.now().strftime("%Y%m%d-%H%M%S-clustering")) + callbacks = [ + clustering_callbacks.ClusteringSummaries( + log_dir, + cluster_update_freq='epoch', + update_freq='batch', + histogram_freq=1) + ] + + # Fine-tune model + clustered_model.fit( + x_train, + y_train, + batch_size=batch_size, + epochs=epochs_fine_tuning, + verbose=1, + callbacks=callbacks, + validation_data=(x_test, y_test)) + + score = clustered_model.evaluate(x_test, y_test, verbose=0) + print('Clustered Model Test loss:', score[0]) + print('Clustered Model Test accuracy:', score[1]) + + #Ensure accuracy persists after stripping the model + stripped_model = cluster.strip_clustering(clustered_model) + + stripped_model.compile( + loss=tf.keras.losses.categorical_crossentropy, optimizer='adam', metrics=['accuracy']) + stripped_model.save('stripped_model.h5') - # Checking that the stripped model's accuracy matches the clustered model - score = stripped_model.evaluate(x_test, y_test, verbose=0) - print('Stripped model test loss:', score[0]) - print('Stripped model test accuracy:', score[1]) + # To acquire the stripped model, + # deserialize with clustering scope + with cluster.cluster_scope(): + loaded_model = keras.models.load_model('stripped_model.h5') + # Checking that the stripped model's accuracy matches the clustered model + score = loaded_model.evaluate(x_test, y_test, verbose=0) + print('Stripped Model Test loss:', score[0]) + print('Stripped Model Test accuracy:', score[1]) def main(unused_argv): if FLAGS.enable_eager: print('Running in Eager mode.') tf.compat.v1.enable_eager_execution() - # the data, shuffled and split between train and test sets - (x_train, y_train), (x_test, y_test) = load_mnist_dataset() + # input image dimensions + img_rows, img_cols = 28, 28 + # the data, shuffled and split between train and test sets + (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data() + + if tf.keras.backend.image_data_format() == 'channels_first': + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) + else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + + x_train = x_train.astype('float32') + x_test = x_test.astype('float32') + x_train /= 255 + x_test /= 255 print('x_train shape:', x_train.shape) print(x_train.shape[0], 'train samples') print(x_test.shape[0], 'test samples') - # Build model - model = build_sequential_model() - # Train model - model = train_model(model, x_train, y_train, x_test, y_test) - # Cluster and fine-tune model - clustered_model = cluster_model(model, x_train, y_train, x_test, y_test) - # Test clustered model (serialize/deserialize, strip clustering) - test_clustered_model(clustered_model, x_test, y_test) + # convert class vectors to binary class matrices + y_train = tf.keras.utils.to_categorical(y_train, num_classes) + y_test = tf.keras.utils.to_categorical(y_test, num_classes) + + sequential_model = build_sequential_model(input_shape) + functional_model = build_functional_model(input_shape) + models = [sequential_model, functional_model] + train_and_save(models, x_train, y_train, x_test, y_test) if __name__ == '__main__':