# 函数装饰器和闭包

### 把被装饰对象的名称绑定给装饰器返回的对象

In [1]:
def deco(func):
    """Return a stand-in for *func* that ignores it completely.

    Applying ``@deco`` rebinds the decorated name to ``inner``.
    """

    def inner():
        # The decorated function is never invoked; only this line is printed.
        print('running inner()')

    return inner

In [2]:
@deco
def target():
    """Original body; never runs because ``@deco`` replaces ``target``."""
    print('running target()')

In [3]:
target()

running inner()


In [4]:
target

<function __main__.deco.<locals>.inner()>

In [5]:
registry = []  # functions passed to register(), in decoration order


def register(func):
    """Record *func* in ``registry`` and give it back unchanged.

    This runs at decoration time (when the module is loaded), which is why
    the ``running register(...)`` lines are printed before ``main()`` runs.
    """
    print('running register(%s)' % func)
    registry.append(func)
    return func


@register
def f1():
    """Demo function; ``@register`` records it when the module is loaded."""
    print('running f1()')


@register
def f2():
    """Demo function; ``@register`` records it when the module is loaded."""
    print('running f2()')


def f3():
    """Undecorated demo function, so it never enters ``registry``."""
    print('running f3()')


def main():
    """Show what was registered at import time, then call each demo function."""
    print('running main()')
    print('registry ->', registry)
    for demo in (f1, f2, f3):
        demo()


if __name__ == '__main__':
    main()

running register(<function f1 at 0x7f9fb8246160>)
running register(<function f2 at 0x7f9fb8246310>)
running main()
registry -> [<function f1 at 0x7f9fb8246160>, <function f2 at 0x7f9fb8246310>]
running f1()
running f2()
running f3()


In [6]:
import registration

running register(<function f1 at 0x7f9fb82465e0>)
running register(<function f2 at 0x7f9fb8246280>)


In [7]:
registration.registry

[<function registration.f1()>, <function registration.f2()>]

In [8]:
promos = []  # promotion strategies registered by @promotion, in definition order


def promotion(promo_func):
    """Register *promo_func* in ``promos`` and return it unchanged."""
    promos.append(promo_func)
    return promo_func


# Each promotion takes an ``order`` exposing ``total()``, ``customer.fidelity``
# and ``cart``; cart items expose ``product``, ``quantity`` and ``total()``.


@promotion
def fidelity(order):
    """5% discount on the order total when the customer has >= 1000 points."""
    return order.total() * .05 if order.customer.fidelity >= 1000 else 0


@promotion
def bulk_item(order):
    """10% discount on each line item with 20 or more units.

    Discounts from all qualifying line items are added together.
    """
    discount = 0
    for item in order.cart:
        if item.quantity >= 20:
            # Sum over every qualifying item, not just the last one.
            discount += item.total() * .1
    return discount


@promotion
def large_order(order):
    """7% discount on the order total for 10 or more distinct products."""
    distinct_items = {item.product for item in order.cart}
    if len(distinct_items) >= 10:
        return order.total() * .07
    return 0


def best_promo(order):
    """Return the largest discount offered by any registered promotion."""
    return max(promo(order) for promo in promos)

In [9]:
def f1(a):
    """Print *a*, then ``b``.

    ``b`` is never assigned in this body, so the compiler treats it as a
    global (LOAD_GLOBAL in ``dis(f1)``). The call below therefore raises
    NameError until a global ``b`` exists.
    """
    print(a)
    print(b)


# Raises NameError: no global b has been defined yet.
f1(3)

3


NameError: name 'b' is not defined

In [10]:
b = 6
f1(3)

3
6


In [13]:
b = 6


def f2(a):
    """Print *a*, then ``b``; the second print raises UnboundLocalError.

    Because ``b`` is assigned later in this body, the compiler makes ``b`` a
    local for the whole function (LOAD_FAST/STORE_FAST in ``dis(f2)``). So
    ``print(b)`` reads the local before it is bound, and the global
    ``b = 6`` is never consulted.
    """
    print(a)
    # If the assignment were moved above this line, this would print 9 instead.
    print(b)
    b = 9


f2(3)

3


UnboundLocalError: local variable 'b' referenced before assignment

In [14]:
b = 6


def f3(a):
    """Print *a* and the module-level ``b``, then rebind that global to 9."""
    # The global declaration makes both the read and the assignment below use
    # the module-level name instead of creating a local.
    global b
    print(a)
    print(b)
    b = 9


f3(3)

3
6


In [15]:
b

9

In [16]:
f3(3)

3
9


In [17]:
b = 30

In [18]:
b

30

In [19]:
from dis import dis

dis(f1)

  2           0 LOAD_GLOBAL              0 (print)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 POP_TOP

  3           8 LOAD_GLOBAL              0 (print)
             10 LOAD_GLOBAL              1 (b)
             12 CALL_FUNCTION            1
             14 POP_TOP
             16 LOAD_CONST               0 (None)
             18 RETURN_VALUE


In [20]:
dis(f2)

  5           0 LOAD_GLOBAL              0 (print)
              2 LOAD_FAST                0 (a)
              4 CALL_FUNCTION            1
              6 POP_TOP

  6           8 LOAD_GLOBAL              0 (print)
             10 LOAD_FAST                1 (b)
             12 CALL_FUNCTION            1
             14 POP_TOP

  7          16 LOAD_CONST               1 (9)
             18 STORE_FAST               1 (b)
             20 LOAD_CONST               0 (None)
             22 RETURN_VALUE


In [21]:
class Averager:
    """Callable returning the running mean of every value passed so far.

    All values are kept in ``self.series`` and re-summed on each call.
    """

    def __init__(self):
        self.series = []

    def __call__(self, new_value):
        """Record *new_value* and return the mean of all recorded values."""
        values = self.series
        values.append(new_value)
        return sum(values) / len(values)

In [22]:
avg = Averager()
avg(10)

10.0

In [23]:
avg(11)

10.5

In [24]:
avg(12)

11.0

In [25]:
def make_averager():
    """Return a closure that yields the running mean of the values given to it.

    ``series`` is a free variable of ``averager``. It is stored in a closure
    cell (``avg.__closure__``), so the list outlives this call.
    """
    series = []

    def averager(new_value):
        # series is only mutated (append), never rebound, so no nonlocal is needed.
        series.append(new_value)
        total = sum(series)
        return total/len(series)

    return averager

In [26]:
avg = make_averager()
avg(10)

10.0

In [27]:
avg(11)

10.5

In [28]:
avg(12)

11.0

In [29]:
avg.__code__.co_varnames

('new_value', 'total')

In [30]:
avg.__code__.co_freevars

('series',)

In [31]:
avg.__closure__

(<cell at 0x7f9fa9232490: list object at 0x7f9fa8ef1dc0>,)

In [32]:
avg.__closure__[0].cell_contents

[10, 11, 12]

In [33]:
def make_averager():
    """Broken variant: calling the returned function raises UnboundLocalError.

    ``count += 1`` and ``total += new_value`` are assignments. They make
    ``count`` and ``total`` locals of ``averager``, which are read before
    being bound; the outer variables are not captured. Declaring the two
    names ``nonlocal`` inside ``averager`` is what fixes this.
    """
    count = 0
    total = 0

    def averager(new_value):
        count += 1
        total += new_value
        return total/count

    return averager

In [35]:
avg = make_averager()
avg(10)

UnboundLocalError: local variable 'count' referenced before assignment

In [36]:
def make_averager():
    """Return a closure computing a running mean in constant time per call.

    Only the running ``count`` and ``total`` are kept, not the values.
    """
    count = total = 0

    def averager(new_value):
        # Rebinding count/total needs nonlocal; without it the updates would
        # make them locals of averager.
        nonlocal count, total
        count, total = count + 1, total + new_value
        return total / count

    return averager

In [37]:
avg = make_averager()
avg(10)

10.0

In [38]:
import time


def clock(func):
    """Wrap *func* so each call prints its elapsed time, arguments and result.

    The wrapper takes positional arguments only and does not copy *func*'s
    metadata, so the decorated function reports ``__name__ == 'clocked'``.
    """
    def clocked(*args):
        # func is a free variable: it is referenced in this body but bound in
        # the enclosing clock() call, and it is not a global.
        start = time.perf_counter()
        result = func(*args)
        elapsed = time.perf_counter() - start
        shown_args = ', '.join(map(repr, args))
        print('[%0.8fs] %s(%s) -> %r' % (elapsed, func.__name__, shown_args, result))
        return result
    return clocked

In [39]:
@clock
def snooze(seconds):
    """Sleep for *seconds*; the call is timed by ``clock``."""
    time.sleep(seconds)


@clock
def factorial(n):
    """Recursive factorial; every recursive call also goes through ``clock``."""
    if n < 2:
        return 1
    return n * factorial(n - 1)


if __name__ == '__main__':
    banner = '*' * 40
    print(banner, 'Calling snooze(.123)')
    snooze(.123)
    print(banner, 'Calling factorial(6)')
    print('6! =', factorial(6))

**************************************** Calling snooze(.123)
[0.12319834s] snooze(0.123) -> None
**************************************** Calling factorial(6)
[0.00000100s] factorial(1) -> 1
[0.00004536s] factorial(2) -> 2
[0.00106242s] factorial(3) -> 6
[0.00117357s] factorial(4) -> 24
[0.00122535s] factorial(5) -> 120
[0.00127038s] factorial(6) -> 720
6! = 720


In [41]:
import clockdeco_demo

clockdeco_demo.factorial.__name__

'clocked'

In [42]:
import time
import functools


def clock(func):
    """Decorate *func* so each call prints its duration, arguments and result.

    ``functools.wraps`` copies *func*'s ``__name__``/``__doc__`` onto the
    wrapper. Keyword arguments are accepted and shown as ``k=v`` pairs,
    sorted by name.
    """
    @functools.wraps(func)
    def clocked(*args, **kwargs):
        # perf_counter is monotonic and high-resolution. time.time is neither:
        # it can jump when the system clock is adjusted.
        t0 = time.perf_counter()
        # The closure keeps the binding that the free variable func had when
        # clock() was called.
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - t0
        name = func.__name__
        arg_lst = []
        if args:
            arg_lst.append(', '.join(repr(arg) for arg in args))
        if kwargs:
            pairs = ['%s=%r' % (k, w) for k, w in sorted(kwargs.items())]
            arg_lst.append(', '.join(pairs))
        arg_str = ', '.join(arg_lst)
        print('[%0.8fs] %s(%s) -> %r ' % (elapsed, name, arg_str, result))
        return result
    return clocked

In [47]:
@clock
def fibonacci(n):
    """Naive doubly recursive Fibonacci; every call is timed by ``clock``."""
    if n >= 2:
        # Keep the (n-2) call first: it fixes the order of the printed lines.
        return fibonacci(n-2) + fibonacci(n-1)
    return n

In [48]:
fibonacci(6)

[0.00000048s] fibonacci(0) -> 0 
[0.00000024s] fibonacci(1) -> 1 
[0.00006342s] fibonacci(2) -> 1 
[0.00000024s] fibonacci(1) -> 1 
[0.00000024s] fibonacci(0) -> 0 
[0.00000000s] fibonacci(1) -> 1 
[0.00005722s] fibonacci(2) -> 1 
[0.00014997s] fibonacci(3) -> 2 
[0.00023150s] fibonacci(4) -> 3 
[0.00000000s] fibonacci(1) -> 1 
[0.00000024s] fibonacci(0) -> 0 
[0.00000024s] fibonacci(1) -> 1 
[0.00002861s] fibonacci(2) -> 1 
[0.00008368s] fibonacci(3) -> 2 
[0.00000024s] fibonacci(0) -> 0 
[0.00000000s] fibonacci(1) -> 1 
[0.00001884s] fibonacci(2) -> 1 
[0.00000024s] fibonacci(1) -> 1 
[0.00000000s] fibonacci(0) -> 0 
[0.00000024s] fibonacci(1) -> 1 
[0.00001884s] fibonacci(2) -> 1 
[0.00003719s] fibonacci(3) -> 2 
[0.00007439s] fibonacci(4) -> 3 
[0.00017786s] fibonacci(5) -> 5 
[0.00042558s] fibonacci(6) -> 8 


8

In [52]:
import functools


# lru_cache is the outermost decorator, so it caches the clocked function:
# each fibonacci(k) is computed (and printed by clock) only once.
@functools.lru_cache()
@clock
def fibonacci(n):
    """Memoized Fibonacci; recursive calls are served from the cache too."""
    if n >= 2:
        return fibonacci(n-2) + fibonacci(n-1)
    return n


if __name__ == '__main__':
    answer = fibonacci(6)
    print(answer)

[0.00000072s] fibonacci(0) -> 0 
[0.00000143s] fibonacci(1) -> 1 
[0.00025439s] fibonacci(2) -> 1 
[0.00000238s] fibonacci(3) -> 2 
[0.00043225s] fibonacci(4) -> 3 
[0.00000191s] fibonacci(5) -> 5 
[0.00060391s] fibonacci(6) -> 8 
8


In [53]:
fibonacci(30)

[0.00000048s] fibonacci(7) -> 13 
[0.00005698s] fibonacci(8) -> 21 
[0.00000048s] fibonacci(9) -> 34 
[0.00024724s] fibonacci(10) -> 55 
[0.00000024s] fibonacci(11) -> 89 
[0.00027108s] fibonacci(12) -> 144 
[0.00000024s] fibonacci(13) -> 233 
[0.00064158s] fibonacci(14) -> 377 
[0.00000072s] fibonacci(15) -> 610 
[0.00068283s] fibonacci(16) -> 987 
[0.00000024s] fibonacci(17) -> 1597 
[0.00070357s] fibonacci(18) -> 2584 
[0.00000024s] fibonacci(19) -> 4181 
[0.00072289s] fibonacci(20) -> 6765 
[0.00000024s] fibonacci(21) -> 10946 
[0.00074291s] fibonacci(22) -> 17711 
[0.00000024s] fibonacci(23) -> 28657 
[0.00076318s] fibonacci(24) -> 46368 
[0.00000024s] fibonacci(25) -> 75025 
[0.00078249s] fibonacci(26) -> 121393 
[0.00000024s] fibonacci(27) -> 196418 
[0.00080180s] fibonacci(28) -> 317811 
[0.00000024s] fibonacci(29) -> 514229 
[0.00082302s] fibonacci(30) -> 832040 


832040

In [54]:
import html

def htmlize(obj):
    """Return the HTML-escaped ``repr`` of *obj* wrapped in a ``<pre>`` element."""
    content = html.escape(repr(obj))
    # A closing tag uses a forward slash: '</pre>'.
    return '<pre>{}</pre>'.format(content)

In [2]:
from functools import singledispatch
from collections import abc
import numbers
import html


@singledispatch    # single-dispatch generic function: @singledispatch marks the base function, which handles `object`
def htmlize(obj):
    """Fallback: HTML-escaped ``repr`` of *obj* inside ``<pre>``."""
    content = html.escape(repr(obj))
    return '<pre>{}</pre>'.format(content)


@htmlize.register(str)
def _(text):
    """Strings: escaped text in ``<p>``, with each newline preceded by ``<br>``."""
    content = html.escape(text).replace('\n', '<br>\n')
    return '<p>{0}</p>'.format(content)


@htmlize.register(numbers.Integral)
def _(n):
    """Integers: decimal value followed by its hexadecimal form in parentheses."""
    return '<pre>{0} (0x{0:x})</pre>'.format(n)


@htmlize.register(tuple)
@htmlize.register(abc.MutableSequence)
def _(seq):
    """Tuples and mutable sequences: an HTML list of each item's htmlize()."""
    inner = '</li>\n<li>'.join(htmlize(item) for item in seq)
    return '<ul>\n<li>' + inner + '</li>\n</ul>'

In [3]:
htmlize({1, 2, 3})

'<pre>{1, 2, 3}</pre>'

In [4]:
htmlize(abs)

'<pre>&lt;built-in function abs&gt;</pre>'

In [5]:
registry = []  # functions collected by register(), in decoration order


def register(func):
    """Announce *func*, append it to ``registry``, and return it unchanged."""
    print('running register(%s)' % func)
    registry.append(func)
    return func


@register
def f1():
    """Demo function, registered as soon as this definition executes."""
    print('running f1()')


# Top-level "main" section: runs right after the definitions above.
print('running main()')
print('registry ->', registry)
f1()

running register(<function f1 at 0x7fe13d312430>)
running main()
registry -> [<function f1 at 0x7fe13d312430>]
running f1()


In [6]:
registry = set()  # functions currently registered as active


def register(active=True):
    """Decorator factory that adds the decorated function to ``registry`` or removes it.

    ``register()`` adds the function. ``register(active=False)`` discards it,
    with no error if it was absent. Either way the function is returned as is.
    """
    def decorate(func):
        print('running register(active=%s)->decorate(%s)' % (active, func))
        update = registry.add if active else registry.discard
        update(func)
        return func

    return decorate


@register(active=False)
def f1():
    print('running f1()')


@register()
def f2():
    print('running f2()')


def f3():
    print('running f3()')

running register(active=False)->decorate(<function f1 at 0x7fe13ca71e50>)
running register(active=True)->decorate(<function f2 at 0x7fe13d312b80>)


In [7]:
registry

{<function __main__.f2()>}

In [9]:
register()(f3)

running register(active=True)->decorate(<function f3 at 0x7fe13ca718b0>)


<function __main__.f3()>

In [10]:
registry

{<function __main__.f2()>, <function __main__.f3()>}

In [11]:
register(active=False)(f2)

running register(active=False)->decorate(<function f2 at 0x7fe13d312b80>)


<function __main__.f2()>

In [12]:
registry

{<function __main__.f3()>}

In [14]:
import functools
import time

DEFAULT_FMT = '[{elapsed:0.8f}s] {name}({args}) -> {result}'


def clock(fmt=DEFAULT_FMT):
    """Decorator factory: time each call and print a line formatted by *fmt*.

    *fmt* is a ``str.format`` template filled from the wrapper's local
    variables. The intended fields are:

    - ``elapsed``: seconds, as a float.
    - ``name``: the function's ``__name__``.
    - ``args``: positional reprs followed by sorted ``k=v`` keyword pairs.
    - ``result``: the ``repr`` of the return value.
    """
    def decorate(func):
        @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
        def clocked(*_args, **_kwargs):
            # perf_counter is monotonic and high-resolution; time.time can jump
            # when the system clock is adjusted.
            t0 = time.perf_counter()
            _result = func(*_args, **_kwargs)
            elapsed = time.perf_counter() - t0
            name = func.__name__
            arg_parts = [repr(arg) for arg in _args]
            arg_parts.extend('%s=%r' % (k, v) for k, v in sorted(_kwargs.items()))
            args = ', '.join(arg_parts)
            result = repr(_result)
            # fmt may reference any local of this wrapper, e.g. {elapsed}, {name}.
            print(fmt.format(**locals()))
            return _result

        return clocked

    return decorate


@clock()
def snooze(seconds):
    """Sleep for *seconds*; the call is reported with the default format."""
    time.sleep(seconds)


for i in (0, 1, 2):
    snooze(i)

[0.00000572s] snooze(0) -> None
[1.00129914s] snooze(1) -> None
[2.00210810s] snooze(2) -> None


In [16]:
@clock('{name}: {elapsed}')
def snooze(seconds):
    """Sleep for *seconds*; the timing is printed as '<name>: <elapsed>'."""
    time.sleep(seconds)


for i in (0, 1, 2):
    snooze(i)

snooze: 2.002716064453125e-05
snooze: 1.0013184547424316
snooze: 2.002094268798828


In [17]:
@clock('{name}({args}) dt={elapsed:0.3f}s')
def snooze(seconds):
    """Sleep for *seconds*; the timing is printed with millisecond precision."""
    time.sleep(seconds)


for i in (0, 1, 2):
    snooze(i)

snooze(0) dt=0.000s
snooze(1) dt=1.001s
snooze(2) dt=2.002s
