Skip to content

Commit

Permalink
Allow intransitive unpacking of unpacked_jars targets.
Browse files Browse the repository at this point in the history
Adds an `intransitive` property to the target.

Also improves testing of the unpack_jars task: Removes
an old test that really just tested fingerprinting and
invalidation. Replaces it with one that actually checks
task result output.
  • Loading branch information
benjyw committed Jan 28, 2018
1 parent 74b324a commit 5f61192
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 100 deletions.
2 changes: 1 addition & 1 deletion src/python/pants/backend/jvm/targets/import_jars_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def gen_specs():
@memoized_property
def imported_jars(self):
""":returns: the JarDependencies referenced by imported_jar_library_specs
:rtype: list of string
:rtype: list of JarDependency
"""
return JarLibrary.to_jar_dependencies(self.address,
self.imported_jar_library_specs(payload=self.payload),
Expand Down
7 changes: 4 additions & 3 deletions src/python/pants/backend/jvm/targets/unpacked_jars.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,21 @@ class ExpectedLibrariesError(Exception):
pass

def __init__(self, payload=None, libraries=None, include_patterns=None, exclude_patterns=None,
**kwargs):
intransitive=False, **kwargs):
"""
:param libraries: List of addresses of `jar_library <#jar_library>`_
targets which contain .proto definitions.
:param list libraries: addresses of jar_library targets that specify the jars you want to unpack
:param list include_patterns: fileset patterns to include from the archive
:param list exclude_patterns: fileset patterns to exclude from the archive. Exclude patterns
are processed before include_patterns.
:param bool intransitive: If True, unpack only the jars themselves; if False (the default),
  also unpack all of their resolved transitive dependencies.
"""
payload = payload or Payload()
payload.add_fields({
'library_specs': PrimitiveField(libraries or ()),
'include_patterns' : PrimitiveField(include_patterns or ()),
'exclude_patterns' : PrimitiveField(exclude_patterns or ()),
'intransitive': PrimitiveField(intransitive)
})
super(UnpackedJars, self).__init__(payload=payload, **kwargs)

Expand Down
2 changes: 1 addition & 1 deletion src/python/pants/backend/jvm/tasks/ivy_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


class IvyImports(IvyTaskMixin, NailgunTask):
"""Resolves all jar files for the import_jar_libraries property on all `ImportJarsMixin` targets.
"""Resolves jar files for imported_jar_libraries on `ImportJarsMixin` targets.
One use case is for JavaProtobufLibrary, which includes imports for jars containing .proto files.
"""
Expand Down
34 changes: 14 additions & 20 deletions src/python/pants/backend/jvm/tasks/unpack_jars.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def _file_filter(cls, filename, include_patterns, exclude_patterns):
return True

@classmethod
def _compile_patterns(cls, patterns, field_name="Unknown", spec="Unknown"):
def compile_patterns(cls, patterns, field_name="Unknown", spec="Unknown"):
compiled_patterns = []
for p in patterns:
try:
Expand All @@ -100,12 +100,12 @@ def calculate_unpack_filter(cls, includes=None, excludes=None, spec=None):
:param list includes: List of include patterns to pass to _file_filter.
:param list excludes: List of exclude patterns to pass to _file_filter.
"""
include_patterns = cls._compile_patterns(includes or [],
field_name='include_patterns',
spec=spec)
exclude_patterns = cls._compile_patterns(excludes or [],
field_name='exclude_patterns',
spec=spec)
include_patterns = cls.compile_patterns(includes or [],
field_name='include_patterns',
spec=spec)
exclude_patterns = cls.compile_patterns(excludes or [],
field_name='exclude_patterns',
spec=spec)
return lambda f: cls._file_filter(f, include_patterns, exclude_patterns)

# TODO(mateor) move unpack code that isn't jar-specific to fs.archive or an Unpack base class.
Expand All @@ -130,28 +130,25 @@ def _unpack(self, unpacked_jars):
if not os.path.exists(unpack_dir):
os.makedirs(unpack_dir)

direct_coords = {jar.coordinate for jar in unpacked_jars.imported_jars}
unpack_filter = self.get_unpack_filter(unpacked_jars)
jar_import_products = self.context.products.get_data(JarImportProducts)
for coordinate, jar_path in jar_import_products.imports(unpacked_jars):
self.context.log.debug('Unpacking jar {coordinate} from {jar_path} to {unpack_dir}.'
.format(coordinate=coordinate,
jar_path=jar_path,
unpack_dir=unpack_dir))
ZIP.extract(jar_path, unpack_dir, filter_func=unpack_filter)
if not unpacked_jars.payload.intransitive or coordinate in direct_coords:
self.context.log.info('Unpacking jar {coordinate} from {jar_path} to {unpack_dir}.'.format(
coordinate=coordinate, jar_path=jar_path, unpack_dir=unpack_dir))
ZIP.extract(jar_path, unpack_dir, filter_func=unpack_filter)

def execute(self):
addresses = [target.address for target in self.context.targets()]
closure = self.context.build_graph.transitive_subgraph_of_addresses(addresses)
unpacked_jars_list = [t for t in closure if isinstance(t, UnpackedJars)]

unpacked_targets = []
with self.invalidated(unpacked_jars_list,
fingerprint_strategy=UnpackJarsFingerprintStrategy(),
invalidate_dependents=True) as invalidation_check:
if invalidation_check.invalid_vts:
unpacked_targets.extend([vt.target for vt in invalidation_check.invalid_vts])
for target in unpacked_targets:
self._unpack(target)
for vt in invalidation_check.invalid_vts:
self._unpack(vt.target)

for unpacked_jars_target in unpacked_jars_list:
unpack_dir = self._unpack_dir(unpacked_jars_target)
Expand All @@ -167,6 +164,3 @@ def execute(self):
rel_unpack_dir = os.path.relpath(unpack_dir, get_buildroot())
unpacked_sources_product = self.context.products.get_data('unpacked_archives', lambda: {})
unpacked_sources_product[unpacked_jars_target] = [found_files, rel_unpack_dir]

# Returning the list of unpacked targets for testing purposes
return unpacked_targets
126 changes: 51 additions & 75 deletions tests/python/pants_test/backend/jvm/tasks/test_unpack_jars.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from pants.java.jar.jar_dependency import JarDependency
from pants.java.jar.jar_dependency_utils import M2Coordinate
from pants.util.contextutil import open_zip, temporary_dir
from pants.util.dirutil import safe_walk
from pants_test.tasks.task_test_base import TaskTestBase


Expand All @@ -28,24 +27,24 @@ def task_type(cls):
return UnpackJars

@contextmanager
def sample_jarfile(self):
"""Create a jar file with a/b/c/data.txt and a/b/c/foo.proto"""
def sample_jarfile(self, name):
with temporary_dir() as temp_dir:
jar_name = os.path.join(temp_dir, 'foo.jar')
jar_name = os.path.join(temp_dir, '{}.jar'.format(name))
with open_zip(jar_name, 'w') as proto_jarfile:
proto_jarfile.writestr('a/b/c/data.txt', 'Foo text')
proto_jarfile.writestr('a/b/c/foo.proto', 'message Foo {}')
proto_jarfile.writestr('a/b/c/{}.txt'.format(name), 'Some text')
proto_jarfile.writestr('a/b/c/{}.proto'.format(name), 'message Msg {}')
yield jar_name

def test_invalid_pattern(self):
with self.assertRaises(UnpackJars.InvalidPatternError):
UnpackJars._compile_patterns([45])
UnpackJars.compile_patterns([45])

def _run_filter(self, filename, include_patterns=None, exclude_patterns=None):
@staticmethod
def _run_filter(filename, include_patterns=None, exclude_patterns=None):
return UnpackJars._file_filter(
filename,
UnpackJars._compile_patterns(include_patterns or []),
UnpackJars._compile_patterns(exclude_patterns or []))
UnpackJars.compile_patterns(include_patterns or []),
UnpackJars.compile_patterns(exclude_patterns or []))

def test_file_filter(self):
# If no patterns are specified, everything goes through
Expand Down Expand Up @@ -81,90 +80,67 @@ def _make_jar_library(self, coord):
jars=[JarDependency(org=coord.org, name=coord.name, rev=coord.rev,
url='file:///foo.jar')])

def _make_unpacked_jar(self, coord, include_patterns):
bar = self._make_jar_library(coord)
def _make_unpacked_jar(self, coord, include_patterns, intransitive=False):
jarlib = self._make_jar_library(coord)
return self.make_target(spec='unpack:foo',
target_type=UnpackedJars,
libraries=[bar.address.spec],
include_patterns=include_patterns)

def _make_coord(self, rev):
return M2Coordinate(org='com.example', name='bar', rev=rev)
libraries=[jarlib.address.spec],
include_patterns=include_patterns,
intransitive=intransitive)

def test_unpack_jar_fingerprint_strategy(self):
fingerprint_strategy = UnpackJarsFingerprintStrategy()

make_unpacked_jar = functools.partial(self._make_unpacked_jar, include_patterns=['bar'])
rev1 = self._make_coord(rev='0.0.1')
rev1 = M2Coordinate(org='com.example', name='bar', rev='0.0.1')
target = make_unpacked_jar(rev1)
fingerprint1 = fingerprint_strategy.compute_fingerprint(target)

# Now, replace the build file with a different version
# Now, replace the build file with a different version.
self.reset_build_graph()
target = make_unpacked_jar(self._make_coord(rev='0.0.2'))
target = make_unpacked_jar(M2Coordinate(org='com.example', name='bar', rev='0.0.2'))
fingerprint2 = fingerprint_strategy.compute_fingerprint(target)
self.assertNotEqual(fingerprint1, fingerprint2)

# Go back to the original library
# Go back to the original library.
self.reset_build_graph()
target = make_unpacked_jar(rev1)
fingerprint3 = fingerprint_strategy.compute_fingerprint(target)

self.assertEqual(fingerprint1, fingerprint3)

def _add_dummy_product(self, unpack_task, foo_target, jar_filename, coord):
jar_import_products = unpack_task.context.products.get_data(JarImportProducts,
init_func=JarImportProducts)
@staticmethod
def _add_dummy_product(context, foo_target, jar_filename, coord):
jar_import_products = context.products.get_data(JarImportProducts, init_func=JarImportProducts)
jar_import_products.imported(foo_target, coord, jar_filename)

def test_incremental(self):
make_unpacked_jar = functools.partial(self._make_unpacked_jar,
include_patterns=['a/b/c/*.proto'])

with self.sample_jarfile() as jar_filename:
rev1 = self._make_coord(rev='0.0.1')
foo_target = make_unpacked_jar(rev1)

# The first time through, the target should be unpacked.
unpack_task = self.create_task(self.context(target_roots=[foo_target]))
self._add_dummy_product(unpack_task, foo_target, jar_filename, rev1)
unpacked_targets = unpack_task.execute()

self.assertEquals([foo_target], unpacked_targets)
unpack_dir = unpack_task._unpack_dir(foo_target)
files = []
for _, dirname, filenames in safe_walk(unpack_dir):
files += filenames
self.assertEquals(['foo.proto'], files)

# Calling the task a second time should not need to unpack any targets
unpack_task = self.create_task(self.context(target_roots=[foo_target]))
self._add_dummy_product(unpack_task, foo_target, jar_filename, rev1)
unpacked_targets = unpack_task.execute()

self.assertEquals([], unpacked_targets)

# Change the library version and the target should be unpacked again.
self.reset_build_graph() # Forget about the old definition of the unpack/jars:foo-jar target
rev2 = self._make_coord(rev='0.0.2')
foo_target = make_unpacked_jar(rev2)

unpack_task = self.create_task(self.context(target_roots=[foo_target]))
self._add_dummy_product(unpack_task, foo_target, jar_filename, rev2)
unpacked_targets = unpack_task.execute()

self.assertEquals([foo_target], unpacked_targets)

# Change the include pattern and the target should be unpacked again
self.reset_build_graph() # Forget about the old definition of the unpack/jars:foo-jar target

make_unpacked_jar = functools.partial(self._make_unpacked_jar,
include_patterns=['a/b/c/foo.proto'])
foo_target = make_unpacked_jar(rev2)
unpack_task = self.create_task(self.context(target_roots=[foo_target]))
self._add_dummy_product(unpack_task, foo_target, jar_filename, rev2)
unpacked_targets = unpack_task.execute()

self.assertEquals([foo_target], unpacked_targets)

# TODO(Eric Ayers) Check the 'unpacked_archives' product
def _do_test_products(self, intransitive):
self.maxDiff = None
with self.sample_jarfile('foo') as foo_jar:
with self.sample_jarfile('bar') as bar_jar:
foo_coords = M2Coordinate(org='com.example', name='foo', rev='0.0.1')
bar_coords = M2Coordinate(org='com.example', name='bar', rev='0.0.7')
unpacked_jar_tgt = self._make_unpacked_jar(
foo_coords, include_patterns=['a/b/c/*.proto'], intransitive=intransitive)

context = self.context(target_roots=[unpacked_jar_tgt])
unpack_task = self.create_task(context)
self._add_dummy_product(context, unpacked_jar_tgt, foo_jar, foo_coords)
# We add bar_jar as a product against unpacked_jar_tgt, to simulate it being an
# externally-resolved dependency of foo_jar.
self._add_dummy_product(context, unpacked_jar_tgt, bar_jar, bar_coords)
unpack_task.execute()

expected = ['a/b/c/foo.proto']
if not intransitive:
expected.append('a/b/c/bar.proto')
self.assertEquals(
{unpacked_jar_tgt:
[expected, '.pants.d/pants_backend_jvm_tasks_unpack_jars_UnpackJars/unpack.foo']},
context.products.get_data('unpacked_archives', dict))

def test_transitive(self):
self._do_test_products(intransitive=False)

def test_intransitive(self):
self._do_test_products(intransitive=True)

0 comments on commit 5f61192

Please sign in to comment.