diff --git a/Lib/collections/__init__.py b/Lib/collections/__init__.py index 2eee4c70955513..febab521629228 100644 --- a/Lib/collections/__init__.py +++ b/Lib/collections/__init__.py @@ -328,14 +328,14 @@ def __ior__(self, other): return self def __or__(self, other): - if not isinstance(other, dict): + if not isinstance(other, (dict, frozendict)): return NotImplemented new = self.__class__(self) new.update(other) return new def __ror__(self, other): - if not isinstance(other, dict): + if not isinstance(other, (dict, frozendict)): return NotImplemented new = self.__class__(other) new.update(self) @@ -1216,14 +1216,14 @@ def __repr__(self): def __or__(self, other): if isinstance(other, UserDict): return self.__class__(self.data | other.data) - if isinstance(other, dict): + if isinstance(other, (dict, frozendict)): return self.__class__(self.data | other) return NotImplemented def __ror__(self, other): if isinstance(other, UserDict): return self.__class__(other.data | self.data) - if isinstance(other, dict): + if isinstance(other, (dict, frozendict)): return self.__class__(other | self.data) return NotImplemented diff --git a/Lib/test/test_ordered_dict.py b/Lib/test/test_ordered_dict.py index 4204a6a47d2a81..7963995a575744 100644 --- a/Lib/test/test_ordered_dict.py +++ b/Lib/test/test_ordered_dict.py @@ -698,6 +698,7 @@ def test_merge_operator(self): d |= list(b.items()) expected = OrderedDict({0: 0, 1: 1, 2: 2, 3: 3}) self.assertEqual(a | dict(b), expected) + self.assertEqual(a | frozendict(b), expected) self.assertEqual(a | b, expected) self.assertEqual(c, expected) self.assertEqual(d, expected) @@ -706,12 +707,15 @@ def test_merge_operator(self): c |= a expected = OrderedDict({1: 1, 2: 1, 3: 3, 0: 0}) self.assertEqual(dict(b) | a, expected) + self.assertEqual(frozendict(b) | a, expected) self.assertEqual(b | a, expected) self.assertEqual(c, expected) self.assertIs(type(a | b), OrderedDict) self.assertIs(type(dict(a) | b), OrderedDict) + self.assertIs(type(frozendict(a) | b), frozendict) # BUG: should be OrderedDict self.assertIs(type(a | dict(b)), OrderedDict) + self.assertIs(type(a | frozendict(b)), OrderedDict) expected = a.copy() a |= () diff --git a/Lib/test/test_userdict.py b/Lib/test/test_userdict.py index 13285c9b2a3b7f..c60135ca5a12a8 100644 --- a/Lib/test/test_userdict.py +++ b/Lib/test/test_userdict.py @@ -245,7 +245,7 @@ class G(collections.UserDict): test_repr_deep = mapping_tests.TestHashMappingProtocol.test_repr_deep def test_mixed_or(self): - for t in UserDict, dict, types.MappingProxyType: + for t in UserDict, dict, frozendict, types.MappingProxyType: with self.subTest(t.__name__): u = UserDict({0: 'a', 1: 'b'}) | t({1: 'c', 2: 'd'}) self.assertEqual(u, {0: 'a', 1: 'c', 2: 'd'}) @@ -276,7 +276,7 @@ def test_mixed_or(self): self.assertIs(type(u), UserDictSubclass) def test_mixed_ior(self): - for t in UserDict, dict, types.MappingProxyType: + for t in UserDict, dict, frozendict, types.MappingProxyType: with self.subTest(t.__name__): u = u2 = UserDict({0: 'a', 1: 'b'}) u |= t({1: 'c', 2: 'd'}) diff --git a/Objects/odictobject.c b/Objects/odictobject.c index 25928028919c9c..21486806653ad0 100644 --- a/Objects/odictobject.c +++ b/Objects/odictobject.c @@ -906,7 +906,7 @@ odict_or(PyObject *left, PyObject *right) type = Py_TYPE(right); other = left; } - if (!PyDict_Check(other)) { + if (!PyAnyDict_Check(other)) { Py_RETURN_NOTIMPLEMENTED; } PyObject *new = PyObject_CallOneArg((PyObject*)type, left); @@ -2271,7 +2271,7 @@ static int mutablemapping_update_arg(PyObject *self, PyObject *arg) { int res = 0; - if (PyDict_CheckExact(arg)) { + if (PyAnyDict_CheckExact(arg)) { PyObject *items = PyDict_Items(arg); if (items == NULL) { return -1;