Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[package] make GlobGroup a public concept #56238

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 0 additions & 35 deletions test/package/test_dependency_api.py
Expand Up @@ -180,41 +180,6 @@ def test_mock_glob_allow_empty(self):
exporter.mock(include=["package_a.*"], allow_empty=False)
exporter.save_module("package_b.subpackage")

def test_module_glob(self):
    from torch.package.package_exporter import _GlobGroup

    # Each case is: (include, exclude, names that must match, names that must not match).
    # Cases run in order; within a case the positive names are checked before the negative ones.
    cases = [
        (
            "torch.*",
            [],
            ["torch.foo", "torch.bar"],
            ["tor.foo", "torch.foo.bar", "torch"],
        ),
        (
            "torch.**",
            [],
            ["torch.foo", "torch.bar", "torch.foo.bar", "torch"],
            ["what.torch", "torchvision"],
        ),
        ("torch.*.foo", [], ["torch.w.foo"], ["torch.hi.bar.baz"]),
        (
            "torch.**.foo",
            [],
            ["torch.w.foo", "torch.hi.bar.foo"],
            ["torch.f.foo.z"],
        ),
        ("torch*", [], ["torch", "torchvision"], ["torch.f"]),
        (
            "torch.**",
            ["torch.**.foo"],
            ["torch", "torch.bar", "torch.barfoo"],
            ["torch.foo", "torch.some.foo"],
        ),
        ("**.torch", [], ["torch", "bar.torch"], ["visiontorch"]),
    ]

    for include, exclude, expected_hits, expected_misses in cases:
        group = _GlobGroup(include, exclude)
        for name in expected_hits:
            self.assertTrue(group.matches(name))
        for name in expected_misses:
            self.assertFalse(group.matches(name))

@skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
def test_pickle_mocked(self):
import package_a.subpackage
Expand Down
117 changes: 117 additions & 0 deletions test/package/test_glob_group.py
@@ -0,0 +1,117 @@
from typing import Iterable

from torch.package import GlobGroup
from torch.testing._internal.common_utils import run_tests

try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase # type: ignore


class TestGlobGroup(PackageTestCase):
    """Exercise ``GlobGroup.matches`` over single-star, double-star and exclude patterns."""

    def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
        # Every candidate must be accepted by the glob; stops at the first failure.
        for name in candidates:
            matched = glob.matches(name)
            self.assertTrue(matched)

    def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
        # Every candidate must be rejected by the glob; stops at the first failure.
        for name in candidates:
            matched = glob.matches(name)
            self.assertFalse(matched)

    def test_one_star(self):
        # A trailing "*" stands for exactly one segment.
        group = GlobGroup("torch.*")
        accepted = ["torch.foo", "torch.bar"]
        rejected = ["tor.foo", "torch.foo.bar", "torch"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_one_star_middle(self):
        # "*" between two literals stands for exactly one segment.
        group = GlobGroup("foo.*.bar")
        accepted = ["foo.q.bar", "foo.foo.bar"]
        rejected = [
            "foo.bar",
            "foo.foo",
            "outer.foo.inner.bar",
            "foo.q.bar.more",
            "foo.one.two.bar",
        ]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_one_star_partial(self):
        # "*" may complete part of a segment, including with nothing.
        group = GlobGroup("fo*.bar")
        accepted = ["fo.bar", "foo.bar", "foobar.bar"]
        rejected = ["oij.bar", "f.bar", "foo"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_one_star_multiple_in_component(self):
        # Several "*" wildcards may appear inside a single segment.
        group = GlobGroup("foo/a*.htm*", separator="/")
        accepted = ["foo/a.html", "foo/a.htm", "foo/abc.html"]
        self.assertMatchesGlob(group, accepted)

    def test_one_star_partial_extension(self):
        # "*" in front of an extension does not cross the "/" separator.
        group = GlobGroup("foo/*.txt", separator="/")
        accepted = ["foo/hello.txt", "foo/goodbye.txt", "foo/.txt"]
        rejected = ["foo/bar/hello.txt", "bar/foo/hello.txt"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_two_star(self):
        # A trailing "**" stands for zero or more whole segments.
        group = GlobGroup("torch.**")
        accepted = ["torch.foo", "torch.bar", "torch.foo.bar", "torch"]
        rejected = ["what.torch", "torchvision"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_two_star_end(self):
        # A leading "**" stands for zero or more whole segments before the literal.
        group = GlobGroup("**.torch")
        accepted = ["torch", "bar.torch"]
        rejected = ["visiontorch"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_two_star_middle(self):
        # "**" between two literals stands for zero or more whole segments.
        group = GlobGroup("foo.**.baz")
        accepted = ["foo.baz", "foo.bar.baz", "foo.bar1.bar2.baz"]
        rejected = ["foobaz", "foo.bar.baz.z"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_two_star_multiple(self):
        # More than one "**" can be combined with a "*" segment.
        group = GlobGroup("**/bar/**/*.txt", separator="/")
        accepted = ["bar/baz.txt", "a/bar/b.txt", "bar/foo/c.txt"]
        rejected = ["baz.txt", "a/b.txt"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_raw_two_star(self):
        # A bare "**" accepts any non-empty candidate but not the empty string.
        group = GlobGroup("**")
        accepted = ["bar", "foo.bar", "ab.c.d.e"]
        rejected = [""]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_invalid_raw(self):
        # "**" glued to other characters inside a segment is rejected at construction.
        with self.assertRaises(ValueError):
            GlobGroup("a.**b")

    def test_exclude(self):
        # Exclude patterns take candidates away from what the include pattern accepts.
        group = GlobGroup("torch.**", exclude=["torch.**.foo"])
        accepted = ["torch", "torch.bar", "torch.barfoo"]
        rejected = ["torch.foo", "torch.some.foo"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_exclude_from_all(self):
        # Excludes also apply when the include pattern is the catch-all "**".
        group = GlobGroup("**", exclude=["foo.**", "bar.**"])
        accepted = ["a", "hello", "anything.really"]
        rejected = ["foo.bar", "foo.bar.baz"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)

    def test_list_include_exclude(self):
        # Both include and exclude may be given as lists of patterns.
        group = GlobGroup(["foo", "bar.**"], exclude=["bar.baz", "bar.qux"])
        accepted = ["foo", "bar.other", "bar.bazother"]
        rejected = ["bar.baz", "bar.qux"]
        self.assertMatchesGlob(group, accepted)
        self.assertNotMatchesGlob(group, rejected)


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torch/package/__init__.py
@@ -1,3 +1,4 @@
from .glob_group import GlobGroup
from .importer import (
Importer,
ObjMismatchError,
Expand Down
4 changes: 2 additions & 2 deletions torch/package/_file_structure_representation.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from typing import Dict, List

from ._glob_group import GlobPattern, _GlobGroup
from .glob_group import GlobPattern, GlobGroup


class Folder:
Expand Down Expand Up @@ -70,7 +70,7 @@ def _create_folder_from_file_list(
include: "GlobPattern" = "**",
exclude: "GlobPattern" = (),
) -> Folder:
glob_pattern = _GlobGroup(include, exclude, "/")
glob_pattern = GlobGroup(include, exclude=exclude, separator="/")

top_folder = Folder(filename, True)
for file in file_list:
Expand Down
48 changes: 0 additions & 48 deletions torch/package/_glob_group.py

This file was deleted.

78 changes: 78 additions & 0 deletions torch/package/glob_group.py
@@ -0,0 +1,78 @@
import re
from typing import Iterable, Union

GlobPattern = Union[str, Iterable[str]]


class GlobGroup:
    """A set of patterns that candidate strings will be matched against.

    A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz".

    A pattern contains one or more segments. Segments can be:
        - A literal string (e.g. "foo"), which matches exactly.
        - A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches
          any string that does not contain the separator, including the empty string.
        - A double wildcard ("**"). This matches against zero or more complete segments.

    Examples:
        ``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``.
        ``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``.
        ``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules.

    A candidate will match the ``GlobGroup`` if it matches any of the ``include`` patterns and
    none of the ``exclude`` patterns.

    Args:
        include (Union[str, Iterable[str]]): A string or list of strings,
            each representing a pattern to be matched against. A candidate
            will match if it matches *any* include pattern.
        exclude (Union[str, Iterable[str]]): A string or list of strings,
            each representing a pattern to be matched against. A candidate
            will be excluded from matching if it matches *any* exclude pattern.
        separator (str): A string that delimits segments in candidates and
            patterns. By default this is "." which corresponds to how modules are
            named in Python. Another common value for this is "/", which is
            the Unix path separator. Wildcards are implemented with a negated
            character class built from the separator, so a single-character
            separator is assumed.

    Raises:
        ValueError: If "**" appears in a pattern as anything other than an
            entire segment (e.g. "a.**b").
    """

    def __init__(
        self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "."
    ):
        # Human-readable form returned by __str__; it shows the raw patterns, not the regexes.
        self._dbg = f"GlobGroup(include={include}, exclude={exclude})"
        self.include = GlobGroup._glob_list(include, separator)
        self.exclude = GlobGroup._glob_list(exclude, separator)
        self.separator = separator

    def __str__(self):
        return self._dbg

    def matches(self, candidate: str) -> bool:
        """Return True if ``candidate`` matches any include pattern and no exclude pattern."""
        # Every compiled pattern expects each segment, including the first,
        # to be preceded by the separator, so prepend one to the candidate.
        candidate = self.separator + candidate
        return any(p.fullmatch(candidate) for p in self.include) and all(
            not p.fullmatch(candidate) for p in self.exclude
        )

    @staticmethod
    def _glob_list(elems: GlobPattern, separator: str = "."):
        """Compile a single pattern string, or each pattern of an iterable, into a list of regexes."""
        if isinstance(elems, str):
            return [GlobGroup._glob_to_re(elems, separator)]
        else:
            return [GlobGroup._glob_to_re(e, separator) for e in elems]

    @staticmethod
    def _glob_to_re(pattern: str, separator: str = "."):
        """Translate one glob pattern into a compiled regular expression.

        Raises:
            ValueError: If a segment contains "**" together with other characters.
        """
        # to avoid corner cases for the first component, we prefix the candidate string
        # with '.' so `import torch` will regex against `.torch`, assuming '.' is the separator
        escaped_separator = re.escape(separator)
        # "Any character except the separator". The separator must be escaped
        # here as well: a raw backslash separator would otherwise produce the
        # invalid character class ``[^\]``.
        not_separator = "[^" + escaped_separator + "]"

        def component_to_re(component):
            if "**" in component:
                if component == "**":
                    # Zero or more complete, non-empty segments.
                    return "(" + escaped_separator + not_separator + "+)*"
                else:
                    raise ValueError("** can only appear as an entire path segment")
            else:
                # Each "*" becomes "any run of non-separator characters";
                # the literal pieces around it are escaped.
                return escaped_separator + (not_separator + "*").join(
                    re.escape(x) for x in component.split("*")
                )

        result = "".join(component_to_re(c) for c in pattern.split(separator))
        return re.compile(result)
8 changes: 4 additions & 4 deletions torch/package/package_exporter.py
Expand Up @@ -22,7 +22,7 @@
from torch.serialization import location_tag, normalize_storage_type

from ._file_structure_representation import Folder, _create_folder_from_file_list
from ._glob_group import GlobPattern, _GlobGroup
from .glob_group import GlobPattern, GlobGroup
from ._importlib import _normalize_path
from ._mangling import is_mangled
from ._package_pickler import create_pickler
Expand Down Expand Up @@ -476,7 +476,7 @@ def mock(

"""
self.patterns.append(
(_GlobGroup(include, exclude), self.save_mock_module, allow_empty)
(GlobGroup(include, exclude=exclude), self.save_mock_module, allow_empty)
)

def extern(
Expand Down Expand Up @@ -504,7 +504,7 @@ def extern(

"""
self.patterns.append(
(_GlobGroup(include, exclude), self.save_extern_module, allow_empty)
(GlobGroup(include, exclude=exclude), self.save_extern_module, allow_empty)
)

def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
Expand All @@ -518,7 +518,7 @@ def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
"""
self.patterns.append(
(_GlobGroup(include, exclude), self._reject_denied_module, True)
(GlobGroup(include, exclude=exclude), self._reject_denied_module, True)
)

def save_extern_module(self, module_name: str):
Expand Down
2 changes: 1 addition & 1 deletion torch/package/package_importer.py
Expand Up @@ -13,7 +13,7 @@
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._file_structure_representation import Folder, _create_folder_from_file_list
from ._glob_group import GlobPattern
from .glob_group import GlobPattern
from ._importlib import (
_calc___package__,
_normalize_line_endings,
Expand Down