Skip to content

Commit

Permalink
feat: add include option to snapshots, similar to exclude (#797)
Browse files Browse the repository at this point in the history
  • Loading branch information
noahnu committed Aug 28, 2023
1 parent 2a7f43d commit d3f891e
Show file tree
Hide file tree
Showing 11 changed files with 126 additions and 11 deletions.
47 changes: 41 additions & 6 deletions README.md
Expand Up @@ -99,20 +99,32 @@ If you want to limit what properties are serialized at a class type level you co

```py
def limit_foo_attrs(prop, path):
allowed_foo_attrs = {"only", "serialize", "these", "attrs"}
return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs
allowed_foo_attrs = {"do", "not", "serialize", "these", "attrs"}
return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs

def test_bar(snapshot):
actual = Foo(...)
assert actual == snapshot(exclude=limit_foo_attrs)
```

**B**. Or override the `__dir__` implementation to control the attribute list.
**B**. Provide a filter function to the snapshot [include](#include) configuration option.

```py
def limit_foo_attrs(prop, path):
allowed_foo_attrs = {"only", "serialize", "these", "attrs"}
return isinstance(path[-1][1], Foo) and prop in allowed_foo_attrs

def test_bar(snapshot):
actual = Foo(...)
assert actual == snapshot(include=limit_foo_attrs)
```

**C**. Or override the `__dir__` implementation to control the attribute list.

```py
class Foo:
def __dir__(self):
return ["only", "serialize", "these", "attrs"]
def __dir__(self):
return ["only", "serialize", "these", "attrs"]

def test_bar(snapshot):
actual = Foo(...)
Expand Down Expand Up @@ -211,7 +223,7 @@ Only runs replacement for objects at a matching path where the value of the mapp
This allows you to filter out object properties from the serialized snapshot.

The exclude parameter takes a filter function that accepts two keyword arguments.
It should return `true` or `false` if the property should be excluded or included respectively.
It should return `true` if the property should be excluded, or `false` if the property should be included.

| Argument | Description |
| -------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
Expand Down Expand Up @@ -278,6 +290,29 @@ def test_bar(snapshot):
# ---
```

#### `include`

This allows you filter an object's properties to a subset using a predicate. This is the opposite of [exclude](#exclude). All the same property filters supporterd by [exclude](#exclude) are supported for `include`.

The include parameter takes a filter function that accepts two keyword arguments.
It should return `true` if the property should be include, or `false` if the property should not be included.

| Argument | Description |
| -------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
| `prop` | Current property on the object, could be any hashable value that can be used to retrieve a value e.g. `1`, `"prop_str"`, `SomeHashableObject` |
| `path` | Ordered path traversed to the current value e.g. `(("a", dict), ("b", dict))` from `{ "a": { "b": { "c": 1 } } }`}

Note that `include` has some caveats which make it a bit more difficult to use than `exclude`. Both `include` and `exclude` are evaluated for each key of an object before traversing down nested paths. This means if you want to include a nested path, you must include all parents of the nested path, otherwise the nested child will never be reached to be evaluated against the include predicate. For example:

```py
obj = {
"nested": { "key": True }
}
assert obj == snapshot(include=paths("nested", "nested.key"))
```

The extra "nested" is required, otherwise the nested dictionary will never be searched -- it'd get pruned too early.

#### `extension_class`

This is a way to modify how the snapshot matches and serializes your data in a single assertion.
Expand Down
9 changes: 8 additions & 1 deletion src/syrupy/assertion.py
Expand Up @@ -65,6 +65,10 @@ class SnapshotAssertion:
init=False,
default=None,
)
_include: Optional["PropertyFilter"] = field(
init=False,
default=None,
)
_custom_index: Optional[str] = field(
init=False,
default=None,
Expand Down Expand Up @@ -180,7 +184,7 @@ def assert_match(self, data: "SerializableData") -> None:

def _serialize(self, data: "SerializableData") -> "SerializedData":
return self.extension.serialize(
data, exclude=self._exclude, matcher=self.__matcher
data, exclude=self._exclude, include=self._include, matcher=self.__matcher
)

def get_assert_diff(self) -> List[str]:
Expand Down Expand Up @@ -233,6 +237,7 @@ def __call__(
*,
diff: Optional["SnapshotIndex"] = None,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
matcher: Optional["PropertyMatcher"] = None,
name: Optional["SnapshotIndex"] = None,
Expand All @@ -242,6 +247,8 @@ def __call__(
"""
if exclude:
self.__with_prop("_exclude", exclude)
if include:
self.__with_prop("_include", include)
if extension_class:
self.__with_prop("_extension", self.__init_extension(extension_class))
if matcher:
Expand Down
17 changes: 14 additions & 3 deletions src/syrupy/extensions/amber/serializer.py
Expand Up @@ -203,6 +203,7 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> str:
"""
Expand All @@ -211,7 +212,9 @@ def serialize(
same new line control characters. Example snapshots generated on windows os
should not break when running the tests on a unix based system and vice versa.
"""
serialized = cls._serialize(data, exclude=exclude, matcher=matcher)
serialized = cls._serialize(
data, exclude=exclude, include=include, matcher=matcher
)
return serialized.replace("\r\n", "\n").replace("\r", "\n")

@classmethod
Expand All @@ -221,6 +224,7 @@ def _serialize(
*,
depth: int = 0,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
path: "PropertyPath" = (),
visited: Optional[Set[Any]] = None,
Expand All @@ -235,6 +239,7 @@ def _serialize(
"data": data,
"depth": depth,
"exclude": exclude,
"include": include,
"matcher": matcher,
"path": path,
"visited": {*visited, data_id},
Expand Down Expand Up @@ -400,6 +405,7 @@ def serialize_custom_iterable(
close_paren: Optional[str] = None,
depth: int = 0,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
path: "PropertyPath" = (),
separator: Optional[str] = None,
serialize_key: bool = False,
Expand All @@ -414,7 +420,8 @@ def serialize_custom_iterable(
key_values = (
(key, get_value(data, key))
for key in keys
if not exclude or not exclude(prop=key, path=path)
if (not exclude or not exclude(prop=key, path=path))
and (not include or include(prop=key, path=path))
)
entries = (
entry
Expand All @@ -433,7 +440,11 @@ def key_str(key: "PropertyName") -> str:

def value_str(key: "PropertyName", value: "SerializableData") -> str:
serialized = cls._serialize(
data=value, exclude=exclude, path=(*path, (key, type(value))), **kwargs
data=value,
exclude=exclude,
include=include,
path=(*path, (key, type(value))),
**kwargs,
)
return serialized if separator is None else serialized.lstrip(cls._indent)

Expand Down
1 change: 1 addition & 0 deletions src/syrupy/extensions/base.py
Expand Up @@ -65,6 +65,7 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
"""
Expand Down
14 changes: 13 additions & 1 deletion src/syrupy/extensions/json/__init__.py
Expand Up @@ -55,6 +55,7 @@ def _filter(
depth: int = 0,
path: "PropertyPath",
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
visited: Optional[Set[Any]] = None,
) -> "SerializableData":
Expand All @@ -80,13 +81,16 @@ def _filter(
value = data[key]
if exclude and exclude(prop=key, path=path):
continue
if include and not include(prop=key, path=path):
continue
if not isinstance(key, (str,)):
continue
filtered_dct[key] = cls._filter(
data=value,
depth=depth + 1,
path=(*path, (key, type(value))),
exclude=exclude,
include=include,
matcher=matcher,
visited={*visited, data_id},
)
Expand All @@ -101,6 +105,7 @@ def _filter(
depth=depth + 1,
path=(*path, (key, type(value))),
exclude=exclude,
include=include,
matcher=matcher,
visited={*visited, data_id},
)
Expand All @@ -118,6 +123,7 @@ def _filter(
depth=depth + 1,
path=(*path, (key, type(value))),
exclude=exclude,
include=include,
matcher=matcher,
visited={*visited, data_id},
)
Expand All @@ -137,9 +143,15 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
data = self._filter(
data=data, depth=0, path=(), exclude=exclude, matcher=matcher
data=data,
depth=0,
path=(),
exclude=exclude,
include=include,
matcher=matcher,
)
return json.dumps(data, indent=2, ensure_ascii=False, sort_keys=False) + "\n"
1 change: 1 addition & 0 deletions src/syrupy/extensions/single_file.py
Expand Up @@ -47,6 +47,7 @@ def serialize(
data: "SerializableData",
*,
exclude: Optional["PropertyFilter"] = None,
include: Optional["PropertyFilter"] = None,
matcher: Optional["PropertyMatcher"] = None,
) -> "SerializedData":
return self.get_supported_dataclass()(data)
Expand Down
Expand Up @@ -35,3 +35,18 @@
}),
})
# ---
# name: test_only_includes_expected_props
dict({
'date': 'utc',
0: 'some value',
})
# ---
# name: test_only_includes_expected_props.1
dict({
'date': 'utc',
'nested': dict({
'id': 4,
}),
0: 'some value',
})
# ---
12 changes: 12 additions & 0 deletions tests/syrupy/extensions/amber/test_amber_filters.py
Expand Up @@ -38,6 +38,18 @@ def test_filters_expected_props(snapshot):
assert actual == snapshot(exclude=props("0", "date", "id"))


def test_only_includes_expected_props(snapshot):
actual = {
0: "some value",
"date": "utc",
"nested": {"id": 4, "other": "value"},
"list": [1, 2],
}
# Note that "id" won't get included because "nested" (its parent) is not included.
assert actual == snapshot(include=props("0", "date", "id"))
assert actual == snapshot(include=paths("0", "date", "nested", "nested.id"))


@pytest.mark.parametrize(
"predicate", [paths("exclude_me", "nested.exclude_me"), props("exclude_me")]
)
Expand Down
@@ -0,0 +1,4 @@
{
"foo": "__SHOULD_BE_REMOVED_FROM_JSON__",
"id": 123456789
}
@@ -0,0 +1,4 @@
{
"foo": "__SHOULD_BE_REMOVED_FROM_JSON__",
"id": 123456789
}
13 changes: 13 additions & 0 deletions tests/syrupy/extensions/json/test_json_filters.py
Expand Up @@ -46,6 +46,19 @@ def test_exclude_simple(snapshot_json):
assert snapshot_json(exclude=paths("id", "foo")) == content


def test_include_simple(snapshot_json):
content = {
"id": 123456789,
"foo": "__SHOULD_BE_REMOVED_FROM_JSON__",
"I'm": "still alive",
"nested": {
"foo": "is still alive",
},
}
assert snapshot_json(include=props("id", "foo")) == content
assert snapshot_json(include=paths("id", "foo")) == content


def test_exclude_nested(snapshot_json):
content = {
"a": "b",
Expand Down

0 comments on commit d3f891e

Please sign in to comment.