diff --git a/starlette_api/routing.py b/starlette_api/routing.py index 5747c536..7ae949a5 100644 --- a/starlette_api/routing.py +++ b/starlette_api/routing.py @@ -12,7 +12,7 @@ from starlette_api import http, websockets from starlette_api.components import Component from starlette_api.responses import APIResponse -from starlette_api.types import Field, FieldLocation, OptBool, OptFloat, OptInt, OptStr +from starlette_api.types import Field, FieldLocation, HTTPMethod, OptBool, OptFloat, OptInt, OptStr from starlette_api.validation import get_output_schema __all__ = ["Route", "WebSocketRoute", "Router"] @@ -53,13 +53,9 @@ def _get_fields( body_field: typing.Dict[str, Field] = {} output_field: typing.Dict[str, typing.Any] = {} - if hasattr(self, "methods"): + if hasattr(self, "methods") and self.methods is not None: if inspect.isclass(self.endpoint): # HTTP endpoint - methods = ( - [(m, getattr(self.endpoint, m.lower() if m != "HEAD" else "get")) for m in self.methods] - if self.methods - else [] - ) + methods = [(m, getattr(self.endpoint, m.lower() if m != "HEAD" else "get")) for m in self.methods] else: # HTTP function methods = [(m, self.endpoint) for m in self.methods] if self.methods else [] else: # Websocket @@ -141,6 +137,9 @@ def __init__(self, path: str, endpoint: typing.Callable, router: "Router", *args if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): self.app = self.endpoint_wrapper(endpoint) + if self.methods is None: + self.methods = [m for m in HTTPMethod.__members__.keys() if hasattr(self, m.lower())] + self.query_fields, self.path_fields, self.body_field, self.output_field = self._get_fields(router) def endpoint_wrapper(self, endpoint: typing.Callable) -> ASGIApp: diff --git a/starlette_api/schemas.py b/starlette_api/schemas.py index ec976488..3ec0f8b4 100644 --- a/starlette_api/schemas.py +++ b/starlette_api/schemas.py @@ -80,8 +80,8 @@ def get_endpoints( path=path, method=method.lower(), func=route.endpoint, - query_fields=route.query_fields.get(method), - path_fields=route.path_fields.get(method), + query_fields=route.query_fields.get(method, {}), + path_fields=route.path_fields.get(method, {}), body_field=route.body_field.get(method), output_field=route.output_field.get(method), ) @@ -97,8 +97,8 @@ def get_endpoints( path=path, method=method.lower(), func=func, - query_fields=route.query_fields.get(method.upper()), - path_fields=route.path_fields.get(method.upper()), + query_fields=route.query_fields.get(method.upper(), {}), + path_fields=route.path_fields.get(method.upper(), {}), body_field=route.body_field.get(method.upper()), output_field=route.output_field.get(method.upper()), ) @@ -108,13 +108,13 @@ def get_endpoints( return endpoints_info - def get_endpoint_parameters_schema(self, endpoint: EndpointInfo, schema: typing.Dict) -> typing.List[typing.Dict]: + def _add_endpoint_parameters(self, endpoint: EndpointInfo, schema: typing.Dict): schema["parameters"] = [ self.openapi.field2parameter(field.schema, name=field.name, default_in=field.location.name) for field in itertools.chain(endpoint.query_fields.values(), endpoint.path_fields.values()) ] - def get_endpoint_body_schema(self, endpoint: EndpointInfo, schema: typing.Dict): + def _add_endpoint_body(self, endpoint: EndpointInfo, schema: typing.Dict): component_schema = ( endpoint.body_field.schema if inspect.isclass(endpoint.body_field.schema) @@ -132,7 +132,7 @@ def get_endpoint_body_schema(self, endpoint: EndpointInfo, schema: typing.Dict): "schema", ) - def get_endpoint_response_schema(self, endpoint: EndpointInfo, schema: typing.Dict): + def _add_endpoint_response(self, endpoint: EndpointInfo, schema: typing.Dict): component_schema = ( endpoint.output_field if inspect.isclass(endpoint.output_field) else endpoint.output_field.__class__ ) @@ -153,18 +153,19 @@ def get_endpoint_schema(self, endpoint: EndpointInfo) -> typing.Dict[str, typing schema = self.parse_docstring(endpoint.func) # Query and Path parameters - self.get_endpoint_parameters_schema(endpoint, schema) + if endpoint.query_fields or endpoint.path_fields: + self._add_endpoint_parameters(endpoint, schema) # Body if endpoint.body_field: - self.get_endpoint_body_schema(endpoint, schema) + self._add_endpoint_body(endpoint, schema) # Response if endpoint.output_field and ( (inspect.isclass(endpoint.output_field) and issubclass(endpoint.output_field, marshmallow.Schema)) or isinstance(endpoint.output_field, marshmallow.Schema) ): - self.get_endpoint_response_schema(endpoint, schema) + self._add_endpoint_response(endpoint, schema) return schema diff --git a/starlette_api/types.py b/starlette_api/types.py index 40d178e1..83089b81 100644 --- a/starlette_api/types.py +++ b/starlette_api/types.py @@ -91,3 +91,6 @@ class ResourceMethodMeta(typing.NamedTuple): methods: typing.List[str] = ["GET"] name: str = None kwargs: typing.Dict[str, typing.Any] = {} + + +HTTPMethod = enum.Enum("HTTPMethod", ["GET", "HEAD", "POST", "PUT", "DELETE", "CONNECT", "OPTIONS", "TRACE", "PATCH"]) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 191e1ab4..48cc21a4 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -132,13 +132,13 @@ def test_schema_path_params(self, app): def test_schema_body_params(self, app): schema = app.schema["paths"]["/body-param/"]["post"] - parameters = schema.get("parameters", {}) + parameters = schema.get("parameters") response = schema.get("responses", {}).get(200, {}) body = schema.get("requestBody", {}) assert schema["description"] == "Body param." assert response == {"description": "Param."} - assert parameters == [] + assert parameters is None assert body == { "content": {"application/json": {"schema": {"type": "object", "properties": {"name": {"type": "string"}}}}} } @@ -155,9 +155,11 @@ def test_schema_output_schema(self, app): def test_schema_output_schema_many(self, app): schema = app.schema["paths"]["/many-custom-component/"]["get"] + parameters = schema.get("parameters") response = schema.get("responses", {}).get(200, {}) assert schema["description"] == "Many custom component." + assert parameters is None assert response == { "description": "Components.", "content": { @@ -167,9 +169,11 @@ def test_schema_output_schema_many(self, app): def test_schema_output_schema_using_endpoint(self, app): schema = app.schema["paths"]["/endpoint/"]["get"] + parameters = schema.get("parameters") response = schema.get("responses", {}).get(200, {}) assert schema["description"] == "Custom component." + assert parameters is None assert response == { "description": "Component.", "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}}, @@ -177,9 +181,11 @@ def test_schema_output_schema_using_endpoint(self, app): def test_schema_output_schema_using_mount(self, app): schema = app.schema["paths"]["/mount/custom-component/"]["get"] + parameters = schema.get("parameters") response = schema.get("responses", {}).get(200, {}) assert schema["description"] == "Custom component." + assert parameters is None assert response == { "description": "Component.", "content": {"application/json": {"schema": {"$ref": "#/components/schemas/Puppy"}}},