Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[package] make GlobGroup a public concept (#56238)
Summary: Pull Request resolved: #56238 It's already functionally public due to `extern` and `mock`, but exposing the underlying implementation makes extending PackageExporter easier. Changed the underscores, expose on `torch.package`, add docs, etc. Differential Revision: D27817013 Test Plan: Imported from OSS Reviewed By: Lilyjjo Pulled By: suo fbshipit-source-id: e39199e7cb5242a8bfb815777e4bb82462864027
- Loading branch information
1 parent
1ec12fd
commit 8d4e6c9
Showing
8 changed files
with
203 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
from typing import Iterable | ||
|
||
from torch.package import GlobGroup | ||
from torch.testing._internal.common_utils import run_tests | ||
|
||
try: | ||
from .common import PackageTestCase | ||
except ImportError: | ||
# Support the case where we run this file directly. | ||
from common import PackageTestCase # type: ignore | ||
|
||
|
||
class TestGlobGroup(PackageTestCase): | ||
def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]): | ||
for candidate in candidates: | ||
self.assertTrue(glob.matches(candidate)) | ||
|
||
def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]): | ||
for candidate in candidates: | ||
self.assertFalse(glob.matches(candidate)) | ||
|
||
def test_one_star(self): | ||
glob_group = GlobGroup("torch.*") | ||
self.assertMatchesGlob(glob_group, ["torch.foo", "torch.bar"]) | ||
self.assertNotMatchesGlob(glob_group, ["tor.foo", "torch.foo.bar", "torch"]) | ||
|
||
def test_one_star_middle(self): | ||
glob_group = GlobGroup("foo.*.bar") | ||
self.assertMatchesGlob(glob_group, ["foo.q.bar", "foo.foo.bar"]) | ||
self.assertNotMatchesGlob( | ||
glob_group, | ||
[ | ||
"foo.bar", | ||
"foo.foo", | ||
"outer.foo.inner.bar", | ||
"foo.q.bar.more", | ||
"foo.one.two.bar", | ||
], | ||
) | ||
|
||
def test_one_star_partial(self): | ||
glob_group = GlobGroup("fo*.bar") | ||
self.assertMatchesGlob(glob_group, ["fo.bar", "foo.bar", "foobar.bar"]) | ||
self.assertNotMatchesGlob(glob_group, ["oij.bar", "f.bar", "foo"]) | ||
|
||
def test_one_star_multiple_in_component(self): | ||
glob_group = GlobGroup("foo/a*.htm*", separator="/") | ||
self.assertMatchesGlob(glob_group, ["foo/a.html", "foo/a.htm", "foo/abc.html"]) | ||
|
||
def test_one_star_partial_extension(self): | ||
glob_group = GlobGroup("foo/*.txt", separator="/") | ||
self.assertMatchesGlob( | ||
glob_group, ["foo/hello.txt", "foo/goodbye.txt", "foo/.txt"] | ||
) | ||
self.assertNotMatchesGlob( | ||
glob_group, ["foo/bar/hello.txt", "bar/foo/hello.txt"] | ||
) | ||
|
||
def test_two_star(self): | ||
glob_group = GlobGroup("torch.**") | ||
self.assertMatchesGlob( | ||
glob_group, ["torch.foo", "torch.bar", "torch.foo.bar", "torch"] | ||
) | ||
self.assertNotMatchesGlob(glob_group, ["what.torch", "torchvision"]) | ||
|
||
def test_two_star_end(self): | ||
glob_group = GlobGroup("**.torch") | ||
self.assertMatchesGlob(glob_group, ["torch", "bar.torch"]) | ||
self.assertNotMatchesGlob(glob_group, ["visiontorch"]) | ||
|
||
def test_two_star_middle(self): | ||
glob_group = GlobGroup("foo.**.baz") | ||
self.assertMatchesGlob( | ||
glob_group, ["foo.baz", "foo.bar.baz", "foo.bar1.bar2.baz"] | ||
) | ||
self.assertNotMatchesGlob(glob_group, ["foobaz", "foo.bar.baz.z"]) | ||
|
||
def test_two_star_multiple(self): | ||
glob_group = GlobGroup("**/bar/**/*.txt", separator="/") | ||
self.assertMatchesGlob( | ||
glob_group, ["bar/baz.txt", "a/bar/b.txt", "bar/foo/c.txt"] | ||
) | ||
self.assertNotMatchesGlob(glob_group, ["baz.txt", "a/b.txt"]) | ||
|
||
def test_raw_two_star(self): | ||
glob_group = GlobGroup("**") | ||
self.assertMatchesGlob(glob_group, ["bar", "foo.bar", "ab.c.d.e"]) | ||
self.assertNotMatchesGlob(glob_group, [""]) | ||
|
||
def test_invalid_raw(self): | ||
with self.assertRaises(ValueError): | ||
GlobGroup("a.**b") | ||
|
||
def test_exclude(self): | ||
glob_group = GlobGroup("torch.**", exclude=["torch.**.foo"]) | ||
self.assertMatchesGlob( | ||
glob_group, | ||
["torch", "torch.bar", "torch.barfoo"], | ||
) | ||
self.assertNotMatchesGlob( | ||
glob_group, | ||
["torch.foo", "torch.some.foo"], | ||
) | ||
|
||
def test_exclude_from_all(self): | ||
glob_group = GlobGroup("**", exclude=["foo.**", "bar.**"]) | ||
self.assertMatchesGlob(glob_group, ["a", "hello", "anything.really"]) | ||
self.assertNotMatchesGlob(glob_group, ["foo.bar", "foo.bar.baz"]) | ||
|
||
def test_list_include_exclude(self): | ||
glob_group = GlobGroup(["foo", "bar.**"], exclude=["bar.baz", "bar.qux"]) | ||
self.assertMatchesGlob(glob_group, ["foo", "bar.other", "bar.bazother"]) | ||
self.assertNotMatchesGlob(glob_group, ["bar.baz", "bar.qux"]) | ||
|
||
|
||
if __name__ == "__main__": | ||
run_tests() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from .glob_group import GlobGroup | ||
from .importer import ( | ||
Importer, | ||
ObjMismatchError, | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
import re | ||
from typing import Iterable, Union | ||
|
||
GlobPattern = Union[str, Iterable[str]] | ||
|
||
|
||
class GlobGroup: | ||
"""A set of patterns that candidate strings will be matched against. | ||
A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz". | ||
A pattern contains one or more segments. Segments can be: | ||
- A literal string (e.g. "foo"), which matches exactly. | ||
- A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches | ||
any string, including the empty string. | ||
- A double wildcard ("**"). This matches against zero or more complete segments. | ||
Examples: | ||
``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``. | ||
``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``. | ||
``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules. | ||
A candidates will match the ``GlobGroup`` if it matches any of the ``include`` patterns and | ||
none of the ``exclude`` patterns. | ||
Args: | ||
include (Union[str, Iterable[str]]): A string or list of strings, | ||
each representing a pattern to be matched against. A candidate | ||
will match if it matches *any* include pattern | ||
exclude (Union[str, Iterable[str]]): A string or list of strings, | ||
each representing a pattern to be matched against. A candidate | ||
will be excluded from matching if it matches *any* exclude pattern. | ||
separator (str): A string that delimits segments in candidates and | ||
patterns. By default this is "." which corresponds to how modules are | ||
named in Python. Another common value for this is "/", which is | ||
the Unix path separator. | ||
""" | ||
def __init__( | ||
self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "." | ||
): | ||
self._dbg = f"GlobGroup(include={include}, exclude={exclude})" | ||
self.include = GlobGroup._glob_list(include, separator) | ||
self.exclude = GlobGroup._glob_list(exclude, separator) | ||
self.separator = separator | ||
|
||
def __str__(self): | ||
return self._dbg | ||
|
||
def matches(self, candidate: str) -> bool: | ||
candidate = self.separator + candidate | ||
return any(p.fullmatch(candidate) for p in self.include) and all( | ||
not p.fullmatch(candidate) for p in self.exclude | ||
) | ||
|
||
@staticmethod | ||
def _glob_list(elems: GlobPattern, separator: str = "."): | ||
if isinstance(elems, str): | ||
return [GlobGroup._glob_to_re(elems, separator)] | ||
else: | ||
return [GlobGroup._glob_to_re(e, separator) for e in elems] | ||
|
||
@staticmethod | ||
def _glob_to_re(pattern: str, separator: str = "."): | ||
# to avoid corner cases for the first component, we prefix the candidate string | ||
# with '.' so `import torch` will regex against `.torch`, assuming '.' is the separator | ||
def component_to_re(component): | ||
if "**" in component: | ||
if component == "**": | ||
return "(" + re.escape(separator) + "[^" + separator + "]+)*" | ||
else: | ||
raise ValueError("** can only appear as an entire path segment") | ||
else: | ||
return re.escape(separator) + ("[^" + separator + "]*").join( | ||
re.escape(x) for x in component.split("*") | ||
) | ||
|
||
result = "".join(component_to_re(c) for c in pattern.split(separator)) | ||
return re.compile(result) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters