# --- pydantic/_internal/_model_construction.py -------------------------------
# NOTE: this section relies on the module's existing imports (`weakref`, `Any`,
# `Callable`). `build_lenient_weakvaluedict`, defined right after this class in
# the module, is unchanged and therefore not reproduced here.


class _PydanticWeakRef:
    """Wrapper for `weakref.ref` that enables `pickle` serialization.

    Cloudpickle fails to serialize `weakref.ref` objects due to an arcane error related
    to abstract base classes (`abc.ABC`). This class works around the issue by wrapping
    `weakref.ref` instead of subclassing it.

    See https://github.com/pydantic/pydantic/issues/6763 for context.

    Semantics:
        - If not pickled, calling the instance behaves the same as calling a
          `weakref.ref`: it returns the referent, or `None` once the referent is gone.
        - If pickled along with the referenced object, the same `weakref.ref` behavior
          will be maintained between them after unpickling.
        - If pickled without the referenced object, after unpickling the underlying
          reference will be cleared (`__call__` will always return `None`).

    Note:
        Pickling goes through `__reduce__`, which serializes the *referent itself* (if it
        is still alive) as the constructor argument, so a live referent must be picklable.
    """

    def __init__(self, obj: Any):
        """Create a weak reference to `obj`.

        Args:
            obj: The object to reference weakly. `None` produces an already-cleared
                reference; this is what deserialization passes when the serialized
                weakref had lost its underlying object.
        """
        # `weakref.ref(None)` raises `TypeError`, so a cleared reference is stored as `None`.
        self._wr = None if obj is None else weakref.ref(obj)

    def __call__(self) -> Any:
        """Return the referenced object, or `None` if it is gone or was never set."""
        if self._wr is None:
            return None
        return self._wr()

    def __reduce__(self) -> tuple[Callable[..., Any], tuple[Any]]:
        """Describe how to rebuild this wrapper when unpickling.

        Returns:
            A `(callable, args)` pair. `args` holds the referent (or `None` if the
            reference is already cleared), *not* a `weakref.ref`: the wrapper is rebuilt
            by calling `_PydanticWeakRef(referent)` on the unpickling side.
        """
        return _PydanticWeakRef, (self(),)


# --- tests/test_pickle_pydantic_weakref.py ------------------------------------
import gc
import pickle

# When this section lives in its own test module, `_PydanticWeakRef` is imported with:
#     from pydantic._internal._model_construction import _PydanticWeakRef


class IntWrapper:
    """Minimal picklable, weak-referenceable object used as a weakref target."""

    def __init__(self, v: int):
        self._v = v

    def get(self) -> int:
        """Return the wrapped integer."""
        return self._v

    def __eq__(self, other: object) -> bool:
        # Defer to the other operand for foreign types instead of raising
        # `AttributeError` from `other.get()`.
        if not isinstance(other, IntWrapper):
            return NotImplemented
        return self.get() == other.get()


def test_pickle_pydantic_weakref():
    """Round-trip `_PydanticWeakRef` through `pickle` with live, orphaned and dead referents."""
    obj1 = IntWrapper(1)
    ref1 = _PydanticWeakRef(obj1)
    assert ref1() is obj1

    obj2 = IntWrapper(2)
    ref2 = _PydanticWeakRef(obj2)
    assert ref2() is obj2

    ref3 = _PydanticWeakRef(IntWrapper(3))
    gc.collect()  # PyPy does not use reference counting and always relies on GC.
    assert ref3() is None

    d = {
        # Hold a hard reference to the underlying object for ref1 that will also
        # be pickled.
        'hard_ref': obj1,
        # ref1's underlying object has a hard reference in the pickled object so it
        # should maintain the reference after deserialization.
        'has_hard_ref': ref1,
        # ref2's underlying object has no hard reference in the pickled object so it
        # should be `None` after deserialization.
        'has_no_hard_ref': ref2,
        # ref3's underlying object had already gone out of scope before pickling so it
        # should be `None` after deserialization.
        'ref_out_of_scope': ref3,
    }

    loaded = pickle.loads(pickle.dumps(d))
    gc.collect()  # PyPy does not use reference counting and always relies on GC.

    assert loaded['hard_ref'] == IntWrapper(1)
    # The round trip must produce a copy, otherwise the identity check below is vacuous.
    assert loaded['hard_ref'] is not obj1
    assert loaded['has_hard_ref']() is loaded['hard_ref']
    assert loaded['has_no_hard_ref']() is None
    assert loaded['ref_out_of_scope']() is None