diff --git a/docs/docs/guides/response/django-pydantic.md b/docs/docs/guides/response/django-pydantic.md index fc425af4..b3364566 100644 --- a/docs/docs/guides/response/django-pydantic.md +++ b/docs/docs/guides/response/django-pydantic.md @@ -106,3 +106,29 @@ also you can define just a few optional fields instead of all: ```python fields_optional = ['description'] ``` + +### Using type annotations in models + +If you define type annotations on your model, they will be taken into account in the generated schema: + +```python hl_lines="6 7 18 19" +class TaskMetadata(TypedDict): + foo: str + bar: int + +class Task(models.Model): + status: Literal['todo', 'done'] = models.CharField(max_length=10) + metadata: TaskMetadata = models.JSONField() + +class TaskSchema(ModelSchema): + class Meta: + model = Task + fields = "__all__" + +# Will create schema like this: +# +# class TaskSchema(Schema): +# id: int +# status: Literal['todo', 'done'] +# metadata: TaskMetadata +``` diff --git a/ninja/orm/factory.py b/ninja/orm/factory.py index 18b839f9..a2d97d56 100644 --- a/ninja/orm/factory.py +++ b/ninja/orm/factory.py @@ -6,7 +6,7 @@ from pydantic import create_model as create_pydantic_model from ninja.errors import ConfigError -from ninja.orm.fields import get_schema_field +from ninja.orm.fields import ModelField, get_schema_field from ninja.schema import Schema # MAYBE: @@ -55,7 +55,7 @@ def create_schema( if key in self.schemas: return self.schemas[key] - model_fields_list = self._selected_model_fields(model, fields, exclude) + model_fields_list = list(self._selected_model_fields(model, fields, exclude)) if optional_fields: if optional_fields == "__all__": optional_fields = [f.name for f in model_fields_list] @@ -71,8 +71,6 @@ def create_schema( if custom_fields: for fld_name, python_type, field_info in custom_fields: - # if not isinstance(field_info, FieldInfo): - # field_info = Field(field_info) definitions[fld_name] = (python_type, field_info) if name in self.schema_names: @@ -133,7 +131,7 @@ def _selected_model_fields( model: Type[Model], fields: Optional[List[str]] = None, exclude: Optional[List[str]] = None, - ) -> Iterator[DjangoField]: + ) -> Iterator[ModelField]: "Returns iterator for model fields based on `exclude` or `fields` arguments" all_fields = {f.name: f for f in self._model_fields(model)} @@ -155,13 +153,27 @@ def _selected_model_fields( if f.name not in exclude: yield f - def _model_fields(self, model: Type[Model]) -> Iterator[DjangoField]: + def _model_fields(self, model: Type[Model]) -> Iterator[ModelField]: "returns iterator with all the fields that can be part of schema" + type_annotations = self._get_type_annotations(model) for fld in model._meta.get_fields(): if isinstance(fld, (ManyToOneRel, ManyToManyRel)): # skipping relations continue - yield cast(DjangoField, fld) + field = cast(DjangoField, fld) + yield ModelField(field, type_annotations.get(field.name)) + + def _get_type_annotations(self, model_class: Any) -> Dict[str, Any]: + # Take inherited classes annotations into account + classes: List[Any] = model_class.mro() + # Reverse the list so child classes annotations have precedence + classes.reverse() + + annotations = {} + for cls in classes: + cls_annotations = getattr(cls, "__annotations__", {}) + annotations.update(cls_annotations) + return annotations factory = SchemaFactory() diff --git a/ninja/orm/fields.py b/ninja/orm/fields.py index f180c417..83166905 100644 --- a/ninja/orm/fields.py +++ b/ninja/orm/fields.py @@ -1,6 +1,17 @@ import datetime from decimal import Decimal -from typing import Any, Callable, Dict, List, Tuple, Type, TypeVar, Union, no_type_check +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, + no_type_check, +) from uuid import UUID from django.db.models import ManyToManyField @@ -24,6 +35,18 @@ def title_if_lower(s: str) -> str: return s +class ModelField: + def __init__( + self, django_field: DjangoField, type_annotation: Optional[Any] + ) -> None: + self.django_field = django_field + self.type_annotation = type_annotation + + @property + def name(self): + return self.django_field.name + + class AnyObject: @classmethod def __get_pydantic_core_schema__(cls, source: Any, handler: Callable) -> Any: @@ -104,7 +127,7 @@ def _validate(cls, v: Any, _): @no_type_check def get_schema_field( - field: DjangoField, *, depth: int = 0, optional: bool = False + field: ModelField, *, depth: int = 0, optional: bool = False ) -> Tuple: "Returns pydantic field from django's model field" alias = None @@ -115,53 +138,61 @@ def get_schema_field( max_length = None python_type = None - if field.is_relation: + if field.django_field.is_relation: if depth > 0: - return get_related_field_schema(field, depth=depth) + return get_related_field_schema(field.django_field, depth=depth) - internal_type = field.related_model._meta.pk.get_internal_type() + internal_type = field.django_field.related_model._meta.pk.get_internal_type() - if not field.concrete and field.auto_created or field.null: + if ( + not field.django_field.concrete + and field.django_field.auto_created + or field.django_field.null + ): default = None - alias = getattr(field, "get_attname", None) and field.get_attname() + alias = ( + getattr(field.django_field, "get_attname", None) + and field.django_field.get_attname() + ) pk_type = TYPES.get(internal_type, int) - if field.one_to_many or field.many_to_many: + if field.django_field.one_to_many or field.django_field.many_to_many: m2m_type = create_m2m_link_type(pk_type) python_type = List[m2m_type] # type: ignore else: python_type = pk_type else: - _f_name, _f_path, _f_pos, field_options = field.deconstruct() + _f_name, _f_path, _f_pos, field_options = field.django_field.deconstruct() blank = field_options.get("blank", False) null = field_options.get("null", False) max_length = field_options.get("max_length") - internal_type = field.get_internal_type() - python_type = TYPES[internal_type] + if field.type_annotation and not isinstance(field.type_annotation, str): + python_type = field.type_annotation + else: + internal_type = field.django_field.get_internal_type() + python_type = TYPES[internal_type] - if field.has_default(): - if callable(field.default): - default_factory = field.default + if field.django_field.has_default(): + if callable(field.django_field.default): + default_factory = field.django_field.default else: - default = field.default - elif field.primary_key or blank or null: + default = field.django_field.default + elif field.django_field.primary_key or blank or null: default = None if default_factory: default = PydanticUndefined - - if optional: + elif optional: default = None if default is None: - default = None python_type = Union[python_type, None] # aka Optional in 3.7+ - description = field.help_text or None - title = title_if_lower(field.verbose_name) + description = field.django_field.help_text or None + title = title_if_lower(field.django_field.verbose_name) return ( python_type, diff --git a/tests/test_orm_schemas.py b/tests/test_orm_schemas.py index 07df9fa1..d1f68bd2 100644 --- a/tests/test_orm_schemas.py +++ b/tests/test_orm_schemas.py @@ -5,6 +5,7 @@ from django.contrib.postgres import fields as ps_fields from django.db import models from django.db.models import Manager +from typing_extensions import Literal, TypedDict from ninja.errors import ConfigError from ninja.orm import create_schema @@ -575,3 +576,72 @@ class Meta: SomeReqFieldModel, optional_fields=["some_field", "other_field", "optional"] ) assert Schema.json_schema().get("required") is None + + +def test_type_annotations(): + class TestModelConfiguration(TypedDict): + region: Literal[0, 1] + index: int + + class TestModel(models.Model): + status: Literal["todo", "done"] = models.CharField() # type: ignore + configuration: TestModelConfiguration = models.JSONField() # type: ignore + + class Meta: + app_label = "tests" + + Schema = create_schema(TestModel) + + assert Schema.json_schema() == { + "$defs": { + "TestModelConfiguration": { + "properties": { + "region": {"enum": [0, 1], "title": "Region", "type": "integer"}, + "index": {"title": "Index", "type": "integer"}, + }, + "required": ["region", "index"], + "title": "TestModelConfiguration", + "type": "object", + } + }, + "properties": { + "id": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "ID"}, + "status": {"enum": ["todo", "done"], "title": "Status", "type": "string"}, + "configuration": { + "allOf": [{"$ref": "#/$defs/TestModelConfiguration"}], + "title": "Configuration", + }, + }, + "required": ["status", "configuration"], + "title": "TestModel", + "type": "object", + } + + +def test_type_annotations_inherited(): + class AParentModel(models.Model): + rank: Literal[0, 1] = models.PositiveIntegerField() # type: ignore + status: Literal["todo", "wip", "done"] = models.CharField() # type: ignore + + class Meta: + app_label = "tests" + + class AChildModel(AParentModel): + status: Literal["todo", "done"] # Narrow type of parent field + + class Meta: + app_label = "tests" + + Schema = create_schema(AChildModel) + + assert Schema.json_schema() == { + "properties": { + "id": {"anyOf": [{"type": "integer"}, {"type": "null"}], "title": "ID"}, + "rank": {"enum": [0, 1], "title": "Rank", "type": "integer"}, + "status": {"enum": ["todo", "done"], "title": "Status", "type": "string"}, + "aparentmodel_ptr_id": {"title": "Aparentmodel Ptr", "type": "integer"}, + }, + "required": ["rank", "status", "aparentmodel_ptr_id"], + "title": "AChildModel", + "type": "object", + }