Skip to content
Permalink
Browse files

move wheel extraction for UnpackWheels into PexBuilderWrapper and res…

…olve for a single platform only (#7289)

### Problem

Resolves #7245. We are seeing wheels getting unpacked for the wrong platform nondeterministically in our internal repo, where we have multiple platforms in `python-setup.platforms`.

### Solution

- Resolve wheels for `UnpackWheels` for the current platform only.
- Ensure that there is exactly a single wheel resolved for the desired distribution name.

### Result

The complex and unnecessary dist resolution / location process in `UnpackWheels` is a lot smoother, we resolve for only a single platform, and we check to ensure only a single dist can be located from the resolved requirements.
  • Loading branch information...
cosmicexplorer committed Mar 1, 2019
1 parent 3c5c70f commit 2c9c338cd721387844d1352c1b0bf77b35c33f0c
@@ -23,4 +23,5 @@ unpacked_whls(
'include/**/*',
'./*.so',
],
within_data_subdir='purelib/tensorflow',
)
@@ -26,6 +26,7 @@
from pants.base.exceptions import TaskError
from pants.build_graph.files import Files
from pants.subsystem.subsystem import Subsystem
from pants.util.collections import assert_single_element


def is_python_target(tgt):
@@ -148,16 +149,33 @@ def add_requirement_libs_from(self, req_libs, platforms=None):
reqs = [req for req_lib in req_libs for req in req_lib.requirements]
self.add_resolved_requirements(reqs, platforms=platforms)

def add_resolved_requirements(self, reqs, platforms=None):
"""Multi-platform dependency resolution for PEX files.
class SingleDistExtractionError(Exception): pass

:param builder: Dump the requirements into this builder.
:param interpreter: The :class:`PythonInterpreter` to resolve requirements for.
:param reqs: A list of :class:`PythonRequirement` to resolve.
:param log: Use this logger.
:param platforms: A list of :class:`Platform`s to resolve requirements for.
Defaults to the platforms specified by PythonSetup.
def extract_single_dist_for_current_platform(self, reqs, dist_key):
"""Resolve a specific distribution from a set of requirements matching the current platform.
:param list reqs: A list of :class:`PythonRequirement` to resolve.
:param str dist_key: The value of `distribution.key` to match for a `distribution` from the
resolved requirements.
:return: The single :class:`pkg_resources.Distribution` matching `dist_key`.
:raises: :class:`self.SingleDistExtractionError` if no dists or multiple dists matched the given
`dist_key`.
"""
distributions = self._resolve_distributions_by_platform(reqs, platforms=['current'])
try:
matched_dist = assert_single_element(list(
dist
for _, dists in distributions.items()
for dist in dists
if dist.key == dist_key
))
except (StopIteration, ValueError) as e:
raise self.SingleDistExtractionError(
"Exactly one dist was expected to match name {} in requirements {}: {}"
.format(dist_key, reqs, e))
return matched_dist

def _resolve_distributions_by_platform(self, reqs, platforms):
deduped_reqs = OrderedSet(reqs)
find_links = OrderedSet()
for req in deduped_reqs:
@@ -169,6 +187,19 @@ def add_resolved_requirements(self, reqs, platforms=None):
# Resolve the requirements into distributions.
distributions = self._resolve_multi(self._builder.interpreter, deduped_reqs, platforms,
find_links)
return distributions

def add_resolved_requirements(self, reqs, platforms=None):
"""Multi-platform dependency resolution for PEX files.
:param builder: Dump the requirements into this builder.
:param interpreter: The :class:`PythonInterpreter` to resolve requirements for.
:param reqs: A list of :class:`PythonRequirement` to resolve.
:param log: Use this logger.
:param platforms: A list of :class:`Platform`s to resolve requirements for.
Defaults to the platforms specified by PythonSetup.
"""
distributions = self._resolve_distributions_by_platform(reqs, platforms=platforms)
locations = set()
for platform, dists in distributions.items():
for dist in dists:
@@ -20,6 +20,8 @@
class UnpackedWheels(ImportWheelsMixin, Target):
"""A set of sources extracted from JAR files.
NB: Currently, wheels are always resolved for the 'current' platform.
:API: public
"""

@@ -34,8 +36,10 @@ class ExpectedLibrariesError(Exception):
"""Thrown when the target has no libraries defined."""
pass

# TODO: consider introducing some form of source roots instead of the manual `within_data_subdir`
# kwarg!
def __init__(self, module_name, libraries=None, include_patterns=None, exclude_patterns=None,
compatibility=None, payload=None, **kwargs):
compatibility=None, within_data_subdir=None, payload=None, **kwargs):
"""
:param str module_name: The name of the specific python module containing headers and/or
libraries to extract (e.g. 'tensorflow').
@@ -47,6 +51,14 @@ def __init__(self, module_name, libraries=None, include_patterns=None, exclude_p
:param compatibility: Python interpreter constraints used to create the pex for the requirement
target. If unset, the default interpreter constraints are used. This
argument is unnecessary unless the native code depends on libpython.
:param str within_data_subdir: If provided, descend into '<name>-<version>.data/<subdir>' when
matching `include_patterns`. For python wheels which declare any
non-code data, this is usually needed to extract that without
manually specifying the relative path, including the package
version. For example, when `data_files` is used in a setup.py,
`within_data_subdir='data'` will allow specifying
`include_patterns` matching exactly what is specified in the
setup.py.
"""
payload = payload or Payload()
payload.add_fields({
@@ -55,7 +67,9 @@ def __init__(self, module_name, libraries=None, include_patterns=None, exclude_p
'include_patterns' : PrimitiveField(include_patterns or ()),
'exclude_patterns' : PrimitiveField(exclude_patterns or ()),
'compatibility': PrimitiveField(maybe_list(compatibility or ())),
'within_data_subdir': PrimitiveField(within_data_subdir),
# TODO: consider supporting transitive deps like UnpackedJars!
# TODO: consider supporting `platforms` as in PythonBinary!
})
super(UnpackedWheels, self).__init__(payload=payload, **kwargs)

@@ -70,3 +84,7 @@ def module_name(self):
@property
def compatibility(self):
return self.payload.compatibility

@property
def within_data_subdir(self):
return self.payload.within_data_subdir
@@ -5,28 +5,24 @@
from __future__ import absolute_import, division, print_function, unicode_literals

import os
import re
from builtins import str
from hashlib import sha1

from future.utils import PY3
from pex.pex import PEX
from pex.pex_builder import PEXBuilder
from pex.platforms import Platform

from pants.backend.native.config.environment import Platform as NativeBackendPlatform
from pants.backend.python.interpreter_cache import PythonInterpreterCache
from pants.backend.python.subsystems.pex_build_util import PexBuilderWrapper
from pants.backend.python.subsystems.python_setup import PythonSetup
from pants.backend.python.targets.unpacked_whls import UnpackedWheels
from pants.base.exceptions import TaskError
from pants.base.fingerprint_strategy import DefaultFingerprintHashingMixin, FingerprintStrategy
from pants.fs.archive import ZIP
from pants.task.unpack_remote_sources_base import UnpackRemoteSourcesBase
from pants.util.contextutil import temporary_dir, temporary_file
from pants.util.contextutil import temporary_dir
from pants.util.dirutil import mergetree, safe_concurrent_creation
from pants.util.memo import memoized_classproperty, memoized_method
from pants.util.memo import memoized_method
from pants.util.objects import SubclassesOf
from pants.util.process_handler import subprocess


class UnpackWheelsFingerprintStrategy(DefaultFingerprintHashingMixin, FingerprintStrategy):
@@ -62,114 +58,47 @@ def subsystem_dependencies(cls):

class _NativeCodeExtractionSetupFailure(Exception): pass

@staticmethod
def _exercise_module(pex, expected_module):
# Ripped from test_resolve_requirements.py.
with temporary_file(binary_mode=False) as f:
f.write('import {m}; print({m}.__file__)'.format(m=expected_module))
f.close()
proc = pex.run(args=[f.name], blocking=False,
stdout=subprocess.PIPE, stderr=subprocess.PIPE)
stdout, stderr = proc.communicate()
return (stdout.decode('utf-8'), stderr.decode('utf-8'))

@classmethod
def _get_wheel_dir(cls, pex, module_name):
"""Get the directory of a specific wheel contained within an unpacked pex."""
stdout_data, stderr_data = cls._exercise_module(pex, module_name)
if stderr_data != '':
raise cls._NativeCodeExtractionSetupFailure(
"Error extracting module '{}' from pex at {}.\nstdout:\n{}\n----\nstderr:\n{}"
.format(module_name, pex.path, stdout_data, stderr_data))

module_path = stdout_data.strip()
wheel_dir = os.path.join(
module_path[0:module_path.find('{sep}.deps{sep}'.format(sep=os.sep))],
'.deps',
)
if not os.path.isdir(wheel_dir):
raise cls._NativeCodeExtractionSetupFailure(
"Wheel dir for module '{}' was not found in path '{}' of pex at '{}'."
.format(module_name, module_path, pex.path))
return wheel_dir

@staticmethod
def _name_and_platform(whl):
# The wheel filename is of the format
# {distribution}-{version}(-{build tag})?-{python tag}-{abi tag}-{platform tag}.whl
# See https://www.python.org/dev/peps/pep-0425/.
# We don't care about the python or abi versions because we expect pex to resolve the
# appropriate versions for the current host.
parts = os.path.splitext(whl)[0].split('-')
return '{}-{}'.format(parts[0], parts[1]), parts[-1]

@memoized_classproperty
def _current_platform_abbreviation(cls):
return NativeBackendPlatform.create().resolve_for_enum_variant({
'darwin': 'macosx',
'linux': 'linux',
})

@classmethod
def _get_matching_wheel_dir(cls, wheel_dir, module_name):
wheels = os.listdir(wheel_dir)

names_and_platforms = {w:cls._name_and_platform(w) for w in wheels}
for whl_filename, (name, platform) in names_and_platforms.items():
if cls._current_platform_abbreviation in platform:
# TODO: this guards against packages which have names that are prefixes of other packages by
# checking if there is a version number beginning -- is there a more canonical way to do
# this?
if re.match(r'^{}\-[0-9]'.format(re.escape(module_name)), name):
return os.path.join(wheel_dir, whl_filename, module_name)

raise cls._NativeCodeExtractionSetupFailure(
"Could not find wheel in dir '{wheel_dir}' matching module name '{module_name}' "
"for current platform '{pex_current_platform}', when looking for platforms containing the "
"substring {cur_platform_abbrev}.\n"
"wheels: {wheels}"
.format(wheel_dir=wheel_dir,
module_name=module_name,
pex_current_platform=Platform.current().platform,
cur_platform_abbrev=cls._current_platform_abbreviation,
wheels=wheels))

def _generate_requirements_pex(self, pex_path, interpreter, requirements):
if not os.path.exists(pex_path):
with self.context.new_workunit('extract-native-wheels'):
with safe_concurrent_creation(pex_path) as chroot:
pex_builder = PexBuilderWrapper.Factory.create(
builder=PEXBuilder(path=chroot, interpreter=interpreter),
log=self.context.log)
pex_builder.add_resolved_requirements(requirements)
pex_builder.freeze()
return PEX(pex_path, interpreter=interpreter)
def _get_matching_wheel(self, pex_path, interpreter, requirements, module_name):
"""Use PexBuilderWrapper to resolve a single wheel from the requirement specs using pex."""
with self.context.new_workunit('extract-native-wheels'):
with safe_concurrent_creation(pex_path) as chroot:
pex_builder = PexBuilderWrapper.Factory.create(
builder=PEXBuilder(path=chroot, interpreter=interpreter),
log=self.context.log)
return pex_builder.extract_single_dist_for_current_platform(requirements, module_name)

@memoized_method
def _compatible_interpreter(self, unpacked_whls):
constraints = PythonSetup.global_instance().compatibility_or_constraints(unpacked_whls)
allowable_interpreters = PythonInterpreterCache.global_instance().setup(filters=constraints)
return min(allowable_interpreters)

class NativeCodeExtractionError(TaskError): pass
class WheelUnpackingError(TaskError): pass

def unpack_target(self, unpacked_whls, unpack_dir):
interpreter = self._compatible_interpreter(unpacked_whls)

with temporary_dir() as tmp_dir:
# NB: The pex needs to be in a subdirectory for some reason, and pants task caching ensures it
# is the only member of this directory, so the dirname doesn't matter.
pex_path = os.path.join(tmp_dir, 'xxx.pex')
with temporary_dir() as resolve_dir,\
temporary_dir() as extract_dir:
try:
pex = self._generate_requirements_pex(pex_path, interpreter,
unpacked_whls.all_imported_requirements)
wheel_dir = self._get_wheel_dir(pex, unpacked_whls.module_name)
matching_wheel_dir = self._get_matching_wheel_dir(wheel_dir, unpacked_whls.module_name)
matched_dist = self._get_matching_wheel(resolve_dir, interpreter,
unpacked_whls.all_imported_requirements,
unpacked_whls.module_name)
ZIP.extract(matched_dist.location, extract_dir)
if unpacked_whls.within_data_subdir:
data_dir_prefix = '{name}-{version}.data/{subdir}'.format(
name=matched_dist.project_name,
version=matched_dist.version,
subdir=unpacked_whls.within_data_subdir,
)
dist_data_dir = os.path.join(extract_dir, data_dir_prefix)
else:
dist_data_dir = extract_dir
unpack_filter = self.get_unpack_filter(unpacked_whls)
# Copy over the module's data files into `unpack_dir`.
mergetree(matching_wheel_dir, unpack_dir, file_filter=unpack_filter)
mergetree(dist_data_dir, unpack_dir, file_filter=unpack_filter)
except Exception as e:
raise self.NativeCodeExtractionError(
raise self.WheelUnpackingError(
"Error extracting wheel for target {}: {}"
.format(unpacked_whls, str(e)),
e)
@@ -100,12 +100,12 @@ def _calculate_unpack_filter(cls, includes=None, excludes=None, spec=None):
field_name='include_patterns',
spec=spec)
logger.debug('include_patterns: {}'
.format(p.pattern for p in include_patterns))
.format(list(p.pattern for p in include_patterns)))
exclude_patterns = cls.compile_patterns(excludes or [],
field_name='exclude_patterns',
spec=spec)
logger.debug('exclude_patterns: {}'
.format(p.pattern for p in exclude_patterns))
.format(list(p.pattern for p in exclude_patterns)))
return lambda f: cls._file_filter(f, include_patterns, exclude_patterns)

@classmethod
@@ -42,6 +42,7 @@ def test_simple(self):

def test_bad_libraries_ref(self):
self.make_target(':right-type', JarLibrary, jars=[JarDependency('foo', 'bar', '123')])
# Making a target which is not a jar library, which causes an error.
self.make_target(':wrong-type', UnpackedJars, libraries=[':right-type'])
target = self.make_target(':foo', UnpackedJars, libraries=[':wrong-type'])
with self.assertRaises(ImportJarsMixin.WrongTargetTypeError):
@@ -85,6 +86,6 @@ def assert_dep(dep, org, name, rev):
unpacked_jar_deps = unpacked_lib.all_imported_jar_deps

self.assertEqual(3, len(unpacked_jar_deps))
assert_dep(lib1.jar_dependencies[0], 'testOrg1', 'testName1', '123')
assert_dep(lib2.jar_dependencies[0], 'testOrg2', 'testName2', '456')
assert_dep(lib2.jar_dependencies[1], 'testOrg3', 'testName3', '789')
assert_dep(unpacked_jar_deps[0], 'testOrg1', 'testName1', '123')
assert_dep(unpacked_jar_deps[1], 'testOrg2', 'testName2', '456')
assert_dep(unpacked_jar_deps[2], 'testOrg3', 'testName3', '789')
@@ -6,7 +6,6 @@

import functools
import os
import unittest
from contextlib import contextmanager

from pants.backend.jvm.targets.jar_library import JarLibrary
@@ -36,45 +35,6 @@ def sample_jarfile(self, name):
proto_jarfile.writestr('a/b/c/{}.proto'.format(name), 'message Msg {}')
yield jar_name

def test_invalid_pattern(self):
with self.assertRaises(UnpackJars.InvalidPatternError):
UnpackJars.compile_patterns([45])

@staticmethod
def _run_filter(filename, include_patterns=None, exclude_patterns=None):
return UnpackJars._file_filter(
filename,
UnpackJars.compile_patterns(include_patterns or []),
UnpackJars.compile_patterns(exclude_patterns or []))

def test_file_filter(self):
# If no patterns are specified, everything goes through
self.assertTrue(self._run_filter("foo/bar.java"))

self.assertTrue(self._run_filter("foo/bar.java", include_patterns=["**/*.java"]))
self.assertTrue(self._run_filter("bar.java", include_patterns=["**/*.java"]))
self.assertTrue(self._run_filter("bar.java", include_patterns=["**/*.java", "*.java"]))
self.assertFalse(self._run_filter("foo/bar.java", exclude_patterns=["**/bar.*"]))
self.assertFalse(self._run_filter("foo/bar.java",
include_patterns=["**/*/java"],
exclude_patterns=["**/bar.*"]))

# exclude patterns should be computed before include patterns
self.assertFalse(self._run_filter("foo/bar.java",
include_patterns=["foo/*.java"],
exclude_patterns=["foo/b*.java"]))
self.assertTrue(self._run_filter("foo/bar.java",
include_patterns=["foo/*.java"],
exclude_patterns=["foo/x*.java"]))

@unittest.expectedFailure
def test_problematic_cases(self):
"""These should pass, but don't"""
# See https://github.com/twitter/commons/issues/380. 'foo*bar' doesn't match 'foobar'
self.assertFalse(self._run_filter("foo/bar.java",
include_patterns=['foo/*.java'],
exclude_patterns=['foo/bar*.java']))

def _make_jar_library(self, coord):
return self.make_target(spec='unpack/jars:foo-jars',
target_type=JarLibrary,

0 comments on commit 2c9c338

Please sign in to comment.
You can’t perform that action at this time.