Skip to content

Commit

Permalink
Fixes #15: Discover methods for HTTPEndpoint if are not specified
Browse files Browse the repository at this point in the history
  • Loading branch information
perdy committed Mar 20, 2019
1 parent 84db0c6 commit 0429f27
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 19 deletions.
13 changes: 6 additions & 7 deletions starlette_api/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from starlette_api import http, websockets
from starlette_api.components import Component
from starlette_api.responses import APIResponse
from starlette_api.types import Field, FieldLocation, OptBool, OptFloat, OptInt, OptStr
from starlette_api.types import Field, FieldLocation, HTTPMethod, OptBool, OptFloat, OptInt, OptStr
from starlette_api.validation import get_output_schema

__all__ = ["Route", "WebSocketRoute", "Router"]
Expand Down Expand Up @@ -53,13 +53,9 @@ def _get_fields(
body_field: typing.Dict[str, Field] = {}
output_field: typing.Dict[str, typing.Any] = {}

if hasattr(self, "methods"):
if hasattr(self, "methods") and self.methods is not None:
if inspect.isclass(self.endpoint): # HTTP endpoint
methods = (
[(m, getattr(self.endpoint, m.lower() if m != "HEAD" else "get")) for m in self.methods]
if self.methods
else []
)
methods = [(m, getattr(self.endpoint, m.lower() if m != "HEAD" else "get")) for m in self.methods]
else: # HTTP function
methods = [(m, self.endpoint) for m in self.methods] if self.methods else []
else: # Websocket
Expand Down Expand Up @@ -141,6 +137,9 @@ def __init__(self, path: str, endpoint: typing.Callable, router: "Router", *args
if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
self.app = self.endpoint_wrapper(endpoint)

if self.methods is None:
self.methods = [m for m in HTTPMethod.__members__.keys() if hasattr(self, m.lower())]

self.query_fields, self.path_fields, self.body_field, self.output_field = self._get_fields(router)

def endpoint_wrapper(self, endpoint: typing.Callable) -> ASGIApp:
Expand Down
21 changes: 11 additions & 10 deletions starlette_api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def get_endpoints(
path=path,
method=method.lower(),
func=route.endpoint,
query_fields=route.query_fields.get(method),
path_fields=route.path_fields.get(method),
query_fields=route.query_fields.get(method, {}),
path_fields=route.path_fields.get(method, {}),
body_field=route.body_field.get(method),
output_field=route.output_field.get(method),
)
Expand All @@ -97,8 +97,8 @@ def get_endpoints(
path=path,
method=method.lower(),
func=func,
query_fields=route.query_fields.get(method.upper()),
path_fields=route.path_fields.get(method.upper()),
query_fields=route.query_fields.get(method.upper(), {}),
path_fields=route.path_fields.get(method.upper(), {}),
body_field=route.body_field.get(method.upper()),
output_field=route.output_field.get(method.upper()),
)
Expand All @@ -108,13 +108,13 @@ def get_endpoints(

return endpoints_info

def get_endpoint_parameters_schema(self, endpoint: EndpointInfo, schema: typing.Dict) -> typing.List[typing.Dict]:
def _add_endpoint_parameters(self, endpoint: EndpointInfo, schema: typing.Dict):
schema["parameters"] = [
self.openapi.field2parameter(field.schema, name=field.name, default_in=field.location.name)
for field in itertools.chain(endpoint.query_fields.values(), endpoint.path_fields.values())
]

def get_endpoint_body_schema(self, endpoint: EndpointInfo, schema: typing.Dict):
def _add_endpoint_body(self, endpoint: EndpointInfo, schema: typing.Dict):
component_schema = (
endpoint.body_field.schema
if inspect.isclass(endpoint.body_field.schema)
Expand All @@ -132,7 +132,7 @@ def get_endpoint_body_schema(self, endpoint: EndpointInfo, schema: typing.Dict):
"schema",
)

def get_endpoint_response_schema(self, endpoint: EndpointInfo, schema: typing.Dict):
def _add_endpoint_response(self, endpoint: EndpointInfo, schema: typing.Dict):
component_schema = (
endpoint.output_field if inspect.isclass(endpoint.output_field) else endpoint.output_field.__class__
)
Expand All @@ -153,18 +153,19 @@ def get_endpoint_schema(self, endpoint: EndpointInfo) -> typing.Dict[str, typing
schema = self.parse_docstring(endpoint.func)

# Query and Path parameters
self.get_endpoint_parameters_schema(endpoint, schema)
if endpoint.query_fields or endpoint.path_fields:
self._add_endpoint_parameters(endpoint, schema)

# Body
if endpoint.body_field:
self.get_endpoint_body_schema(endpoint, schema)
self._add_endpoint_body(endpoint, schema)

# Response
if endpoint.output_field and (
(inspect.isclass(endpoint.output_field) and issubclass(endpoint.output_field, marshmallow.Schema))
or isinstance(endpoint.output_field, marshmallow.Schema)
):
self.get_endpoint_response_schema(endpoint, schema)
self._add_endpoint_response(endpoint, schema)

return schema

Expand Down
3 changes: 3 additions & 0 deletions starlette_api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,6 @@ class ResourceMethodMeta(typing.NamedTuple):
methods: typing.List[str] = ["GET"]
name: str = None
kwargs: typing.Dict[str, typing.Any] = {}


HTTPMethod = enum.Enum("HTTPMethod", ["GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"])
10 changes: 8 additions & 2 deletions tests/test_schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,13 +132,13 @@ def test_schema_path_params(self, app):

def test_schema_body_params(self, app):
schema = app.schema["paths"]["/body-param/"]["post"]
parameters = schema.get("parameters", {})
parameters = schema.get("parameters")
response = schema.get("responses", {}).get(200, {})
body = schema.get("requestBody", {})

assert schema["description"] == "Body param."
assert response == {"description": "Param."}
assert parameters == []
assert parameters is None
assert body == {
"content": {"application/json": {"schema": {"type": "object", "properties": {"name": {"type": "string"}}}}}
}
Expand All @@ -155,9 +155,11 @@ def test_schema_output_schema(self, app):

def test_schema_output_schema_many(self, app):
schema = app.schema["paths"]["/many-custom-component/"]["get"]
parameters = schema.get("parameters")
response = schema.get("responses", {}).get(200, {})

assert schema["description"] == "Many custom component."
assert parameters is None
assert response == {
"description": "Components.",
"content": {
Expand All @@ -167,19 +169,23 @@ def test_schema_output_schema_many(self, app):

def test_schema_output_schema_using_endpoint(self, app):
schema = app.schema["paths"]["/endpoint/"]["get"]
parameters = schema.get("parameters")
response = schema.get("responses", {}).get(200, {})

assert schema["description"] == "Custom component."
assert parameters is None
assert response == {
"description": "Component.",
"content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}},
}

def test_schema_output_schema_using_mount(self, app):
schema = app.schema["paths"]["/mount/custom-component/"]["get"]
parameters = schema.get("parameters")
response = schema.get("responses", {}).get(200, {})

assert schema["description"] == "Custom component."
assert parameters is None
assert response == {
"description": "Component.",
"content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}},
Expand Down

0 comments on commit 0429f27

Please sign in to comment.