Skip to content

Commit

Permalink
move wheel extraction for UnpackWheels into PexBuilderWrapper and resolve for a single platform only (#7289)
Browse files Browse the repository at this point in the history

### Problem

Resolves #7245. We are seeing wheels getting unpacked for the wrong platform nondeterministically in our internal repo, where we have multiple platforms in `python-setup.platforms`.

### Solution

- Resolve wheels for `UnpackWheels` for the current platform only.
- Ensure that exactly one wheel is resolved for the desired distribution name.

### Result

The complex and unnecessary dist resolution / location process in `UnpackWheels` is now much simpler: we resolve for only a single platform, and we verify that exactly one dist can be located from the resolved requirements.
  • Loading branch information
cosmicexplorer committed Mar 1, 2019
1 parent 3c5c70f commit 2c9c338
Show file tree
Hide file tree
Showing 13 changed files with 352 additions and 154 deletions.
1 change: 1 addition & 0 deletions examples/3rdparty/python/BUILD
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ unpacked_whls(
'include/**/*', 'include/**/*',
'./*.so', './*.so',
], ],
within_data_subdir='purelib/tensorflow',
) )
47 changes: 39 additions & 8 deletions src/python/pants/backend/python/subsystems/pex_build_util.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pants.base.exceptions import TaskError from pants.base.exceptions import TaskError
from pants.build_graph.files import Files from pants.build_graph.files import Files
from pants.subsystem.subsystem import Subsystem from pants.subsystem.subsystem import Subsystem
from pants.util.collections import assert_single_element




def is_python_target(tgt): def is_python_target(tgt):
Expand Down Expand Up @@ -148,16 +149,33 @@ def add_requirement_libs_from(self, req_libs, platforms=None):
reqs = [req for req_lib in req_libs for req in req_lib.requirements] reqs = [req for req_lib in req_libs for req in req_lib.requirements]
self.add_resolved_requirements(reqs, platforms=platforms) self.add_resolved_requirements(reqs, platforms=platforms)


def add_resolved_requirements(self, reqs, platforms=None): class SingleDistExtractionError(Exception): pass
"""Multi-platform dependency resolution for PEX files.


:param builder: Dump the requirements into this builder. def extract_single_dist_for_current_platform(self, reqs, dist_key):
:param interpreter: The :class:`PythonInterpreter` to resolve requirements for. """Resolve a specific distribution from a set of requirements matching the current platform.
:param reqs: A list of :class:`PythonRequirement` to resolve.
:param log: Use this logger. :param list reqs: A list of :class:`PythonRequirement` to resolve.
:param platforms: A list of :class:`Platform`s to resolve requirements for. :param str dist_key: The value of `distribution.key` to match for a `distribution` from the
Defaults to the platforms specified by PythonSetup. resolved requirements.
:return: The single :class:`pkg_resources.Distribution` matching `dist_key`.
:raises: :class:`self.SingleDistExtractionError` if no dists or multiple dists matched the given
`dist_key`.
""" """
distributions = self._resolve_distributions_by_platform(reqs, platforms=['current'])
try:
matched_dist = assert_single_element(list(
dist
for _, dists in distributions.items()
for dist in dists
if dist.key == dist_key
))
except (StopIteration, ValueError) as e:
raise self.SingleDistExtractionError(
"Exactly one dist was expected to match name {} in requirements {}: {}"
.format(dist_key, reqs, e))
return matched_dist

def _resolve_distributions_by_platform(self, reqs, platforms):
deduped_reqs = OrderedSet(reqs) deduped_reqs = OrderedSet(reqs)
find_links = OrderedSet() find_links = OrderedSet()
for req in deduped_reqs: for req in deduped_reqs:
Expand All @@ -169,6 +187,19 @@ def add_resolved_requirements(self, reqs, platforms=None):
# Resolve the requirements into distributions. # Resolve the requirements into distributions.
distributions = self._resolve_multi(self._builder.interpreter, deduped_reqs, platforms, distributions = self._resolve_multi(self._builder.interpreter, deduped_reqs, platforms,
find_links) find_links)
return distributions

def add_resolved_requirements(self, reqs, platforms=None):
"""Multi-platform dependency resolution for PEX files.
:param builder: Dump the requirements into this builder.
:param interpreter: The :class:`PythonInterpreter` to resolve requirements for.
:param reqs: A list of :class:`PythonRequirement` to resolve.
:param log: Use this logger.
:param platforms: A list of :class:`Platform`s to resolve requirements for.
Defaults to the platforms specified by PythonSetup.
"""
distributions = self._resolve_distributions_by_platform(reqs, platforms=platforms)
locations = set() locations = set()
for platform, dists in distributions.items(): for platform, dists in distributions.items():
for dist in dists: for dist in dists:
Expand Down
20 changes: 19 additions & 1 deletion src/python/pants/backend/python/targets/unpacked_whls.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
class UnpackedWheels(ImportWheelsMixin, Target): class UnpackedWheels(ImportWheelsMixin, Target):
"""A set of sources extracted from JAR files. """A set of sources extracted from JAR files.
NB: Currently, wheels are always resolved for the 'current' platform.
:API: public :API: public
""" """


Expand All @@ -34,8 +36,10 @@ class ExpectedLibrariesError(Exception):
"""Thrown when the target has no libraries defined.""" """Thrown when the target has no libraries defined."""
pass pass


# TODO: consider introducing some form of source roots instead of the manual `within_data_subdir`
# kwarg!
def __init__(self, module_name, libraries=None, include_patterns=None, exclude_patterns=None, def __init__(self, module_name, libraries=None, include_patterns=None, exclude_patterns=None,
compatibility=None, payload=None, **kwargs): compatibility=None, within_data_subdir=None, payload=None, **kwargs):
""" """
:param str module_name: The name of the specific python module containing headers and/or :param str module_name: The name of the specific python module containing headers and/or
libraries to extract (e.g. 'tensorflow'). libraries to extract (e.g. 'tensorflow').
Expand All @@ -47,6 +51,14 @@ def __init__(self, module_name, libraries=None, include_patterns=None, exclude_p
:param compatibility: Python interpreter constraints used to create the pex for the requirement :param compatibility: Python interpreter constraints used to create the pex for the requirement
target. If unset, the default interpreter constraints are used. This target. If unset, the default interpreter constraints are used. This
argument is unnecessary unless the native code depends on libpython. argument is unnecessary unless the native code depends on libpython.
:param str within_data_subdir: If provided, descend into '<name>-<version>.data/<subdir>' when
matching `include_patterns`. For python wheels which declare any
non-code data, this is usually needed to extract that without
manually specifying the relative path, including the package
version. For example, when `data_files` is used in a setup.py,
`within_data_subdir='data'` will allow specifying
`include_patterns` matching exactly what is specified in the
setup.py.
""" """
payload = payload or Payload() payload = payload or Payload()
payload.add_fields({ payload.add_fields({
Expand All @@ -55,7 +67,9 @@ def __init__(self, module_name, libraries=None, include_patterns=None, exclude_p
'include_patterns' : PrimitiveField(include_patterns or ()), 'include_patterns' : PrimitiveField(include_patterns or ()),
'exclude_patterns' : PrimitiveField(exclude_patterns or ()), 'exclude_patterns' : PrimitiveField(exclude_patterns or ()),
'compatibility': PrimitiveField(maybe_list(compatibility or ())), 'compatibility': PrimitiveField(maybe_list(compatibility or ())),
'within_data_subdir': PrimitiveField(within_data_subdir),
# TODO: consider supporting transitive deps like UnpackedJars! # TODO: consider supporting transitive deps like UnpackedJars!
# TODO: consider supporting `platforms` as in PythonBinary!
}) })
super(UnpackedWheels, self).__init__(payload=payload, **kwargs) super(UnpackedWheels, self).__init__(payload=payload, **kwargs)


Expand All @@ -70,3 +84,7 @@ def module_name(self):
@property @property
def compatibility(self): def compatibility(self):
return self.payload.compatibility return self.payload.compatibility

@property
def within_data_subdir(self):
return self.payload.within_data_subdir
129 changes: 29 additions & 100 deletions src/python/pants/backend/python/tasks/unpack_wheels.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -5,28 +5,24 @@
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals


import os import os
import re
from builtins import str from builtins import str
from hashlib import sha1 from hashlib import sha1


from future.utils import PY3 from future.utils import PY3
from pex.pex import PEX
from pex.pex_builder import PEXBuilder from pex.pex_builder import PEXBuilder
from pex.platforms import Platform


from pants.backend.native.config.environment import Platform as NativeBackendPlatform
from pants.backend.python.interpreter_cache import PythonInterpreterCache from pants.backend.python.interpreter_cache import PythonInterpreterCache
from pants.backend.python.subsystems.pex_build_util import PexBuilderWrapper from pants.backend.python.subsystems.pex_build_util import PexBuilderWrapper
from pants.backend.python.subsystems.python_setup import PythonSetup from pants.backend.python.subsystems.python_setup import PythonSetup
from pants.backend.python.targets.unpacked_whls import UnpackedWheels from pants.backend.python.targets.unpacked_whls import UnpackedWheels
from pants.base.exceptions import TaskError from pants.base.exceptions import TaskError
from pants.base.fingerprint_strategy import DefaultFingerprintHashingMixin, FingerprintStrategy from pants.base.fingerprint_strategy import DefaultFingerprintHashingMixin, FingerprintStrategy
from pants.fs.archive import ZIP
from pants.task.unpack_remote_sources_base import UnpackRemoteSourcesBase from pants.task.unpack_remote_sources_base import UnpackRemoteSourcesBase
from pants.util.contextutil import temporary_dir, temporary_file from pants.util.contextutil import temporary_dir
from pants.util.dirutil import mergetree, safe_concurrent_creation from pants.util.dirutil import mergetree, safe_concurrent_creation
from pants.util.memo import memoized_classproperty, memoized_method from pants.util.memo import memoized_method
from pants.util.objects import SubclassesOf from pants.util.objects import SubclassesOf
from pants.util.process_handler import subprocess




class UnpackWheelsFingerprintStrategy(DefaultFingerprintHashingMixin, FingerprintStrategy): class UnpackWheelsFingerprintStrategy(DefaultFingerprintHashingMixin, FingerprintStrategy):
Expand Down Expand Up @@ -62,114 +58,47 @@ def subsystem_dependencies(cls):


class _NativeCodeExtractionSetupFailure(Exception): pass class _NativeCodeExtractionSetupFailure(Exception): pass


@staticmethod def _get_matching_wheel(self, pex_path, interpreter, requirements, module_name):
def _exercise_module(pex, expected_module): """Use PexBuilderWrapper to resolve a single wheel from the requirement specs using pex."""
# Ripped from test_resolve_requirements.py. with self.context.new_workunit('extract-native-wheels'):
with temporary_file(binary_mode=False) as f: with safe_concurrent_creation(pex_path) as chroot:
f.write('import {m}; print({m}.__file__)'.format(m=expected_module)) pex_builder = PexBuilderWrapper.Factory.create(
f.close() builder=PEXBuilder(path=chroot, interpreter=interpreter),
proc = pex.run(args=[f.name], blocking=False, log=self.context.log)
stdout=subprocess.PIPE, stderr=subprocess.PIPE) return pex_builder.extract_single_dist_for_current_platform(requirements, module_name)
stdout, stderr = proc.communicate()
return (stdout.decode('utf-8'), stderr.decode('utf-8'))

@classmethod
def _get_wheel_dir(cls, pex, module_name):
"""Get the directory of a specific wheel contained within an unpacked pex."""
stdout_data, stderr_data = cls._exercise_module(pex, module_name)
if stderr_data != '':
raise cls._NativeCodeExtractionSetupFailure(
"Error extracting module '{}' from pex at {}.\nstdout:\n{}\n----\nstderr:\n{}"
.format(module_name, pex.path, stdout_data, stderr_data))

module_path = stdout_data.strip()
wheel_dir = os.path.join(
module_path[0:module_path.find('{sep}.deps{sep}'.format(sep=os.sep))],
'.deps',
)
if not os.path.isdir(wheel_dir):
raise cls._NativeCodeExtractionSetupFailure(
"Wheel dir for module '{}' was not found in path '{}' of pex at '{}'."
.format(module_name, module_path, pex.path))
return wheel_dir

@staticmethod
def _name_and_platform(whl):
# The wheel filename is of the format
# {distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl
# See https://www.python.org/dev/peps/pep-0425/.
# We don't care about the python or abi versions because we expect pex to resolve the
# appropriate versions for the current host.
parts = os.path.splitext(whl)[0].split('-')
return '{}-{}'.format(parts[0], parts[1]), parts[-1]

@memoized_classproperty
def _current_platform_abbreviation(cls):
return NativeBackendPlatform.create().resolve_for_enum_variant({
'darwin': 'macosx',
'linux': 'linux',
})

@classmethod
def _get_matching_wheel_dir(cls, wheel_dir, module_name):
wheels = os.listdir(wheel_dir)

names_and_platforms = {w:cls._name_and_platform(w) for w in wheels}
for whl_filename, (name, platform) in names_and_platforms.items():
if cls._current_platform_abbreviation in platform:
# TODO: this guards against packages which have names that are prefixes of other packages by
# checking if there is a version number beginning -- is there a more canonical way to do
# this?
if re.match(r'^{}\-[0-9]'.format(re.escape(module_name)), name):
return os.path.join(wheel_dir, whl_filename, module_name)

raise cls._NativeCodeExtractionSetupFailure(
"Could not find wheel in dir '{wheel_dir}' matching module name '{module_name}' "
"for current platform '{pex_current_platform}', when looking for platforms containing the "
"substring {cur_platform_abbrev}.\n"
"wheels: {wheels}"
.format(wheel_dir=wheel_dir,
module_name=module_name,
pex_current_platform=Platform.current().platform,
cur_platform_abbrev=cls._current_platform_abbreviation,
wheels=wheels))

def _generate_requirements_pex(self, pex_path, interpreter, requirements):
if not os.path.exists(pex_path):
with self.context.new_workunit('extract-native-wheels'):
with safe_concurrent_creation(pex_path) as chroot:
pex_builder = PexBuilderWrapper.Factory.create(
builder=PEXBuilder(path=chroot, interpreter=interpreter),
log=self.context.log)
pex_builder.add_resolved_requirements(requirements)
pex_builder.freeze()
return PEX(pex_path, interpreter=interpreter)


@memoized_method @memoized_method
def _compatible_interpreter(self, unpacked_whls): def _compatible_interpreter(self, unpacked_whls):
constraints = PythonSetup.global_instance().compatibility_or_constraints(unpacked_whls) constraints = PythonSetup.global_instance().compatibility_or_constraints(unpacked_whls)
allowable_interpreters = PythonInterpreterCache.global_instance().setup(filters=constraints) allowable_interpreters = PythonInterpreterCache.global_instance().setup(filters=constraints)
return min(allowable_interpreters) return min(allowable_interpreters)


class NativeCodeExtractionError(TaskError): pass class WheelUnpackingError(TaskError): pass


def unpack_target(self, unpacked_whls, unpack_dir): def unpack_target(self, unpacked_whls, unpack_dir):
interpreter = self._compatible_interpreter(unpacked_whls) interpreter = self._compatible_interpreter(unpacked_whls)


with temporary_dir() as tmp_dir: with temporary_dir() as resolve_dir,\
# NB: The pex needs to be in a subdirectory for some reason, and pants task caching ensures it temporary_dir() as extract_dir:
# is the only member of this directory, so the dirname doesn't matter.
pex_path = os.path.join(tmp_dir, 'xxx.pex')
try: try:
pex = self._generate_requirements_pex(pex_path, interpreter, matched_dist = self._get_matching_wheel(resolve_dir, interpreter,
unpacked_whls.all_imported_requirements) unpacked_whls.all_imported_requirements,
wheel_dir = self._get_wheel_dir(pex, unpacked_whls.module_name) unpacked_whls.module_name)
matching_wheel_dir = self._get_matching_wheel_dir(wheel_dir, unpacked_whls.module_name) ZIP.extract(matched_dist.location, extract_dir)
if unpacked_whls.within_data_subdir:
data_dir_prefix = '{name}-{version}.data/{subdir}'.format(
name=matched_dist.project_name,
version=matched_dist.version,
subdir=unpacked_whls.within_data_subdir,
)
dist_data_dir = os.path.join(extract_dir, data_dir_prefix)
else:
dist_data_dir = extract_dir
unpack_filter = self.get_unpack_filter(unpacked_whls) unpack_filter = self.get_unpack_filter(unpacked_whls)
# Copy over the module's data files into `unpack_dir`. # Copy over the module's data files into `unpack_dir`.
mergetree(matching_wheel_dir, unpack_dir, file_filter=unpack_filter) mergetree(dist_data_dir, unpack_dir, file_filter=unpack_filter)
except Exception as e: except Exception as e:
raise self.NativeCodeExtractionError( raise self.WheelUnpackingError(
"Error extracting wheel for target {}: {}" "Error extracting wheel for target {}: {}"
.format(unpacked_whls, str(e)), .format(unpacked_whls, str(e)),
e) e)
4 changes: 2 additions & 2 deletions src/python/pants/task/unpack_remote_sources_base.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ def _calculate_unpack_filter(cls, includes=None, excludes=None, spec=None):
field_name='include_patterns', field_name='include_patterns',
spec=spec) spec=spec)
logger.debug('include_patterns: {}' logger.debug('include_patterns: {}'
.format(p.pattern for p in include_patterns)) .format(list(p.pattern for p in include_patterns)))
exclude_patterns = cls.compile_patterns(excludes or [], exclude_patterns = cls.compile_patterns(excludes or [],
field_name='exclude_patterns', field_name='exclude_patterns',
spec=spec) spec=spec)
logger.debug('exclude_patterns: {}' logger.debug('exclude_patterns: {}'
.format(p.pattern for p in exclude_patterns)) .format(list(p.pattern for p in exclude_patterns)))
return lambda f: cls._file_filter(f, include_patterns, exclude_patterns) return lambda f: cls._file_filter(f, include_patterns, exclude_patterns)


@classmethod @classmethod
Expand Down
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_simple(self):


def test_bad_libraries_ref(self): def test_bad_libraries_ref(self):
self.make_target(':right-type', JarLibrary, jars=[JarDependency('foo', 'bar', '123')]) self.make_target(':right-type', JarLibrary, jars=[JarDependency('foo', 'bar', '123')])
# Making a target which is not a jar library, which causes an error.
self.make_target(':wrong-type', UnpackedJars, libraries=[':right-type']) self.make_target(':wrong-type', UnpackedJars, libraries=[':right-type'])
target = self.make_target(':foo', UnpackedJars, libraries=[':wrong-type']) target = self.make_target(':foo', UnpackedJars, libraries=[':wrong-type'])
with self.assertRaises(ImportJarsMixin.WrongTargetTypeError): with self.assertRaises(ImportJarsMixin.WrongTargetTypeError):
Expand Down Expand Up @@ -85,6 +86,6 @@ def assert_dep(dep, org, name, rev):
unpacked_jar_deps = unpacked_lib.all_imported_jar_deps unpacked_jar_deps = unpacked_lib.all_imported_jar_deps


self.assertEqual(3, len(unpacked_jar_deps)) self.assertEqual(3, len(unpacked_jar_deps))
assert_dep(lib1.jar_dependencies[0], 'testOrg1', 'testName1', '123') assert_dep(unpacked_jar_deps[0], 'testOrg1', 'testName1', '123')
assert_dep(lib2.jar_dependencies[0], 'testOrg2', 'testName2', '456') assert_dep(unpacked_jar_deps[1], 'testOrg2', 'testName2', '456')
assert_dep(lib2.jar_dependencies[1], 'testOrg3', 'testName3', '789') assert_dep(unpacked_jar_deps[2], 'testOrg3', 'testName3', '789')
40 changes: 0 additions & 40 deletions tests/python/pants_test/backend/jvm/tasks/test_unpack_jars.py
Original file line number Original file line Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@


import functools import functools
import os import os
import unittest
from contextlib import contextmanager from contextlib import contextmanager


from pants.backend.jvm.targets.jar_library import JarLibrary from pants.backend.jvm.targets.jar_library import JarLibrary
Expand Down Expand Up @@ -36,45 +35,6 @@ def sample_jarfile(self, name):
proto_jarfile.writestr('a/b/c/{}.proto'.format(name), 'message Msg {}') proto_jarfile.writestr('a/b/c/{}.proto'.format(name), 'message Msg {}')
yield jar_name yield jar_name


def test_invalid_pattern(self):
with self.assertRaises(UnpackJars.InvalidPatternError):
UnpackJars.compile_patterns([45])

@staticmethod
def _run_filter(filename, include_patterns=None, exclude_patterns=None):
return UnpackJars._file_filter(
filename,
UnpackJars.compile_patterns(include_patterns or []),
UnpackJars.compile_patterns(exclude_patterns or []))

def test_file_filter(self):
# If no patterns are specified, everything goes through
self.assertTrue(self._run_filter("foo/bar.java"))

self.assertTrue(self._run_filter("foo/bar.java", include_patterns=["**/*.java"]))
self.assertTrue(self._run_filter("bar.java", include_patterns=["**/*.java"]))
self.assertTrue(self._run_filter("bar.java", include_patterns=["**/*.java", "*.java"]))
self.assertFalse(self._run_filter("foo/bar.java", exclude_patterns=["**/bar.*"]))
self.assertFalse(self._run_filter("foo/bar.java",
include_patterns=["**/*/java"],
exclude_patterns=["**/bar.*"]))

# exclude patterns should be computed before include patterns
self.assertFalse(self._run_filter("foo/bar.java",
include_patterns=["foo/*.java"],
exclude_patterns=["foo/b*.java"]))
self.assertTrue(self._run_filter("foo/bar.java",
include_patterns=["foo/*.java"],
exclude_patterns=["foo/x*.java"]))

@unittest.expectedFailure
def test_problematic_cases(self):
"""These should pass, but don't"""
# See https://github.com/twitter/commons/issues/380. 'foo*bar' doesn't match 'foobar'
self.assertFalse(self._run_filter("foo/bar.java",
include_patterns=['foo/*.java'],
exclude_patterns=['foo/bar*.java']))

def _make_jar_library(self, coord): def _make_jar_library(self, coord):
return self.make_target(spec='unpack/jars:foo-jars', return self.make_target(spec='unpack/jars:foo-jars',
target_type=JarLibrary, target_type=JarLibrary,
Expand Down
Loading

0 comments on commit 2c9c338

Please sign in to comment.