### Decorator Factories and Optional Arguments

In this video, we're going to look at how to create decorators that take optional arguments and can be called with or without parentheses, much like the `lru_cache` decorator in the Python standard library:

```python
from functools import lru_cache

@lru_cache
def my_func(a):
    pass

@lru_cache(maxsize=5)
def my_other_func(a):
    pass
```

I am going to assume that you already know what closures and decorators are, and how to create simple decorators.

Let's review them quickly in case you're a bit rusty; otherwise, skip ahead to the section titled [Decorator Arguments](#decorator_arguments).

#### Simple Decorator

For example, this decorator can be used to time any function:

```python
from time import perf_counter

def timer(fn):
    def inner(*args, **kwargs):
        start = perf_counter()
        result = fn(*args, **kwargs)
        end = perf_counter()
        print(f"{fn.__name__}: {end-start:0.5f} seconds")
        return result
    return inner
```

We can now decorate any function we want to time with it:

```python
from random import random
from time import sleep

@timer
def my_func(a, *, b):
    sleep(random())
    return a * b
```

We can then call this decorated function:

```python
my_func("*", b=20)
```

```
my_func: 0.24645 seconds
'********************'
```

Also remember that the decorator syntax is just syntactic sugar - we could just as easily have decorated our function this way:

```python
def my_func(a, *, b):
    sleep(random())
    return a * b

my_func = timer(my_func)
```

```python
my_func("*", b=10)
```

```
my_func: 0.58107 seconds
'**********'
```

So a decorator is nothing more than a function that receives a function as an argument (the decorated function), "modifies" that function's behavior in some way, and returns the modified function. Usually we modify the behavior by running extra code before or after calling the original function, and then returning the result of that original call.

But there is one small issue - the decorated `my_func` function is no longer the **original** one - and we lose some of its metadata:

```python
def my_func(a, *, b):
    """A docstring annotation"""
    sleep(random())
    return a * b
```

```python
my_func.__name__
```

```
'my_func'
```

```python
my_func.__doc__
```

```
'A docstring annotation'
```

Now let's decorate it:

```python
my_func = timer(my_func)
```

```python
my_func.__name__
```

```
'inner'
```

```python
my_func.__doc__
```

As you can see, the name and docstring have been replaced by those of the `inner` function in our decorator (which makes sense, since `inner` is what `timer` returns). Because `inner` has no docstring, `my_func.__doc__` is now `None`, which is why that last cell displays nothing.

And of course, since it is equivalent, the same happens if we use the `@` decorator syntax:

```python
@timer
def my_func(a, *, b):
    """A docstring annotation"""
    sleep(random())
    return a * b
```

```python
my_func.__name__
```

```
'inner'
```

```python
my_func.__doc__
```

#### Using `@wraps` to retain original function metadata

An easy fix is to use the `@wraps` decorator in the `functools` module, which essentially copies some relevant metadata from the source function (`my_func`) to that `inner` function:

```python
from functools import wraps
```

```python
def timer(fn):
    @wraps(fn)
    def inner(*args, **kwargs):
        start = perf_counter()
        result = fn(*args, **kwargs)
        end = perf_counter()
        print(f"{fn.__name__}: {end-start:0.5f} seconds")
        return result
    return inner
```

```python
@timer
def my_func(a, *, b):
    """A docstring annotation"""
    sleep(random())
    return a * b
```

```python
my_func.__name__
```

```
'my_func'
```

```python
my_func.__doc__
```

```
'A docstring annotation'
```

So this is a common pattern for simple decorators: we use nested functions and closures, along with `@wraps`.
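`@wraps` does a little more than copy the name and docstring: it also stores the original function on the wrapper as `__wrapped__`. The sketch below (a self-contained copy of the `@wraps` version of `timer`) shows two things that attribute makes possible: calling the undecorated function directly, and letting `inspect.signature` report the original parameters instead of `(*args, **kwargs)`.

```python
import inspect
from functools import wraps
from time import perf_counter

def timer(fn):
    @wraps(fn)  # copies __name__, __doc__, etc. and sets inner.__wrapped__ = fn
    def inner(*args, **kwargs):
        start = perf_counter()
        result = fn(*args, **kwargs)
        end = perf_counter()
        print(f"{fn.__name__}: {end-start:0.5f} seconds")
        return result
    return inner

@timer
def my_func(a, *, b):
    """A docstring annotation"""
    return a * b

original = my_func.__wrapped__       # the undecorated function
print(original("*", b=3))            # ***  (no timing output)
print(inspect.signature(my_func))    # (a, *, b)  (signature follows __wrapped__)
```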


<a id="decorator_arguments"></a>
#### Decorator Arguments

But these decorators do not let us pass arguments to the decorator itself.

If you look at the decorators in the Python standard library, you'll notice that some of them are **parametrized**.

For example the `lru_cache` decorator in the `functools` module:

```python
from functools import lru_cache

@lru_cache(5)
def my_func(x):
    print(f"Cache miss: {x}")
    return x ** 2
```

Now we can call this decorated function, which will cache the results for the 5 most recently used arguments.

When we call the function with a previously unseen (and hence uncached) argument, we get:

```python
my_func(1)
```

```
Cache miss: 1
1
```

If we call it again, you'll notice we do not see the `print` output - the result was retrieved from the cache:

```python
my_func(1)
```

```
1
```

If we call the function with different arguments, they are added to the cache. Since our cache is limited to 5 entries, once we add 5 more, the least recently used entry (the one for `1`) is evicted:

```python
my_func(2)
my_func(3)
my_func(4)
my_func(5)
my_func(6)
```

```
Cache miss: 2
Cache miss: 3
Cache miss: 4
Cache miss: 5
Cache miss: 6
36
```

And if we call the function with `1` again:

```python
my_func(1)
```

```
Cache miss: 1
1
```

So that's the LRU cache decorator: it takes an argument, which is not something our previous pattern can handle.
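If you want to confirm the hits and evictions without relying on `print` output, functions decorated with `lru_cache` also get a `cache_info()` method that returns a named tuple of hits, misses, maximum size and current size. This sketch replays the same sequence of calls as above on a fresh function:

```python
from functools import lru_cache

@lru_cache(5)
def my_func(x):
    print(f"Cache miss: {x}")
    return x ** 2

my_func(1)              # miss
my_func(1)              # hit
for x in range(2, 7):   # misses for 2..6; adding 6 evicts the entry for 1
    my_func(x)
my_func(1)              # miss again (this evicts 2, now the least recently used)

print(my_func.cache_info())
# CacheInfo(hits=1, misses=7, maxsize=5, currsize=5)
```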

To get that functionality, we need a function that takes in that parameter and then **creates** and returns a decorator, hence the term **decorator factory**. In other words, we are going to write a function that returns a decorator, not a decorated function like we did previously.

Let's do a simple example of this. We'll go back to our `timer` example, but this time we want to output not only the function name and its run time, but also a category (maybe we are writing these timings to a database, and it would be helpful to group them under various categories).

```python
def timer(category="null"):        # the decorator factory
    def decorator(fn):             # the decorator it creates
        @wraps(fn)
        def inner(*args, **kwargs):  # the wrapper that replaces fn
            start = perf_counter()
            result = fn(*args, **kwargs)
            end = perf_counter()
            print(f"{category}: {fn.__name__}: {end-start:0.5f} seconds")
            return result
        return inner
    return decorator
```

As you can see, `timer` is no longer a decorator: it is a function that **returns** a decorator when called. So to use `timer` as a decorator, we first have to **call** it to get a decorator function, and that return value can then be used as a decorator in the usual way.

Let's break it down:

```python
dec_timer_section_1 = timer("section 1")
```

```python
def my_func(a, *, b):
    sleep(random())
    return a * b

my_func = dec_timer_section_1(my_func)
```

```python
my_func.__name__
```

```
'my_func'
```

```python
my_func(5, b=10)
```

```
section 1: my_func: 0.88532 seconds
50
```

And we can collapse all this code using the `@` syntax:

```python
@timer("section 1")
def my_func(a, *, b):
    sleep(random())
    return a * b
```

```python
my_func(5, b=10)
```

```
section 1: my_func: 0.43780 seconds
50
```

What if we do not want to provide a category? We have defined a default, so we can just do this:

```python
@timer()
def my_func(a, *, b):
    sleep(random())
    return a * b
```

```python
my_func('*', b=5)
```

```
null: my_func: 0.23383 seconds
'*****'
```

#### Optional Arguments

But the syntax is a bit awkward: we have to remember to use `()` even when we do not want to pass a value (and just want the default).
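To see why the parentheses are required with this version, here is a small sketch (a self-contained copy of the `timer` factory above) of what happens if we leave them off. Python calls `timer(my_func)`, so the function ends up bound to `category`, and the name `my_func` is rebound to the inner `decorator` function. Calling it then fails with a `TypeError`; the exact message differs between Python versions, so the sketch only records the exception type.

```python
from functools import wraps
from time import perf_counter

def timer(category="null"):
    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            start = perf_counter()
            result = fn(*args, **kwargs)
            end = perf_counter()
            print(f"{category}: {fn.__name__}: {end-start:0.5f} seconds")
            return result
        return inner
    return decorator

@timer  # no parentheses: same as my_func = timer(my_func)
def my_func(a, *, b):
    return a * b

# `category` now holds the original function, and `my_func` is `decorator`
print(my_func.__name__)  # decorator

error = None
try:
    my_func(5, b=10)  # really calls decorator(5, b=10)
except TypeError as exc:
    error = exc
print(type(error).__name__)  # TypeError
```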

If we go back to the `lru_cache` decorator we looked at earlier, the default cache size is `128`.

The way this decorator is implemented (since Python 3.8), we can simply use the decorator name without passing a size and **without** the empty `()`, and it will default to a cache size of `128`:
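A quick way to check that both spellings end up with the same default is to look at `maxsize` in `cache_info()`. The function names `square` and `cube` are just made up for this sketch, and the bare `@lru_cache` form needs Python 3.8 or later:

```python
from functools import lru_cache

@lru_cache            # bare decorator, no parentheses
def square(x):
    return x ** 2

@lru_cache()          # called with no arguments, so the defaults apply
def cube(x):
    return x ** 3

print(square.cache_info().maxsize)  # 128
print(cube.cache_info().maxsize)    # 128
```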

```python
@lru_cache
def my_func(x):
    print(f"Cache miss: {x}")
    return x ** 2
```

```python
for x in range(128):
    my_func(x)
```

```
Cache miss: 0
Cache miss: 1
Cache miss: 2
...
Cache miss: 127
```

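```python
for x in range(128):
    my_func(x)
```

This time there is no output at all: all 128 results are already in the cache. The cache is now full, so a new value such as `128` is a miss, and adding it evicts the least recently used entry, which is the one for `0`:

```python
my_func(128)
```

```
Cache miss: 128
16384
```

So calling `my_func(0)` again is a cache miss too:

```python
my_func(0)
```

```
Cache miss: 0
0
```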

So, how can we achieve the same with our own decorator? In our previous example, we had to use `()` to get the default `null` value for `category`:

```python
@timer()
def my_func(a, b):
    return a * b
```

```python
my_func(5, 10)
```

```
null: my_func: 0.00000 seconds
50
```

What we really want is to use something like this:

```python
@timer
def my_func(a, b):
    return a * b
```

as well as 

```python
@timer("section 1")
def my_func(a, b):
    return a * b
```

To do this, let's deconstruct the decorator syntax first, to see what's going on:

In the first variant (no arguments, no call):

```python
def my_func(a, b):
    return a * b

my_func = timer(my_func)
```

But in the second variant:

```python
def my_func(a, b):
    return a * b

my_func = timer("section 1")(my_func)
```

As you can see, in the first variant, the argument received by `timer` will be the **function** we want to decorate, and in the second variant it will be the category **string**.

So our decorator **factory** needs to distinguish between these two cases and change what it returns: the decorated function (with the default category of `null` set in the closure) for the first variant, or the decorator itself (with the specified `category` value set in the closure) for the second variant. If we call `timer()` with no arguments at all, the argument is `None`, which we can treat like the second variant with the default category.

We're going to use the built-in function `callable` to determine if that argument is a function (1st variant), or a string (2nd variant):
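Note that `callable` is broader than "is this a function?": it returns `True` for anything Python can call, including classes and instances that define `__call__`, so the check works for anything we might reasonably decorate. A quick sketch (the `greet` function and `Greeter` class are made-up examples):

```python
from functools import partial

def greet():
    return "hello"

class Greeter:
    def __call__(self):
        return "hi"

print(callable(greet))           # True: plain functions
print(callable(lambda x: x))     # True: lambdas
print(callable(Greeter))         # True: classes are called to create instances
print(callable(Greeter()))       # True: instances whose class defines __call__
print(callable(partial(greet)))  # True: functools.partial objects
print(callable("section 1"))     # False: strings
print(callable(None))            # False: None
```

For what it's worth, `functools.lru_cache` itself uses a similar `callable` check on its first argument to detect the bare `@lru_cache` form, and raises a `TypeError` for unsupported argument types.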

```python
def timer(func_or_category=None):
    def decorator(fn):
        @wraps(fn)
        def inner(*args, **kwargs):
            start = perf_counter()
            result = fn(*args, **kwargs)
            end = perf_counter()
            # `category` is assigned further down in `timer`; the closure
            # only looks it up when `inner` actually runs
            print(f"{category}: {fn.__name__}: {end-start:0.5f} seconds")
            return result
        return inner

    if callable(func_or_category):
        # a callable was passed in (1st variant: @timer)
        func = func_or_category
        category = "null"  # bound to the decorator's closure
        return decorator(func)
    elif isinstance(func_or_category, str) or func_or_category is None:
        # a string (2nd variant: @timer("...")) or None (@timer()) was passed
        category = func_or_category or "null"  # bound to the decorator's closure
        return decorator
    else:
        # wrong argument type, so TypeError (as lru_cache does)
        raise TypeError("Expected argument to be a string, a callable, or None.")
```


Let's give it a try:

```python
@timer("section 1")
def my_func(a, b):
    return a * b
```

```python
my_func.__name__
```

```
'my_func'
```

```python
my_func(2, 3)
```

```
section 1: my_func: 0.00000 seconds
6
```

```python
@timer
def my_func(a, b):
    return a * b
```

```python
my_func.__name__
```

```
'my_func'
```

```python
my_func(3, 4)
```

```
null: my_func: 0.00000 seconds
12
```

And in fact, we can still use the `()` syntax too if we want (but why would we after going to the trouble of not having to? 😀)

```python
@timer()
def my_func(a, b):
    return a * b

my_func(2, 3)
```

```
null: my_func: 0.00000 seconds
6
```

As you can see, we achieved a cleaner syntax, and the code to do so is not overly complicated.
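For comparison, here is a sketch of a different way to get the same three call styles; it is an alternative idiom, not the approach used above. The decorated function becomes an optional first parameter and the category becomes keyword-only, so no type inspection of a string argument is needed: when `fn` is missing, the factory returns itself with `category` pre-bound via `functools.partial`. The function names `add`, `mul` and `sub` are illustrative only. The standard library's `dataclasses.dataclass` accepts its class as an optional first argument in a similar way (`@dataclass` versus `@dataclass(frozen=True)`).

```python
from functools import partial, wraps
from time import perf_counter

def timer(fn=None, *, category="null"):
    """Usable as @timer, @timer(), or @timer(category="...")."""
    if fn is None:
        # Called with parentheses: return a decorator that remembers `category`
        return partial(timer, category=category)
    if not callable(fn):
        # e.g. @timer("section 1"): in this version the category must be a keyword
        raise TypeError("Pass the category as a keyword: @timer(category=...)")

    @wraps(fn)
    def inner(*args, **kwargs):
        start = perf_counter()
        result = fn(*args, **kwargs)
        end = perf_counter()
        print(f"{category}: {fn.__name__}: {end-start:0.5f} seconds")
        return result
    return inner

@timer
def add(a, b):
    return a + b

@timer(category="section 1")
def mul(a, b):
    return a * b

@timer()
def sub(a, b):
    return a - b

add(2, 3)  # prints "null: add: ..." and returns 5
mul(2, 3)  # prints "section 1: mul: ..." and returns 6
sub(5, 1)  # prints "null: sub: ..." and returns 4
```

The trade-off is that the category has to be passed by keyword: `@timer("section 1")` would put the string into `fn`, which the `callable` guard turns into an immediate `TypeError` rather than a confusing failure later.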

Is it worth the trouble? That's entirely up to you!