Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support django model type annotations #974

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
26 changes: 26 additions & 0 deletions docs/docs/guides/response/django-pydantic.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,29 @@ also you can define just a few optional fields instead of all:
```python
fields_optional = ['description']
```

### Using type annotations in models

If you define type annotations on your model, they will be taken into account in the generated schema:

```python hl_lines="6 7 18 19"
class TaskMetadata(TypedDict):
foo: str
bar: int

class Task(models.Model):
status: Literal['todo', 'done'] = models.CharField(max_length=10)
metadata: TaskMetadata = models.JSONField()

class TaskSchema(ModelSchema):
class Meta:
model = Task
fields = "__all__"

# Will create schema like this:
#
# class TaskSchema(Schema):
# id: int
# status: Literal['todo', 'done']
# metadata: TaskMetadata
```
26 changes: 19 additions & 7 deletions ninja/orm/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pydantic import create_model as create_pydantic_model

from ninja.errors import ConfigError
from ninja.orm.fields import get_schema_field
from ninja.orm.fields import ModelField, get_schema_field
from ninja.schema import Schema

# MAYBE:
Expand Down Expand Up @@ -55,7 +55,7 @@ def create_schema(
if key in self.schemas:
return self.schemas[key]

model_fields_list = self._selected_model_fields(model, fields, exclude)
model_fields_list = list(self._selected_model_fields(model, fields, exclude))
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vitalik I added a cast to list here to fix this bug:
If optional_fields was equal to __all__, optional_fields = [f.name for f in model_fields_list] a few lines below was exhausting the model_fields_list iterator, and made for fld in model_fields_list iterate over an empty list, producing an empty Pydantic model.

if optional_fields:
if optional_fields == "__all__":
optional_fields = [f.name for f in model_fields_list]
Expand All @@ -71,8 +71,6 @@ def create_schema(

if custom_fields:
for fld_name, python_type, field_info in custom_fields:
# if not isinstance(field_info, FieldInfo):
# field_info = Field(field_info)
definitions[fld_name] = (python_type, field_info)

if name in self.schema_names:
Expand Down Expand Up @@ -133,7 +131,7 @@ def _selected_model_fields(
model: Type[Model],
fields: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
) -> Iterator[DjangoField]:
) -> Iterator[ModelField]:
"Returns iterator for model fields based on `exclude` or `fields` arguments"
all_fields = {f.name: f for f in self._model_fields(model)}

Expand All @@ -155,13 +153,27 @@ def _selected_model_fields(
if f.name not in exclude:
yield f

def _model_fields(self, model: Type[Model]) -> Iterator[DjangoField]:
def _model_fields(self, model: Type[Model]) -> Iterator[ModelField]:
"returns iterator with all the fields that can be part of schema"
type_annotations = self._get_type_annotations(model)
for fld in model._meta.get_fields():
if isinstance(fld, (ManyToOneRel, ManyToManyRel)):
# skipping relations
continue
yield cast(DjangoField, fld)
field = cast(DjangoField, fld)
yield ModelField(field, type_annotations.get(field.name))

def _get_type_annotations(self, model_class: Any) -> Dict[str, Any]:
# Take inherited classes annotations into account
classes: List[Any] = model_class.mro()
# Reverse the list so child classes annotations have precedence
classes.reverse()

annotations = {}
for cls in classes:
cls_annotations = getattr(cls, "__annotations__", {})
annotations.update(cls_annotations)
return annotations


factory = SchemaFactory()
Expand Down
73 changes: 52 additions & 21 deletions ninja/orm/fields.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
import datetime
from decimal import Decimal
from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, no_type_check
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
no_type_check,
)
from uuid import UUID

from django.db.models import ManyToManyField
Expand All @@ -24,6 +35,18 @@ def title_if_lower(s: str) -> str:
return s


class ModelField:
def __init__(
self, django_field: DjangoField, type_annotation: Optional[Any]
) -> None:
self.django_field = django_field
self.type_annotation = type_annotation

@property
def name(self):
return self.django_field.name


class AnyObject:
@classmethod
def __get_pydantic_core_schema__(cls, source: Any, handler: Callable) -> Any:
Expand Down Expand Up @@ -104,7 +127,7 @@ def _validate(cls, v: Any, _):

@no_type_check
def get_schema_field(
field: DjangoField, *, depth: int = 0, optional: bool = False
field: ModelField, *, depth: int = 0, optional: bool = False
) -> Tuple:
"Returns pydantic field from django's model field"
alias = None
Expand All @@ -115,53 +138,61 @@ def get_schema_field(
max_length = None
python_type = None

if field.is_relation:
if field.django_field.is_relation:
if depth > 0:
return get_related_field_schema(field, depth=depth)
return get_related_field_schema(field.django_field, depth=depth)

internal_type = field.related_model._meta.pk.get_internal_type()
internal_type = field.django_field.related_model._meta.pk.get_internal_type()

if not field.concrete and field.auto_created or field.null:
if (
not field.django_field.concrete
and field.django_field.auto_created
or field.django_field.null
):
default = None

alias = getattr(field, "get_attname", None) and field.get_attname()
alias = (
getattr(field.django_field, "get_attname", None)
and field.django_field.get_attname()
)

pk_type = TYPES.get(internal_type, int)
if field.one_to_many or field.many_to_many:
if field.django_field.one_to_many or field.django_field.many_to_many:
m2m_type = create_m2m_link_type(pk_type)
python_type = List[m2m_type] # type: ignore
else:
python_type = pk_type

else:
_f_name, _f_path, _f_pos, field_options = field.deconstruct()
_f_name, _f_path, _f_pos, field_options = field.django_field.deconstruct()
blank = field_options.get("blank", False)
null = field_options.get("null", False)
max_length = field_options.get("max_length")

internal_type = field.get_internal_type()
python_type = TYPES[internal_type]
if field.type_annotation and not isinstance(field.type_annotation, str):
python_type = field.type_annotation
else:
internal_type = field.django_field.get_internal_type()
python_type = TYPES[internal_type]

if field.has_default():
if callable(field.default):
default_factory = field.default
if field.django_field.has_default():
if callable(field.django_field.default):
default_factory = field.django_field.default
else:
default = field.default
elif field.primary_key or blank or null:
default = field.django_field.default
elif field.django_field.primary_key or blank or null:
default = None

if default_factory:
default = PydanticUndefined

if optional:
elif optional:
default = None

if default is None:
default = None
python_type = Union[python_type, None] # aka Optional in 3.7+

description = field.help_text or None
title = title_if_lower(field.verbose_name)
description = field.django_field.help_text or None
title = title_if_lower(field.django_field.verbose_name)

return (
python_type,
Expand Down
70 changes: 70 additions & 0 deletions tests/test_orm_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from django.contrib.postgres import fields as ps_fields
from django.db import models
from django.db.models import Manager
from typing_extensions import Literal, TypedDict

from ninja.errors import ConfigError
from ninja.orm import create_schema
Expand Down Expand Up @@ -575,3 +576,72 @@ class Meta:
SomeReqFieldModel, optional_fields=["some_field", "other_field", "optional"]
)
assert Schema.json_schema().get("required") is None


def test_type_annotations():
class TestModelConfiguration(TypedDict):
region: Literal[0, 1]
index: int

class TestModel(models.Model):
status: Literal["todo", "done"] = models.CharField() # type: ignore
configuration: TestModelConfiguration = models.JSONField() # type: ignore

class Meta:
app_label = "tests"

Schema = create_schema(TestModel)

assert Schema.json_schema() == {
"$defs": {
"TestModelConfiguration": {
"properties": {
"region": {"enum": [0, 1], "title": "Region", "type": "integer"},
"index": {"title": "Index", "type": "integer"},
},
"required": ["region", "index"],
"title": "TestModelConfiguration",
"type": "object",
}
},
"properties": {
"id": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "ID"},
"status": {"enum": ["todo", "done"], "title": "Status", "type": "string"},
"configuration": {
"allOf": [{"$ref": "#/$defs/TestModelConfiguration"}],
"title": "Configuration",
},
},
"required": ["status", "configuration"],
"title": "TestModel",
"type": "object",
}


def test_type_annotations_inherited():
class AParentModel(models.Model):
rank: Literal[0, 1] = models.PositiveIntegerField() # type: ignore
status: Literal["todo", "wip", "done"] = models.CharField() # type: ignore

class Meta:
app_label = "tests"

class AChildModel(AParentModel):
status: Literal["todo", "done"] # Narrow type of parent field

class Meta:
app_label = "tests"

Schema = create_schema(AChildModel)

assert Schema.json_schema() == {
"properties": {
"id": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "ID"},
"rank": {"enum": [0, 1], "title": "Rank", "type": "integer"},
"status": {"enum": ["todo", "done"], "title": "Status", "type": "string"},
"aparentmodel_ptr_id": {"title": "Aparentmodel Ptr", "type": "integer"},
},
"required": ["rank", "status", "aparentmodel_ptr_id"],
"title": "AChildModel",
"type": "object",
}