Skip to content

Commit

Permalink
Use stb_image for image encoding and decoding instead of relying on the TensorFlow DecodeJpg/EncodePng/etc Ops.
Browse files Browse the repository at this point in the history

This drastically reduces the compilation time for the Vision Task CLI demo tools.

PiperOrigin-RevId: 323527482
  • Loading branch information
tensorflower-gardener authored and tflite-support-robot committed Jul 28, 2020
1 parent bbef260 commit feba8f4
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 219 deletions.
7 changes: 7 additions & 0 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ http_archive(
build_file = "//third_party:libyuv.BUILD",
)

# Declares the external repository "stblib", fetched from the nothings/stb
# GitHub project and built with the BUILD file at //third_party:stblib.BUILD.
# NOTE(review): the URL points at the `master` branch archive and no `sha256`
# is given, so the downloaded contents may change between fetches. Consider
# pinning to a specific commit archive and adding a checksum.
http_archive(
name = "stblib",
strip_prefix = "stb-master",
urls = ["https://github.com/nothings/stb/archive/master.zip"],
build_file = "//third_party:stblib.BUILD",
)


load("//third_party/flatbuffers:workspace.bzl", flatbuffers = "repo")

Expand Down
3 changes: 0 additions & 3 deletions tensorflow_lite_support/examples/task/vision/desktop/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ cc_binary(
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/core:lib",
],
)

Expand All @@ -45,7 +44,6 @@ cc_binary(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/core:lib",
],
)

Expand All @@ -66,6 +64,5 @@ cc_binary(
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/core:lib",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ limitations under the License.
#include "absl/flags/parse.h"
#include "absl/status/status.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
Expand All @@ -40,7 +39,8 @@ limitations under the License.
ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' image classifier model.");
ABSL_FLAG(std::string, image_path, "",
"Absolute path to the image to classify. The image EXIF orientation "
"Absolute path to the image to classify. The image must be RGB or "
"RGBA (grayscale is not supported). The image EXIF orientation "
"flag, if any, is NOT taken into account.");
ABSL_FLAG(int32, max_results, 5,
"Maximum number of classification results to display.");
Expand Down Expand Up @@ -116,18 +116,28 @@ absl::Status Classify() {
ImageClassifier::CreateFromOptions(options));

// Load image in a FrameBuffer.
ASSIGN_OR_RETURN(RgbImageData image,
ASSIGN_OR_RETURN(ImageData image,
DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
std::unique_ptr<FrameBuffer> frame_buffer;
if (image.channels == 3) {
frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
} else if (image.channels == 4) {
frame_buffer =
CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
} else {
return absl::InvalidArgumentError(absl::StrFormat(
"Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
image.channels));
}

// Run classification and display results.
ASSIGN_OR_RETURN(ClassificationResult result,
image_classifier->Classify(*frame_buffer));
DisplayResult(result);

// Cleanup and return.
RgbImageDataFree(&image);
ImageDataFree(&image);
return absl::OkStatus();
}

Expand All @@ -154,10 +164,6 @@ int main(int argc, char** argv) {
return 1;
}

// We need to call this to set up global state for Tensorflow, which is used
// internally for decoding various image formats (JPEG, PNG, etc).
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Run classification.
absl::Status status = tflite::support::task::vision::Classify();
if (status.ok()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
Expand All @@ -41,7 +40,8 @@ limitations under the License.
ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' image segmenter model.");
ABSL_FLAG(std::string, image_path, "",
"Absolute path to the image to segment. The image EXIF orientation "
"Absolute path to the image to segment. The image must be RGB or "
"RGBA (grayscale is not supported). The image EXIF orientation "
"flag, if any, is NOT taken into account.");
ABSL_FLAG(std::string, output_mask_png, "",
"Absolute path to the output category mask (confidence masks outputs "
Expand Down Expand Up @@ -75,9 +75,10 @@ absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {
// Create ImageData (3 channels, RGB) for the output mask.
uint8* pixel_data = static_cast<uint8*>(
malloc(segmentation.width() * segmentation.height() * 3 * sizeof(uint8)));
RgbImageData mask = {.pixel_data = pixel_data,
.width = segmentation.width(),
.height = segmentation.height()};
ImageData mask = {.pixel_data = pixel_data,
.width = segmentation.width(),
.height = segmentation.height(),
.channels = 3};

// Populate ImageData from the raw mask and ColoredLabel-s.
for (int i = 0; i < segmentation.width() * segmentation.height(); ++i) {
Expand All @@ -90,12 +91,12 @@ absl::Status EncodeMaskToPngFile(const SegmentationResult& result) {

// Encode mask as PNG.
RETURN_IF_ERROR(
EncodeRgbImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
EncodeImageToPngFile(mask, absl::GetFlag(FLAGS_output_mask_png)));
std::cout << absl::StrFormat("Category mask saved to: %s\n",
absl::GetFlag(FLAGS_output_mask_png));

// Cleanup and return.
RgbImageDataFree(&mask);
ImageDataFree(&mask);
return absl::OkStatus();
}

Expand Down Expand Up @@ -138,10 +139,20 @@ absl::Status Segment() {
ImageSegmenter::CreateFromOptions(options));

// Load image in a FrameBuffer.
ASSIGN_OR_RETURN(RgbImageData image,
ASSIGN_OR_RETURN(ImageData image,
DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
std::unique_ptr<FrameBuffer> frame_buffer;
if (image.channels == 3) {
frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
} else if (image.channels == 4) {
frame_buffer =
CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
} else {
return absl::InvalidArgumentError(absl::StrFormat(
"Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
image.channels));
}

// Run segmentation and save category mask.
ASSIGN_OR_RETURN(SegmentationResult result,
Expand All @@ -152,7 +163,7 @@ absl::Status Segment() {
RETURN_IF_ERROR(DisplayColorLegend(result));

// Cleanup and return.
RgbImageDataFree(&image);
ImageDataFree(&image);
return absl::OkStatus();
}

Expand Down Expand Up @@ -181,10 +192,6 @@ int main(int argc, char** argv) {
return 1;
}

// We need to call this to set up global state for Tensorflow, which is used
// internally for decoding various image formats (JPEG, PNG, etc).
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Run segmentation.
absl::Status status = tflite::support::task::vision::Segment();
if (status.ok()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/strings/match.h"
#include "absl/strings/str_format.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/external_file_handler.h"
#include "tensorflow_lite_support/cc/task/core/proto/external_file_proto_inc.h"
Expand All @@ -44,7 +43,8 @@ limitations under the License.
ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' object detector model.");
ABSL_FLAG(std::string, image_path, "",
"Absolute path to the image to perform detection on. The image EXIF "
"Absolute path to the image to run detection on. The image must be "
"RGB or RGBA (grayscale is not supported). The image EXIF "
"orientation flag, if any, is NOT taken into account.");
ABSL_FLAG(std::string, output_png, "",
"Absolute path to a file where to draw the detection results on top "
Expand Down Expand Up @@ -111,7 +111,7 @@ ObjectDetectorOptions BuildOptions() {
}

absl::Status EncodeResultToPngFile(const DetectionResult& result,
const RgbImageData* image) {
const ImageData* image) {
for (int index = 0; index < result.detections_size(); ++index) {
// Get bounding box as left, top, right, bottom.
const BoundingBox& box = result.detections(index).bounding_box();
Expand All @@ -127,7 +127,7 @@ absl::Status EncodeResultToPngFile(const DetectionResult& result,
// is applied.
for (int y = std::max(0, top); y < std::min(image->height, bottom); ++y) {
for (int x = std::max(0, left); x < std::min(image->width, right); ++x) {
int pixel_index = 3 * (image->width * y + x);
int pixel_index = image->channels * (image->width * y + x);
if (x < left + kLineThickness || x > right - kLineThickness ||
y < top + kLineThickness || y > bottom - kLineThickness) {
image->pixel_data[pixel_index] = r;
Expand All @@ -139,7 +139,7 @@ absl::Status EncodeResultToPngFile(const DetectionResult& result,
}
// Encode to PNG and return.
RETURN_IF_ERROR(
EncodeRgbImageToPngFile(*image, absl::GetFlag(FLAGS_output_png)));
EncodeImageToPngFile(*image, absl::GetFlag(FLAGS_output_png)));
std::cout << absl::StrFormat("Results saved to: %s\n",
absl::GetFlag(FLAGS_output_png));
return absl::OkStatus();
Expand Down Expand Up @@ -183,10 +183,20 @@ absl::Status Detect() {
ObjectDetector::CreateFromOptions(options));

// Load image in a FrameBuffer.
ASSIGN_OR_RETURN(RgbImageData image,
ASSIGN_OR_RETURN(ImageData image,
DecodeImageFromFile(absl::GetFlag(FLAGS_image_path)));
std::unique_ptr<FrameBuffer> frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
std::unique_ptr<FrameBuffer> frame_buffer;
if (image.channels == 3) {
frame_buffer =
CreateFromRgbRawBuffer(image.pixel_data, {image.width, image.height});
} else if (image.channels == 4) {
frame_buffer =
CreateFromRgbaRawBuffer(image.pixel_data, {image.width, image.height});
} else {
return absl::InvalidArgumentError(absl::StrFormat(
"Expected image with 3 (RGB) or 4 (RGBA) channels, found %d",
image.channels));
}

// Run object detection and draw results on input image.
ASSIGN_OR_RETURN(DetectionResult result,
Expand All @@ -197,7 +207,7 @@ absl::Status Detect() {
DisplayResult(result);

// Cleanup and return.
RgbImageDataFree(&image);
ImageDataFree(&image);
return absl::OkStatus();
}

Expand Down Expand Up @@ -232,10 +242,6 @@ int main(int argc, char** argv) {
return 1;
}

// We need to call this to set up global state for Tensorflow, which is used
// internally for decoding various image formats (JPEG, PNG, etc).
tensorflow::port::InitMain(argv[0], &argc, &argv);

// Run detection.
absl::Status status = tflite::support::task::vision::Detect();
if (status.ok()) {
Expand Down
11 changes: 2 additions & 9 deletions tensorflow_lite_support/examples/task/vision/desktop/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,10 @@ cc_library(
"//tensorflow_lite_support/cc/port:integral_types",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/port:statusor",
"//tensorflow_lite_support/cc/task/core:external_file_handler",
"//tensorflow_lite_support/cc/task/core/proto:external_file_proto_inc",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@org_tensorflow//tensorflow/cc:cc_ops",
"@org_tensorflow//tensorflow/cc:scope",
"@org_tensorflow//tensorflow/core:core_cpu",
"@org_tensorflow//tensorflow/core:framework",
"@org_tensorflow//tensorflow/core:lib",
"@org_tensorflow//tensorflow/core:protos_all_cc",
"@org_tensorflow//tensorflow/core:tensorflow",
"@stblib//:stb_image",
"@stblib//:stb_image_write",
],
)
Loading

0 comments on commit feba8f4

Please sign in to comment.