Error with JAX backend and 'keras_hub' #1511

@t-kalinowski

On Linux with a GPU, if you configure Keras to use the JAX backend and also try to use keras_hub, you might encounter an error like this:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/keras_hub/__init__.py", line 7, in <module>
    from keras_hub import layers as layers
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/keras_hub/layers/__init__.py", line 8, in <module>
    from keras_hub.src.layers.modeling.anchor_generator import (
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/keras_hub/src/layers/modeling/anchor_generator.py", line 7, in <module>
    from keras_hub.src.utils.tensor_utils import assert_bounding_box_support
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/keras_hub/src/utils/tensor_utils.py", line 15, in <module>
    import tensorflow_text as tf_text
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/__init__.py", line 21, in <module>
    from tensorflow_text.python import keras
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/python/keras/__init__.py", line 21, in <module>
    from tensorflow_text.python.keras.layers import *
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/python/keras/layers/__init__.py", line 22, in <module>
    from tensorflow_text.python.keras.layers.tokenization_layers import *
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/python/keras/layers/tokenization_layers.py", line 24, in <module>
    from tensorflow_text.python.ops import unicode_script_tokenizer
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/python/ops/__init__.py", line 26, in <module>
    from tensorflow_text.python.ops.boise_offset_converter import boise_tags_to_offsets
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/python/ops/boise_offset_converter.py", line 32, in <module>
    gen_boise_offset_converter = load_library.load_op_library(resource_loader.get_path_to_datafile('_boise_offset_converter.so'))
                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow/python/framework/load_library.py", line 54, in load_op_library
    lib_handle = py_tf.TF_LoadLibrary(library_filename)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
tensorflow.python.framework.errors_impl.NotFoundError: /home/tomasz/.cache/R/reticulate/uv/cache/archive-v0/MGY4YDAhCyJ7TGcR_S_D9/lib/python3.11/site-packages/tensorflow_text/python/ops/_boise_offset_converter.so: undefined symbol: _ZN6tflite4shim23TfShapeInferenceContextC1EPN10tensorflow15shape_inference16InferenceContextE

This happens because tensorflow_text (a dependency of keras_hub) does not correctly recognize tensorflow-cpu as satisfying the dependency for tensorflow. As a result, both tensorflow and tensorflow-cpu are pulled in as dependencies.

Since both tensorflow and tensorflow-cpu are declared as required packages, there is a race condition in uv when creating an ephemeral venv. Both distributions install the same top-level tensorflow package into site-packages, uv parallelizes installation, and the order of package installation is non-deterministic, so the last one installed wins. If tensorflow wins, you see the error. If tensorflow-cpu wins, you don’t see the error.
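
To check whether a given environment ended up with both distributions, a small diagnostic along these lines may help. This is only a sketch: the helper names (normalize, installed_tf_variants) are made up for illustration, and it reports which distributions are recorded as installed, not which one wrote its files last.

```python
import re
from importlib import metadata

# Distributions that both ship the top-level `tensorflow` package.
TF_VARIANTS = {"tensorflow", "tensorflow-cpu"}


def normalize(name):
    """Normalize a distribution name: lowercase, runs of '-', '_', '.' become '-'."""
    return re.sub(r"[-_.]+", "-", name).lower()


def installed_tf_variants(names=None):
    """Return the sorted TensorFlow variants found among `names`.

    If `names` is None, inspect the distributions visible to the
    interpreter running this script.
    """
    if names is None:
        names = [dist.metadata["Name"] for dist in metadata.distributions()]
    # Skip entries with missing names (possible for broken installs).
    return sorted(TF_VARIANTS & {normalize(n) for n in names if n})


if __name__ == "__main__":
    found = installed_tf_variants()
    if len(found) > 1:
        print("Conflicting TensorFlow distributions installed:", ", ".join(found))
    else:
        print("TensorFlow distributions installed:", ", ".join(found) or "none")
```

Run it with the python binary of the environment in question. Seeing both names listed means the environment is exposed to the install-order race described above.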

We avoid pulling in tensorflow for two reasons:
a) Its CUDA requirements are often out of sync with JAX.
b) We want to prevent TensorFlow from allocating any memory on the GPU, since we are explicitly using JAX for training and only using tensorflow-text for preprocessing (which should happen on the CPU).
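
For environments where the GPU build of tensorflow does get installed, one possible mitigation for point (b) is to hide the GPUs from TensorFlow before it initializes any devices. This is not what keras3 does (it installs tensorflow-cpu instead), and the helper name below is hypothetical; it is only a sketch built on tf.config.set_visible_devices.

```python
def hide_gpus_from_tensorflow():
    """Ask TensorFlow to ignore all GPUs.

    Returns True if the request was applied, False if TensorFlow is not
    installed or its devices were already initialized.
    """
    try:
        import tensorflow as tf
    except ImportError:
        return False  # TensorFlow is not installed; nothing to hide.
    try:
        # An empty list means "no GPU is visible to TensorFlow".
        tf.config.set_visible_devices([], "GPU")
    except RuntimeError:
        # Raised when TensorFlow has already initialized its devices.
        return False
    return True


if __name__ == "__main__":
    print("GPUs hidden from TensorFlow:", hide_gpus_from_tensorflow())
```

This has to run before any TensorFlow op touches a device, and it does nothing about the CUDA version mismatch in point (a).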

The workaround for now is to re-run the uv venv creation until it succeeds. It usually takes fewer than 5 tries.

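library(keras3)
use_backend("jax", gpu = TRUE)
reticulate::py_require("keras-hub")
# Collect the full list of Python packages requested so far.
packages <- reticulate::py_require()$packages

repeat {
  # Get (or create) the ephemeral uv environment; this is the path to its python binary.
  python <- reticulate:::uv_get_or_create_env(packages)
  # Exit status is 0 when keras_hub imports cleanly.
  status <- system2(python, c("-c", shQuote("import keras_hub")))
  if (status != 0L) {
    # Import failed: delete the cached environment so the next iteration rebuilds it.
    unlink(dirname(dirname(python)), recursive = TRUE, force = TRUE)
  } else {
    break
  }
}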

Once the “correct” venv is created, subsequent R sessions will reuse it from the cache, so this only needs to happen once on initial installation. (If additional py_require() calls are made, you might need to do this again with the updated packages list.)
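
The same check-and-rebuild idea can be sketched in Python with an upper bound on attempts, so that a persistent failure does not loop forever. The create_env and remove_env callables are hypothetical stand-ins for whatever builds and deletes the environment (reticulate's internal uv helper in the R code above); only the import check reflects what the workaround actually does.

```python
import subprocess
import sys


def can_import(python, module):
    """Return True if `<python> -c "import <module>"` exits with status 0."""
    result = subprocess.run(
        [python, "-c", f"import {module}"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    return result.returncode == 0


def build_until_importable(create_env, remove_env, module, max_tries=5):
    """Rebuild an environment until `module` imports, at most `max_tries` times.

    create_env() must return the path to the environment's python binary;
    remove_env(python) must delete that environment. Both are hypothetical
    callables supplied by the caller.
    """
    for attempt in range(1, max_tries + 1):
        python = create_env()
        if can_import(python, module):
            return python, attempt
        remove_env(python)  # Discard the broken environment and try again.
    raise RuntimeError(f"{module} still not importable after {max_tries} tries")


if __name__ == "__main__":
    # Demonstration only: "create" the current interpreter and check a stdlib module.
    python, tries = build_until_importable(lambda: sys.executable, lambda p: None, "json")
    print(f"import succeeded with {python} after {tries} attempt(s)")
```

The cap of 5 is taken from the observation that the workaround usually succeeds in fewer than 5 tries; raise it if your environment proves less lucky.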

Hopefully, we will have a proper solution once one of these upstream issues is resolved:
