# Python程序设计进阶

## 装饰器、上下文管理器

### 装饰器

装饰器 decorator：修改函数行为的函数

本质上是一个函数：
- 其输入是一个函数（`callable`）
  - 函数
  - 方法
  - 类
  - 实现`__call__`的实体

- 输出将被用于替换这一个函数
- 可能产生副作用
  - 对非局部变量造成改变
  - 每次访问函数得到结果可能不同


In [1]:
def decorator(func):
    def wrapper():
        print('Before')
        func()
        print('After')
    return wrapper

def foo():
    print('Hello')

foo = decorator(foo)

foo()

Before
Hello
After


In [3]:
@decorator
def bar():
    print('World')

bar()

Before
World
After


例：函数计时器

In [5]:
def fib1(n):
    def fibr(n):
        if n < 2:
            return n
        return fibr(n-1) + fibr(n-2)
    return fibr(n)

def fib2(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a+b
    return a

assert fib1(42) == fib2(42)

In [7]:
import time

start = time.time()
fib2(42)
end = time.time()
print(f"{fib2.__name__} took {end - start} seconds")

fib2 took 0.0 seconds


In [9]:
import time

def timeit(fn):
    def wrapper(*args, **kwargs):
        start = time.time()
        result = fn(*args, **kwargs)
        end = time.time()
        print(f"{fn.__name__} took {end - start} seconds")
        return result
    return wrapper

In [11]:
@timeit
def fib1(n):
    def fibr(n):
        if n < 2:
            return n
        return fibr(n-1) + fibr(n-2)
    return fibr(n)

@timeit
def fib2(n):
    a, b = 0, 1
    for _ in range(n):
        a, b = b, a+b
    return a

In [13]:
fib2(42)

fib2 took 0.0 seconds


267914296

例：函数缓存器（懒人动态规划）

In [15]:
function_cache = {}

def cacheit(fn):
    def wrapper(*args, **kwargs):
        key = (fn, args, tuple(kwargs.items()))
        if key not in function_cache:
            function_cache[key] = fn(*args, **kwargs)
        return function_cache[key]
    return wrapper

In [17]:
@timeit
def fib3(n):
    @cacheit
    def fibr(n):
        if n < 2:
            return n
        return fibr(n-1) + fibr(n-2)
    return fibr(n)

fib3(42)

fib3 took 0.0 seconds


267914296

副作用：非局部状态的改变

`cacheit`例子中，调用可能会导致`function_cache`改变（可能增加一条记录），既副作用。

多个装饰器的执行顺序：从下到上

In [19]:
@timeit
@decorator
def bar():
    print('World')

bar()

Before
World
After
wrapper took 0.0 seconds


含参数的装饰器

In [21]:
def param_decorator(param):
    def decorator(func):
        def wrapper():
            print(f'Before: {param}')
            func()
            print(f'After: {param}')
        return wrapper
    return decorator

@param_decorator('Hello')
def bar():
    print('World')

bar()

Before: Hello
World
After: Hello


### 上下文管理器

- 进入`with`语句下的代码块时执行一个操作（如，获取资源）
- 离开代码块时（如，释放资源）
- 使用`as`关键词将资源绑定到一个变量上

例：文件读写管理

In [None]:
fout = open("Hello.txt", "w")
fout.write("Hello World!")
fout.close()

In [None]:
with open("Hello.txt", "w") as fout:
    fout.write("Hello World!")

魔法方法实现

In [None]:
class OpenMy:
    prefix = "guoquan-"
    
    def __init__(self, filename, mode="r"):
        self.filename = filename
        self.mode = mode
    
    def __enter__(self):
        print("Enter")
        self.fout = open(self.prefix + self.filename, self.mode)
        return self.fout
    
    def __exit__(self, *args):
        print("Exit")
        self.fout.close()

In [None]:
with OpenMy("Hello.txt", "w") as fout:
    fout.write("Hello World!")

装饰器实现

In [None]:
from contextlib import contextmanager

prefix = "鸭梨-"

@contextmanager
def open_my(filename, mode="r"):
    print('Enter')
    file = open(f"{prefix}{filename}", mode)
    yield file
    print('Exit')
    file.close()

In [None]:
with open_my("Hello.txt", "w") as fout:
    fout.write("Hello World!")

例：配置管理

In [None]:
from contextlib import contextmanager

class Speech:
    speaker = "guoquan"

    def __call__(self, message):
        print(f"{self.speaker}: {message}")

say = Speech()

@contextmanager
def interrupt(name):
    orig_speaker = Speech.speaker
    Speech.speaker = name
    yield
    Speech.speaker = orig_speaker

say("Hello")
with interrupt("鸭梨"):
    say("汪汪")
say("World")

## 多继承、元类

### 多继承

In [None]:
class Pet:
    def __init__(self, name):
        self.name = name

    def speak(self):
        print(f"我是{self.name}")

class Dog(Pet):
    woof = "汪"

    def speak(self):
        super().speak()
        print(self.woof)

In [None]:
class LongLegDog(Dog):
    def speak(self):
        super().speak()
        print("抖抖长腿")

class YellowHairDog(Dog):
    def speak(self):
        super().speak()
        print("甩甩金毛")

class Shiba(LongLegDog, YellowHairDog):
    def speak(self):
        super().speak()
        print("嗷呜嗷呜")

In [None]:
Shiba("阿柴").speak()

### 元类

类的类

例：修改类创建的行为

In [None]:
class DisableMultipleInherit(type):
    def __new__(cls, name, bases, attrs):
        if len(bases) > 1:
            raise TypeError("Multiple Inheritance is not allowed")
        return super().__new__(cls, name, bases, attrs)
    
class PurebredDog(Dog, metaclass=DisableMultipleInherit):
    pass

class Husky(PurebredDog):
    pass

class Poodle(PurebredDog):
    pass

class NewBred(Poodle, Husky):  # TypeError: Multiple Inheritance is not allowed
    pass

类的创建过程既其元类的实例化过程

例：注入操作

In [None]:
class InjectTimer(type):
    def __init__(cls, name, bases, attrs):
        super().__init__(name, bases, attrs)
        for key, value in attrs.items():
            if callable(value):
                setattr(cls, key, cls.inject(value))

    @staticmethod
    def inject(fn):
        return timeit(fn)

class Fib(metaclass=InjectTimer):
    def fib1(self, n):
        def fibr(n):
            if n < 2:
                return n
            return fibr(n-1) + fibr(n-2)
        return fibr(n)

    def fib2(self, n):
        a, b = 0, 1
        for _ in range(n):
            a, b = b, a+b
        return a

fib = Fib()

fib.fib2(42)

## 循环、递归、迭代器

In [None]:
例：累加器

循环

In [None]:
def accumulate(n):
    total = 0
    for value in range(n+1):
        total = total + value
    return total

accumulate(100)

递归

In [None]:
def accumulate_r(n):
    if n == 1:
        return 1
    else:
        total = accumulate_r(n-1)
        total = total + n
        return total

accumulate_r(100)

尾递归

In [None]:
def accumulate_t(n, total=0):
    if n == 0:
        return total
    else:
        return accumulate_t(n-1, total+n)

accumulate_t(100)

消除尾递归

In [None]:
def accumulate_e(n, total=0):
    while True:
        if n == 0:
            return total
        else:
            n, total = n-1, total+n

accumulate_e(100)

例：汉诺塔

In [None]:
def hanoi_r(n, src, dst, tmp):
    if n == 1:
        print(f"Move {src} -> {dst}")
    else:
        hanoi_r(n-1, src, tmp, dst)
        print(f"Move {src} -> {dst}")
        hanoi_r(n-1, tmp, dst, src)
        
hanoi_r(3, 'A', 'C', 'B')

In [None]:
def hanoi_l(n, src, dst, tmp):
    stack = [(n, src, dst, tmp)]
    while stack:
        n, src, dst, tmp = stack.pop()
        if n == 1:
            print(f"Move {src} -> {dst}")
        else:
            stack.append((n-1, tmp, dst, src))
            stack.append((1, src, dst, tmp))
            stack.append((n-1, src, tmp, dst))

hanoi_l(3, 'A', 'C', 'B')

例：哑谜机的齿轮组合

In [None]:
A = ord("A")
for r1 in range(26):
    for r2 in range(26):
        for r3 in range(26):
            print(f"{chr(r1+A)}{chr(r2+A)}{chr(r3+A)}")

### 迭代器

In [None]:
def display_gen():
    A = ord("A")
    for r1 in range(26):
        for r2 in range(26):
            for r3 in range(26):
                yield f"{chr(r1+A)}{chr(r2+A)}{chr(r3+A)}"

for i, display in enumerate(display_gen()):
    print(i, display)

如果需要更多的齿轮？如何支持不同数量的齿轮？

In [None]:
def display_gen():
    A = ord("A")
    for r1 in range(26):
        for r2 in range(26):
            for r3 in range(26):
                for r4 in range(26):
                    for r5 in range(26):
                        pass


In [None]:
def display_gen_rs(n=3):
    if n == 0:
        yield ""
    else:
        A = ord("A")
        for r in range(26):
            yield from (chr(r+A) + suffix for suffix in display_gen_rs(n-1))

for i, display in enumerate(display_gen_rs()):
    print(i, display)

In [None]:
def display_gen_rp(n=3, prefix=""):
    if n == 0:
        yield prefix
    else:
        A = ord("A")
        for r in range(26):
            yield from display_gen_rp(n-1, prefix + chr(r+A))

for i, display in enumerate(display_gen_rp()):
    print(i, display)

In [None]:
from itertools import product

for i, display in enumerate(product((chr(r+A) for r in range(26)), repeat=3)):
    print(i, "".join(display))

## 其他常见概念