Skip to content

Commit

Permalink
Add support for local wheel files in hermetic python
Browse files Browse the repository at this point in the history
This allows users to specify a list of workspaces that contain pre-built local wheels without need to manually add them in requirements.txt files.

The wheels will be automatically processed by bazel rules and injected into the requirements_lock_<py_version>.txt on the fly (assuming `HERMETIC_PYTHON_VERSION=py_version`).

This feature is mainly inspired by pytorch/xla demand, since building pytorch/xla implies first building pytorch repo locally and then pointing to its artifacts (both raw .so files and entire .whl) in pytorch/xla build.

This also helps JAX to facilitate build_jaxlib=false case, as it would eliminate need to manually update requirements_locak.txt files in JAX CI as well.

PiperOrigin-RevId: 636691616
  • Loading branch information
vam-google authored and tensorflower-gardener committed May 23, 2024
1 parent 14f88cc commit 36f7c31
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 12 deletions.
4 changes: 2 additions & 2 deletions third_party/py/python_init_pip.bzl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Hermetic Python initialization. Consult the WORKSPACE on how to use it."""

load("@python//:defs.bzl", "interpreter")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS_WITH_LOCAL_WHEELS")
load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse")

def python_init_pip():
Expand Down Expand Up @@ -30,5 +30,5 @@ cc_library(
name = "pypi",
annotations = numpy_annotations,
python_interpreter_target = interpreter,
requirements_lock = REQUIREMENTS,
requirements_lock = REQUIREMENTS_WITH_LOCAL_WHEELS,
)
7 changes: 6 additions & 1 deletion third_party/py/python_init_repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
load("@rules_python//python:repositories.bzl", "py_repositories")
load("//third_party/py:python_repo.bzl", "python_repository")

def python_init_repositories(requirements = {}):
def python_init_repositories(
requirements = {},
local_wheel_workspaces = [],
local_wheel_dist_folder = None):
python_repository(
name = "python_version_repo",
requirements_versions = requirements.keys(),
requirements_locks = requirements.values(),
local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder,
)
py_repositories()
79 changes: 78 additions & 1 deletion third_party/py/python_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ HERMETIC_PYTHON_VERSION = "{version}"
WHEEL_NAME = "{wheel_name}"
WHEEL_COLLAB = "{wheel_collab}"
REQUIREMENTS = "{requirements}"
REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}"
"""

def _python_repository_impl(ctx):
Expand All @@ -43,22 +44,90 @@ def _python_repository_impl(ctx):
else:
print("Using hermetic Python %s" % version) # buildifier: disable=print

requirements = ""
requirements = None
for i in range(0, len(ctx.attr.requirements_locks)):
if ctx.attr.requirements_versions[i] == version:
requirements = ctx.attr.requirements_locks[i]
break

if not requirements:
fail("""
Could not find requirements_lock.txt file matching specified Python version.
Specified python version: {version}
Python versions with available requirement_lock.txt files: {versions}
Please check python_init_repositories() in your WORKSPACE file.
""".format(
version = version,
versions = ", ".join(ctx.attr.requirements_versions),
))

requirements_with_local_wheels = str(requirements)
if ctx.attr.local_wheel_workspaces:
local_wheel_requirements = _get_injected_local_wheels(
ctx,
version,
ctx.attr.local_wheel_workspaces,
)
requirements_content = [ctx.read(requirements)] + local_wheel_requirements
merged_requirements_content = "\n".join(requirements_content)
requirements_with_local_wheels = requirements_with_local_wheels.replace(
"@" + requirements.repo_name,
"@" + ctx.name,
)

ctx.file(
requirements.name,
merged_requirements_content,
)

ctx.file(
"py_version.bzl",
content.format(
version = version,
wheel_name = wheel_name,
wheel_collab = wheel_collab,
requirements = str(requirements),
requirements_with_local_wheels = requirements_with_local_wheels,
),
)

def _get_injected_local_wheels(ctx, py_version, local_wheel_workspaces):
local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "")
wheels = {}

for local_wheel_workspace in local_wheel_workspaces:
local_wheel_workspace_path = ctx.path(local_wheel_workspace)
dist_folder = ctx.attr.local_wheel_dist_folder
dist_wheels = local_wheel_workspace_path.dirname.get_child(dist_folder).readdir()

for wheel in dist_wheels:
bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0:
continue

name_components = bn.split("-")
package_name = name_components[0]
for name_component in name_components[1:]:
if name_component[0].isdigit():
break
package_name += "-" + name_component

latest_wheel = wheels.get(package_name, None)

if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel

for wheel_name, wheel_path in wheels.items():
local_wheel_requirements.append(
"{wheel_name} @ file://{wheel_path}".format(
wheel_name = wheel_name,
wheel_path = wheel_path.realpath,
),
)

return local_wheel_requirements

python_repository = repository_rule(
implementation = _python_repository_impl,
attrs = {
Expand All @@ -70,6 +139,14 @@ python_repository = repository_rule(
mandatory = False,
default = [],
),
"local_wheel_workspaces": attr.label_list(
mandatory = False,
default = [],
),
"local_wheel_dist_folder": attr.string(
mandatory = False,
default = "dist",
),
},
environ = [
"TF_PYTHON_VERSION",
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/third_party/py/python_init_pip.bzl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Hermetic Python initialization. Consult the WORKSPACE on how to use it."""

load("@python//:defs.bzl", "interpreter")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS_WITH_LOCAL_WHEELS")
load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse")

def python_init_pip():
Expand Down Expand Up @@ -30,5 +30,5 @@ cc_library(
name = "pypi",
annotations = numpy_annotations,
python_interpreter_target = interpreter,
requirements_lock = REQUIREMENTS,
requirements_lock = REQUIREMENTS_WITH_LOCAL_WHEELS,
)
7 changes: 6 additions & 1 deletion third_party/xla/third_party/py/python_init_repositories.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
load("@rules_python//python:repositories.bzl", "py_repositories")
load("//third_party/py:python_repo.bzl", "python_repository")

def python_init_repositories(requirements = {}):
def python_init_repositories(
requirements = {},
local_wheel_workspaces = [],
local_wheel_dist_folder = None):
python_repository(
name = "python_version_repo",
requirements_versions = requirements.keys(),
requirements_locks = requirements.values(),
local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder,
)
py_repositories()
79 changes: 78 additions & 1 deletion third_party/xla/third_party/py/python_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ HERMETIC_PYTHON_VERSION = "{version}"
WHEEL_NAME = "{wheel_name}"
WHEEL_COLLAB = "{wheel_collab}"
REQUIREMENTS = "{requirements}"
REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}"
"""

def _python_repository_impl(ctx):
Expand All @@ -43,22 +44,90 @@ def _python_repository_impl(ctx):
else:
print("Using hermetic Python %s" % version) # buildifier: disable=print

requirements = ""
requirements = None
for i in range(0, len(ctx.attr.requirements_locks)):
if ctx.attr.requirements_versions[i] == version:
requirements = ctx.attr.requirements_locks[i]
break

if not requirements:
fail("""
Could not find requirements_lock.txt file matching specified Python version.
Specified python version: {version}
Python versions with available requirement_lock.txt files: {versions}
Please check python_init_repositories() in your WORKSPACE file.
""".format(
version = version,
versions = ", ".join(ctx.attr.requirements_versions),
))

requirements_with_local_wheels = str(requirements)
if ctx.attr.local_wheel_workspaces:
local_wheel_requirements = _get_injected_local_wheels(
ctx,
version,
ctx.attr.local_wheel_workspaces,
)
requirements_content = [ctx.read(requirements)] + local_wheel_requirements
merged_requirements_content = "\n".join(requirements_content)
requirements_with_local_wheels = requirements_with_local_wheels.replace(
"@" + requirements.repo_name,
"@" + ctx.name,
)

ctx.file(
requirements.name,
merged_requirements_content,
)

ctx.file(
"py_version.bzl",
content.format(
version = version,
wheel_name = wheel_name,
wheel_collab = wheel_collab,
requirements = str(requirements),
requirements_with_local_wheels = requirements_with_local_wheels,
),
)

def _get_injected_local_wheels(ctx, py_version, local_wheel_workspaces):
local_wheel_requirements = []
py_ver_marker = "-cp%s-" % py_version.replace(".", "")
wheels = {}

for local_wheel_workspace in local_wheel_workspaces:
local_wheel_workspace_path = ctx.path(local_wheel_workspace)
dist_folder = ctx.attr.local_wheel_dist_folder
dist_wheels = local_wheel_workspace_path.dirname.get_child(dist_folder).readdir()

for wheel in dist_wheels:
bn = wheel.basename
if not bn.endswith(".whl") or bn.find(py_ver_marker) < 0:
continue

name_components = bn.split("-")
package_name = name_components[0]
for name_component in name_components[1:]:
if name_component[0].isdigit():
break
package_name += "-" + name_component

latest_wheel = wheels.get(package_name, None)

if not latest_wheel or latest_wheel.basename < wheel.basename:
wheels[package_name] = wheel

for wheel_name, wheel_path in wheels.items():
local_wheel_requirements.append(
"{wheel_name} @ file://{wheel_path}".format(
wheel_name = wheel_name,
wheel_path = wheel_path.realpath,
),
)

return local_wheel_requirements

python_repository = repository_rule(
implementation = _python_repository_impl,
attrs = {
Expand All @@ -70,6 +139,14 @@ python_repository = repository_rule(
mandatory = False,
default = [],
),
"local_wheel_workspaces": attr.label_list(
mandatory = False,
default = [],
),
"local_wheel_dist_folder": attr.string(
mandatory = False,
default = "dist",
),
},
environ = [
"TF_PYTHON_VERSION",
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Hermetic Python initialization. Consult the WORKSPACE on how to use it."""

load("@python//:defs.bzl", "interpreter")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS")
load("@python_version_repo//:py_version.bzl", "REQUIREMENTS_WITH_LOCAL_WHEELS")
load("@rules_python//python:pip.bzl", "package_annotation", "pip_parse")

def python_init_pip():
Expand Down Expand Up @@ -30,5 +30,5 @@ cc_library(
name = "pypi",
annotations = numpy_annotations,
python_interpreter_target = interpreter,
requirements_lock = REQUIREMENTS,
requirements_lock = REQUIREMENTS_WITH_LOCAL_WHEELS,
)
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,15 @@
load("@rules_python//python:repositories.bzl", "py_repositories")
load("//third_party/py:python_repo.bzl", "python_repository")

def python_init_repositories(requirements = {}):
def python_init_repositories(
requirements = {},
local_wheel_workspaces = [],
local_wheel_dist_folder = None):
python_repository(
name = "python_version_repo",
requirements_versions = requirements.keys(),
requirements_locks = requirements.values(),
local_wheel_workspaces = local_wheel_workspaces,
local_wheel_dist_folder = local_wheel_dist_folder,
)
py_repositories()
Loading

0 comments on commit 36f7c31

Please sign in to comment.