Skip to content

Commit

Permalink
✨ Discovering nested schemas (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Sep 19, 2023
1 parent 5734826 commit 0af3b64
Show file tree
Hide file tree
Showing 11 changed files with 832 additions and 162 deletions.
42 changes: 30 additions & 12 deletions flama/schemas/_libs/marshmallow/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from apispec.ext.marshmallow import MarshmallowPlugin, resolve_schema_cls

from flama.injection import Parameter
from flama.schemas._libs.marshmallow.fields import MAPPING
from flama.schemas._libs.marshmallow.fields import MAPPING, MAPPING_TYPES
from flama.schemas.adapter import Adapter
from flama.schemas.exceptions import SchemaGenerationError, SchemaValidationError
from flama.types import JSONSchema
Expand Down Expand Up @@ -35,7 +35,7 @@ def build_field(
required: bool = True,
default: t.Any = None,
multiple: bool = False,
**kwargs
**kwargs,
) -> Field:
field_args = {
"required": required,
Expand Down Expand Up @@ -80,6 +80,10 @@ def dump(self, schema: t.Union[t.Type[Schema], Schema], value: t.Dict[str, t.Any
except Exception as exc:
raise SchemaValidationError(errors=str(exc))

def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str:
s = self.unique_schema(schema)
return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}"

def to_json_schema(self, schema: t.Union[t.Type[Schema], t.Type[Field], Schema, Field]) -> JSONSchema:
json_schema: t.Dict[str, t.Any]
try:
Expand Down Expand Up @@ -115,18 +119,32 @@ def unique_schema(self, schema: t.Union[Schema, t.Type[Schema]]) -> t.Type[Schem

return schema

def is_schema(
self, obj: t.Any
) -> t.TypeGuard[ # type: ignore # PORT: Remove this comment when stop supporting 3.9
t.Union[Schema, t.Type[Schema]]
]:
def _get_field_type(self, field: Field) -> t.Union[Schema, t.Type]:
if isinstance(field, marshmallow.fields.Nested):
return field.schema

if isinstance(field, marshmallow.fields.List):
return self._get_field_type(field.inner) # type: ignore

if isinstance(field, marshmallow.fields.Dict):
return self._get_field_type(field.value_field) # type: ignore

try:
return MAPPING_TYPES[field.__class__]
except KeyError:
return None

def schema_fields(
self, schema: t.Union[Schema, t.Type[Schema]]
) -> t.Dict[str, t.Tuple[t.Union[t.Type, Schema], Field]]:
return {
name: (self._get_field_type(field), field) for name, field in self._schema_instance(schema).fields.items()
}

def is_schema(self, obj: t.Any) -> t.TypeGuard[t.Union[Schema, t.Type[Schema]]]: # type: ignore
return isinstance(obj, Schema) or (inspect.isclass(obj) and issubclass(obj, Schema))

def is_field(
self, obj: t.Any
) -> t.TypeGuard[ # type: ignore # PORT: Remove this comment when stop supporting 3.9
t.Union[Field, t.Type[Field]]
]:
def is_field(self, obj: t.Any) -> t.TypeGuard[t.Union[Field, t.Type[Field]]]: # type: ignore
return isinstance(obj, Field) or (inspect.isclass(obj) and issubclass(obj, Field))

def _schema_instance(self, schema: t.Union[t.Type[Schema], Schema]) -> Schema:
Expand Down
7 changes: 4 additions & 3 deletions flama/schemas/_libs/marshmallow/fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# ruff: noqa
import datetime
import typing
import typing as t
import uuid

import marshmallow.fields
from marshmallow.fields import *

MAPPING: typing.Dict[typing.Optional[typing.Type], typing.Type[marshmallow.fields.Field]] = {
MAPPING: t.Dict[t.Union[t.Type, None], t.Type[Field]] = {
None: Field,
int: Integer,
float: Float,
Expand All @@ -19,3 +18,5 @@
datetime.datetime: DateTime,
datetime.time: Time,
}

MAPPING_TYPES = {v: k for k, v in MAPPING.items()}
29 changes: 29 additions & 0 deletions flama/schemas/_libs/pydantic/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ def dump(self, schema: t.Union[Schema, t.Type[Schema]], value: t.Dict[str, t.Any

return self.validate(schema_cls, value)

def name(self, schema: t.Union[Schema, t.Type[Schema]]) -> str:
s = self.unique_schema(schema)
return s.__qualname__ if s.__module__ == "builtins" else f"{s.__module__}.{s.__qualname__}"

def to_json_schema(self, schema: t.Union[Schema, t.Type[Schema], Field]) -> JSONSchema:
try:
if self.is_schema(schema):
Expand All @@ -116,6 +120,31 @@ def to_json_schema(self, schema: t.Union[Schema, t.Type[Schema], Field]) -> JSON
def unique_schema(self, schema: t.Union[Schema, t.Type[Schema]]) -> t.Type[Schema]:
return schema.__class__ if isinstance(schema, Schema) else schema

def _get_field_type(
self, field: Field
) -> t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]]:
if not self.is_field(field):
return field

if t.get_origin(field.annotation) == list:
return self._get_field_type(t.get_args(field.annotation)[0])

if t.get_origin(field.annotation) == dict:
return self._get_field_type(t.get_args(field.annotation)[1])

return field.annotation

def schema_fields(
self, schema: t.Union[Schema, t.Type[Schema]]
) -> t.Dict[
str,
t.Tuple[
t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]],
Field,
],
]:
return {name: (self._get_field_type(field), field) for name, field in schema.model_fields.items()}

def is_schema(
self, obj: t.Any
) -> t.TypeGuard[t.Type[Schema]]: # type: ignore # PORT: Remove this comment when stop supporting 3.9
Expand Down
46 changes: 43 additions & 3 deletions flama/schemas/_libs/typesystem/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import typesystem

from flama.injection import Parameter
from flama.schemas._libs.typesystem.fields import MAPPING
from flama.schemas._libs.typesystem.fields import MAPPING, MAPPING_TYPES
from flama.schemas.adapter import Adapter
from flama.schemas.exceptions import SchemaGenerationError, SchemaValidationError
from flama.types import JSONSchema
Expand All @@ -30,7 +30,7 @@ def build_field(
required: bool = True,
default: t.Any = None,
multiple: bool = False,
**kwargs
**kwargs,
) -> Field:
if required is False and default is not Parameter.empty:
kwargs["default"] = default
Expand All @@ -44,7 +44,7 @@ def build_field(
if self.is_schema(type_)
else MAPPING[type_]()
),
**kwargs
**kwargs,
)

return MAPPING[type_](**kwargs)
Expand Down Expand Up @@ -82,6 +82,13 @@ def _dump(self, value: t.Any) -> t.Any:

return value

@t.no_type_check
def name(self, schema: Schema) -> str:
if not schema.title:
raise ValueError(f"Schema '{schema}' needs to define title attribute")

return schema.title if schema.__module__ == "builtins" else f"{schema.__module__}.{schema.title}"

@t.no_type_check
def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema:
try:
Expand All @@ -100,6 +107,39 @@ def to_json_schema(self, schema: t.Union[Schema, Field]) -> JSONSchema:
def unique_schema(self, schema: Schema) -> Schema:
return schema

def _get_field_type(
self, field: Field
) -> t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]]:
if isinstance(field, typesystem.Reference):
return field.target

if isinstance(field, typesystem.Array):
return (
[self._get_field_type(x) for x in field.items]
if isinstance(field.items, (list, tuple, set))
else self._get_field_type(field.items)
)

if isinstance(field, typesystem.Object):
return {k: self._get_field_type(v) for k, v in field.properties.items()}

try:
return MAPPING_TYPES[field.__class__]
except KeyError:
return None

@t.no_type_check
def schema_fields(
self, schema: Schema
) -> t.Dict[
str,
t.Tuple[
t.Union[t.Union[Schema, t.Type], t.List[t.Union[Schema, t.Type]], t.Dict[str, t.Union[Schema, t.Type]]],
Field,
],
]:
return {name: (self._get_field_type(field), field) for name, field in schema.fields.items()}

@t.no_type_check
def is_schema(
self, obj: t.Any
Expand Down
6 changes: 4 additions & 2 deletions flama/schemas/_libs/typesystem/fields.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# ruff: noqa
import datetime
import typing
import typing as t
import uuid

from typesystem.fields import *
from typesystem.schemas import Reference

MAPPING: typing.Dict[typing.Any, typing.Type[Field]] = {
MAPPING: t.Dict[t.Union[t.Type, None], t.Type[Field]] = {
None: Field,
int: Integer,
float: Float,
Expand All @@ -19,3 +19,5 @@
datetime.datetime: DateTime,
datetime.time: Time,
}

MAPPING_TYPES = {v: k for k, v in MAPPING.items()}
18 changes: 18 additions & 0 deletions flama/schemas/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@ def load(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]], value: t.Dict[str,
def dump(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]], value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
...

@abc.abstractmethod
def name(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]]) -> str:
...

@abc.abstractmethod
def to_json_schema(self, schema: t.Union[_T_Schema, t.Type[_T_Schema], _T_Field]) -> JSONSchema:
...
Expand All @@ -82,6 +86,20 @@ def to_json_schema(self, schema: t.Union[_T_Schema, t.Type[_T_Schema], _T_Field]
def unique_schema(self, schema: t.Union[_T_Schema, t.Type[_T_Schema]]) -> t.Union[_T_Schema, t.Type[_T_Schema]]:
...

@abc.abstractmethod
def schema_fields(
self, schema: t.Union[_T_Schema, t.Type[_T_Schema]]
) -> t.Dict[
str,
t.Tuple[
t.Union[
t.Union[_T_Schema, t.Type], t.List[t.Union[_T_Schema, t.Type]], t.Dict[str, t.Union[_T_Schema, t.Type]]
],
_T_Field,
],
]:
...

@abc.abstractmethod
def is_schema(
self, obj: t.Any
Expand Down
34 changes: 31 additions & 3 deletions flama/schemas/data_structures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import sys
import typing as t

import flama.types
from flama import schemas, types
from flama.injection.resolver import Parameter as InjectionParameter

Expand All @@ -15,6 +14,9 @@
__all__ = ["Field", "Schema", "Parameter", "Parameters"]


UNKNOWN = t.TypeVar("UNKNOWN")


class ParameterLocation(enum.Enum):
query = enum.auto()
path = enum.auto()
Expand Down Expand Up @@ -81,7 +83,7 @@ def is_http_valid_type(cls, type_: t.Type) -> bool:
)

@property
def json_schema(self) -> flama.types.JSONSchema:
def json_schema(self) -> types.JSONSchema:
return schemas.adapter.to_json_schema(self.field)


Expand Down Expand Up @@ -121,13 +123,39 @@ def is_schema(cls, obj: t.Any) -> bool:
return schemas.adapter.is_schema(obj)

@property
def json_schema(self) -> t.Dict[str, t.Any]:
def name(self) -> str:
return schemas.adapter.name(self.schema)

@property
def json_schema(self) -> types.JSONSchema:
return schemas.adapter.to_json_schema(self.schema)

@property
def unique_schema(self) -> t.Any:
return schemas.adapter.unique_schema(self.schema)

@property
def fields(self) -> t.Dict[str, t.Tuple[t.Any, t.Any]]:
return schemas.adapter.schema_fields(self.unique_schema)

def nested_schemas(self, schema: t.Any = UNKNOWN) -> t.List[t.Any]:
if schema == UNKNOWN:
return self.nested_schemas(self)

if schemas.adapter.is_schema(schema):
return [schema]

if isinstance(schema, (list, tuple, set)):
return [x for field in schema for x in self.nested_schemas(field)]

if isinstance(schema, dict):
return [x for field in schema.values() for x in self.nested_schemas(field)]

if isinstance(schema, Schema):
return [x for field_type, _ in schema.fields.values() for x in self.nested_schemas(field_type)]

return []

@t.overload
def validate(self, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]:
...
Expand Down
Loading

0 comments on commit 0af3b64

Please sign in to comment.