Skip to content

Commit

Permalink
Add traverse method to SourceRootTrie
Browse files Browse the repository at this point in the history
The new traverse method returns a set of named tuples, one for each valid path
(with filesystem globbing included) that a given SourceRootTrie instance
describes, along with the languages and category that each path applies to.
This will be useful for implementing a v2 engine-compatible way of getting a
list of all source files, but it is also a standalone change in its own right.
  • Loading branch information
gshuflin committed Aug 15, 2019
1 parent aa47d93 commit 4f0d075
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 1 deletion.
25 changes: 25 additions & 0 deletions src/python/pants/source/source_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import os
from collections import namedtuple
from typing import Set

from pants.base.project_tree_factory import get_project_tree
from pants.subsystem.subsystem import Subsystem
Expand All @@ -19,6 +20,10 @@ class SourceRootCategories:

SourceRoot = namedtuple('_SourceRoot', ('path', 'langs', 'category'))

# A (path, langs, category) record whose langs are intentionally left
# uncanonicalized, so that they line up with the directory names that
# actually exist on disk.
UncanonicalizedSourceRoot = namedtuple(
    'UncanonicalizedSourceRoot',
    ('path', 'langs', 'category'),
)


class SourceRootFactory:
"""Creates source roots that respect language canonicalizations."""
Expand Down Expand Up @@ -361,6 +366,25 @@ def _do_add_pattern(self, pattern, langs, category):
node.category = category
node.is_terminal = True

def traverse(self) -> Set[UncanonicalizedSourceRoot]:
    """Return every source root described by this trie.

    Walks the trie depth-first and emits one UncanonicalizedSourceRoot for
    each terminal node. The emitted path is the '/'-joined sequence of node
    keys leading to that node (so '*' wildcards appear verbatim), and the
    langs are the node's own, uncanonicalized lang names. A terminal node
    with no langs of its own is reported with every lang name known to the
    source root factory.

    :returns: The set of (path, langs, category) tuples; empty if the trie
      has no terminal nodes.
    """
    source_roots: Set[UncanonicalizedSourceRoot] = set()
    lang_canonicalizations = self._source_root_factory._lang_canonicalizations
    all_lang_names = tuple(lang_canonicalizations.keys())

    def visit(node, path_components):
        for name, child in node.children.items():
            # '^' is a sentinel key rather than a directory name: find()
            # prepends it to the lookup keys and maps a match on it alone
            # to the path ''. Keep it out of this node's own path as well
            # as out of its descendants' paths, so a terminal '^' node is
            # reported as '' instead of the literal '^'.
            child_components = path_components if name == '^' else [*path_components, name]
            if child.is_terminal:
                source_roots.add(UncanonicalizedSourceRoot(
                    '/'.join(child_components),
                    child.langs or all_lang_names,
                    child.category,
                ))
            visit(child, child_components)

    visit(self._root, [])
    return source_roots

def find(self, path):
"""Find the source root for the given path."""
keys = ['^'] + path.split(os.path.sep)
Expand All @@ -377,6 +401,7 @@ def find(self, path):
else:
node = child
j += 1

if node.is_terminal:
if j == 1: # The match was on the root itself.
path = ''
Expand Down
47 changes: 46 additions & 1 deletion tests/python/pants_test/source/test_source_root.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Licensed under the Apache License, Version 2.0 (see LICENSE).

from pants.source.source_root import (SourceRoot, SourceRootCategories, SourceRootConfig,
SourceRootFactory, SourceRootTrie)
SourceRootFactory, SourceRootTrie, UncanonicalizedSourceRoot)
from pants_test.test_base import TestBase


Expand Down Expand Up @@ -91,6 +91,51 @@ def root(path, langs):
self.assertEqual(root('src/go/src', ('go',)),
trie.find('src/go/src/foo/bar/baz.go'))

def test_source_root_trie_traverse(self):
    """traverse() reports every pattern and fixed root added to a trie."""
    every_lang = ('jvm', 'py')

    def new_trie() -> SourceRootTrie:
        lang_canonicalizations = {'jvm': ('java', 'scala'), 'py': ('python',)}
        return SourceRootTrie(SourceRootFactory(lang_canonicalizations))

    def expected(path, langs=every_lang):
        return UncanonicalizedSourceRoot(path, langs, UNKNOWN)

    def add_patterns(target, *patterns):
        for pattern in patterns:
            target.add_pattern(pattern)

    # A freshly built trie describes no source roots at all.
    trie = new_trie()
    self.assertEqual(set(), trie.traverse())

    # Patterns added without langs are reported with every known lang name.
    add_patterns(trie, 'src/*', 'src/main/*')
    self.assertEqual({expected('src/*'), expected('src/main/*')}, trie.traverse())

    # Overlapping patterns at several depths are each reported exactly once.
    trie = new_trie()
    overlapping = (
        '*',
        'src/*/code',
        'src/main/*/code',
        'src/main/*',
        'src/main/*/foo',
    )
    add_patterns(trie, *overlapping)
    self.assertEqual({expected(pattern) for pattern in overlapping}, trie.traverse())

    # A fixed root keeps its own langs and sits alongside the patterns.
    trie = new_trie()
    trie.add_fixed('src/scala-source-code', ('scala',))
    add_patterns(trie, 'src/*/code', 'src/main/*/code')
    self.assertEqual(
        {
            expected('src/*/code'),
            expected('src/scala-source-code', ('scala',)),
            expected('src/main/*/code'),
        },
        trie.traverse(),
    )

def test_fixed_source_root_at_buildroot(self):
trie = SourceRootTrie(SourceRootFactory({}))
trie.add_fixed('', ('proto',))
Expand Down

0 comments on commit 4f0d075

Please sign in to comment.