Skip to content

Commit

Permalink
fix[python]: Fix datatype pickling (#4621)
Browse files Browse the repository at this point in the history
  • Loading branch information
OneRaynyDay committed Aug 30, 2022
1 parent c24d9b9 commit 5a6b32f
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
13 changes: 13 additions & 0 deletions py-polars/polars/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,16 @@ def get_idx_type() -> type[DataType]:
return _get_idx_type()


def _custom_reconstruct(cls: type[Any], base: type[Any], state: Any) -> PolarsDataType:
if state:
obj = base.__new__(cls, state)
if base.__init__ != object.__init__:
base.__init__(obj, state)
else:
obj = object.__new__(cls)
return obj


class DataType:
"""Base class for all Polars data types."""

Expand All @@ -56,6 +66,9 @@ def __new__(cls, *args: Any, **kwargs: Any) -> PolarsDataType: # type: ignore[m
return super().__new__(cls)
return cls

def __reduce__(self) -> Any:
return (_custom_reconstruct, (type(self), object, None), self.__dict__)

@classmethod
def string_repr(cls) -> str:
return dtype_str_repr(cls)
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import inspect
import pickle

import polars as pl
from polars import datatypes
Expand Down Expand Up @@ -32,3 +33,10 @@ def test_dtype_temporal_units() -> None:

def test_get_idx_type() -> None:
assert datatypes.get_idx_type() == datatypes.UInt32


def test_dtypes_picklable() -> None:
parametric_type = pl.Datetime("ns")
singleton_type = pl.Float64
assert pickle.loads(pickle.dumps(parametric_type)) == parametric_type
assert pickle.loads(pickle.dumps(singleton_type)) == singleton_type

0 comments on commit 5a6b32f

Please sign in to comment.