diff --git a/tfx/examples/cifar10/README.md b/tfx/examples/cifar10/README.md index 7a524c7a53..f7f7d8174a 100644 --- a/tfx/examples/cifar10/README.md +++ b/tfx/examples/cifar10/README.md @@ -1,8 +1,6 @@ -# CIFAR-10 Transfer Learning and MLKit integration Example -This example illustrates how to use Transfer Learning for image classification -with TFX, and use trained model to do object detection with -[MLKit](https://developers.google.com/ml-kit) +# CIFAR-10 Transfer Learning and MLKit integration Example +This example illustrates how to use Transfer Learning for image classification with TFX, and use trained model to do object detection with [MLKit](https://developers.google.com/ml-kit) ## Instruction @@ -26,36 +24,41 @@ version of TF2 will be installed automatically). ``` pip install -e cifar10/ -# The following is needed until tensorflow-model-analysis 0.23.0 is released -pip uinstall tensorflow-model-analysis -pip install git+https://github.com/tensorflow/model-analysis.git#egg=tensorflow_model_analysis ``` ### Dataset - -There is a subset of CIFAR10 (128 images) available in the data folder. To -prepare the whole dataset, first create a script and run the following Python -code: `import tensorflow_datasets as tfds ds = tfds.load('cifar10', -data_dir='./cifar10/data/',split=['train', 'test'])` Then, create sub-folders -for different dataset splits and move different splits to corresponding folders. -`cd cifar10/data mkdir train_whole mkdir test_whole mv -cifar10/3.0.2/cifar10-train.tfrecord-00000-of-00001 train_whole mv -cifar10/3.0.2/cifar10-test.tfrecord-00000-of-00001 test_whole` You'll find the -final dataset under `train_whole` and `test_whole` folders. Finally, clean up -the data folder. `rm -r cifar10` - +There is a subset of CIFAR10 (128 images) available in the data folder. To prepare the whole dataset, first create a script and run the following Python code: +``` +import tensorflow_datasets as tfds +ds = tfds.load('cifar10', data_dir='./cifar10/data/',split=['train', 'test']) +``` +Then, create sub-folders for different dataset splits and move different splits to corresponding folders. +``` +cd cifar10/data +mkdir train_whole +mkdir test_whole +mv cifar10/3.0.2/cifar10-train.tfrecord-00000-of-00001 train_whole +mv cifar10/3.0.2/cifar10-test.tfrecord-00000-of-00001 test_whole +``` +You'll find the final dataset under `train_whole` and `test_whole` folders. +Finally, clean up the data folder. +``` +rm -r cifar10 +``` ### Train the model +Execute the pipeline python file : +``` +python ~/cifar10/cifar_pipeline_native_keras.py +``` +The trained model is located at `~/cifar10/serving_model_lite/tflite` -Execute the pipeline python file : `python -~/cifar10/cifar_pipeline_native_keras.py` The trained model is located at -`~/cifar10/serving_model_lite/tflite` +This model is ready to be used for object detection with MLKit. Follow MLKit's [documentation](https://developers.google.com/ml-kit/vision/object-detection/custom-models/android) to set up an App and use it. -This model is ready to be used for object detection with MLKit. Follow MLKit's -[documentation](https://developers.google.com/ml-kit/vision/object-detection/custom-models/android) -to set up an App and use it. +### Train the model on GKE -## Acknowledge Data Source +To speed up model training with a distributed node pool on Google Kubernetes Engine (GKE), you may follow the instructions in `distributed/README.md`. +## Acknowledge Data Source ``` @TECHREPORT{Krizhevsky09learningmultiple, author = {Alex Krizhevsky}, diff --git a/tfx/examples/cifar10/distributed/README.md b/tfx/examples/cifar10/distributed/README.md new file mode 100644 index 0000000000..713dc95dec --- /dev/null +++ b/tfx/examples/cifar10/distributed/README.md @@ -0,0 +1,89 @@ +# CIFAR-10 Transfer Learning and MLKit integration Example with Distributed Training + +This example illustrates how to modify the base example for distributed training with a distributed node pool on Google Kubernetes Engine (GKE). + +## Instructions + +### Set Up a Node Pool +The guide assumes that you have command line access to `kubectl`. +If you do not have access to a GKE cluster, follow the quickstart guide [here](https://cloud.google.com/kubernetes-engine/docs/quickstart) to start one. + +Then, follow the instructions below for how to set up a node pool on GKE: +https://cloud.google.com/kubernetes-engine/docs/how-to/node-pools + +If you wish to use GPUs in your training, follow the instructions here: +https://cloud.google.com/kubernetes-engine/docs/how-to/gpus + +**Important:** In order to have write access to the GCS buckets in the project, you need to set up the node pool to have full service account scope access. If you are using the console, you can go to +Security -> Access scopes and select "Allow full access to all Cloud APIs". If you are using gcloud, +you configure this with the `--scopes` flag. + +### Set Up a Custom Image for Training (Optional) +The base CIFAR10 example relies on dependencies including TensorFlowJS that do not come with the default TFX installation. +If you do not wish to use a custom Docker image, you can remove the TFLite model rewriting portion in `cifar10_utils_native_keras`. + +Start from a base TensorFlow GPU image in your Dockerfile: + +``` +FROM tensorflow/tensorflow:latest-gpu +``` + +Then, follow the instructions in the base example to configure your training environment. For the example dataset to work, you would need ro run the following commands: +``` +RUN git clone https://github.com/tensorflow/tfx ~/tfx-source && \ + pushd ~/tfx-source && \ + cp -r ~/tfx-source/tfx/examples/cifar10 ~/ +RUN pip install -e cifar10/ +``` + +Finally, build and upload the new image to your container registry: +``` +docker build -t $YOUR_IMAGE_NAME . +docker push $YOUR_IMAGE_NAME +``` + +### Modify Pipeline and Util Files for Distributed Training + +First, modify the `_cifar10_root` in `cifar10_pipeline_native_keras` to point to a remote directory (For example, `gs://YOUR_GCS_PATH`), +and upload the CIFAR10 base directory there when you are done modifying the files. This allows training worker pods to pull the utility files +from your remote directory during training. + +Then, specify the custom Trainer executor `kubernetes_trainer_executor.GenericExecutor` imported from `tfx.extensions.google_cloud_kubernetes.trainer`: + +``` + trainer = Trainer( + module_file=module_file, + custom_executor_spec=executor_spec.ExecutorClassSpec(kubernetes_trainer_executor.GenericExecutor), + examples=transform.outputs['transformed_examples'], + transform_graph=transform.outputs['transform_graph'], + schema=schema_gen.outputs['schema'], + train_args=trainer_pb2.TrainArgs(num_steps=160), + eval_args=trainer_pb2.EvalArgs(num_steps=4), + custom_config={ + kubernetes_trainer_executor.TRAINING_ARGS_KEY: { + `tfx_image`: None, # SPECIFY CUSTOM IMAGE (IF ANY) + 'num_workers': 1, # SPECIFY NUMBER OF WORKERS HERE + 'num_gpus_per_worker': 0} # SPECIFY NUMBER OF GPUs HERE + } + ) +``` + +You should specify the training resources you wish to use in `custom_config`. Note that this should be consistent with the Node Pool you created in step 1. +If you are using a custom image, supply it under `tfx_image`. + +Finally, replace `cifar10_utils_native_keras` with the one in this directory. You may need to edit `_LABEL_MAP_FILE_PATH` in the file to point to your remote path. + +You should then be able to execute your pipeline with distributed training: +``` +python ~/cifar10/cifar_pipeline_native_keras.py +``` + +## Recommendations + +The following has been tested to yield the best speed-up results compared to training on a single node: + +- Multi-worker training with 4 or 8 replicas with 8 or 16 nodes, >= 64 GB memory +- Single worker training with GPU (i.e. NVIDIA K80) + +If configured correctly, the above two configurations should yield a 60% to 70% speed up in wall time versus single worker with no GPU. If you need an even higher speed up (up to 80%), +consider using multi-worker training with 1 GPU per node. However, cost efficiency will be lower, as the training is bottlenecked by preprocessing of images on CPU. diff --git a/tfx/examples/cifar10/distributed/cifar10_utils_native_keras.py b/tfx/examples/cifar10/distributed/cifar10_utils_native_keras.py new file mode 100644 index 0000000000..9069f886b8 --- /dev/null +++ b/tfx/examples/cifar10/distributed/cifar10_utils_native_keras.py @@ -0,0 +1,426 @@ +# Lint as: python2, python3 +# Copyright 2019 Google LLC. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Python source file includes CIFAR10 utils for Keras model. + +The utilities in this file are used to build a model with native Keras. +This module file will be used in Transform and generic Trainer. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +from typing import List, Text +import absl +import json +import tensorflow as tf +import tensorflow_transform as tft +from tfx.components.trainer import constants + +import flatbuffers +# pylint: disable=g-direct-tensorflow-import +from tflite_support import metadata as _metadata +from tflite_support import metadata_schema_py_generated as _metadata_fb +# pylint: enable=g-direct-tensorflow-import + +from tfx.components.trainer.rewriting import converters +from tfx.components.trainer.rewriting import rewriter +from tfx.components.trainer.rewriting import rewriter_factory + +from tfx.components.trainer.executor import TrainerFnArgs + +# When training on the whole dataset use following constants instead. +# This setting should give ~91% accuracy on the whole test set +# _TRAIN_DATA_SIZE = 50000 +# _EVAL_DATA_SIZE = 10000 +# _TRAIN_BATCH_SIZE = 64 +# _EVAL_BATCH_SIZE = 64 +# _CLASSIFIER_LEARNING_RATE = 3e-4 +# _FINETUNE_LEARNING_RATE = 5e-5 +# _CLASSIFIER_EPOCHS = 12 + +_TRAIN_DATA_SIZE = 128 +_EVAL_DATA_SIZE = 128 +_TRAIN_BATCH_SIZE = 32 +_EVAL_BATCH_SIZE = 32 +_CLASSIFIER_LEARNING_RATE = 1e-3 +_FINETUNE_LEARNING_RATE = 7e-6 +_CLASSIFIER_EPOCHS = 30 + +_IMAGE_KEY = 'image' +_LABEL_KEY = 'label' + +_TFLITE_MODEL_NAME = 'tflite' + +# For distributed training, modify this path to point to a remote location. +_LABEL_MAP_FILE_PATH = 'cifar10/data/labels.txt' + +def _transformed_name(key): + return key + '_xf' + +def _gzip_reader_fn(filenames): + """Small utility returning a record reader that can read gzip'ed files.""" + return tf.data.TFRecordDataset(filenames, compression_type='GZIP') + +def _get_serve_image_fn(model): + """Returns a function that feeds the input tensor into the model.""" + + @tf.function + def serve_image_fn(image_tensor): + """Returns the output to be used in the serving signature. + + Args: + image_tensor: A tensor represeting input image. The image should + have 3 channels. + + Returns: + The model's predicton on input image tensor + """ + return model(image_tensor) + + return serve_image_fn + +def _image_augmentation(image_features): + """Perform image augmentation on batches of images . + + Args: + image_feature: a batch of image features + + Returns: + The augmented image features + """ + batch_size = tf.shape(image_features)[0] + image_features = tf.image.random_flip_left_right(image_features) + image_features = tf.image.resize_with_crop_or_pad(image_features, 250, 250) + image_features = tf.image.random_crop(image_features, + (batch_size, 224, 224, 3)) + return image_features + +def _data_augmentation(feature_dict): + """Perform data augmentation on batches of data. + + Args: + feature_dict: a dict containing features of samples + + Returns: + The feature dict with augmented features + """ + image_features = feature_dict[_transformed_name(_IMAGE_KEY)] + image_features = _image_augmentation(image_features) + feature_dict[_transformed_name(_IMAGE_KEY)] = image_features + return feature_dict + +def _input_fn(file_pattern: List[Text], + tf_transform_output: tft.TFTransformOutput, + is_train: bool = False, + batch_size: int = 200) -> tf.data.Dataset: + """Generates features and label for tuning/training. + + Args: + file_pattern: List of paths or patterns of input tfrecord files. + tf_transform_output: A TFTransformOutput. + is_train: Whether the input dataset is train split or not. + batch_size: representing the number of consecutive elements of returned + dataset to combine in a single batch + + Returns: + A dataset that contains (features, indices) tuple where features is a + dictionary of Tensors, and indices is a single Tensor of label indices. + """ + transformed_feature_spec = ( + tf_transform_output.transformed_feature_spec().copy()) + dataset = tf.data.experimental.make_batched_features_dataset( + file_pattern=file_pattern, + batch_size=batch_size, + features=transformed_feature_spec, + reader_num_threads=16, + parser_num_threads=32, + reader=_gzip_reader_fn, + label_key=_transformed_name(_LABEL_KEY)) + + # Apply data augmentation. We have to do data augmentation here because + # we need to apply data agumentation on-the-fly during training. If we put + # it in Transform, it will only be applied once on the whole dataset, which + # will lose the point of data augmentation. + if is_train: + dataset = dataset.map(lambda x, y: (_data_augmentation(x), y)) + + return dataset + +def _freeze_model_by_percentage(model: tf.keras.Model, + percentage: float): + """Freeze part of the model based on specified percentage + + Args: + model: The keras model need to be partially frozen + percentage: the percentage of layers to freeze + """ + if percentage < 0 or percentage > 1: + raise Exception("Freeze percentage should between 0.0 and 1.0") + + if not model.trainable: + raise Exception( + "The model is not trainable, please set model.trainable to True") + + num_layers = len(model.layers) + num_layers_to_freeze = int(num_layers * percentage) + for idx, layer in enumerate(model.layers): + if idx < num_layers_to_freeze: + layer.trainable = False + else: + layer.trainable = True + +def _build_keras_model() -> tf.keras.Model: + """Creates a Image classification model with MobileNet backbone + + Returns: + The image classifcation Keras Model and the backbone MobileNet model + """ + # We create a MobileNet model with weights pre-trained on ImageNet. + # We remove the top classification layer of the MobileNet, which was + # used for classifying ImageNet objects. We will add our own classification + # layer for CIFAR10 later. We use average pooling at the last convolution + # layer to get a 1D vector for classifcation, which is consistent with the + # origin MobileNet setup + base_model = tf.keras.applications.MobileNet( + input_shape=(224, 224, 3), include_top=False, weights='imagenet', + pooling='avg') + + # We add a Dropout layer at the top of MobileNet backbone we just created to prevent + # overfiting, and then a Dense layer to classifying CIFAR10 objects + model = tf.keras.Sequential([ + tf.keras.layers.InputLayer( + input_shape=(224, 224, 3), name=_transformed_name(_IMAGE_KEY)), + base_model, + tf.keras.layers.Dropout(0.1), + tf.keras.layers.Dense(10, activation='softmax') + ]) + + # Freeze the whole MobileNet backbone to first train the top classifer only + _freeze_model_by_percentage(base_model, 1.0) + + model.compile( + loss='sparse_categorical_crossentropy', + optimizer=tf.keras.optimizers.RMSprop(lr=_CLASSIFIER_LEARNING_RATE), + metrics=['sparse_categorical_accuracy']) + model.summary(print_fn=absl.logging.info) + + return model, base_model + +# TFX Transform will call this function. +def preprocessing_fn(inputs): + """tf.transform's callback function for preprocessing inputs. + + Args: + inputs: map from feature keys to raw not-yet-transformed features. + + Returns: + Map from string feature key to transformed feature operations. + """ + outputs = {} + + # tf.io.decode_png function cannot be applied on a batch of data. + # We have to use tf.map_fn + image_features = tf.map_fn(lambda x: tf.io.decode_png(x[0], channels=3), + inputs[_IMAGE_KEY], dtype=tf.uint8) + # image_features = tf.cast(image_features, tf.float32) + image_features = tf.image.resize(image_features, [224, 224]) + image_features = tf.keras.applications.mobilenet.preprocess_input( + image_features) + + outputs[_transformed_name(_IMAGE_KEY)] = image_features + # TODO(b/157064428): Support label transformation for Keras. + # Do not apply label transformation as it will result in wrong evaluation. + outputs[_transformed_name(_LABEL_KEY)] = inputs[_LABEL_KEY] + + return outputs + +def _write_metadata(model_path: str, + label_map_path: str, + mean: list, + std: list): + """Add normalization option and label map TFLite metadata to the model + + Args: + model_path: The path of the TFLite model + label_map_path: The path of the label map file + mean: The mean value used to normalize input image tensor + std: The standard deviation used to normalize input image tensor + """ + + # Creates flatbuffer for model information. + model_meta = _metadata_fb.ModelMetadataT() + + # Creates flatbuffer for model input metadata. + # Here we add the input normalization info to input metadata. + input_meta = _metadata_fb.TensorMetadataT() + input_normalization = _metadata_fb.ProcessUnitT() + input_normalization.optionsType = ( + _metadata_fb.ProcessUnitOptions.NormalizationOptions) + input_normalization.options = _metadata_fb.NormalizationOptionsT() + input_normalization.options.mean = mean + input_normalization.options.std = std + input_meta.processUnits = [input_normalization] + + # Creates flatbuffer for model output metadata. + # Here we add label file to output metadata. + output_meta = _metadata_fb.TensorMetadataT() + label_file = _metadata_fb.AssociatedFileT() + label_file.name = os.path.basename(label_map_path) + label_file.type = _metadata_fb.AssociatedFileType.TENSOR_AXIS_LABELS + output_meta.associatedFiles = [label_file] + + # Creates subgraph to contain input and output information, + # and add subgraph to the model information. + subgraph = _metadata_fb.SubGraphMetadataT() + subgraph.inputTensorMetadata = [input_meta] + subgraph.outputTensorMetadata = [output_meta] + model_meta.subgraphMetadata = [subgraph] + + # Serialize the model metadata buffer we created above using flatbuffer builder. + b = flatbuffers.Builder(0) + b.Finish( + model_meta.Pack(b), + _metadata.MetadataPopulator.METADATA_FILE_IDENTIFIER) + metadata_buf = b.Output() + + # Populates metadata and label file to the model file. + populator = _metadata.MetadataPopulator.with_model_file(model_path) + populator.load_metadata_buffer(metadata_buf) + populator.load_associated_files([label_map_path]) + populator.populate() + + +# TFX Trainer will call this function. +def run_fn(fn_args: TrainerFnArgs): + """Train the model based on given args. + + Args: + fn_args: Holds args used to train the model as name/value pairs. + """ + multi_worker_strategy = tf.distribute.experimental.MultiWorkerMirroredStrategy() + + tf_transform_output = tft.TFTransformOutput(fn_args.transform_output) + + + with multi_worker_strategy.scope(): + + train_dataset = _input_fn(fn_args.train_files, tf_transform_output, + is_train=True, batch_size=_TRAIN_BATCH_SIZE) + eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, + is_train=False, batch_size=_EVAL_BATCH_SIZE) + + model, base_model = _build_keras_model() + + try: + log_dir = fn_args.model_run_dir + except KeyError: + # TODO(b/158106209): use ModelRun instead of Model artifact for logging. + log_dir = os.path.join(os.path.dirname(fn_args.serving_model_dir), 'logs') + + absl.logging.info('Tensorboard logging to {}'.format(log_dir)) + # Write logs to path + tensorboard_callback = tf.keras.callbacks.TensorBoard( + log_dir=log_dir, update_freq='batch', profile_batch=(30, 60)) + + # Our training regime has two phases: we first freeze the backbone and train + # the newly added classifier only, then unfreeze part of the backbone and + # fine-tune with classifier jointly. + steps_per_epoch = int(_TRAIN_DATA_SIZE / (_TRAIN_BATCH_SIZE)) + total_epochs = int(fn_args.train_steps / steps_per_epoch) + if _CLASSIFIER_EPOCHS > total_epochs: + raise Exception('Classifier epochs is greater than the total epochs') + + absl.logging.info('Start training the top classifier') + model.fit( + train_dataset, + epochs=_CLASSIFIER_EPOCHS, + steps_per_epoch=steps_per_epoch, + validation_data=eval_dataset, + validation_steps=fn_args.eval_steps, + callbacks=[tensorboard_callback] + ) + + absl.logging.info('Start fine-tuning the model') + # Unfreeze the top MobileNet layers and do joint fine-tuning + _freeze_model_by_percentage(base_model, 0.9) + + + with multi_worker_strategy.scope(): + # We need to recompile the model because layer properties have changed + model.compile( + loss='sparse_categorical_crossentropy', + optimizer=tf.keras.optimizers.RMSprop(lr=_FINETUNE_LEARNING_RATE), + metrics=['sparse_categorical_accuracy']) + model.summary(print_fn=absl.logging.info) + + model.fit( + train_dataset, + initial_epoch=_CLASSIFIER_EPOCHS, + epochs=total_epochs, + steps_per_epoch=steps_per_epoch, + validation_data=eval_dataset, + validation_steps=fn_args.eval_steps, + callbacks=[tensorboard_callback] + ) + + # Prepare the TFLite model used for serving in MLKit + signatures = { + 'serving_default': + _get_serve_image_fn( + model).get_concrete_function( + tf.TensorSpec( + shape=[None, 224, 224, 3], + dtype=tf.float32, + name=_transformed_name(_IMAGE_KEY) + )) + } + + + tf_config = json.loads(os.environ.get(constants.TF_CONFIG_ENV) or '{}') + + task_type = tf_config['task']['type'] + task_id = tf_config['task']['index'] + + def _is_chief(task_type, task_id): + """Returns true if this is run in the master (chief) of training cluster.""" + # 'master' is a legacy notation of chief node in distributed training flock. + return task_type in ('chief', None) or (task_type in ('master', 'worker') + and task_id == 0) + + temp_saving_model_dir = os.path.join(fn_args.serving_model_dir, 'temp_{}'.format(task_id)) + model.save(temp_saving_model_dir, save_format='tf', signatures=signatures) + + if _is_chief(task_type, task_id): + tfrw = rewriter_factory.create_rewriter( + rewriter_factory.TFLITE_REWRITER, name='tflite_rewriter', + enable_experimental_new_converter=True) + converters.rewrite_saved_model(temp_saving_model_dir, + fn_args.serving_model_dir, + tfrw, + rewriter.ModelType.TFLITE_MODEL) + + # Add necessary TFLite metadata to the model in order to use it within MLKit + # TODO(dzats@): Handle label map file path more properly, currently hard-coded + tflite_model_path = os.path.join(fn_args.serving_model_dir, + _TFLITE_MODEL_NAME) + # TODO(dzats@): Extend the TFLite rewriter to be able to add TFLite metadata to the model + _write_metadata(model_path=tflite_model_path, + label_map_path=_LABEL_MAP_FILE_PATH, + mean=[127.5], + std=[127.5]) + else: + tf.io.gfile.rmtree(temp_saving_model_dir)