Skip to content

Commit

Permalink
[package] make GlobGroup a public concept (#56238)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #56238

It's already functionally public due to `extern` and `mock`, but
exposing the underlying implementation makes extending PackageExporter
easier.

Changed the underscores, expose on `torch.package`, add docs, etc.

Differential Revision: D27817013

Test Plan: Imported from OSS

Reviewed By: Lilyjjo

Pulled By: suo

fbshipit-source-id: e39199e7cb5242a8bfb815777e4bb82462864027
  • Loading branch information
suo authored and facebook-github-bot committed Apr 16, 2021
1 parent 1ec12fd commit 8d4e6c9
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 90 deletions.
35 changes: 0 additions & 35 deletions test/package/test_dependency_api.py
Expand Up @@ -184,41 +184,6 @@ def test_mock_glob_allow_empty(self):
exporter.mock(include=["package_b.*"], allow_empty=False)
exporter.save_module("package_a.subpackage")

def test_module_glob(self):
from torch.package.package_exporter import _GlobGroup

def check(include, exclude, should_match, should_not_match):
x = _GlobGroup(include, exclude)
for e in should_match:
self.assertTrue(x.matches(e))
for e in should_not_match:
self.assertFalse(x.matches(e))

check(
"torch.*",
[],
["torch.foo", "torch.bar"],
["tor.foo", "torch.foo.bar", "torch"],
)
check(
"torch.**",
[],
["torch.foo", "torch.bar", "torch.foo.bar", "torch"],
["what.torch", "torchvision"],
)
check("torch.*.foo", [], ["torch.w.foo"], ["torch.hi.bar.baz"])
check(
"torch.**.foo", [], ["torch.w.foo", "torch.hi.bar.foo"], ["torch.f.foo.z"]
)
check("torch*", [], ["torch", "torchvision"], ["torch.f"])
check(
"torch.**",
["torch.**.foo"],
["torch", "torch.bar", "torch.barfoo"],
["torch.foo", "torch.some.foo"],
)
check("**.torch", [], ["torch", "bar.torch"], ["visiontorch"])

@skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
def test_pickle_mocked(self):
import package_a.subpackage
Expand Down
117 changes: 117 additions & 0 deletions test/package/test_glob_group.py
@@ -0,0 +1,117 @@
from typing import Iterable

from torch.package import GlobGroup
from torch.testing._internal.common_utils import run_tests

try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase # type: ignore


class TestGlobGroup(PackageTestCase):
def assertMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertTrue(glob.matches(candidate))

def assertNotMatchesGlob(self, glob: GlobGroup, candidates: Iterable[str]):
for candidate in candidates:
self.assertFalse(glob.matches(candidate))

def test_one_star(self):
glob_group = GlobGroup("torch.*")
self.assertMatchesGlob(glob_group, ["torch.foo", "torch.bar"])
self.assertNotMatchesGlob(glob_group, ["tor.foo", "torch.foo.bar", "torch"])

def test_one_star_middle(self):
glob_group = GlobGroup("foo.*.bar")
self.assertMatchesGlob(glob_group, ["foo.q.bar", "foo.foo.bar"])
self.assertNotMatchesGlob(
glob_group,
[
"foo.bar",
"foo.foo",
"outer.foo.inner.bar",
"foo.q.bar.more",
"foo.one.two.bar",
],
)

def test_one_star_partial(self):
glob_group = GlobGroup("fo*.bar")
self.assertMatchesGlob(glob_group, ["fo.bar", "foo.bar", "foobar.bar"])
self.assertNotMatchesGlob(glob_group, ["oij.bar", "f.bar", "foo"])

def test_one_star_multiple_in_component(self):
glob_group = GlobGroup("foo/a*.htm*", separator="/")
self.assertMatchesGlob(glob_group, ["foo/a.html", "foo/a.htm", "foo/abc.html"])

def test_one_star_partial_extension(self):
glob_group = GlobGroup("foo/*.txt", separator="/")
self.assertMatchesGlob(
glob_group, ["foo/hello.txt", "foo/goodbye.txt", "foo/.txt"]
)
self.assertNotMatchesGlob(
glob_group, ["foo/bar/hello.txt", "bar/foo/hello.txt"]
)

def test_two_star(self):
glob_group = GlobGroup("torch.**")
self.assertMatchesGlob(
glob_group, ["torch.foo", "torch.bar", "torch.foo.bar", "torch"]
)
self.assertNotMatchesGlob(glob_group, ["what.torch", "torchvision"])

def test_two_star_end(self):
glob_group = GlobGroup("**.torch")
self.assertMatchesGlob(glob_group, ["torch", "bar.torch"])
self.assertNotMatchesGlob(glob_group, ["visiontorch"])

def test_two_star_middle(self):
glob_group = GlobGroup("foo.**.baz")
self.assertMatchesGlob(
glob_group, ["foo.baz", "foo.bar.baz", "foo.bar1.bar2.baz"]
)
self.assertNotMatchesGlob(glob_group, ["foobaz", "foo.bar.baz.z"])

def test_two_star_multiple(self):
glob_group = GlobGroup("**/bar/**/*.txt", separator="/")
self.assertMatchesGlob(
glob_group, ["bar/baz.txt", "a/bar/b.txt", "bar/foo/c.txt"]
)
self.assertNotMatchesGlob(glob_group, ["baz.txt", "a/b.txt"])

def test_raw_two_star(self):
glob_group = GlobGroup("**")
self.assertMatchesGlob(glob_group, ["bar", "foo.bar", "ab.c.d.e"])
self.assertNotMatchesGlob(glob_group, [""])

def test_invalid_raw(self):
with self.assertRaises(ValueError):
GlobGroup("a.**b")

def test_exclude(self):
glob_group = GlobGroup("torch.**", exclude=["torch.**.foo"])
self.assertMatchesGlob(
glob_group,
["torch", "torch.bar", "torch.barfoo"],
)
self.assertNotMatchesGlob(
glob_group,
["torch.foo", "torch.some.foo"],
)

def test_exclude_from_all(self):
glob_group = GlobGroup("**", exclude=["foo.**", "bar.**"])
self.assertMatchesGlob(glob_group, ["a", "hello", "anything.really"])
self.assertNotMatchesGlob(glob_group, ["foo.bar", "foo.bar.baz"])

def test_list_include_exclude(self):
glob_group = GlobGroup(["foo", "bar.**"], exclude=["bar.baz", "bar.qux"])
self.assertMatchesGlob(glob_group, ["foo", "bar.other", "bar.bazother"])
self.assertNotMatchesGlob(glob_group, ["bar.baz", "bar.qux"])


if __name__ == "__main__":
run_tests()
1 change: 1 addition & 0 deletions torch/package/__init__.py
@@ -1,3 +1,4 @@
from .glob_group import GlobGroup
from .importer import (
Importer,
ObjMismatchError,
Expand Down
4 changes: 2 additions & 2 deletions torch/package/_file_structure_representation.py
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
from typing import Dict, List

from ._glob_group import GlobPattern, _GlobGroup
from .glob_group import GlobPattern, GlobGroup


class Folder:
Expand Down Expand Up @@ -70,7 +70,7 @@ def _create_folder_from_file_list(
include: "GlobPattern" = "**",
exclude: "GlobPattern" = (),
) -> Folder:
glob_pattern = _GlobGroup(include, exclude, "/")
glob_pattern = GlobGroup(include, exclude=exclude, separator="/")

top_folder = Folder(filename, True)
for file in file_list:
Expand Down
48 changes: 0 additions & 48 deletions torch/package/_glob_group.py

This file was deleted.

78 changes: 78 additions & 0 deletions torch/package/glob_group.py
@@ -0,0 +1,78 @@
import re
from typing import Iterable, Union

GlobPattern = Union[str, Iterable[str]]


class GlobGroup:
"""A set of patterns that candidate strings will be matched against.
A candidate is composed of a list of segments separated by ``separator``, e.g. "foo.bar.baz".
A pattern contains one or more segments. Segments can be:
- A literal string (e.g. "foo"), which matches exactly.
- A string containing a wildcard (e.g. "torch*", or "foo*baz*"). The wildcard matches
any string, including the empty string.
- A double wildcard ("**"). This matches against zero or more complete segments.
Examples:
``torch.**``: matches ``torch`` and all its submodules, e.g. ``torch.nn`` and ``torch.nn.functional``.
``torch.*``: matches ``torch.nn`` or ``torch.functional``, but not ``torch.nn.functional``.
``torch*.**``: matches ``torch``, ``torchvision``, and all their submodules.
A candidates will match the ``GlobGroup`` if it matches any of the ``include`` patterns and
none of the ``exclude`` patterns.
Args:
include (Union[str, Iterable[str]]): A string or list of strings,
each representing a pattern to be matched against. A candidate
will match if it matches *any* include pattern
exclude (Union[str, Iterable[str]]): A string or list of strings,
each representing a pattern to be matched against. A candidate
will be excluded from matching if it matches *any* exclude pattern.
separator (str): A string that delimits segments in candidates and
patterns. By default this is "." which corresponds to how modules are
named in Python. Another common value for this is "/", which is
the Unix path separator.
"""
def __init__(
self, include: GlobPattern, *, exclude: GlobPattern = (), separator: str = "."
):
self._dbg = f"GlobGroup(include={include}, exclude={exclude})"
self.include = GlobGroup._glob_list(include, separator)
self.exclude = GlobGroup._glob_list(exclude, separator)
self.separator = separator

def __str__(self):
return self._dbg

def matches(self, candidate: str) -> bool:
candidate = self.separator + candidate
return any(p.fullmatch(candidate) for p in self.include) and all(
not p.fullmatch(candidate) for p in self.exclude
)

@staticmethod
def _glob_list(elems: GlobPattern, separator: str = "."):
if isinstance(elems, str):
return [GlobGroup._glob_to_re(elems, separator)]
else:
return [GlobGroup._glob_to_re(e, separator) for e in elems]

@staticmethod
def _glob_to_re(pattern: str, separator: str = "."):
# to avoid corner cases for the first component, we prefix the candidate string
# with '.' so `import torch` will regex against `.torch`, assuming '.' is the separator
def component_to_re(component):
if "**" in component:
if component == "**":
return "(" + re.escape(separator) + "[^" + separator + "]+)*"
else:
raise ValueError("** can only appear as an entire path segment")
else:
return re.escape(separator) + ("[^" + separator + "]*").join(
re.escape(x) for x in component.split("*")
)

result = "".join(component_to_re(c) for c in pattern.split(separator))
return re.compile(result)
8 changes: 4 additions & 4 deletions torch/package/package_exporter.py
Expand Up @@ -22,7 +22,7 @@
from torch.serialization import location_tag, normalize_storage_type

from ._file_structure_representation import Folder, _create_folder_from_file_list
from ._glob_group import GlobPattern, _GlobGroup
from .glob_group import GlobPattern, GlobGroup
from ._importlib import _normalize_path
from ._mangling import is_mangled
from ._package_pickler import create_pickler
Expand Down Expand Up @@ -476,7 +476,7 @@ def mock(
"""
self.patterns.append(
(_GlobGroup(include, exclude), self.save_mock_module, allow_empty)
(GlobGroup(include, exclude=exclude), self.save_mock_module, allow_empty)
)

def extern(
Expand Down Expand Up @@ -504,7 +504,7 @@ def extern(
"""
self.patterns.append(
(_GlobGroup(include, exclude), self.save_extern_module, allow_empty)
(GlobGroup(include, exclude=exclude), self.save_extern_module, allow_empty)
)

def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
Expand All @@ -518,7 +518,7 @@ def deny(self, include: "GlobPattern", *, exclude: "GlobPattern" = ()):
exclude (Union[List[str], str]): An optional pattern that excludes some patterns that match the include string.
"""
self.patterns.append(
(_GlobGroup(include, exclude), self._reject_denied_module, True)
(GlobGroup(include, exclude=exclude), self._reject_denied_module, True)
)

def save_extern_module(self, module_name: str):
Expand Down
2 changes: 1 addition & 1 deletion torch/package/package_importer.py
Expand Up @@ -13,7 +13,7 @@
from torch.serialization import _get_restore_location, _maybe_decode_ascii

from ._file_structure_representation import Folder, _create_folder_from_file_list
from ._glob_group import GlobPattern
from .glob_group import GlobPattern
from ._importlib import (
_calc___package__,
_normalize_line_endings,
Expand Down

0 comments on commit 8d4e6c9

Please sign in to comment.