Skip to content

Commit

Permalink
refactor arg_checker
Browse files Browse the repository at this point in the history
  • Loading branch information
Cuizi7 committed Oct 19, 2018
1 parent 9f7a2a3 commit 5494ca4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 11 deletions.
4 changes: 2 additions & 2 deletions rqalpha/api/api_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -852,8 +852,8 @@ def get_dividend(order_book_id, start_date, *args, **kwargs):
@ExecutionContext.enforce_phase(EXECUTION_PHASE.ON_BAR,
EXECUTION_PHASE.ON_TICK,
EXECUTION_PHASE.SCHEDULED)
@apply_rules(verify_that('series_name').is_instance_of(str),
verify_that('value').is_number())
@apply_rules(verify_that('series_name', pre_check=True).is_instance_of(str),
verify_that('value', pre_check=True).is_number())
def plot(series_name, value):
"""
Add a point to custom series.
Expand Down
36 changes: 27 additions & 9 deletions rqalpha/utils/arg_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@


class ArgumentChecker(object):
def __init__(self, arg_name):
def __init__(self, arg_name, pre_check):
self._arg_name = arg_name
self._pre_check = pre_check
self._rules = []

def is_instance_of(self, types):
Expand Down Expand Up @@ -364,15 +365,34 @@ def verify(self, func_name, value):
def arg_name(self):
return self._arg_name

@property
def pre_check(self):
return self._pre_check


def verify_that(arg_name, pre_check=False):
return ArgumentChecker(arg_name, pre_check)


def verify_that(arg_name):
return ArgumentChecker(arg_name)
def get_call_args(func, args, kwargs, traceback=None):
try:
return inspect.getcallargs(unwrapper(func), *args, **kwargs)
except TypeError as e:
six.reraise(RQTypeError, RQTypeError(*e.args), traceback)


def apply_rules(*rules):
def decorator(func):
@wraps(func)
def api_rule_check_wrapper(*args, **kwargs):
call_args = None
for r in rules:
if not r.pre_check:
continue
if call_args is None:
call_args = get_call_args(func, args, kwargs)
r.verify(func.__name__, call_args[r.arg_name])

try:
return func(*args, **kwargs)
except RQInvalidArgument:
Expand All @@ -381,14 +401,12 @@ def api_rule_check_wrapper(*args, **kwargs):
exc_info = sys.exc_info()
t, v, tb = exc_info

try:
call_args = inspect.getcallargs(unwrapper(func), *args, **kwargs)
except TypeError as e:
six.reraise(RQTypeError, RQTypeError(*e.args), tb)
return

if call_args is None:
call_args = get_call_args(func, args, kwargs, tb)
try:
for r in rules:
if r.pre_check:
continue
r.verify(func.__name__, call_args[r.arg_name])
except RQInvalidArgument as e:
six.reraise(RQInvalidArgument, e, tb)
Expand Down

0 comments on commit 5494ca4

Please sign in to comment.