diff --git a/changelog.d/864.change.md b/changelog.d/864.change.md new file mode 100644 index 000000000..2ef8add8a --- /dev/null +++ b/changelog.d/864.change.md @@ -0,0 +1,2 @@ +`attrs.filters.include()` and `attrs.filters.exclude()` now match `attrs.Attribute` instances by identity. +Passing a field returned by `attrs.fields()` therefore only matches that exact class's field; pass a string field name to match same-named fields across classes. diff --git a/docs/examples.md b/docs/examples.md index 2393decf4..2c87994da 100644 --- a/docs/examples.md +++ b/docs/examples.md @@ -271,6 +271,7 @@ For the common case where you want to [`include`](attrs.filters.include) or [`ex Though using string names directly is convenient, mistyping attribute names will silently do the wrong thing and neither Python nor your type checker can help you. {func}`attrs.fields()` will raise an `AttributeError` when the field doesn't exist while literal string names won't. Using {func}`attrs.fields()` to get attributes is worth being recommended in most cases. +String names match all fields with that name, while fields returned from {func}`attrs.fields()` only match that exact class's field. ```{doctest} >>> asdict( diff --git a/src/attr/filters.py b/src/attr/filters.py index 689b1705a..bb65022d2 100644 --- a/src/attr/filters.py +++ b/src/attr/filters.py @@ -9,15 +9,19 @@ def _split_what(what): """ - Returns a tuple of `frozenset`s of classes and attributes. + Returns a tuple of classes, names, and attributes to match. """ return ( frozenset(cls for cls in what if isinstance(cls, type)), frozenset(cls for cls in what if isinstance(cls, str)), - frozenset(cls for cls in what if isinstance(cls, Attribute)), + tuple(cls for cls in what if isinstance(cls, Attribute)), ) +def _matches_attribute(attribute, attrs): + return any(attribute is a for a in attrs) + + def include(*what): """ Create a filter that only allows *what*. @@ -39,7 +43,7 @@ def include_(attribute, value): return ( value.__class__ in cls or attribute.name in names - or attribute in attrs + or _matches_attribute(attribute, attrs) ) return include_ @@ -66,7 +70,7 @@ def exclude_(attribute, value): return not ( value.__class__ in cls or attribute.name in names - or attribute in attrs + or _matches_attribute(attribute, attrs) ) return exclude_ diff --git a/tests/test_filters.py b/tests/test_filters.py index 08314fa88..1c56c6408 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -18,6 +18,11 @@ class C: b = attr.ib() +@attr.s +class D: + a = attr.ib() + + class TestSplitWhat: """ Tests for `_split_what`. @@ -30,7 +35,7 @@ def test_splits(self): assert ( frozenset((int, str)), frozenset(("abcd", "123")), - frozenset((fields(C).a,)), + (fields(C).a,), ) == _split_what((str, "123", fields(C).a, int, "abcd")) @@ -79,6 +84,15 @@ def test_drop_class(self, incl, value): i = include(*incl) assert i(fields(C).a, value) is False + def test_allow_attributes_by_identity(self): + """ + Attributes with the same name on other classes are not included. + """ + i = include(fields(C).a) + + assert i(fields(C).a, 42) is True + assert i(fields(D).a, 42) is False + class TestExclude: """ @@ -124,3 +138,12 @@ def test_drop_class(self, excl, value): """ e = exclude(*excl) assert e(fields(C).a, value) is False + + def test_drop_attributes_by_identity(self): + """ + Attributes with the same name on other classes are not excluded. + """ + e = exclude(fields(C).a) + + assert e(fields(C).a, 42) is False + assert e(fields(D).a, 42) is True