Skip to content

Commit

Permalink
hooks: tensorflow: collect plugins from tensorflow-plugins
Browse files Browse the repository at this point in the history
Have the `tensorflow` standard hook collect binaries from the
`tensorflow-plugins` package; this contains plugins for tensorflow's
pluggable device architecture (such as `tensorflow-metal` for macOS
and `tensorflow-directml-plugin` for Windows).

Have the `tensorflow` run-time hook override the `site.getsitepackages()`
with custom implementation that allows us to trick `tensorflow` into
loading the plugins.
  • Loading branch information
rokm committed Dec 23, 2023
1 parent 4ac0e97 commit a2f65ef
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 8 deletions.
5 changes: 5 additions & 0 deletions news/676.update.7.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Extend ``tensorflow`` hook to collect plugins installed in the
``tensorflow-plugins`` directory/package. Have the run-time ``tensorflow``
hook provide an override for ``site.getsitepackages()`` that allows us
to work around a broken module file location check and trick ``tensorflow``
into loading the collected plugins.
50 changes: 42 additions & 8 deletions src/_pyinstaller_hooks_contrib/hooks/rthooks/pyi_rth_tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,45 @@
# SPDX-License-Identifier: Apache-2.0
#-----------------------------------------------------------------------------

# `tensorflow` versions prior to 2.3.0 attempt to use `site.USER_SITE` in path/string manipulation functions.
# As frozen application runs with disabled `site`, the value of this variable is `None`, and causes path/string
# manipulation functions to raise an error. As a work-around, we set `site.USER_SITE` to an empty string, which is
# also what the fake `site` module available in PyInstaller prior to v5.5 did.
import site

if site.USER_SITE is None:
site.USER_SITE = ''
def _pyi_rthook():
import sys

# `tensorflow` versions prior to 2.3.0 attempt to use `site.USER_SITE` in path/string manipulation functions.
# As frozen application runs with disabled `site`, the value of this variable is `None`, and causes path/string
# manipulation functions to raise an error. As a work-around, we set `site.USER_SITE` to an empty string, which is
# also what the fake `site` module available in PyInstaller prior to v5.5 did.
import site

if site.USER_SITE is None:
site.USER_SITE = ''

# The issue described about with site.USER_SITE being None has largely been resolved in contemporary `tensorflow`
# versions, which now check that `site.ENABLE_USER_SITE` is set and that `site.USER_SITE` is not None before
# trying to use it.
#
# However, `tensorflow` will attempt to search and load its plugins only if it believes that it is running from
# "a pip-based installation" - if the package's location is rooted in one of the "site-packages" directories. See
# https://github.com/tensorflow/tensorflow/blob/6887368d6d46223f460358323c4b76d61d1558a8/tensorflow/api_template.__init__.py#L110C76-L156
# Unfortunately, they "cleverly" infer the module's location via `inspect.getfile(inspect.currentframe())`, which
# in the frozen application returns anonymized relative source file name (`tensorflow/__init__.py`) - so we need one
# of the "site directories" to be just "tensorflow" (to fool the `_running_from_pip_package()` check), and we also
# need `sys._MEIPASS` to be among them (to load the plugins from the actual `sys._MEIPASS/tensorflow-plugins`).
# Therefore, we monkey-patch `site.getsitepackages` to add those two entries to the list of "site directories".

_orig_getsitepackages = getattr(site, 'getsitepackages')

def _pyi_getsitepackages():
return [
sys._MEIPASS,
"tensorflow",
*(_orig_getsitepackages() if _orig_getsitepackages is not None else []),
]

site.getsitepackages = _pyi_getsitepackages

# NOTE: instead of the above override, we could also set TF_PLUGGABLE_DEVICE_LIBRARY_PATH, but that works only
# for tensorflow >= 2.12.


_pyi_rthook()
del _pyi_rthook
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from PyInstaller.compat import is_linux
from PyInstaller.utils.hooks import (
collect_data_files,
collect_dynamic_libs,
collect_submodules,
get_module_attribute,
is_module_satisfies,
Expand Down Expand Up @@ -125,6 +126,7 @@ def _submodules_filter(x):
if version >= Version("2.14.0"):
hiddenimports += ['ml_dtypes']

binaries = []
excludedimports = excluded_submodules

# Suppress warnings for missing hidden imports generated by this hook.
Expand Down Expand Up @@ -165,3 +167,8 @@ def _infer_nvidia_hiddenimports():
nvidia_hiddenimports = []
logger.info("hook-tensorflow: inferred hidden imports for CUDA libraries: %r", nvidia_hiddenimports)
hiddenimports += nvidia_hiddenimports


# Collect the tensorflow-plugins (pluggable device plugins)
hiddenimports += ['tensorflow-plugins']
binaries += collect_dynamic_libs('tensorflow-plugins')

0 comments on commit a2f65ef

Please sign in to comment.