Skip to content

Commit

Permalink
938 Improve create_pydantic_model for multidimensional arrays (#939)
Browse files Browse the repository at this point in the history
* handle multidimensional arrays in pydantic

* make sure the inner type is also correct

for example, if it's `Array(Array(Varchar()))`, the inner `Varchar` should be `constr`.

* ignore type warning
  • Loading branch information
dantownsend committed Mar 4, 2024
1 parent 2322a07 commit 0e2ec8a
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 12 deletions.
51 changes: 39 additions & 12 deletions piccolo/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,43 @@ def validate_columns(
)


def get_array_value_type(
column: Array, inner: t.Optional[t.Type] = None
) -> t.Type:
"""
Gets the correct type for an ``Array`` column (which might be
multidimensional).
"""
if isinstance(column.base_column, Array):
inner_type = get_array_value_type(column.base_column, inner=inner)
else:
inner_type = get_pydantic_value_type(column.base_column)

return t.List[inner_type] # type: ignore


def get_pydantic_value_type(column: Column) -> t.Type:
"""
Map the Piccolo ``Column`` to a Pydantic type.
"""
value_type: t.Type

if isinstance(column, (Decimal, Numeric)):
value_type = pydantic.condecimal(
max_digits=column.precision, decimal_places=column.scale
)
elif isinstance(column, Email):
value_type = pydantic.EmailStr # type: ignore
elif isinstance(column, Varchar):
value_type = pydantic.constr(max_length=column.length)
elif isinstance(column, Array):
value_type = get_array_value_type(column=column)
else:
value_type = column.value_type

return value_type


def create_pydantic_model(
table: t.Type[Table],
nested: t.Union[bool, t.Tuple[ForeignKey, ...]] = False,
Expand Down Expand Up @@ -211,17 +248,7 @@ def create_pydantic_model(
#######################################################################
# Work out the column type

if isinstance(column, (Decimal, Numeric)):
value_type: t.Type = pydantic.condecimal(
max_digits=column.precision, decimal_places=column.scale
)
elif isinstance(column, Email):
value_type = pydantic.EmailStr
elif isinstance(column, Varchar):
value_type = pydantic.constr(max_length=column.length)
elif isinstance(column, Array):
value_type = t.List[column.base_column.value_type] # type: ignore
elif isinstance(column, (JSON, JSONB)):
if isinstance(column, (JSON, JSONB)):
if deserialize_json:
value_type = pydantic.Json
else:
Expand All @@ -235,7 +262,7 @@ def create_pydantic_model(
validator # type: ignore
)
else:
value_type = column.value_type
value_type = get_pydantic_value_type(column=column)

_type = t.Optional[value_type] if is_optional else value_type

Expand Down
26 changes: 26 additions & 0 deletions tests/utils/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,32 @@ class Band(Table):
"string",
)

def test_multidimensional_array(self):
"""
Make sure that multidimensional arrays have the correct type.
"""

class Band(Table):
members = Array(Array(Varchar(length=255)), required=True)

pydantic_model = create_pydantic_model(table=Band)

self.assertEqual(
pydantic_model.model_fields["members"].annotation,
t.List[t.List[pydantic.constr(max_length=255)]],
)

# Should not raise a validation error:
pydantic_model(
members=[
["Alice", "Bob", "Francis"],
["Alan", "Georgia", "Sue"],
]
)

with self.assertRaises(ValueError):
pydantic_model(members=["Bob"])


class TestForeignKeyColumn(TestCase):
def test_target_column(self):
Expand Down

0 comments on commit 0e2ec8a

Please sign in to comment.