33 changes: 33 additions & 0 deletions README.md
@@ -78,6 +78,39 @@ List[...] | pa.list_(...) |
Dict[..., ...] | pa.map_(pa key_type, pa value_type) |
Enum of str | pa.dictionary(pa.int32(), pa.string()) |
Enum of int | pa.int64() |
UUID (uuid.UUID or pydantic.types.UUID*) | pa.uuid() | SEE NOTE BELOW!

Note on UUIDs: the UUID type is only supported in pyarrow 18.0 and above. However,
as of pyarrow 19.0, when pyarrow creates a table with e.g. `pa.Table.from_pylist(objs, schema=schema)`,
it expects bytes rather than a `uuid.UUID` instance. Hence, if you are using `.model_dump()` to create
the data for pyarrow, you need to add a serializer to your pydantic model that converts the UUID to bytes.
This may be fixed in a later pyarrow version (see https://github.com/apache/arrow/issues/43855).

For example (with pyarrow >= 18.0):
```py
import uuid
from typing import Annotated

import pyarrow as pa
from pydantic import BaseModel, PlainSerializer
from pydantic_to_pyarrow import get_pyarrow_schema

class ModelWithUuid(BaseModel):
    uuid: Annotated[uuid.UUID, PlainSerializer(lambda x: x.bytes, return_type=bytes)]


schema = get_pyarrow_schema(ModelWithUuid)

model1 = ModelWithUuid(uuid=uuid.uuid1())
model2 = ModelWithUuid(uuid=uuid.uuid4())
data = [model1.model_dump(), model2.model_dump()]
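# No schema is passed here, so pyarrow infers a plain binary column from the
# serialized bytes (see the schema-aware variant after this example)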
table = pa.Table.from_pylist(data)
print(table)
#> pyarrow.Table
#> uuid: binary
#> ----
#> uuid: [[BF206AC0DA4711EF8271EF4F4B7A3587,211C4C5D94C74876AE5E32DBCCDC16C7]]
```
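
To keep the `pa.uuid()` extension type rather than the inferred plain binary column, the generated schema can be passed explicitly. A minimal sketch continuing the example above (assuming pyarrow >= 18.0; the exact printed representation of the extension type may vary between pyarrow versions):

```py
# Continuing the example: pass the generated schema explicitly so the column
# uses the uuid extension type instead of inferred binary.
table_with_schema = pa.Table.from_pylist(data, schema=schema)
print(table_with_schema.schema.field("uuid").type)
```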

## Settings

20 changes: 19 additions & 1 deletion src/pydantic_to_pyarrow/schema.py
@@ -1,5 +1,6 @@
import datetime
import types
import uuid
from decimal import Decimal
from enum import EnumMeta
from typing import Any, List, Literal, NamedTuple, Optional, Type, TypeVar, Union, cast
@@ -156,6 +157,18 @@ def _get_enum_type(field_type: Type[Any]) -> pa.DataType:
    raise SchemaCreationError(msg)


def _get_uuid_type() -> pa.DataType:
    # Different branches will execute depending on the installed pyarrow version.
    # This is exercised through nox across python / pyarrow versions, but any
    # single run only covers one branch, hence the coverage exclusions.
    if hasattr(pa, "uuid"):  # pragma: no cover
        return pa.uuid()
    else:  # pragma: no cover
        msg = f"pyarrow version {pa.__version__} does not support pa.uuid() type, "
        msg += "needs version 18.0 or higher"
        raise SchemaCreationError(msg)


def _is_optional(field_type: Type[Any]) -> bool:
    origin = get_origin(field_type)
    is_python_39_union = origin is Union
@@ -167,14 +180,19 @@ def _is_optional(field_type: Type[Any]) -> bool:
    return type(None) in get_args(field_type)


def _get_pyarrow_type(
# noqa: PLR0911 - ignore until a refactoring can reduce the number of
# return statements.
def _get_pyarrow_type(  # noqa: PLR0911
    field_type: Type[Any],
    metadata: List[Any],
    settings: Settings,
) -> pa.DataType:
    if field_type in FIELD_MAP:
        return FIELD_MAP[field_type]

    if field_type is uuid.UUID:
        return _get_uuid_type()

    if settings.allow_losing_tz and field_type in LOSING_TZ_TYPES:
        return LOSING_TZ_TYPES[field_type]

58 changes: 57 additions & 1 deletion tests/test_schema.py
@@ -1,5 +1,6 @@
import datetime
import tempfile
import uuid
from decimal import Decimal
from enum import Enum, auto
from pathlib import Path
@@ -11,8 +12,12 @@
import pytest
from annotated_types import Gt
from packaging import version
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, PlainSerializer
from pydantic.types import (
    UUID1,
    UUID3,
    UUID4,
    UUID5,
    AwareDatetime,
    NaiveDatetime,
    PositiveInt,
@@ -592,6 +597,57 @@ class DictModel(BaseModel):
    assert objs == [{"foo": dict(t["foo"])} for t in new_objs]


def test_uuid() -> None:
    # pyarrow 18.0.0+ is required for UUID support.
    # Even then, pyarrow doesn't automatically convert UUIDs to bytes
    # during serialization, so we need to do that manually
    # (https://github.com/apache/arrow/issues/43855)
    as_bytes = PlainSerializer(lambda x: x.bytes, return_type=bytes)

    class ModelWithUUID(BaseModel):
        foo_0: Annotated[uuid.UUID, as_bytes] = Field(default_factory=uuid.uuid1)
        foo_1: Annotated[UUID1, as_bytes] = Field(default_factory=uuid.uuid1)
        foo_3: Annotated[UUID3, as_bytes] = Field(
            default_factory=lambda: uuid.uuid3(uuid.NAMESPACE_DNS, "pydantic.org")
        )
        foo_4: Annotated[UUID4, as_bytes] = Field(default_factory=uuid.uuid4)
        foo_5: Annotated[UUID5, as_bytes] = Field(
            default_factory=lambda: uuid.uuid5(uuid.NAMESPACE_DNS, "pydantic.org")
        )

    if version.Version(pa.__version__) < version.Version("18.0.0"):
        with pytest.raises(SchemaCreationError) as err:
            get_pyarrow_schema(ModelWithUUID)
        assert "needs version 18.0 or higher" in str(err)
    else:
        expected = pa.schema(
            [
                pa.field("foo_0", pa.uuid(), nullable=False),
                pa.field("foo_1", pa.uuid(), nullable=False),
                pa.field("foo_3", pa.uuid(), nullable=False),
                pa.field("foo_4", pa.uuid(), nullable=False),
                pa.field("foo_5", pa.uuid(), nullable=False),
            ]
        )

        actual = get_pyarrow_schema(ModelWithUUID)
        assert actual == expected

        objs = [
            ModelWithUUID().model_dump(),
            ModelWithUUID().model_dump(),
        ]

        new_schema, new_objs = _write_pq_and_read(objs, expected)
        assert new_schema == expected
        # objs were created with the uuid-to-bytes serializer, but pyarrow
        # reads the uuids back into uuid.UUID objects directly
        for obj in objs:
            for key in obj:
                obj[key] = uuid.UUID(bytes=obj[key])
        assert new_objs == objs


def test_alias() -> None:
    class AliasModel(BaseModel):
        field1: str