diff --git a/Doc/library/shelve.rst b/Doc/library/shelve.rst index 23a2e0c3d0c758..23808619524056 100644 --- a/Doc/library/shelve.rst +++ b/Doc/library/shelve.rst @@ -17,7 +17,8 @@ This includes most class instances, recursive data types, and objects containing lots of shared sub-objects. The keys are ordinary strings. -.. function:: open(filename, flag='c', protocol=None, writeback=False) +.. function:: open(filename, flag='c', protocol=None, writeback=False, *, \ + serializer=None, deserializer=None) Open a persistent dictionary. The filename specified is the base filename for the underlying database. As a side-effect, an extension may be added to the @@ -41,6 +42,21 @@ lots of shared sub-objects. The keys are ordinary strings. determine which accessed entries are mutable, nor which ones were actually mutated). + By default, :mod:`shelve` uses :func:`pickle.dumps` and :func:`pickle.loads` + for serializing and deserializing. This can be changed by supplying + *serializer* and *deserializer*, respectively. + + The *serializer* argument must be a callable which takes an object ``obj`` + and the *protocol* as inputs and returns the representation ``obj`` as a + :term:`bytes-like object`; the *protocol* value may be ignored by the + serializer. + + The *deserializer* argument must be callable which takes a serialized object + given as a :class:`bytes` object and returns the corresponding object. + + A :exc:`ShelveError` is raised if *serializer* is given but *deserializer* + is not, or vice-versa. + .. versionchanged:: 3.10 :const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle protocol. @@ -48,6 +64,10 @@ lots of shared sub-objects. The keys are ordinary strings. .. versionchanged:: 3.11 Accepts :term:`path-like object` for filename. + .. versionchanged:: next + Accepts custom *serializer* and *deserializer* functions in place of + :func:`pickle.dumps` and :func:`pickle.loads`. + .. note:: Do not rely on the shelf being closed automatically; always call @@ -129,7 +149,8 @@ Restrictions explicitly. -.. class:: Shelf(dict, protocol=None, writeback=False, keyencoding='utf-8') +.. class:: Shelf(dict, protocol=None, writeback=False, \ + keyencoding='utf-8', *, serializer=None, deserializer=None) A subclass of :class:`collections.abc.MutableMapping` which stores pickled values in the *dict* object. @@ -147,6 +168,9 @@ Restrictions The *keyencoding* parameter is the encoding used to encode keys before they are used with the underlying dict. + The *serializer* and *deserializer* parameters have the same interpretation + as in :func:`~shelve.open`. + A :class:`Shelf` object can also be used as a context manager, in which case it will be automatically closed when the :keyword:`with` block ends. @@ -161,8 +185,13 @@ Restrictions :const:`pickle.DEFAULT_PROTOCOL` is now used as the default pickle protocol. + .. versionchanged:: next + Added the *serializer* and *deserializer* parameters. -.. class:: BsdDbShelf(dict, protocol=None, writeback=False, keyencoding='utf-8') + +.. class:: BsdDbShelf(dict, protocol=None, writeback=False, \ + keyencoding='utf-8', *, \ + serializer=None, deserializer=None) A subclass of :class:`Shelf` which exposes :meth:`!first`, :meth:`!next`, :meth:`!previous`, :meth:`!last` and :meth:`!set_location` methods. @@ -172,18 +201,27 @@ Restrictions modules. The *dict* object passed to the constructor must support those methods. This is generally accomplished by calling one of :func:`!bsddb.hashopen`, :func:`!bsddb.btopen` or :func:`!bsddb.rnopen`. The - optional *protocol*, *writeback*, and *keyencoding* parameters have the same - interpretation as for the :class:`Shelf` class. + optional *protocol*, *writeback*, *keyencoding*, *serializer* and *deserializer* + parameters have the same interpretation as in :func:`~shelve.open`. + + .. versionchanged:: next + Added the *serializer* and *deserializer* parameters. -.. class:: DbfilenameShelf(filename, flag='c', protocol=None, writeback=False) +.. class:: DbfilenameShelf(filename, flag='c', protocol=None, \ + writeback=False, *, serializer=None, \ + deserializer=None) A subclass of :class:`Shelf` which accepts a *filename* instead of a dict-like object. The underlying file will be opened using :func:`dbm.open`. By default, the file will be created and opened for both read and write. The - optional *flag* parameter has the same interpretation as for the :func:`.open` - function. The optional *protocol* and *writeback* parameters have the same - interpretation as for the :class:`Shelf` class. + optional *flag* parameter has the same interpretation as for the + :func:`.open` function. The optional *protocol*, *writeback*, *serializer* + and *deserializer* parameters have the same interpretation as in + :func:`~shelve.open`. + + .. versionchanged:: next + Added the *serializer* and *deserializer* parameters. .. _shelve-example: @@ -225,6 +263,20 @@ object):: d.close() # close it +Exceptions +---------- + +.. exception:: ShelveError + + Exception raised when one of the arguments *deserializer* and *serializer* + is missing in the :func:`~shelve.open`, :class:`Shelf`, :class:`BsdDbShelf` + and :class:`DbfilenameShelf`. + + The *deserializer* and *serializer* arguments must be given together. + + .. versionadded:: next + + .. seealso:: Module :mod:`dbm` diff --git a/Lib/shelve.py b/Lib/shelve.py index b53dc8b7a8ece9..1010be1e09d702 100644 --- a/Lib/shelve.py +++ b/Lib/shelve.py @@ -56,12 +56,17 @@ the persistent dictionary on disk, if feasible). """ -from pickle import DEFAULT_PROTOCOL, Pickler, Unpickler +from pickle import DEFAULT_PROTOCOL, dumps, loads from io import BytesIO import collections.abc -__all__ = ["Shelf", "BsdDbShelf", "DbfilenameShelf", "open"] +__all__ = ["ShelveError", "Shelf", "BsdDbShelf", "DbfilenameShelf", "open"] + + +class ShelveError(Exception): + pass + class _ClosedDict(collections.abc.MutableMapping): 'Marker for a closed dict. Access attempts raise a ValueError.' @@ -82,7 +87,7 @@ class Shelf(collections.abc.MutableMapping): """ def __init__(self, dict, protocol=None, writeback=False, - keyencoding="utf-8"): + keyencoding="utf-8", *, serializer=None, deserializer=None): self.dict = dict if protocol is None: protocol = DEFAULT_PROTOCOL @@ -91,6 +96,16 @@ def __init__(self, dict, protocol=None, writeback=False, self.cache = {} self.keyencoding = keyencoding + if serializer is None and deserializer is None: + self.serializer = dumps + self.deserializer = loads + elif (serializer is None) ^ (deserializer is None): + raise ShelveError("serializer and deserializer must be " + "defined together") + else: + self.serializer = serializer + self.deserializer = deserializer + def __iter__(self): for k in self.dict.keys(): yield k.decode(self.keyencoding) @@ -110,8 +125,8 @@ def __getitem__(self, key): try: value = self.cache[key] except KeyError: - f = BytesIO(self.dict[key.encode(self.keyencoding)]) - value = Unpickler(f).load() + f = self.dict[key.encode(self.keyencoding)] + value = self.deserializer(f) if self.writeback: self.cache[key] = value return value @@ -119,10 +134,8 @@ def __getitem__(self, key): def __setitem__(self, key, value): if self.writeback: self.cache[key] = value - f = BytesIO() - p = Pickler(f, self._protocol) - p.dump(value) - self.dict[key.encode(self.keyencoding)] = f.getvalue() + serialized_value = self.serializer(value, self._protocol) + self.dict[key.encode(self.keyencoding)] = serialized_value def __delitem__(self, key): del self.dict[key.encode(self.keyencoding)] @@ -191,33 +204,29 @@ class BsdDbShelf(Shelf): """ def __init__(self, dict, protocol=None, writeback=False, - keyencoding="utf-8"): - Shelf.__init__(self, dict, protocol, writeback, keyencoding) + keyencoding="utf-8", *, serializer=None, deserializer=None): + Shelf.__init__(self, dict, protocol, writeback, keyencoding, + serializer=serializer, deserializer=deserializer) def set_location(self, key): (key, value) = self.dict.set_location(key) - f = BytesIO(value) - return (key.decode(self.keyencoding), Unpickler(f).load()) + return (key.decode(self.keyencoding), self.deserializer(value)) def next(self): (key, value) = next(self.dict) - f = BytesIO(value) - return (key.decode(self.keyencoding), Unpickler(f).load()) + return (key.decode(self.keyencoding), self.deserializer(value)) def previous(self): (key, value) = self.dict.previous() - f = BytesIO(value) - return (key.decode(self.keyencoding), Unpickler(f).load()) + return (key.decode(self.keyencoding), self.deserializer(value)) def first(self): (key, value) = self.dict.first() - f = BytesIO(value) - return (key.decode(self.keyencoding), Unpickler(f).load()) + return (key.decode(self.keyencoding), self.deserializer(value)) def last(self): (key, value) = self.dict.last() - f = BytesIO(value) - return (key.decode(self.keyencoding), Unpickler(f).load()) + return (key.decode(self.keyencoding), self.deserializer(value)) class DbfilenameShelf(Shelf): @@ -227,9 +236,11 @@ class DbfilenameShelf(Shelf): See the module's __doc__ string for an overview of the interface. """ - def __init__(self, filename, flag='c', protocol=None, writeback=False): + def __init__(self, filename, flag='c', protocol=None, writeback=False, *, + serializer=None, deserializer=None): import dbm - Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback) + Shelf.__init__(self, dbm.open(filename, flag), protocol, writeback, + serializer=serializer, deserializer=deserializer) def clear(self): """Remove all items from the shelf.""" @@ -238,8 +249,8 @@ def clear(self): self.cache.clear() self.dict.clear() - -def open(filename, flag='c', protocol=None, writeback=False): +def open(filename, flag='c', protocol=None, writeback=False, *, + serializer=None, deserializer=None): """Open a persistent dictionary for reading and writing. The filename parameter is the base filename for the underlying @@ -252,4 +263,5 @@ def open(filename, flag='c', protocol=None, writeback=False): See the module's __doc__ string for an overview of the interface. """ - return DbfilenameShelf(filename, flag, protocol, writeback) + return DbfilenameShelf(filename, flag, protocol, writeback, + serializer=serializer, deserializer=deserializer) diff --git a/Lib/test/test_shelve.py b/Lib/test/test_shelve.py index 08c6562f2a273e..64609ab9dd9a62 100644 --- a/Lib/test/test_shelve.py +++ b/Lib/test/test_shelve.py @@ -1,10 +1,11 @@ +import array import unittest import dbm import shelve import pickle import os -from test.support import os_helper +from test.support import import_helper, os_helper from collections.abc import MutableMapping from test.test_dbm import dbm_iterator @@ -165,6 +166,239 @@ def test_default_protocol(self): with shelve.Shelf({}) as s: self.assertEqual(s._protocol, pickle.DEFAULT_PROTOCOL) + def test_custom_serializer_and_deserializer(self): + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol): + if isinstance(obj, (bytes, bytearray, str)): + if protocol == 5: + return obj + return type(obj).__name__ + elif isinstance(obj, array.array): + return obj.tobytes() + raise TypeError(f"Unsupported type for serialization: {type(obj)}") + + def deserializer(data): + if isinstance(data, (bytes, bytearray, str)): + return data.decode("utf-8") + elif isinstance(data, array.array): + return array.array("b", data) + raise TypeError( + f"Unsupported type for deserialization: {type(data)}" + ) + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto), shelve.open( + self.fn, + protocol=proto, + serializer=serializer, + deserializer=deserializer + ) as s: + bar = "bar" + bytes_data = b"Hello, world!" + bytearray_data = bytearray(b"\x00\x01\x02\x03\x04") + array_data = array.array("i", [1, 2, 3, 4, 5]) + + s["foo"] = bar + s["bytes_data"] = bytes_data + s["bytearray_data"] = bytearray_data + s["array_data"] = array_data + + if proto == 5: + self.assertEqual(s["foo"], str(bar)) + self.assertEqual(s["bytes_data"], "Hello, world!") + self.assertEqual( + s["bytearray_data"], bytearray_data.decode() + ) + self.assertEqual( + s["array_data"], array_data.tobytes().decode() + ) + else: + self.assertEqual(s["foo"], "str") + self.assertEqual(s["bytes_data"], "bytes") + self.assertEqual(s["bytearray_data"], "bytearray") + self.assertEqual( + s["array_data"], array_data.tobytes().decode() + ) + + def test_custom_incomplete_serializer_and_deserializer(self): + dbm_sqlite3 = import_helper.import_module("dbm.sqlite3") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + with self.assertRaises(dbm_sqlite3.error): + def serializer(obj, protocol=None): + pass + + def deserializer(data): + return data.decode("utf-8") + + with shelve.open(self.fn, serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + + def serializer(obj, protocol=None): + return type(obj).__name__.encode("utf-8") + + def deserializer(data): + pass + + with shelve.open(self.fn, serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + self.assertNotEqual(s["foo"], "bar") + self.assertIsNone(s["foo"]) + + def test_custom_serializer_and_deserializer_bsd_db_shelf(self): + berkeleydb = import_helper.import_module("berkeleydb") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol=None): + data = obj.__class__.__name__ + if protocol == 5: + data = str(len(data)) + return data.encode("utf-8") + + def deserializer(data): + return data.decode("utf-8") + + def type_name_len(obj): + return f"{(len(type(obj).__name__))}" + + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + with self.subTest(proto=proto), shelve.BsdDbShelf( + berkeleydb.btopen(self.fn), + protocol=proto, + serializer=serializer, + deserializer=deserializer, + ) as s: + bar = "bar" + bytes_obj = b"Hello, world!" + bytearray_obj = bytearray(b"\x00\x01\x02\x03\x04") + arr_obj = array.array("i", [1, 2, 3, 4, 5]) + + s["foo"] = bar + s["bytes_data"] = bytes_obj + s["bytearray_data"] = bytearray_obj + s["array_data"] = arr_obj + + if proto == 5: + self.assertEqual(s["foo"], type_name_len(bar)) + self.assertEqual(s["bytes_data"], type_name_len(bytes_obj)) + self.assertEqual(s["bytearray_data"], + type_name_len(bytearray_obj)) + self.assertEqual(s["array_data"], type_name_len(arr_obj)) + + k, v = s.set_location(b"foo") + self.assertEqual(k, "foo") + self.assertEqual(v, type_name_len(bar)) + + k, v = s.previous() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, type_name_len(bytes_obj)) + + k, v = s.previous() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, type_name_len(bytearray_obj)) + + k, v = s.previous() + self.assertEqual(k, "array_data") + self.assertEqual(v, type_name_len(arr_obj)) + + k, v = s.next() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, type_name_len(bytearray_obj)) + + k, v = s.next() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, type_name_len(bytes_obj)) + + k, v = s.first() + self.assertEqual(k, "array_data") + self.assertEqual(v, type_name_len(arr_obj)) + else: + k, v = s.set_location(b"foo") + self.assertEqual(k, "foo") + self.assertEqual(v, "str") + + k, v = s.previous() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, "bytes") + + k, v = s.previous() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, "bytearray") + + k, v = s.previous() + self.assertEqual(k, "array_data") + self.assertEqual(v, "array") + + k, v = s.next() + self.assertEqual(k, "bytearray_data") + self.assertEqual(v, "bytearray") + + k, v = s.next() + self.assertEqual(k, "bytes_data") + self.assertEqual(v, "bytes") + + k, v = s.first() + self.assertEqual(k, "array_data") + self.assertEqual(v, "array") + + self.assertEqual(s["foo"], "str") + self.assertEqual(s["bytes_data"], "bytes") + self.assertEqual(s["bytearray_data"], "bytearray") + self.assertEqual(s["array_data"], "array") + + def test_custom_incomplete_serializer_and_deserializer_bsd_db_shelf(self): + berkeleydb = import_helper.import_module("berkeleydb") + os.mkdir(self.dirname) + self.addCleanup(os_helper.rmtree, self.dirname) + + def serializer(obj, protocol=None): + return type(obj).__name__.encode("utf-8") + + def deserializer(data): + pass + + with shelve.BsdDbShelf(berkeleydb.btopen(self.fn), + serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + self.assertIsNone(s["foo"]) + self.assertNotEqual(s["foo"], "bar") + + def serializer(obj, protocol=None): + pass + + def deserializer(data): + return data.decode("utf-8") + + with shelve.BsdDbShelf(berkeleydb.btopen(self.fn), + serializer=serializer, + deserializer=deserializer) as s: + s["foo"] = "bar" + self.assertEqual(s["foo"], "") + self.assertNotEqual(s["foo"], "bar") + + def test_missing_custom_deserializer(self): + def serializer(obj, protocol=None): + pass + + kwargs = dict(protocol=2, writeback=False, serializer=serializer) + self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs) + self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs) + + def test_missing_custom_serializer(self): + def deserializer(data): + pass + + kwargs = dict(protocol=2, writeback=False, deserializer=deserializer) + self.assertRaises(shelve.ShelveError, shelve.Shelf, {}, **kwargs) + self.assertRaises(shelve.ShelveError, shelve.BsdDbShelf, {}, **kwargs) + class TestShelveBase: type2test = shelve.Shelf diff --git a/Misc/NEWS.d/next/Library/2024-07-16-00-01-04.gh-issue-99631.GWD4fD.rst b/Misc/NEWS.d/next/Library/2024-07-16-00-01-04.gh-issue-99631.GWD4fD.rst new file mode 100644 index 00000000000000..735249b4dae224 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2024-07-16-00-01-04.gh-issue-99631.GWD4fD.rst @@ -0,0 +1,2 @@ +The :mod:`shelve` module now accepts custom serialization +and deserialization functions.