Skip to content

Commit

Permalink
fix before hook parse issue, test operation logs
Browse files Browse the repository at this point in the history
  • Loading branch information
voidZXL committed Mar 8, 2024
1 parent a66e5ff commit 632e485
Show file tree
Hide file tree
Showing 8 changed files with 226 additions and 51 deletions.
6 changes: 3 additions & 3 deletions tests/test_api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def hello(self):
return 'world'

@api.before(SubAPI)
def assign_var(self, x_test_id: int = request.HeaderParam('X-Test-Id')):
def assign_var(self, x_test_id: int = request.HeaderParam('X-Test-Id', default='null')):
test_var.setter(self.request, x_test_id)


Expand Down Expand Up @@ -245,8 +245,8 @@ def headers_kwargs(
self,
q: int = request.QueryParam(gt=0),
data: DataSchema = request.Body,
x_test_id: int = request.HeaderParam('X-Test-ID'),
) -> Tuple[int, int, DataSchema]:
x_test_id: int = request.HeaderParam('X-Test-ID', default='null'),
) -> Tuple[Union[int, str], int, DataSchema]:
return x_test_id, q, data

# class FileData(utype.Schema):
Expand Down
19 changes: 18 additions & 1 deletion tests/test_api/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,12 +279,21 @@ def test_api_features(self):
'test0',
200,
),
(
"get",
"@parent/sub/test0",
{},
None,
{}, # test default for before
'test0',
200,
),
(
"get",
"@parent/sub/test1",
{},
None,
{'X-Test-ID': 3},
{'X-Test-ID': '3'},
{'test1': 3},
200,
),
Expand Down Expand Up @@ -398,6 +407,14 @@ def test_api_features(self):
{"X-test-ID": "11"},
[11, 3, {"title": "test", "views": 0}],
200,
),(
"post",
"headers_kwargs",
{"q": '6'},
{"title": "test"},
{}, # test default value in request Param
['null', 6, {"title": "test", "views": 0}],
200,
),
]
for method, path, query, body, headers, result, status in requests:
Expand Down
7 changes: 5 additions & 2 deletions utilmeta/core/api/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ class BaseEndpoint(PluginTarget):
parser_cls = FunctionParser
wrapper_cls = RequestContextWrapper

PARSE_PARAMS = False
PARSE_RESULT = True

def __init__(self, f: Callable, *,
method: str,
plugins: list = None,
Expand All @@ -74,8 +77,8 @@ def __init__(self, f: Callable, *,
self.wrapper = self.wrapper_cls(self.parser_cls.apply_for(f))
self.executor = self.parser.wrap(
eager_parse=self.eager,
parse_params=False,
parse_result=True
parse_params=self.PARSE_PARAMS,
parse_result=self.PARSE_RESULT
)

def iter_plugins(self):
Expand Down
3 changes: 2 additions & 1 deletion utilmeta/core/api/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ async def __call__(self, *args, **kwargs):
class BeforeHook(Hook):
hook_type = utils.EndpointAttr.before_hook
wrapper_cls = RequestContextWrapper
parse_params = True
# parse_params = False
# already pared for request

@classmethod
def apply_for(cls, func: Callable) -> 'BeforeHook':
Expand Down
32 changes: 16 additions & 16 deletions utilmeta/core/auth/plugins/require.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,31 +45,31 @@ async def enter_endpoint(self, api: 'API', *args, **kwargs):
if self.functions:
await self.validate_functions(api)
if self.scopes:
await self.validate_scopes(api)
self.validate_scopes(api)

def validate_scopes(self, api: 'API'):
scopes = self.scopes_var.get(api.request)
scopes = self.scopes_var.getter(api.request)
if not set(scopes or []).issuperset(self.scopes):
raise exceptions.PermissionDenied(
'insufficient scope',
scopes=scopes,
required_scopes=self.scopes,
scope=scopes,
required_scope=self.scopes,
name=self.name
)

@awaitable
async def validate_scopes(self, api: 'API'):
scopes = await self.scopes_var.get(api.request)
if not set(scopes or []).issuperset(self.scopes):
raise exceptions.PermissionDenied(
'insufficient scope',
scopes=scopes,
required_scopes=self.scopes,
name=self.name
)
# @awaitable
# async def validate_scopes(self, api: 'API'):
# scopes = await self.scopes_var.getter(api.request)
# if not set(scopes or []).issuperset(self.scopes):
# raise exceptions.PermissionDenied(
# 'insufficient scope',
# scopes=scopes,
# required_scopes=self.scopes,
# name=self.name
# )

def validate_functions(self, api: 'API'):
user = self.user_var.get(api.request)
user = self.user_var.getter(api.request)
if user is None:
pass
for func in self.functions:
Expand All @@ -83,7 +83,7 @@ def validate_functions(self, api: 'API'):

@awaitable(validate_functions)
async def validate_functions(self, api: 'API'):
user = await self.user_var.get(api.request)
user = await self.user_var.getter(api.request)
if user is None:
pass
for func in self.functions:
Expand Down
96 changes: 72 additions & 24 deletions utilmeta/ops/api.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from utilmeta.core import api, orm, request, response, auth
from .schema import SupervisorData
from .schema import SupervisorData, ServiceLogSchema, ServiceLogBase, AccessTokenSchema
from utilmeta.utils import exceptions
from .models import Supervisor, AccessToken
from .models import Supervisor, AccessToken, ServiceLog
from . import __spec_version__
from .config import Operations
from .key import decode_token
Expand All @@ -12,6 +12,18 @@
from utilmeta.core.api.specs.openapi import OpenAPI


class SupervisorObject(orm.Schema[Supervisor]):
id: int
service: str
node_id: str
url: Optional[str] = None
public_key: Optional[str] = None
ops_api: str
ident: str
base_url: Optional[str] = None
local: bool = False


# excludes = var.RequestContextVar('_excludes', cached=True)
# params = var.RequestContextVar('_params', cached=True)
supervisor_var = var.RequestContextVar('_ops.supervisor', cached=True)
Expand All @@ -25,14 +37,14 @@ class opsRequire(auth.Require):
def validate_scopes(self, api_inst: api.API):
if config.disabled_scope and config.disabled_scope.intersection(self.scopes):
raise exceptions.PermissionDenied(f'Operation: {self.scopes} denied by config')
scopes = self.scopes_var.get(api_inst.request)
scopes = self.scopes_var.getter(api_inst.request)
if '*' in scopes:
return
return super().validate_scopes(api_inst)


class QueryAPI(api.API):
supervisor: Supervisor = supervisor_var
supervisor: SupervisorObject = supervisor_var

# scope: data.view:[TABLE_IDENT]
@opsRequire('data.query')
Expand All @@ -53,11 +65,31 @@ def delete(self):


class LogAPI(api.API):
supervisor: Supervisor = supervisor_var
supervisor: SupervisorObject = supervisor_var

@opsRequire('log.view')
def get(self):
pass
def get(self, id: int) -> ServiceLogSchema:
try:
return ServiceLogSchema.init(id)
except orm.EmptyQueryset:
raise exceptions.NotFound

class LogQuery(orm.Query[ServiceLog]):
offset: int = orm.Offset()
page: int = orm.Page()
rows: int = orm.Limit(default=20, le=100, alias_from=['limit'])

@opsRequire('log.view')
@api.get
def service(self, query: LogQuery) -> List[ServiceLogBase]:
return ServiceLogBase.serialize(
query.get_queryset(
ServiceLog.objects.filter(
service=self.supervisor.service,
node_id=self.supervisor.node_id
).order_by('-time')
)
)

@opsRequire('log.delete')
def delete(self):
Expand All @@ -66,11 +98,11 @@ def delete(self):

@opsRequire('metrics.view')
class MetricsAPI(api.API):
supervisor: Supervisor = supervisor_var
supervisor: SupervisorObject = supervisor_var


class TokenAPI(api.API):
supervisor: Supervisor = supervisor_var
supervisor: SupervisorObject = supervisor_var

def get(self):
pass
Expand All @@ -87,20 +119,27 @@ def revoke(self, id_list: List[str] = request.Body) -> int:
for token_id in set(id_list).difference({exists}):
AccessToken.objects.create(
token_id=token_id,
issuer=supervisor_var,
issuer_id=self.supervisor.id,
expiry_time=self.request.time + timedelta(days=1),
revoked=True
)

if exists:
AccessToken.objects.filter(
token_id__in=id_list,
issuer=self.supervisor
issuer_id=self.supervisor.id
).update(revoked=True)

return len(exists)


@api.CORS(
allow_origin='*',
allow_headers=[
'authorization',
'x-node-id'
]
)
class OperationsAPI(api.API):
__external__ = True

Expand Down Expand Up @@ -169,10 +208,13 @@ def handle_token(self, node_id: str = request.HeaderParam('X-Node-ID', default=N
raise exceptions.BadRequest('Node ID required', state='node_required')
validated = False
from utilmeta import service
for supervisor in Supervisor.objects.filter(
service=service.name,
node_id=node_id,
disabled=False,
for supervisor in SupervisorObject.serialize(
Supervisor.objects.filter(
service=service.name,
node_id=node_id,
disabled=False,
public_key__isnull=False
)
):
data = decode_token(token, public_key=supervisor.public_key)
if not data:
Expand Down Expand Up @@ -200,7 +242,7 @@ def handle_token(self, node_id: str = request.HeaderParam('X-Node-ID', default=N
scopes = scope.split(' ') if ' ' in scope else scope.split(',')
scope_names = []
resources = []
for name in scope_names:
for name in scopes:
if ':' in name:
name, resource = name.split(':')
resources.append(resource)
Expand All @@ -213,24 +255,29 @@ def handle_token(self, node_id: str = request.HeaderParam('X-Node-ID', default=N
if not token_id:
raise exceptions.BadRequest('Invalid token: id required', state='token_expired')

token_obj: AccessToken = AccessToken.objects.filter(
token_id=token_id,
issuer=supervisor
).first()
try:
token_obj = AccessTokenSchema.init(
AccessToken.objects.filter(
token_id=token_id,
issuer_id=supervisor.id
)
)
except orm.EmptyQueryset:
token_obj = None

if token_obj:
if token_obj.revoked:
# force revoked
# e.g. the subject permissions has changed after the token issued
raise exceptions.BadRequest('Invalid token: revoked', state='token_expired')
token_obj.last_activity = self.request.time
token_obj.used_times = models.F('used_times') + 1
token_obj.save(update_fields=['last_activity', 'used_times'])
token_obj.used_times += 1
token_obj.save()
else:
try:
token_obj = AccessToken.objects.create(
token_obj = AccessTokenSchema(
token_id=token_id,
issuer=supervisor,
issuer_id=supervisor.id,
issued_at=datetime.fromtimestamp(data.get('iat')),
expiry_time=datetime.fromtimestamp(expires),
subject=data.get('sub'),
Expand All @@ -239,6 +286,7 @@ def handle_token(self, node_id: str = request.HeaderParam('X-Node-ID', default=N
ip=str(self.request.ip_address),
scope=scopes
)
token_obj.save()
except utils.IntegrityError:
raise exceptions.BadRequest('Invalid token: id duplicated', state='token_expired')

Expand Down
14 changes: 10 additions & 4 deletions utilmeta/ops/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,22 +371,28 @@ def generate_log(self, response: Response):
request_headers = {}
response_headers = {}
if level >= self.STORE_HEADERS_LEVEL:
request_headers = self.parse_values(request.headers)
response_headers = self.parse_values(response.headers)
request_headers = self.parse_values(dict(request.headers))
response_headers = self.parse_values(dict(response.headers))

operation_names = var.operation_names.getter(request) or []
endpoint_ident = '_'.join(operation_names)
endpoint_ref = var.endpoint_ref.getter(request) or None
endpoint = _endpoints_map.get(endpoint_ident) if endpoint_ident else None
access_token = access_token_var.getter(request)

try:
level_str = LOG_LEVELS[level]
except IndexError:
level_str = LogLevel.DEBUG

return ServiceLog(
service=self.service.name,
instance=_instance,
version=_version,
node_id=getattr(_supervisor, 'node_id', None),
supervisor=_supervisor,
access_token=access_token,
level=str(level),
access_token_id=getattr(access_token, 'id', None),
level=level_str,
volatile=volatile,
time=request.time,
duration=duration,
Expand Down
Loading

0 comments on commit 632e485

Please sign in to comment.