diff --git a/root_pandas/readwrite.py b/root_pandas/readwrite.py index 5fc2083..dc7873f 100644 --- a/root_pandas/readwrite.py +++ b/root_pandas/readwrite.py @@ -7,7 +7,7 @@ from numpy.lib.recfunctions import append_fields from pandas import DataFrame, RangeIndex from root_numpy import root2array, list_trees -from fnmatch import fnmatch +import fnmatch from root_numpy import list_branches from root_numpy.extern.six import string_types import itertools @@ -59,17 +59,24 @@ def get_nonscalar_columns(array): def get_matching_variables(branches, patterns, fail=True): - selected = [] - - for p in patterns: + # Convert branches to a set to make x "in branches" O(1) on average + branches = set(branches) + patterns = set(patterns) + # Find any trivial matches + selected = list(branches.intersection(patterns)) + # Any matches that weren't trivial need to be looped over... + for pattern in patterns.difference(selected): found = False - for b in branches: - if fnmatch(b, p): + # Avoid using fnmatch if the pattern if possible + if re.findall(r'(\*)|(\?)|(\[.*\])|(\[\!.*\])', pattern): + for match in fnmatch.filter(branches, pattern): found = True - if fnmatch(b, p) and b not in selected: - selected.append(b) + if match not in selected: + selected.append(match) + elif pattern in branches: + raise NotImplementedError('I think this is impossible?') if not found and fail: - raise ValueError("Pattern '{}' didn't match any branch".format(p)) + raise ValueError("Pattern '{}' didn't match any branch".format(pattern)) return selected diff --git a/tests/test.py b/tests/test.py index 4b3b99f..792f672 100644 --- a/tests/test.py +++ b/tests/test.py @@ -230,6 +230,18 @@ def test_nonscalar_columns(): os.remove(path) +def test_get_matching_variables_performance(): + """Performance regression test for #59""" + import random + import string + import root_pandas.readwrite + for n in [10, 100, 1000, 10000]: + branches = [' '.join(random.sample(string.ascii_letters*100, k=100)) for i in range(n)] + patterns = [' '.join(random.sample(string.ascii_letters*100, k=100)) for i in range(n)] + root_pandas.readwrite.get_matching_variables(branches, patterns, fail=False) + root_pandas.readwrite.get_matching_variables(branches, branches, fail=False) + + def test_noexpand_prefix(): xs = np.array([1, 2, 3]) df = pd.DataFrame({'x': xs})