Skip to content

Commit

Permalink
Refactored getting argument names.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rahul Kumar committed Jul 4, 2011
1 parent 214a26c commit 0759600
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions augment.py
Expand Up @@ -18,9 +18,12 @@ def get_args_and_name(fn):
(('a', 'b', 'c'), 'foo')
"""
code = fn.func_code
allargs = code.co_varnames[:code.co_argcount]
fn_name = fn.__name__
allargs, fn_name = getattr(fn, '__allargs__', None), \
getattr(fn, '__fnname__', None)
if not allargs:
code = fn.func_code
allargs = code.co_varnames[:code.co_argcount]
fn_name = fn.__name__
return allargs, fn_name

def _propogate_error(errors, handler=None, exception_type=TypeError):
Expand Down Expand Up @@ -62,19 +65,16 @@ def ensure_args(error_handler=None, **rules):
TypeError: Errors in 'foo'. 'b = ab' violates constraint.
"""
def decorator(fn):
allargs, fn_name = getattr(fn, '__allargs__', None), \
getattr(fn, '__fnname__', None)
if not allargs:
allargs, fn_name = get_args_and_name(fn)
allargs, fn_name = get_args_and_name(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
pargs = list(allargs)[:len(args)]
results = check_args(rules, pargs, args, kwargs)
errors = []
for arg_name, arg_val, valid in results:
if not valid:
errors.append("'%s = %s' violates constraint. "
% (arg_name, arg_val))
errors.append("'%s = %s' violates constraint %s. "
% (arg_name, arg_val, rules[arg_name]))
if errors:
fn_info = "Errors in '%s'. " % fn_name
errors.insert(0, fn_info)
Expand Down Expand Up @@ -137,10 +137,7 @@ def ensure_one_of(exclusive=False, **rules):
TypeError: Errors in 'foo'. Only one of '['a', 'b']' must validate.
"""
def decorator(fn):
allargs, fn_name = getattr(fn, '__allargs__', None), \
getattr(fn, '__fnname__', None)
if not allargs:
allargs, fn_name = get_args_and_name(fn)
allargs, fn_name = get_args_and_name(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
pargs = list(allargs)[:len(args)]
Expand All @@ -149,10 +146,12 @@ def wrapper(*args, **kwargs):
if valid])
fn_info = "Errors in '%s'. " % fn_name
if valid_count < 1:
error_msg = "One of '%s' must validate." % rules.keys()
error_msg = "One of '%s' must validate. Constraints: %s" % \
(rules.keys(), rules)
_propogate_error(fn_info + error_msg)
elif valid_count > 1 and exclusive:
error_msg = "Only one of '%s' must validate." % rules.keys()
error_msg = "Only one of '%s' must validate. Constraints: %s" % \
(rules.keys(), rules)
_propogate_error(fn_info + error_msg)
else:
return fn(*args, **kwargs)
Expand All @@ -173,10 +172,7 @@ def transform_args(**rules):
4
"""
def decorator(fn):
allargs, fn_name = getattr(fn, '__allargs__', None), \
getattr(fn, '__fnname__', None)
if not allargs:
allargs, fn_name = get_args_and_name(fn)
allargs, fn_name = get_args_and_name(fn)
@wraps(fn)
def wrapper(*args, **kwargs):
pargs = list(allargs)[:len(args)]
Expand Down

0 comments on commit 0759600

Please sign in to comment.