Skip to content

Commit

Permalink
Add traverse method to SourceRootTrie
Browse files Browse the repository at this point in the history
This will return a set of tuples of all valid paths (with filesystem
globbing included), and the languages each path applies to, that a
given SourceRootTrie instance describes. This will be useful for
implementing a v2 engine-compatible way of getting a list of all source
files, but is a standalone change in and of itself.
  • Loading branch information
gshuflin committed Aug 15, 2019
1 parent aa47d93 commit 4bf738e
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 0 deletions.
20 changes: 20 additions & 0 deletions src/python/pants/source/source_root.py
Expand Up @@ -3,6 +3,7 @@

import os
from collections import namedtuple
from typing import Set, Tuple, Sequence

from pants.base.project_tree_factory import get_project_tree
from pants.subsystem.subsystem import Subsystem
Expand Down Expand Up @@ -361,6 +362,24 @@ def _do_add_pattern(self, pattern, langs, category):
node.category = category
node.is_terminal = True

def traverse(self) -> Set[Tuple[str, Tuple[str, ...]]]:
paths_and_langs = set()

lang_canonicalizations = self._source_root_factory._lang_canonicalizations
all_lang_names = tuple(lang_canonicalizations.keys())

def traverse_helper(node, path_components):
for name in node.children:
child = node.children[name]
if child.is_terminal:
effective_path = '/'.join([*path_components, name])
effective_lang_names = child.langs if len(child.langs) != 0 else all_lang_names
paths_and_langs.add((effective_path, effective_lang_names))
traverse_helper(child, path_components if name == '^' else [*path_components, name])

traverse_helper(self._root, [])
return paths_and_langs

def find(self, path):
"""Find the source root for the given path."""
keys = ['^'] + path.split(os.path.sep)
Expand All @@ -377,6 +396,7 @@ def find(self, path):
else:
node = child
j += 1

if node.is_terminal:
if j == 1: # The match was on the root itself.
path = ''
Expand Down
42 changes: 42 additions & 0 deletions tests/python/pants_test/source/test_source_root.py
Expand Up @@ -91,6 +91,48 @@ def root(path, langs):
self.assertEqual(root('src/go/src', ('go',)),
trie.find('src/go/src/foo/bar/baz.go'))

def test_source_root_trie_traverse(self):
def make_trie() -> SourceRootTrie:
return SourceRootTrie(SourceRootFactory({
'jvm': ('java', 'scala'),
'py': ('python',)
}))

trie = make_trie()
self.assertEqual(set(), trie.traverse())

trie.add_pattern('src/*')
trie.add_pattern('src/main/*')
self.assertEqual({
('src/*', ('jvm','py')),
('src/main/*', ('jvm', 'py'))
}, trie.traverse())

trie = make_trie()
trie.add_pattern('*')
trie.add_pattern('src/*/code')
trie.add_pattern('src/main/*/code')
trie.add_pattern('src/main/*')
trie.add_pattern('src/main/*/foo')
self.assertEqual({
('*', ('jvm', 'py')),
('src/*/code', ('jvm', 'py')),
('src/main/*/code', ('jvm', 'py')),
('src/main/*', ('jvm', 'py')),
('src/main/*', ('jvm', 'py')),
('src/main/*/foo', ('jvm', 'py'))
}, trie.traverse())

trie = make_trie()
trie.add_fixed('src/scala-source-code', ('scala',))
trie.add_pattern('src/*/code')
trie.add_pattern('src/main/*/code')
self.assertEqual({
('src/*/code', ('jvm', 'py')),
('src/scala-source-code', ('scala',)),
('src/main/*/code', ('jvm', 'py'))
}, trie.traverse())

def test_fixed_source_root_at_buildroot(self):
trie = SourceRootTrie(SourceRootFactory({}))
trie.add_fixed('', ('proto',))
Expand Down

0 comments on commit 4bf738e

Please sign in to comment.