Skip to content

Commit

Permalink
fix error response assignment, fix to-one schema query
Browse files Browse the repository at this point in the history
  • Loading branch information
voidZXL committed Apr 12, 2024
1 parent 35daf88 commit 22594bb
Show file tree
Hide file tree
Showing 25 changed files with 268 additions and 111 deletions.
7 changes: 7 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,10 @@ def service_process(service: UtilMeta):
else:
server.terminate()
return service_process

# TODO
# currently I am not able to write subprocess tests that can be measured in pytest-cov
# so the current workaround is:
# 1. use server process to test the REAL-WORLD-LIVE-CASE of apis (but not counting to coverage)
# 2. use server thread to make up the coverage (redundant though)
# working on a better solution (eg. make the subprocess executions cover-able)
13 changes: 11 additions & 2 deletions tests/server/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utilmeta.core import api, response, request, file
from utype.types import *
import utype
from utilmeta.utils import exceptions
from utilmeta.utils import exceptions, Error

test_var = request.var.RequestContextVar('test-id', cached=True)

Expand Down Expand Up @@ -70,6 +70,15 @@ def hello(self):
def assign_var(self, x_test_id: int = request.HeaderParam('X-Test-Id', default='null')):
test_var.setter(self.request, x_test_id)

@api.get
def error(self, msg: str):
raise exceptions.UpgradeRequired(msg)

@api.handle(error, exceptions.UpgradeRequired)
def handle_error_1(self, e: Error):
# test error hook: use the API's response
return e.exception.message


@api.route('the/api')
class TheAPI(api.API):
Expand Down Expand Up @@ -210,7 +219,7 @@ def update(
# print('UPDATE:', image, image.size, image.read())
# print(test_cookie)
return {
'image': image.read(),
'image': image.read().decode(),
'cookie': test_cookie,
'name': name,
'desc': desc,
Expand Down
17 changes: 16 additions & 1 deletion tests/server/app/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


__all__ = ["UserSchema", "ArticleSchema", "CommentSchema",
"ContentSchema", 'UserBase', 'UserQuery', 'ArticleQuery']
"ContentSchema", 'UserBase', 'UserQuery', 'ArticleQuery', 'ArticleBase', 'ContentBase']


class UserBase(orm.Schema[User]):
Expand Down Expand Up @@ -125,6 +125,18 @@ class article_schema(cls):
return article_schema


class ArticleBase(orm.Schema[Article]):
id: int
title: str
slug: str


class ContentBase(orm.Schema[BaseContent]):
id: int
content: str
article: Optional[ArticleBase]


class UserSchema(UserBase):
@classmethod
def get_top_articles(cls, *pks):
Expand Down Expand Up @@ -156,6 +168,9 @@ async def get_top_articles(cls, *pks):
articles_num: int = orm.Field(exp.Count("contents", filter=exp.Q(contents__type="article")))
combined_num: int = orm.Field(exp.Count("contents") * exp.Count("followers"))

articles: List[ArticleBase] = orm.Field('contents__article')
# test multi+fk

# @property
# def total_views(self) -> int:
# return sum([article.views for article in self.articles])
Expand Down
9 changes: 9 additions & 0 deletions tests/test_api/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,15 @@ def get_requests(backend: str = None):
{'resp': 5},
200,
),
(
"get",
"@parent/error",
{"msg": "test_message"},
None,
{},
{'test': 'test_message', 'message': ''},
200,
),
# ----------------------
(
"post",
Expand Down
29 changes: 28 additions & 1 deletion tests/test_orm/test_schema_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,13 @@ def test_serialize_users(self, service):
assert res[1]["@views"] == 13
assert len(res[1].top_articles) == 2
assert res[1].top_articles[0].views == 10 # -views
assert res[1].articles_num == 3
assert len(res[1].articles) == 3

# -- test one with no article
sup = UserSchema.init(5)
assert len(sup.articles) == 0
assert sup.articles_num == 0

def test_scope_and_excludes(self):
from app.schema import UserSchema, UserQuery
Expand Down Expand Up @@ -62,7 +69,7 @@ def test_scope_and_excludes(self):
assert res1[0].sum_views == 103

def test_init_articles(self, service):
from app.schema import ArticleSchema
from app.schema import ArticleSchema, ContentBase
article = ArticleSchema.init(1)
assert article.id == article.pk == 1
assert article.liked_bys_num == 3
Expand All @@ -74,6 +81,11 @@ def test_init_articles(self, service):
assert len(article.comments) == 2
assert article.author_tag['name'] == 'bob'

# test sub relation
content = ContentBase.init(1)
assert content.id == 1
assert content.article.id == 1

def test_related_qs(self):
from app.schema import UserBase, ArticleSchema, UserQuery
from app.models import Article, Follow, User
Expand Down Expand Up @@ -144,6 +156,21 @@ async def test_async_init_users(self):
assert user.top_articles[0].author_tag["name"] == "alice"
assert user.top_articles[0].views == 103

# --------------
bob = await UserSchema.ainit(
User.objects.filter(
username='bob',
)
)
assert bob.pk == 2
assert len(bob.articles) == 3
assert bob.articles_num == 3

# ---
sup = UserSchema.init(5)
assert len(sup.articles) == 0
assert sup.articles_num == 0

@pytest.mark.asyncio
async def test_async_init_users_with_sync_query(self):
# for django, it requires bind_service=True in @awaitable
Expand Down
37 changes: 21 additions & 16 deletions utilmeta/core/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,9 @@ def _init_properties(self):
context=context
)
except ParseError as e:
if self.request.is_options:
# ignore parse error for OPTIONS request
continue
raise exc.BadRequest(str(e), detail=e.get_detail()) from e
self.__dict__[name] = value

Expand All @@ -447,7 +450,7 @@ def _handle_error(self, error: Error, error_hooks: dict):
# if nothing return, it implies that follow the api default error flow
if result is None:
raise error.throw()
return process_response(self, result)
return result

@awaitable(_handle_error)
async def _handle_error(self, error: Error, error_hooks: dict):
Expand All @@ -461,7 +464,7 @@ async def _handle_error(self, error: Error, error_hooks: dict):
# if nothing return, it implies that follow the api default error flow
if result is None:
raise error.throw()
return await process_response(self, result)
return result

def _resolve(self) -> APIRoute:
method_routes: Dict[str, APIRoute] = {}
Expand Down Expand Up @@ -501,14 +504,16 @@ def __call__(self):
with self._resolve() as route:
error_hooks = route.error_hooks
result = route(self)
if isinstance(result, Response):
result.request = self.request
elif Response.is_cls(getattr(self.__class__, 'response', None)):
result = self.response(result, request=self.request)
response = process_response(self, result)
except Exception as e:
response = self._handle_error(Error(e), error_hooks)
return response
result = self._handle_error(Error(e), error_hooks)

if isinstance(result, Response):
if not result.request:
result.request = self.request
elif Response.is_cls(getattr(self.__class__, 'response', None)):
result = self.response(result, request=self.request)

return process_response(self, result)

@awaitable(__call__)
async def __call__(self):
Expand All @@ -518,14 +523,14 @@ async def __call__(self):
with self._resolve() as route:
error_hooks = route.error_hooks
result = await route(self)
if isinstance(result, Response):
result.request = self.request
elif Response.is_cls(getattr(self.__class__, 'response', None)):
result = self.response(result, request=self.request)
response = await process_response(self, result)
except Exception as e:
response = await self._handle_error(Error(e), error_hooks)
return response
result = await self._handle_error(Error(e), error_hooks)
if isinstance(result, Response):
if not result.request:
result.request = self.request
elif Response.is_cls(getattr(self.__class__, 'response', None)):
result = self.response(result, request=self.request)
return await process_response(self, result)

def options(self):
return Response(headers={
Expand Down
44 changes: 40 additions & 4 deletions utilmeta/core/api/endpoint.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from utilmeta import utils
from utilmeta.utils import exceptions as exc
from typing import Callable, Union, Type, TYPE_CHECKING
from typing import Callable, Union, Type, List, TYPE_CHECKING
from utilmeta.utils.plugin import PluginTarget, PluginEvent
from utilmeta.utils.error import Error
from utilmeta.utils.context import ContextWrapper, Property
from utype.parser.base import BaseParser
from utype.parser.func import FunctionParser
from utype.parser.field import ParserField
from utype.parser.rule import LogicalType
import inspect
from ..request import Request, var
from ..request.properties import QueryParam, PathParam
Expand Down Expand Up @@ -56,7 +57,10 @@ class BaseEndpoint(PluginTarget):

PATH_REGEX = utils.PATH_REGEX
PARSE_PARAMS = False
PARSE_RESULT = True
# params is already parsed by the request parser
PARSE_RESULT = False
# result will be parsed in the end of endpoint.serve
STRICT_RESULT = False

def __init__(self, f: Callable, *,
method: str,
Expand Down Expand Up @@ -87,6 +91,22 @@ def __init__(self, f: Callable, *,
parse_params=self.PARSE_PARAMS,
parse_result=self.PARSE_RESULT
)
self.response_types: List[Type[Response]] = self.parse_responses(self.parser.return_type)

@classmethod
def parse_responses(cls, return_type):
def is_response(r):
return inspect.isclass(r) and issubclass(r, Response)

if is_response(return_type):
return [return_type]
elif isinstance(return_type, LogicalType):
values = []
for origin in return_type.resolve_origins():
if is_response(origin):
values.append(origin)
return values
return []

def iter_plugins(self):
for cls, plugin in self._plugins.items():
Expand Down Expand Up @@ -271,11 +291,26 @@ def ref(self) -> str:
return f'{self.module_name}.{self.f.__name__}'
return self.f.__name__

def make_response(self, response, request, error=None):
if not self.response_types:
return response
if isinstance(response, Response):
return response
for i, resp_type in enumerate(self.response_types):
try:
return resp_type(response, request=request, error=error, strict=self.STRICT_RESULT)
except Exception as e:
if i == len(self.response_types) - 1:
raise e
continue
return response

def serve(self, api: 'API'):
# ---
var.endpoint_ref.setter(api.request, self.ref)
# ---
retry_index = 0
err = None
while True:
try:
api.request.adaptor.update_context(
Expand Down Expand Up @@ -307,14 +342,15 @@ def serve(self, api: 'API'):
break
retry_index += 1
exit_endpoint(self, api)
return response
return self.make_response(response, request=api.request, error=err)

@utils.awaitable(serve)
async def serve(self, api: 'API'):
# ---
var.endpoint_ref.setter(api.request, self.ref)
# ---
retry_index = 0
err = None
while True:
try:
api.request.adaptor.update_context(
Expand Down Expand Up @@ -346,7 +382,7 @@ async def serve(self, api: 'API'):
break
retry_index += 1
await exit_endpoint(self, api)
return response
return self.make_response(response, request=api.request, error=err)

def parse_request(self, request: Request):
try:
Expand Down
16 changes: 0 additions & 16 deletions utilmeta/core/cli/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,22 +55,6 @@ def __init__(self, f: Callable, *,
)
self.client = client
self.path_args = self.PATH_REGEX.findall(self.route)
self.response_types: List[Type[Response]] = self.parse_responses(self.parser.return_type)

@classmethod
def parse_responses(cls, return_type):
def is_response(r):
return inspect.isclass(r) and issubclass(r, Response)

if is_response(return_type):
return [return_type]
elif isinstance(return_type, LogicalType):
values = []
for origin in return_type.resolve_origins():
if is_response(origin):
values.append(origin)
return values
return []

@property
def ref(self) -> str:
Expand Down
4 changes: 4 additions & 0 deletions utilmeta/core/orm/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,10 @@ def is_many(self):
def is_combined(self):
raise NotImplementedError

@property
def multi_relations(self):
raise NotImplementedError

@classmethod
def get_exp_field(cls, exp) -> Optional[str]:
raise NotImplementedError
Expand Down
Loading

0 comments on commit 22594bb

Please sign in to comment.