Skip to content
This repository has been archived by the owner on Mar 24, 2024. It is now read-only.

Handle raw optional better #282

Merged
merged 1 commit into from
Mar 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
"RedirectOutput"
],
"args": [
"tests/core/test_db.py::test_get_ordered_list"
"tests/spec/test_schema_parser.py::test_schema2json"
]
}
]
Expand Down
6 changes: 5 additions & 1 deletion openapi/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def get_server_version(self, ctx, param, value) -> None:
ctx.exit()


def open_api_cli(ctx: click.Context) -> OpenApiClient:
return ctx.obj["cli"]


@click.command("serve", short_help="Start aiohttp server.")
@click.option(
"--host", "-h", default=HOST, help=f"The interface to bind to (default to {HOST})"
Expand All @@ -135,7 +139,7 @@ def get_server_version(self, ctx, param, value) -> None:
@click.pass_context
def serve(ctx, host, port, index, reload):
"""Run the aiohttp server."""
cli: OpenApiClient = ctx.obj["cli"]
cli = open_api_cli(ctx)
cli.index = index
app = cli.get_serve_app()
access_log = logger if ctx.obj["log_level"] else None
Expand Down
7 changes: 5 additions & 2 deletions openapi/data/dump.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,15 @@ def dump(schema: Any, data: Any) -> Any:
return data


def dump_dataclass(schema: type, data: Optional[Union[Dict, Record]] = None) -> Dict:
def dump_dataclass(schema: Any, data: Optional[Union[Dict, Record]] = None) -> Dict:
"""Dump a dictionary of data with a given dataclass dump functions
If the data is not given, the schema object is assumed to be
an instance of a dataclass.
"""
data = asdict(schema) if data is None else data
if data is None:
data = asdict(schema)
elif isinstance(data, schema):
data = asdict(data)
cleaned = {}
fields_ = {f.name: f for f in fields(schema)}
for name, value in iter_items(data):
Expand Down
11 changes: 7 additions & 4 deletions openapi/db/commands.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import click
from sqlalchemy_utils import create_database, database_exists, drop_database

from openapi.cli import open_api_cli

from .dbmodel import CrudDB
from .migrations import Migration


def migration(ctx):
return Migration(ctx.obj["cli"].web())
def migration(ctx: click.Context) -> Migration:
return Migration(open_api_cli(ctx).web())


def get_db(ctx):
return ctx.obj["cli"].web()["db"]
def get_db(ctx: click.Context) -> CrudDB:
return open_api_cli(ctx).web()["db"]


@click.group()
Expand Down
24 changes: 18 additions & 6 deletions openapi/spec/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dataclasses import fields as get_fields
from dataclasses import is_dataclass
from enum import Enum
from typing import Any, Dict, Iterable, List, Optional, Set, Type, cast
from typing import Any, Dict, Iterable, List, Optional, Set, Type, Union, cast

from aiohttp import hdrs, web

Expand Down Expand Up @@ -107,9 +107,11 @@ def get_parameters(self, schema: Any, default_in: str = "path") -> List:
params.append(entry)
return params

def field2json(self, field: Field, validate: bool = True) -> Dict[str, str]:
def field2json(
self, field_or_type: Union[Type, Field], validate: bool = True
) -> Dict[str, dict]:
"""Convert a dataclass field to Json schema"""
field = fields.as_field(field)
field = fields.as_field(field_or_type)
meta = field.metadata
items = meta.get(fields.ITEMS)
json_property = self.get_schema_info(field.type, items=items)
Expand Down Expand Up @@ -142,11 +144,12 @@ def dataclass2json(self, schema: Any) -> Dict[str, Any]:
properties = {}
required = []
for item in get_fields(type_info.element):
if item.metadata.get(fields.REQUIRED, False):
required.append(item.name)
json_property = self.field2json(item)
field_required = json_property.pop("required", True)
if not json_property:
continue
if item.metadata.get(fields.REQUIRED, field_required):
required.append(item.name)
for name in fields.field_ops(item):
properties[name] = json_property

Expand Down Expand Up @@ -186,7 +189,16 @@ def get_schema_info(
),
}
elif type_info.is_union:
return {"oneOf": [self.get_schema_info(e) for e in type_info.element]}
required = True
one_of = []
for e in type_info.element:
if e.is_none:
required = False
else:
one_of.append(self.get_schema_info(e))
info = one_of[0] if len(one_of) == 1 else {"oneOf": one_of}
info["required"] = required
return info
elif type_info.is_dataclass:
name = self.add_schema_to_parse(type_info.element)
return {"$ref": f"{SCHEMA_BASE_REF}{name}"}
Expand Down
5 changes: 5 additions & 0 deletions openapi/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ def is_complex(self) -> bool:
"""True if :attr:`.element` is either a dataclass or a union"""
return self.container is not None or self.is_union

@property
def is_none(self) -> bool:
"""True if :attr:`.element` is either a dataclass or a union"""
return self.element is type(None) # noqa: E721

@classmethod
def get(cls, value: Any) -> Optional["TypingInfo"]:
"""Create a :class:`.TypingInfo` from a typing annotation or
Expand Down
48 changes: 28 additions & 20 deletions tests/spec/test_schema_parser.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Dict, List
from typing import Dict, List, Optional

import pytest

Expand All @@ -16,19 +16,22 @@
from openapi.spec import SchemaParser


def test_get_schema_ref():
@pytest.fixture
def parser() -> SchemaParser:
return SchemaParser()


def test_get_schema_ref(parser: SchemaParser):
@dataclass
class MyClass:
str_field: str = data_field(description="String field")

parser = SchemaParser()

schema_ref = parser.get_schema_info(MyClass)
assert schema_ref == {"$ref": "#/components/schemas/MyClass"}
assert "MyClass" in parser.schemas_to_parse


def test_schema2json():
def test_schema2json(parser: SchemaParser):
@dataclass
class OtherClass:
str_field: str = data_field(description="String field")
Expand All @@ -37,6 +40,7 @@ class OtherClass:
class MyClass:
"""Test data"""

raw: str
str_field: str = data_field(
required=True, format="uuid", description="String field"
)
Expand All @@ -50,13 +54,16 @@ class MyClass:
metadata={"required": True, "description": "Ref field"}, default=None
)
list_ref_field: List[OtherClass] = data_field(description="List field")
random: Optional[str] = None

parser = SchemaParser()
schema_json = parser.schema2json(MyClass)
expected = {
"type": "object",
"description": "Test data",
"properties": {
"raw": {
"type": "string",
},
"str_field": {
"type": "string",
"format": "uuid",
Expand Down Expand Up @@ -97,26 +104,27 @@ class MyClass:
"items": {"$ref": "#/components/schemas/OtherClass"},
"description": "List field",
},
"random": {"type": "string"},
},
"required": ["str_field", "ref_field"],
"required": ["raw", "str_field", "ref_field"],
"additionalProperties": False,
}
assert schema_json == expected


def test_field2json():
parser = SchemaParser([])
str_json = parser.field2json(str)
int_json = parser.field2json(int)
float_json = parser.field2json(float)
bool_json = parser.field2json(bool)
datetime_json = parser.field2json(datetime)

assert str_json == {"type": "string"}
assert int_json == {"type": "integer", "format": "int32"}
assert float_json == {"type": "number", "format": "float"}
assert bool_json == {"type": "boolean"}
assert datetime_json == {"type": "string", "format": "date-time"}
@pytest.mark.parametrize(
"field,schema",
(
(str, {"type": "string"}),
(int, {"type": "integer", "format": "int32"}),
(float, {"type": "number", "format": "float"}),
(bool, {"type": "boolean"}),
(datetime, {"type": "string", "format": "date-time"}),
(Optional[str], {"type": "string", "required": False}),
),
)
def test_field2json(parser, field, schema):
assert parser.field2json(field) == schema


def test_field2json_format():
Expand Down