diff --git a/tensorflow_lite_support/examples/task/vision/desktop/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/BUILD new file mode 100644 index 000000000..31ea17dde --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/BUILD @@ -0,0 +1,71 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_binary( + name = "image_classifier_demo", + srcs = ["image_classifier_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:image_classifier", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:classifications_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:image_classifier_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +cc_binary( + name = "object_detector_demo", + srcs = ["object_detector_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:object_detector", + "//tensorflow_lite_support/cc/task/vision/proto:bounding_box_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:class_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:detections_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:object_detector_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/core:lib", + ], +) + +cc_binary( + name = "image_segmenter_demo", + srcs = ["image_segmenter_demo.cc"], + deps = [ + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "//tensorflow_lite_support/cc/task/vision:image_segmenter", + "//tensorflow_lite_support/cc/task/vision/proto:image_segmenter_options_proto_inc", + "//tensorflow_lite_support/cc/task/vision/proto:segmentations_proto_inc", + "//tensorflow_lite_support/cc/task/vision/utils:frame_buffer_common_utils", + "//tensorflow_lite_support/examples/task/vision/desktop/utils:image_utils", + "@com_google_absl//absl/flags:flag", + "@com_google_absl//absl/flags:parse", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/core:lib", + ], +) diff --git a/tensorflow_lite_support/examples/task/vision/desktop/README.md b/tensorflow_lite_support/examples/task/vision/desktop/README.md new file mode 100644 index 000000000..62986caf5 --- /dev/null +++ 
b/tensorflow_lite_support/examples/task/vision/desktop/README.md @@ -0,0 +1,181 @@ +# CLI Demos for C++ Vision Task APIs + +This folder contains simple command-line tools for easily trying out the C++ +Vision Task APIs. + +## Image Classifier + +#### Prerequisites + +You will need: + +* a TFLite image classification model (e.g. [aiy/vision/classifier/birds_V1](https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/2), +a bird classification model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run classification on, e.g.: + +![sparrow](g3doc/sparrow.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L https://tfhub.dev/google/lite-model/aiy/vision/classifier/birds_V1/2?lite-format=tflite \ + -o /tmp/aiy_vision_classifier_birds_V1_2.tflite + +# Run the classification tool: +bazel run -c opt examples/task/vision/desktop:image_classifier_demo -- \ + --model_path=/tmp/aiy_vision_classifier_birds_V1_2.tflite \ + --image_path=$(pwd)/examples/task/vision/desktop/g3doc/sparrow.jpg \ + --max_results=3 +``` + +#### Results + +In the console, you should get: + +``` +Results: + Rank #0: + index : 671 + score : 0.91797 + class name : /m/01bwbt + display name: Passer montanus + Rank #1: + index : 670 + score : 0.00391 + class name : /m/0193xp + display name: Troglodytes hiemalis + Rank #2: + index : 495 + score : 0.00391 + class name : /m/05sjn7 + display name: Mimus gilvus +``` + +## Object Detector + +#### Prerequisites + +TODO(b/161960089): the model used in this example has an off-by-one error in its +label map, which will cause the model to return "cat" instead of "dog" in the +following example. It will soon be updated on tfhub.dev with a fix. + +You will need: + +* a TFLite object detection model (e.g. [ssd_mobilenet_v1](https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1), +a generic object detection model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run detection on, e.g.: + +![dogs](g3doc/dogs.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L https://tfhub.dev/tensorflow/lite-model/ssd_mobilenet_v1/1/metadata/1?lite-format=tflite \ + -o /tmp/ssd_mobilenet_v1_1_metadata_1.tflite + +# Run the detection tool: +bazel run -c opt examples/task/vision/desktop:object_detector_demo -- \ + --model_path=/tmp/ssd_mobilenet_v1_1_metadata_1.tflite \ + --image_path=$(pwd)/examples/task/vision/desktop/g3doc/dogs.jpg \ + --output_png=/tmp/detection-output.png \ + --max_results=2 +``` + +#### Results + +In the console, you should get: + +``` +Results saved to: /tmp/detection-output.png +Results: + Detection #0 (red): + Box: (x: 355, y: 133, w: 190, h: 206) + Top-1 class: + index : 17 + score : 0.73828 + class name : dog + Detection #1 (green): + Box: (x: 103, y: 15, w: 138, h: 369) + Top-1 class: + index : 17 + score : 0.73047 + class name : dog +``` + +And `/tmp/detection-output.png` should contain: + +![detection-output](g3doc/detection-output.png) + +## Image Segmenter + +#### Prerequisites + +TODO(b/161957922): the model used in this example doesn't include a label map, +which will cause the console output to be less complete than in the example +below. It will soon be updated on tfhub.dev with a fix. + +You will need: + +* a TFLite image segmentation model (e.g. 
[deeplab_v3](https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1), +a generic segmentation model available on TensorFlow Hub), +* a PNG, JPEG or GIF image to run segmentation on, e.g.: + +![plane](g3doc/plane.jpg) + +#### Usage + +In the console, run: + +```bash +# Download the model: +curl \ + -L https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1?lite-format=tflite \ + -o /tmp/deeplabv3_1_metadata_1.tflite + +# Run the segmentation tool: +bazel run -c opt examples/task/vision/desktop:image_segmenter_demo -- \ + --model_path=/tmp/deeplabv3_1_metadata_1.tflite \ + --image_path=$(pwd)/examples/task/vision/desktop/g3doc/plane.jpg \ + --output_mask_png=/tmp/segmentation-output.png +``` + +#### Results + +In the console, you should get: + +``` +Category mask saved to: /tmp/segmentation-output.png +Color Legend: + (r: 000, g: 000, b: 000): + index : 0 + class name : background + (r: 128, g: 000, b: 000): + index : 1 + class name : aeroplane + +# (omitting multiple lines for conciseness) ... + + (r: 128, g: 192, b: 000): + index : 19 + class name : train + (r: 000, g: 064, b: 128): + index : 20 + class name : tv +Tip: use a color picker on the output PNG file to inspect the output mask with +this legend. +``` + +And `/tmp/segmentation-output.png` should contain the segmentation mask: + +![segmentation-output](g3doc/segmentation-output.png) diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png new file mode 100644 index 000000000..c8d56f405 Binary files /dev/null and b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/detection-output.png differ diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg new file mode 100644 index 000000000..9db4bee75 Binary files /dev/null and b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/dogs.jpg differ diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg new file mode 100644 index 000000000..0edefa40a Binary files /dev/null and b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/plane.jpg differ diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png new file mode 100644 index 000000000..e871df337 Binary files /dev/null and b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/segmentation-output.png differ diff --git a/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg new file mode 100644 index 000000000..25d213ea4 Binary files /dev/null and b/tensorflow_lite_support/examples/task/vision/desktop/g3doc/sparrow.jpg differ diff --git a/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc new file mode 100644 index 000000000..c844e90c5 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc @@ -0,0 +1,169 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:image_classifier_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --image_path=/path/to/image.jpg + +#include <iostream> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/image_classifier.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/classifications_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_classifier_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' image classifier model."); +ABSL_FLAG(std::string, image_path, "", + "Absolute path to the image to classify. The image EXIF orientation " + "flag, if any, is NOT taken into account."); +ABSL_FLAG(int32, max_results, 5, + "Maximum number of classification results to display."); +ABSL_FLAG(float, score_threshold, 0, + "Classification results with a confidence score below this value are " + "rejected. If >= 0, overrides the score threshold(s) provided in the " + "TFLite Model Metadata. Ignored otherwise."); +ABSL_FLAG( + std::vector<std::string>, class_name_whitelist, {}, + "Comma-separated list of class names that acts as a whitelist. If " + "non-empty, classification results whose 'class_name' is not in this list " + "are filtered out. Mutually exclusive with 'class_name_blacklist'."); +ABSL_FLAG( + std::vector<std::string>, class_name_blacklist, {}, + "Comma-separated list of class names that acts as a blacklist. If " + "non-empty, classification results whose 'class_name' is in this list " + "are filtered out. 
Mutually exclusive with 'class_name_whitelist'."); + +namespace tflite { +namespace support { +namespace task { +namespace vision { + +ImageClassifierOptions BuildOptions() { + ImageClassifierOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + options.set_max_results(absl::GetFlag(FLAGS_max_results)); + if (absl::GetFlag(FLAGS_score_threshold) >= 0) { + options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold)); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_whitelist)) { + options.add_class_name_whitelist(class_name); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_blacklist)) { + options.add_class_name_blacklist(class_name); + } + return options; +} + +void DisplayResult(const ClassificationResult& result) { + std::cout << "Results:\n"; + for (int head = 0; head < result.classifications_size(); ++head) { + if (result.classifications_size() > 1) { + std::cout << absl::StrFormat(" Head index %d:\n", head); + } + const Classifications& classifications = result.classifications(head); + for (int rank = 0; rank < classifications.classes_size(); ++rank) { + const Class& classification = classifications.classes(rank); + std::cout << absl::StrFormat(" Rank #%d:\n", rank); + std::cout << absl::StrFormat(" index : %d\n", + classification.index()); + std::cout << absl::StrFormat(" score : %.5f\n", + classification.score()); + if (classification.has_class_name()) { + std::cout << absl::StrFormat(" class name : %s\n", + classification.class_name()); + } + if (classification.has_display_name()) { + std::cout << absl::StrFormat(" display name: %s\n", + classification.display_name()); + } + } + } +} + +absl::Status Classify() { + // Build ImageClassifier. + const ImageClassifierOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ImageClassifier> image_classifier, + ImageClassifier::CreateFromOptions(options)); + + // Load image in a FrameBuffer. + ASSIGN_OR_RETURN(RgbImageData image, + DecodeImageFromFile(absl::GetFlag(FLAGS_image_path))); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height}); + + // Run classification and display results. + ASSIGN_OR_RETURN(ClassificationResult result, + image_classifier->Classify(*frame_buffer)); + DisplayResult(result); + + // Cleanup and return. + RgbImageDataFree(&image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace support +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. + absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_image_path).empty()) { + std::cerr << "Missing mandatory 'image_path' argument.\n"; + return 1; + } + if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() && + !absl::GetFlag(FLAGS_class_name_blacklist).empty()) { + std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments " + "are mutually exclusive.\n"; + return 1; + } + + // We need to call this to set up global state for Tensorflow, which is used + // internally for decoding various image formats (JPEG, PNG, etc). + tensorflow::port::InitMain(argv[0], &argc, &argv); + + // Run classification. 
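+  // Classify() wraps the whole pipeline (classifier construction, image +  // decoding, inference and result display); any failure along the way is +  // surfaced through the returned absl::Status and reported below.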
+ absl::Status status = tflite::support::task::vision::Classify(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Classification failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc new file mode 100644 index 000000000..cedf14e1d --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc @@ -0,0 +1,196 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:image_segmenter_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --image_path=/path/to/image.jpg \ +// --output_mask_png=/path/to/output/mask.png + +#include <iostream> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/image_segmenter.h" +#include "tensorflow_lite_support/cc/task/vision/proto/image_segmenter_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/segmentations_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' image segmenter model."); +ABSL_FLAG(std::string, image_path, "", + "Absolute path to the image to segment. The image EXIF orientation " + "flag, if any, is NOT taken into account."); +ABSL_FLAG(std::string, output_mask_png, "", + "Absolute path to the output category mask (confidence masks outputs " + "are not supported by this tool). Must have a '.png' extension."); + +namespace tflite { +namespace support { +namespace task { +namespace vision { + +ImageSegmenterOptions BuildOptions() { + ImageSegmenterOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + // Confidence masks are not supported by this tool: output_type is set to + // CATEGORY_MASK by default. + return options; +} + +absl::Status EncodeMaskToPngFile(const SegmentationResult& result) { + if (result.segmentation_size() != 1) { + return absl::UnimplementedError( + "Image segmentation models with multiple output segmentations are not " + "supported by this tool."); + } + const Segmentation& segmentation = result.segmentation(0); + // Extract raw mask data as a uint8 pointer. 
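+  // The category mask stores one byte per pixel in row-major order; each byte +  // is the index of the class detected for that pixel, and is used below to +  // look up the matching ColoredLabel.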
+ const uint8* raw_mask = + reinterpret_cast<const uint8*>(segmentation.category_mask().data()); + + // Create RgbImageData for the output mask. + uint8* pixel_data = static_cast<uint8*>( + malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8))); + RgbImageData mask = {.pixel_data = pixel_data, + .width = segmentation.width(), + .height = segmentation.height()}; + + // Populate RgbImageData from the raw mask and ColoredLabel-s. + for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) { + Segmentation::ColoredLabel colored_label = + segmentation.colored_labels(raw_mask[i]); + pixel_data[3 * i] = colored_label.r(); + pixel_data[3 * i + 1] = colored_label.g(); + pixel_data[3 * i + 2] = colored_label.b(); + } + + // Encode mask as PNG. + RETURN_IF_ERROR( + EncodeRgbImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png))); + std::cout << absl::StrFormat("Category mask saved to: %s\n", + absl::GetFlag(FLAGS_output_mask_png)); + + // Cleanup and return. + RgbImageDataFree(&mask); + return absl::OkStatus(); +} + +absl::Status DisplayColorLegend(const SegmentationResult& result) { + if (result.segmentation_size() != 1) { + return absl::UnimplementedError( + "Image segmentation models with multiple output segmentations are not " + "supported by this tool."); + } + const Segmentation& segmentation = result.segmentation(0); + const int num_labels = segmentation.colored_labels_size(); + + std::cout << "Color Legend:\n"; + for (int index = 0; index < num_labels; ++index) { + Segmentation::ColoredLabel colored_label = + segmentation.colored_labels(index); + std::cout << absl::StrFormat(" (r: %03d, g: %03d, b: %03d):\n", + colored_label.r(), colored_label.g(), + colored_label.b()); + std::cout << absl::StrFormat(" index : %d\n", index); + if (colored_label.has_class_name()) { + std::cout << absl::StrFormat(" class name : %s\n", + colored_label.class_name()); + } + if (colored_label.has_display_name()) { + std::cout << absl::StrFormat(" display name: %s\n", + colored_label.display_name()); + } + } + std::cout << "Tip: use a color picker on the output PNG file to inspect the " + "output mask with this legend.\n"; + + return absl::OkStatus(); +} + +absl::Status Segment() { + // Build ImageSegmenter. + const ImageSegmenterOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ImageSegmenter> image_segmenter, + ImageSegmenter::CreateFromOptions(options)); + + // Load image in a FrameBuffer. + ASSIGN_OR_RETURN(RgbImageData image, + DecodeImageFromFile(absl::GetFlag(FLAGS_image_path))); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height}); + + // Run segmentation and save category mask. + ASSIGN_OR_RETURN(SegmentationResult result, + image_segmenter->Segment(*frame_buffer)); + RETURN_IF_ERROR(EncodeMaskToPngFile(result)); + + // Display the legend. + RETURN_IF_ERROR(DisplayColorLegend(result)); + + // Cleanup and return. + RgbImageDataFree(&image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace support +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. 
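+  // All three flags are mandatory for this demo, and 'output_mask_png' must +  // additionally carry a '.png' extension since the mask is always encoded as +  // lossless PNG.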
+ absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_image_path).empty()) { + std::cerr << "Missing mandatory 'image_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_output_mask_png).empty()) { + std::cerr << "Missing mandatory 'output_mask_png' argument.\n"; + return 1; + } + if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_mask_png), ".png")) { + std::cerr << "Argument 'output_mask_png' must end with '.png' or '.PNG'\n"; + return 1; + } + + // We need to call this to set up global state for Tensorflow, which is used + // internally for decoding various image formats (JPEG, PNG, etc). + tensorflow::port::InitMain(argv[0], &argc, &argv); + + // Run segmentation. + absl::Status status = tflite::support::task::vision::Segment(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Segmentation failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc b/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc new file mode 100644 index 000000000..84f679112 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc @@ -0,0 +1,247 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Example usage: +// bazel run -c opt \ +// tensorflow_lite_support/examples/task/vision/desktop:object_detector_demo \ +// -- \ +// --model_path=/path/to/model.tflite \ +// --image_path=/path/to/image.jpg \ +// --output_png=/path/to/output.png + +#include <iostream> +#include <limits> + +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow_lite_support/cc/port/statusor.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/object_detector.h" +#include "tensorflow_lite_support/cc/task/vision/proto/bounding_box_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/class_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/detections_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/proto/object_detector_options_proto_inc.h" +#include "tensorflow_lite_support/cc/task/vision/utils/frame_buffer_common_utils.h" +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +ABSL_FLAG(std::string, model_path, "", + "Absolute path to the '.tflite' object detector model."); +ABSL_FLAG(std::string, image_path, "", + "Absolute path to the image to perform detection on. 
The image EXIF " + "orientation flag, if any, is NOT taken into account."); +ABSL_FLAG(std::string, output_png, "", + "Absolute path to a file where to draw the detection results on top " + "of the input image. Must have a '.png' extension."); +ABSL_FLAG(int32, max_results, 5, + "Maximum number of detection results to display."); +ABSL_FLAG( + float, score_threshold, std::numeric_limits<float>::lowest(), + "Detection results with a confidence score below this value are " + "rejected. If specified, overrides the score threshold(s) provided in the " + "TFLite Model Metadata. Ignored otherwise."); +ABSL_FLAG( + std::vector<std::string>, class_name_whitelist, {}, + "Comma-separated list of class names that acts as a whitelist. If " + "non-empty, detection results whose 'class_name' is not in this list " + "are filtered out. Mutually exclusive with 'class_name_blacklist'."); +ABSL_FLAG(std::vector<std::string>, class_name_blacklist, {}, + "Comma-separated list of class names that acts as a blacklist. If " + "non-empty, detection results whose 'class_name' is in this list " + "are filtered out. Mutually exclusive with 'class_name_whitelist'."); + +namespace tflite { +namespace support { +namespace task { +namespace vision { + +namespace { +// The line thickness (in pixels) for drawing the detection results. +constexpr int kLineThickness = 3; + +// The number of colors used for drawing the detection results. +constexpr int kColorMapSize = 10; + +// The names of the colors used for drawing the detection results. +constexpr std::array<absl::string_view, 10> kColorMapNames = { + "red", "green", "blue", "yellow", "fuchsia", + "dark red", "dark green", "dark blue", "gray", "black"}; + +// The colors used for drawing the detection results as a flattened array of +// {R,G,B} components. +constexpr uint8 kColorMapComponents[30] = { + 255, 0, 0, 0, 255, 0, 0, 0, 255, 255, 255, 0, 255, 0, 255, + 128, 0, 0, 0, 128, 0, 0, 0, 128, 128, 128, 128, 0, 0, 0}; +} // namespace + +ObjectDetectorOptions BuildOptions() { + ObjectDetectorOptions options; + options.mutable_model_file_with_metadata()->set_file_name( + absl::GetFlag(FLAGS_model_path)); + options.set_max_results(absl::GetFlag(FLAGS_max_results)); + if (absl::GetFlag(FLAGS_score_threshold) > + std::numeric_limits<float>::lowest()) { + options.set_score_threshold(absl::GetFlag(FLAGS_score_threshold)); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_whitelist)) { + options.add_class_name_whitelist(class_name); + } + for (const std::string& class_name : + absl::GetFlag(FLAGS_class_name_blacklist)) { + options.add_class_name_blacklist(class_name); + } + return options; +} + +absl::Status EncodeResultToPngFile(const DetectionResult& result, + const RgbImageData* image) { + for (int index = 0; index < result.detections_size(); ++index) { + // Get bounding box as left, top, right, bottom. + const BoundingBox& box = result.detections(index).bounding_box(); + const int left = box.origin_x(); + const int top = box.origin_y(); + const int right = box.origin_x() + box.width(); + const int bottom = box.origin_y() + box.height(); + // Get color components. + const uint8 r = kColorMapComponents[3 * (index % kColorMapSize)]; + const uint8 g = kColorMapComponents[3 * (index % kColorMapSize) + 1]; + const uint8 b = kColorMapComponents[3 * (index % kColorMapSize) + 2]; + // Draw. Boxes might have coordinates outside of [0, w) x [0, h) so clamping + // is applied. 
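+    // A pixel of the clamped box is painted only if it lies within +    // kLineThickness pixels of one of the four box edges, which traces a +    // rectangular outline rather than filling the whole box.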
+ for (int y = std::max(0, top); y < std::min(image->height, bottom); ++y) { + for (int x = std::max(0, left); x < std::min(image->width, right); ++x) { + int pixel_index = 3 * (image->width * y + x); + if (x < left + kLineThickness || x > right - kLineThickness || + y < top + kLineThickness || y > bottom - kLineThickness) { + image->pixel_data[pixel_index] = r; + image->pixel_data[pixel_index + 1] = g; + image->pixel_data[pixel_index + 2] = b; + } + } + } + } + // Encode to PNG and return. + RETURN_IF_ERROR( + EncodeRgbImageToPngFile(*image, absl::GetFlag(FLAGS_output_png))); + std::cout << absl::StrFormat("Results saved to: %s\n", + absl::GetFlag(FLAGS_output_png)); + return absl::OkStatus(); +} + +void DisplayResult(const DetectionResult& result) { + std::cout << "Results:\n"; + for (int index = 0; index < result.detections_size(); ++index) { + std::cout << absl::StrFormat(" Detection #%d (%s):\n", index, + kColorMapNames[index % kColorMapSize]); + const Detection& detection = result.detections(index); + const BoundingBox& box = detection.bounding_box(); + std::cout << absl::StrFormat(" Box: (x: %d, y: %d, w: %d, h: %d)\n", + box.origin_x(), box.origin_y(), box.width(), + box.height()); + if (detection.classes_size() == 0) { + std::cout << " No top-1 class available\n"; + } else { + std::cout << " Top-1 class:\n"; + const Class& classification = detection.classes(0); + std::cout << absl::StrFormat(" index : %d\n", + classification.index()); + std::cout << absl::StrFormat(" score : %.5f\n", + classification.score()); + if (classification.has_class_name()) { + std::cout << absl::StrFormat(" class name : %s\n", + classification.class_name()); + } + if (classification.has_display_name()) { + std::cout << absl::StrFormat(" display name: %s\n", + classification.display_name()); + } + } + } +} + +absl::Status Detect() { + // Build ObjectDetector. + const ObjectDetectorOptions& options = BuildOptions(); + ASSIGN_OR_RETURN(std::unique_ptr<ObjectDetector> object_detector, + ObjectDetector::CreateFromOptions(options)); + + // Load image in a FrameBuffer. + ASSIGN_OR_RETURN(RgbImageData image, + DecodeImageFromFile(absl::GetFlag(FLAGS_image_path))); + std::unique_ptr<FrameBuffer> frame_buffer = + CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height}); + + // Run object detection and draw results on input image. + ASSIGN_OR_RETURN(DetectionResult result, + object_detector->Detect(*frame_buffer)); + RETURN_IF_ERROR(EncodeResultToPngFile(result, &image)); + + // Display results as text. + DisplayResult(result); + + // Cleanup and return. + RgbImageDataFree(&image); + return absl::OkStatus(); +} + +} // namespace vision +} // namespace task +} // namespace support +} // namespace tflite + +int main(int argc, char** argv) { + // Parse command line arguments and perform sanity checks. 
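+  // Besides the mandatory paths, 'output_png' must end in '.png' and the +  // class name whitelist/blacklist flags are checked for mutual exclusion +  // before the detector is built.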
+ absl::ParseCommandLine(argc, argv); + if (absl::GetFlag(FLAGS_model_path).empty()) { + std::cerr << "Missing mandatory 'model_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_image_path).empty()) { + std::cerr << "Missing mandatory 'image_path' argument.\n"; + return 1; + } + if (absl::GetFlag(FLAGS_output_png).empty()) { + std::cerr << "Missing mandatory 'output_png' argument.\n"; + return 1; + } + if (!absl::EndsWithIgnoreCase(absl::GetFlag(FLAGS_output_png), ".png")) { + std::cerr << "Argument 'output_png' must end with '.png' or '.PNG'\n"; + return 1; + } + if (!absl::GetFlag(FLAGS_class_name_whitelist).empty() && + !absl::GetFlag(FLAGS_class_name_blacklist).empty()) { + std::cerr << "'class_name_whitelist' and 'class_name_blacklist' arguments " + "are mutually exclusive.\n"; + return 1; + } + + // We need to call this to set up global state for Tensorflow, which is used + // internally for decoding various image formats (JPEG, PNG, etc). + tensorflow::port::InitMain(argv[0], &argc, &argv); + + // Run detection. + absl::Status status = tflite::support::task::vision::Detect(); + if (status.ok()) { + return 0; + } else { + std::cerr << "Detection failed: " << status.message() << "\n"; + return 1; + } +} diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD new file mode 100644 index 000000000..5970664d2 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD @@ -0,0 +1,29 @@ +package( + default_visibility = [ + "//tensorflow_lite_support:users", + ], + licenses = ["notice"], # Apache 2.0 +) + +cc_library( + name = "image_utils", + srcs = ["image_utils.cc"], + hdrs = ["image_utils.h"], + deps = [ + "//tensorflow_lite_support/cc/port:integral_types", + "//tensorflow_lite_support/cc/port:status_macros", + "//tensorflow_lite_support/cc/port:statusor", + "//tensorflow_lite_support/cc/task/core:external_file_handler", + "//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@org_tensorflow//tensorflow/cc:cc_ops", + "@org_tensorflow//tensorflow/cc:scope", + "@org_tensorflow//tensorflow/core:core_cpu", + "@org_tensorflow//tensorflow/core:framework", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core:protos_all_cc", + "@org_tensorflow//tensorflow/core:tensorflow", + ], +) diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc new file mode 100644 index 000000000..5e538e98f --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.cc @@ -0,0 +1,215 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h" + +#include <cstdlib> +#include <cstring> +#include <memory> + +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_format.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/cc/ops/array_ops.h" +#include "tensorflow/cc/ops/image_ops.h" +#include "tensorflow/cc/ops/io_ops.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow_lite_support/cc/port/status_macros.h" +#include "tensorflow_lite_support/cc/task/core/external_file_handler.h" +#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h" + +namespace tflite { +namespace support { +namespace task { +namespace vision { +namespace { + +using ::tensorflow::Tensor; +using ::tensorflow::tstring; +using ::tensorflow::ops::DecodeBmp; +using ::tensorflow::ops::DecodeGif; +using ::tensorflow::ops::DecodeJpeg; +using ::tensorflow::ops::DecodePng; +using ::tensorflow::ops::Placeholder; +using ::tensorflow::ops::Squeeze; +using ::tflite::support::task::core::ExternalFile; +using ::tflite::support::task::core::ExternalFileHandler; + +absl::Status ReadEntireFile(absl::string_view file_name, Tensor* output) { + ExternalFile external_file; + external_file.set_file_name(std::string(file_name)); + ASSIGN_OR_RETURN(std::unique_ptr<ExternalFileHandler> handler, + ExternalFileHandler::CreateFromExternalFile(&external_file)); + output->scalar<tstring>()() = tstring(handler->GetFileContent()); + return absl::OkStatus(); +} + +} // namespace + +// Core TensorFlow is used for convenience as it provides Ops able to decode +// various image formats. Any other image processing library like OpenCV or +// ImageMagick could be used as an alternative. +StatusOr<RgbImageData> DecodeImageFromFile(absl::string_view file_name) { + // Read file_name into a tensor named input. + Tensor input(tensorflow::DT_STRING, tensorflow::TensorShape()); + RETURN_IF_ERROR(ReadEntireFile(file_name, &input)); + + auto root = tensorflow::Scope::NewRootScope(); + + // Use a placeholder to read input data. + auto file_reader = + Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_STRING); + + // Try to figure out what kind of file it is and decode it. + const int wanted_channels = 3; // for RGB output image + tensorflow::Output image_reader; + if (absl::EndsWithIgnoreCase(file_name, ".png")) { + image_reader = DecodePng(root.WithOpName("image_reader"), file_reader, + DecodePng::Channels(wanted_channels)); + } else if (absl::EndsWithIgnoreCase(file_name, ".gif")) { + // GIF decoder returns 4-D tensor, remove the first dim. 
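+    // DecodeGif outputs a 4-D [num_frames, height, width, 3] tensor, so the +    // leading dimension is squeezed away; only single-frame GIFs therefore +    // decode to the 3-D [height, width, 3] shape expected below.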
+ image_reader = Squeeze( + root.WithOpName("image_reader"), + DecodeGif(root.WithOpName("image_reader_before_squeeze"), file_reader)); + } else if (absl::EndsWithIgnoreCase(file_name, ".bmp")) { + image_reader = DecodeBmp(root.WithOpName("image_reader"), file_reader); + } else if (absl::EndsWithIgnoreCase(file_name, ".jpeg") || + absl::EndsWithIgnoreCase(file_name, ".jpg")) { + image_reader = DecodeJpeg(root.WithOpName("image_reader"), file_reader, + DecodeJpeg::Channels(wanted_channels)); + } else { + return absl::UnimplementedError( + "Only .png, .gif, .bmp and .jpg (or .jpeg) images are supported"); + } + + // This runs the GraphDef network definition constructed above, and returns + // the results in an output tensor. + tensorflow::GraphDef graph; + if (!root.ToGraphDef(&graph).ok()) { + return absl::InternalError( + "Initialization error while decoding input image."); + } + + std::unique_ptr<tensorflow::Session> session( + tensorflow::NewSession(tensorflow::SessionOptions())); + if (!session->Create(graph).ok()) { + return absl::InternalError( + "Initialization error while decoding input image."); + } + + std::vector<std::pair<tensorflow::string, Tensor>> inputs = { + {"input", input}, + }; + std::vector<Tensor> output_tensors; + + tensorflow::Status status = + session->Run({inputs}, /*output_tensor_names=*/{"image_reader"}, + /*target_node_names=*/{}, &output_tensors); + if (!status.ok()) { + return absl::InternalError(absl::StrFormat( + "An internal error occurred while decoding input image: %s", + status.error_message())); + } + + // A single output tensor with shape `[height, width, channels]` where + // `channels=3` is expected. + if (output_tensors.size() != 1 || + output_tensors[0].dtype() != tensorflow::DT_UINT8 || + output_tensors[0].dims() != 3 || + output_tensors[0].shape().dim_size(2) != 3) { + return absl::InternalError("Unexpected output after decoding input image."); + } + + RgbImageData image_data; + size_t total_bytes = output_tensors[0].NumElements() * sizeof(uint8); + image_data.pixel_data = static_cast<uint8*>(malloc(total_bytes)); + memcpy(image_data.pixel_data, output_tensors[0].flat<uint8>().data(), + total_bytes); + image_data.height = output_tensors[0].shape().dim_size(0); + image_data.width = output_tensors[0].shape().dim_size(1); + + return image_data; +} + +absl::Status EncodeRgbImageToPngFile(const RgbImageData& image_data, + absl::string_view image_path) { + // Sanity check inputs. + if (image_data.width <= 0 || image_data.height <= 0) { + return absl::InvalidArgumentError( + absl::StrFormat("Expected positive image dimensions, found %d x %d.", + image_data.width, image_data.height)); + } + if (image_data.pixel_data == nullptr) { + return absl::InvalidArgumentError( + "Expected pixel data to be set, found nullptr."); + } + + // Prepare input tensor. + Tensor input(tensorflow::DataType::DT_UINT8, + {image_data.height, image_data.width, 3}); + auto data = input.flat<uint8>().data(); + for (int i = 0; i < input.NumElements(); ++i) { + *data = image_data.pixel_data[i]; + ++data; + } + std::vector<std::pair<tensorflow::string, Tensor>> inputs = { + {"input", input}, + }; + // Convert path to tensorflow string. + const tensorflow::string output_path(image_path); + + // Build graph. 
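+  // The graph is Placeholder (uint8 image tensor) -> EncodePng -> WriteFile: +  // running the "output" node both encodes the pixels and writes the file, so +  // no output tensors need to be fetched.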
+ auto root = tensorflow::Scope::NewRootScope(); + tensorflow::Output placeholder = + Placeholder(root.WithOpName("input"), tensorflow::DataType::DT_UINT8); + tensorflow::Output png_encoder = + tensorflow::ops::EncodePng(root.WithOpName("pngencoder"), placeholder); + tensorflow::ops::WriteFile file_writer = tensorflow::ops::WriteFile( + root.WithOpName("output"), output_path, png_encoder); + tensorflow::GraphDef graph; + if (!root.ToGraphDef(&graph).ok()) { + return absl::InternalError( + "Initialization error while encoding image to PNG."); + } + + // Create session and run graph. + std::unique_ptr<tensorflow::Session> session( + tensorflow::NewSession(tensorflow::SessionOptions())); + if (!session->Create(graph).ok()) { + return absl::InternalError( + "Initialization error while encoding image to PNG."); + } + tensorflow::Status status = + session->Run({inputs}, /*output_tensor_names=*/{}, + /*target_node_names=*/{"output"}, /*outputs=*/nullptr); + if (!status.ok()) { + return absl::InternalError(absl::StrFormat( + "An internal error occurred while encoding image to PNG: %s", + status.error_message())); + } + + return absl::OkStatus(); +} + +void RgbImageDataFree(RgbImageData* image) { free(image->pixel_data); } + +} // namespace vision +} // namespace task +} // namespace support +} // namespace tflite diff --git a/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h new file mode 100644 index 000000000..e62d51aa2 --- /dev/null +++ b/tensorflow_lite_support/examples/task/vision/desktop/utils/image_utils.h @@ -0,0 +1,54 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ +#define TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "tensorflow_lite_support/cc/port/integral_types.h" +#include "tensorflow_lite_support/cc/port/statusor.h" + +namespace tflite { +namespace support { +namespace task { +namespace vision { + +// Interleaved RGB image with pixels stored as a row-major flattened array. +struct RgbImageData { + uint8* pixel_data; + int width; + int height; +}; + +// Decodes image file and returns the corresponding RGB image if no error +// occurred. Supported formats are JPEG, PNG, GIF and BMP. If decoding +// succeeded, the caller must manage deletion of the underlying pixel data using +// `RgbImageDataFree`. +StatusOr<RgbImageData> DecodeImageFromFile(absl::string_view file_name); + +// Encodes the image provided as an RgbImageData as lossless PNG to the provided +// path. +absl::Status EncodeRgbImageToPngFile(const RgbImageData& image_data, + absl::string_view image_path); + +// Releases image pixel data memory. 
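+// Must be called on any RgbImageData returned by DecodeImageFromFile once it +// is no longer needed: pixel_data is allocated with malloc and freed here.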
+void RgbImageDataFree(RgbImageData* image); + +} // namespace vision +} // namespace task +} // namespace support +} // namespace tflite + +#endif // TENSORFLOW_LITE_SUPPORT_EXAMPLES_TASK_VISION_DESKTOP_UTILS_IMAGE_UTILS_H_