Description
After loading a Vision Transformer model (vit_b32_fe) from TF Hub with version 0.4.1 of tensorflow-java, the first query takes about 30 seconds. The reason is that the graph is first compiled with XLA on that call.
This is a performance issue I want to solve. Using 0.5.0-SNAPSHOT is not an option because it fails to load the SavedModel (see #472). Since I can load and run the model without issues using Python TensorFlow 2.7.1 or 2.9.0, I wonder why I get this issue in Java.
Are there any options (e.g., disabling XLA JIT or similar) that would keep the model from blocking for 30 seconds? Thanks for any help.
System information
- Have I written custom code (as opposed to using a stock example script provided in TensorFlow):
- OS Platform and Distribution (e.g., Linux Ubuntu 16.04 x86_64):
Ubuntu 22.04 on amd64 (12th Gen Intel(R) Core(TM) i7-1270P)
- TensorFlow installed from (source or binary):
Using tensorflow-core-platform 0.4.1 jar
- TensorFlow version (use command below):
2.7.1 (from 0.4.1)
- Java version (i.e., the output of java -version):
openjdk version "11.0.16" 2022-07-19
OpenJDK Runtime Environment (build 11.0.16+8-post-Ubuntu-0ubuntu122.04)
OpenJDK 64-Bit Server VM (build 11.0.16+8-post-Ubuntu-0ubuntu122.04, mixed mode, sharing)
- Java command line flags (e.g., GC parameters):
- Python version (if transferring a model trained in Python):
- Bazel version (if compiling from source):
- GCC/Compiler version (if compiling from source):
- CUDA/cuDNN version:
- GPU model and memory:
Describe the current behavior
Describe the expected behavior
Code to reproduce the issue
Provide a reproducible test case that is the bare minimum necessary to generate the problem.
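import java.util.Collections;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

// graphFile and log are fields of the enclosing class (not shown)
Thread warmUpThread = new Thread() {
    @Override
    public void run() {
        try (SavedModelBundle savedModel = SavedModelBundle.loader(graphFile.toString()).withTags("serve").load()) {
            // warm up query: this first call is where XLA compilation happens
            long start = System.currentTimeMillis();
            log.info("Doing warm up query with tensorflow model");
            // dummy input batch of shape [1, 244, 244, 3]
            try (TFloat32 xTensor = TFloat32.tensorOf(NdArrays.ofFloats(Shape.of(1, 244, 244, 3)));
                 TFloat32 zTensor = (TFloat32) savedModel
                         .call(Collections.singletonMap("inputs", xTensor))
                         .get("output_0")) {
                long end = System.currentTimeMillis();
                log.info("Successfully warmed up tensorflow model, took " + (end - start) + "ms");
            }
        }
    }
};
warmUpThread.start();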
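The logs suggest the 30 seconds is a one-time compilation cost paid by the first call rather than per-query latency. One way to confirm this is to time two consecutive calls with the same input. The following is a minimal, stdlib-only sketch; `WarmUpTiming`, `timeFirstAndSecond`, and the fake query are hypothetical names, and in a real test the `Supplier` would wrap the `savedModel.call(...)` from the snippet above.

```java
import java.util.function.Supplier;

public class WarmUpTiming {

    /** Elapsed wall-clock time of the first call and of the immediately following call. */
    static final class Timings {
        final long firstMs;
        final long secondMs;

        Timings(long firstMs, long secondMs) {
            this.firstMs = firstMs;
            this.secondMs = secondMs;
        }
    }

    /** Runs the same query twice; a large gap between the two suggests a one-time setup cost. */
    static <T> Timings timeFirstAndSecond(Supplier<T> query) {
        long t0 = System.nanoTime();
        query.get();
        long t1 = System.nanoTime();
        query.get();
        long t2 = System.nanoTime();
        return new Timings((t1 - t0) / 1_000_000, (t2 - t1) / 1_000_000);
    }

    public static void main(String[] args) {
        // Hypothetical stand-in for the model call: slow only on the first invocation
        boolean[] compiled = {false};
        Supplier<Integer> fakeQuery = () -> {
            if (!compiled[0]) {
                try {
                    Thread.sleep(200); // simulated one-time compilation
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
                compiled[0] = true;
            }
            return 42;
        };
        Timings t = timeFirstAndSecond(fakeQuery);
        System.out.println("first=" + t.firstMs + "ms, second=" + t.secondMs + "ms");
    }
}
```

If the second call is fast, the cost is a per-process setup step, and running it on a background thread (as in the snippet) hides it from later queries but does not remove it.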
Other info / logs
2022-09-14 12:14:44.843292: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: /tmp/pxl_14542487053289680376
2022-09-14 12:14:44.888522: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:107] Reading meta graph with tags { serve }
2022-09-14 12:14:44.888626: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:148] Reading SavedModel debug info (if present) from: /tmp/pxl_14542487053289680376
2022-09-14 12:14:44.888682: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-14 12:14:45.050174: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2022-09-14 12:14:45.593306: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: /tmp/pxl_14542487053289680376
2022-09-14 12:14:45.877830: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 1034540 microseconds.
Sep 14, 2022 7:44:46 PM de.pixolution.process.module.tf2.SavedModelEmbeddings$1 run
INFO: Doing warm up query with tensorflow model
2022-09-14 12:14:47.213119: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x7f0de8014540 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-09-14 12:14:47.213143: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179] StreamExecutor device (0): Host, Default Version
2022-09-14 12:14:47.255383: I external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:237] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2022-09-14 12:15:18.095626: I external/org_tensorflow/tensorflow/compiler/jit/xla_compilation_cache.cc:351] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
Sep 14, 2022 7:45:18 PM de.pixolution.process.module.tf2.SavedModelEmbeddings$1 run
INFO: Successfully warmed up tensorflow model, took 31899ms