# Decorators with other parameters (besides fn)

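We can pass other arguments (besides fn) to decorators as well.
Essentially, it's a multilayer process: a decorator itself receives only the function being decorated, so we need an extra outer function to accept the argument we want to pass in. That outer function returns the actual decorator.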

The following code ensures that the first argument passed to the decorated function equals a value that we specify in the decorator:

In [9]:
from functools import wraps


def ensure_first_arg_is(val):
    def inner(fn):  # This is the actual decorator
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if args and args[0] != val:
                return f"Invalid! First arg needs to be {val}"
            return fn(*args, **kwargs)
        return wrapper
    return inner


@ensure_first_arg_is("burrito")
def fav_foods(*foods):
    return foods


print(fav_foods("burrito", "ice cream"))  

('burrito', 'ice cream')


In [10]:
print(fav_foods("ice cream", "burrito")) 

Invalid! First arg needs to be burrito


In [11]:
@ensure_first_arg_is(10)
def add_to_ten(num1, num2):
    return num1 + num2


print(add_to_ten(10, 12))  

22


In [18]:
print(add_to_ten(1, 2)) 

Invalid! First arg needs to be 10


What we are really doing is:

```func = decorator_with_args(arg)(func)```

where `decorator_with_args(arg)` is called first and returns `inner`;
`inner(func)` is then called and returns the wrapper, which replaces the original function under the name `func`.
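
The two calls can be written out by hand, without the `@` syntax, to make each layer visible. This is a small sketch that reuses `ensure_first_arg_is` from above; the names `decorator` and `checked_fav_foods` are introduced here only for illustration.

```python
from functools import wraps


def ensure_first_arg_is(val):
    def inner(fn):  # This is the actual decorator
        @wraps(fn)
        def wrapper(*args, **kwargs):
            if args and args[0] != val:
                return f"Invalid! First arg needs to be {val}"
            return fn(*args, **kwargs)
        return wrapper
    return inner


def fav_foods(*foods):
    return foods


# Step 1: calling the outer function with its argument returns the actual decorator.
decorator = ensure_first_arg_is("burrito")

# Step 2: calling that decorator with the function returns the wrapper.
checked_fav_foods = decorator(fav_foods)

print(decorator.__name__)          # inner
print(checked_fav_foods.__name__)  # fav_foods (preserved by @wraps)
print(checked_fav_foods("burrito", "ice cream"))  # ('burrito', 'ice cream')
print(checked_fav_foods("ice cream", "burrito"))  # Invalid! First arg needs to be burrito
```

Writing `@ensure_first_arg_is("burrito")` above the function definition performs both steps and rebinds the result to the original function name.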

Example:
Here, we create a decorator that enforces argument data types by converting each positional argument to the type given in the decorator before the function is called

In [17]:
from functools import wraps


def enforce(*types):  # Outer layer: takes any number of target types
    def decorator(f):  # Actual decorator; the outer layer exists only to receive *types
        @wraps(f)  # keep f's name and docstring on the wrapper
        def new_func(*args, **kwargs):  # wrapper function
            # args is a tuple, so build a new list holding the converted values
            # feel free to have more elaborate conversion:
            newargs = [t(a) for a, t in zip(args, types)]  # cast each argument with its type
            # zip stops at the shorter sequence, so pass any extra positional args through unchanged
            newargs.extend(args[len(types):])
            return f(*newargs, **kwargs)
        return new_func
    return decorator


@enforce(str, int)
def repeat_msg(msg, times):
    for time in range(times):
        print(msg)


@enforce(float, float)
def divide(a, b):
    print(a / b)


repeat_msg("hello", '5')
divide('1', '4')

hello
hello
hello
hello
hello
0.25
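
Note that `enforce` converts arguments rather than rejecting them, which is why the string `'5'` is accepted above. If rejection is what you want, a stricter variant could check types with `isinstance` and raise an error instead. The following is only a sketch under that assumption; the name `require_types` and the error message format are hypothetical and not part of the example above.

```python
from functools import wraps


def require_types(*types):  # hypothetical stricter variant of enforce
    def decorator(f):
        @wraps(f)
        def wrapper(*args, **kwargs):
            # Check each positional argument against the type at the same position.
            for position, (arg, expected) in enumerate(zip(args, types)):
                if not isinstance(arg, expected):
                    raise TypeError(
                        f"Argument {position} of {f.__name__} must be "
                        f"{expected.__name__}, got {type(arg).__name__}"
                    )
            return f(*args, **kwargs)
        return wrapper
    return decorator


@require_types(float, float)
def divide(a, b):
    return a / b


print(divide(1.0, 4.0))  # 0.25

try:
    divide("1", "4")
except TypeError as err:
    print(err)  # Argument 0 of divide must be float, got str
```

The structure is the same three layers as before; only the body of the wrapper changes.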
