# 元编程
---

### 定义一个装饰器

In [1]:
import time
from functools import wraps


def timethis(func):
    """ 计算函数运行时间的装饰器 """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(func.__name__, end-start)
        return result
    return wrapper


@timethis
def countdown(n):
    while n > 0:
        n -= 1
    print(n)

In [2]:
countdown(100000)

0
countdown 0.006216764450073242


---

### 解除装饰器访问原函数

In [3]:
countdown.__wrapped__(100000)

0


---

### 定义一个带参数的装饰器

In [4]:
from functools import wraps
import logging


def logged(level, name=None, message=None):
    def decorate(func):
        logname = name if name else func.__module__
        log = logging.getLogger(logname)
        logmsg = message if message else func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            log.log(level, logmsg)
            return func(*args, **kwargs)
        return wrapper
    return decorate


@logged(logging.DEBUG)
def add(x, y):
    return x + y


@logged(logging.CRITICAL, 'example')
def spam():
    print('Spam!')


logging.basicConfig(level=logging.DEBUG)

In [5]:
add(2, 3)

DEBUG:__main__:add


5

---

### 可自定义属性的装饰器

In [6]:
from functools import wraps, partial
import logging


def attach_wrapper(obj, func=None):
    if func is None:
        return partial(attach_wrapper, obj)
    setattr(obj, func.__name__, func)
    return func


def logged(level, name=None, message=None):
    def decorate(func):
        logname = name if name else func.__module__
        log = logging.getLogger(logname)
        logmsg = message if message else func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            log.log(level, logmsg)
            return func(*args, **kwargs)

        @attach_wrapper(wrapper)
        def set_level(newlevel):
            nonlocal level
            level = newlevel

        @attach_wrapper(wrapper)
        def set_message(newmsg):
            nonlocal logmsg
            logmsg = newmsg

        return wrapper

    return decorate


@logged(logging.DEBUG)
def add(x, y):
    return x + y


@logged(logging.CRITICAL, 'example')
def spam():
    print('Spam!')


logging.basicConfig(level=logging.DEBUG)

In [7]:
add(2, 3)

DEBUG:__main__:add


5

In [8]:
add.set_message('Add called')

In [9]:
add(2, 3)

DEBUG:__main__:Add called


5

In [10]:
add.set_level(logging.WARNING)

In [11]:
add(2, 3)



5

---

### 带可选参数的装饰器

In [12]:
from functools import wraps, partial
import logging


def logged(func=None, *, level=logging.DEBUG, name=None, message=None):
    if func is None:
        return partial(logged, level=level, name=name, message=message)

    logname = name if name else func.__module__
    log = logging.getLogger(logname)
    logmsg = message if message else func.__name__

    @wraps(func)
    def wrapper(*args, **kwargs):
        log.log(level, logmsg)
        return func(*args, **kwargs)

    return wrapper


@logged
def add(x, y):
    return x + y


@logged(level=logging.CRITICAL, name='example')
def spam():
    print('Spam!')

---

### 利用装饰器强制函数上的类型检查

In [13]:
from inspect import signature
from functools import wraps


def typeassert(*ty_args, **ty_kwargs):
    def decorate(func):
        # 当使用 -O 或 -OO 参数的优化模式执行程序时，__debug__ 是 False
        if not __debug__:
            return func

        # signature 可以提取一个可调用对象的参数签名信息
        sig = signature(func)
        # 使用 bind_partial 方法来构成对象参数名与传入类型的绑定（忽略缺失的参数）
        bound_types = sig.bind_partial(*ty_args, **ty_kwargs).arguments

        @wraps(func)
        def wrapper(*args, **kwargs):
            # bind 方法和 bind_partial 方法类似，但不允许忽略任何参数
            bound_values = sig.bind(*args, **kwargs)
            for name, value in bound_values.arguments.items():
                if name in bound_types:
                    if not isinstance(value, bound_types[name]):
                        raise TypeError(
                            'Argument {} must be {}'.format(
                                name, bound_types[name])
                            )
            return func(*args, **kwargs)
        return wrapper
    return decorate


@typeassert(int, z=int)
def spam(x, y, z=42):
    print(x, y, z)

In [14]:
spam(1, 2, 3)

1 2 3


In [15]:
try:
    spam(1, 'hello', 'world')
except TypeError as e:
    print('TypeError:', e)

TypeError: Argument z must be <class 'int'>


---

### 在类中定义装饰器

In [16]:
from functools import wraps


class A:

    def decorator1(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            print('Decorator 1')
            return func(*args, **kwargs)
        return wrapper

    @classmethod
    def decorator2(cls, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            print('Decorator 2')
            return func(*args, **kwargs)
        return wrapper

In [17]:
# 装饰函数
a = A()

@a.decorator1
def spam():
    pass

In [18]:
# 装饰类的方法
class B:
    @A.decorator2
    def bar(self):
        pass


# 涉及继承时，必须用父类名显性调用
class C(A):
    @A.decorator2
    def bar(self):
        pass

---

### 将装饰器定义为类

In [19]:
import types
from functools import wraps


class Profiled:

    def __init__(self, func):
        wraps(func)(self)
        self.ncalls = 0

    def __call__(self, *args, **kwargs):
        self.ncalls += 1
        return self.__wrapped__(*args, **kwargs)

    def __get__(self, instance, cls):
        if instance is None:
            return self
        else:
            # 手动创建一个绑定方法来使用
            return types.MethodType(self, instance)

In [20]:
@Profiled
def add(x, y):
    return x + y

In [21]:
add(2, 3)

5

In [22]:
add(4, 5)

9

In [23]:
add.ncalls

2

In [24]:
class Spam:
    @Profiled
    def bar(self, x):
        print(self, x)

s = Spam()

In [25]:
s.bar(1)

<__main__.Spam object at 0x110171550> 1


In [26]:
s.bar(2)

<__main__.Spam object at 0x110171550> 2


In [27]:
s.bar(3)

<__main__.Spam object at 0x110171550> 3


In [28]:
Spam.bar.ncalls

3

使用闭包实现同样功能

In [29]:
import types
from functools import wraps

def profiled(func):
    ncalls = 0
    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal ncalls
        ncalls += 1
        return func(*args, **kwargs)
    wrapper.ncalls = lambda: ncalls
    return wrapper


class Spam:
    @profiled
    def bar(self, x):
        print(self, x)

In [30]:
s = Spam()
s.bar(1)

<__main__.Spam object at 0x110171e10> 1


In [31]:
s.bar(2)

<__main__.Spam object at 0x110171e10> 2


In [32]:
Spam.bar.ncalls()

2

---

### 装饰器为被包装函数增加参数

In [33]:
from functools import wraps
import inspect

def optional_debug(func):
    if 'debug' in inspect.getfullargspec(func).args:
        raise TypeError('debug argument already defined')

    @wraps(func)
    def wrapper(*args, debug=False, **kwargs):
        if debug:
            print('Calling', func.__name__)
        return func(*args, **kwargs)

    sig = inspect.signature(func)
    parms = list(sig.parameters.values())
    parms.append(inspect.Parameter('debug',
                inspect.Parameter.KEYWORD_ONLY,
                default=False))
    wrapper.__signature__ = sig.replace(parameters=parms)
    return wrapper


@optional_debug
def add(x,y):
    return x+y

In [34]:
add(2, 3, debug=True)

Calling add


5

In [35]:
print(inspect.signature(add))

(x, y, *, debug=False)


---

### 使用装饰器扩充实例方法的功能

In [36]:
def log_getattribute(cls):
    orig_getattribute = cls.__getattribute__

    def new_getattribute(self, name):
        print('getting:', name)
        return orig_getattribute(self, name)

    cls.__getattribute__ = new_getattribute
    return cls


@log_getattribute
class A:
    def __init__(self,x):
        self.x = x
    def spam(self):
        pass

In [37]:
a = A(42)

In [38]:
a.x

getting: x


42

In [39]:
a.spam()

getting: spam


---

### 使用元类控制实例的创建

阻止创建类实例

In [40]:
class NoInstances(type):

    def __call__(self, *args, **kwargs):
        raise TypeError("Can't instantiate directly")


class Spam(metaclass=NoInstances):

    @staticmethod
    def grok(x):
        print('Spam.grok')

In [41]:
Spam.grok(42)

Spam.grok


In [42]:
try:
    Spam()
except TypeError as e:
    print('TypeError:', e)

TypeError: Can't instantiate directly


实现单例

In [43]:
class Singleton(type):

    def __init__(self, *args, **kwargs):
        self.__instance = None
        super().__init__(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        if self.__instance is None:
            self.__instance = super().__call__(*args, **kwargs)
            return self.__instance
        else:
            return self.__instance


class Spam(metaclass=Singleton):

    def __init__(self):
        print('Creating Spam')

In [44]:
a = Spam()

getting: __class__
getting: __class__
Creating Spam


In [45]:
b = Spam()

In [46]:
a is b

True

In [47]:
from collections import OrderedDict


class Typed:
    _expected_type = type(None)

    def __init__(self, name=None):
        self._name = name

    def __set__(self, instance, value):
        if not isinstance(value, self._expected_type):
            raise TypeError('Expected ' + str(self._expected_type))
        instance.__dict__[self._name] = value


class Integer(Typed):
    _expected_type = int


class Float(Typed):
    _expected_type = float


class String(Typed):
    _expected_type = str


class OrderedMeta(type):

    def __new__(cls, clsname, bases, clsdict):
        """
        从类字典中捕获生成的有序名称，并放入类属性 _order 中；
        这里的 clsdict 参数就是 __prepare__ 返回的映射对象
        """
        d = dict(clsdict)
        order = []
        for name, value in clsdict.items():
            if isinstance(value, Typed):
                value._name = name
                order.append(name)
        d['_order'] = order
        return type.__new__(cls, clsname, bases, d)

    @classmethod
    def __prepare__(cls, clsname, bases):
        """
        这个方法会在开始定义类和它的父类的时候被执行，
        它必须返回一个映射对象以便在类定义体中被使用到，
        这里通过返回一个 OrderedDict 来确保定义的顺序
        """
        return OrderedDict()


class Model(metaclass=OrderedMeta):

    def as_csv(self):
        """ 利用元类 """
        return ','.join(
            str(getattr(self,name)) for name in self._order
        )


class Stock(Model):
    name = String()
    shares = Integer()
    price = Float()

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

In [48]:
s = Stock('GOOG', 100, 490.1)

In [49]:
s.name

'GOOG'

In [50]:
s.as_csv()

'GOOG,100,490.1'

---

### 定义有可选参数的元类

In [51]:
from abc import ABCMeta, abstractmethod


class IStream(metaclass=ABCMeta):
    @abstractmethod
    def read(self, maxsize=None):
        pass

    @abstractmethod
    def write(self, data):
        pass


class MyMeta(type):

    # __prepare__ 方法默认接受任意的关键字参数，所以不是必须定义
    @classmethod
    def __prepare__(cls, name, bases, *, debug=False, synchronize=False):
        pass
        return super().__prepare__(name, bases)

    def __new__(cls, name, bases, ns, *, debug=False, synchronize=False):
        pass
        return super().__new__(cls, name, bases, ns)

    def __init__(self, name, bases, ns, *, debug=False, synchronize=False):
        pass
        super().__init__(name, bases, ns)


class Spam(metaclass=MyMeta, debug=True, synchronize=True):
    pass


# 或者这样使用，但会占用类的命名空间
class Spam(metaclass=MyMeta):
    debug = True
    synchronize = True
    pass

---

### *args 和 **kwargs 的强制参数签名

In [52]:
from inspect import Signature, Parameter

def make_sig(*names):
    parms = [Parameter(name, Parameter.POSITIONAL_OR_KEYWORD)
            for name in names]
    return Signature(parms)


class StructureMeta(type):

    def __new__(cls, clsname, bases, clsdict):
        clsdict['__signature__'] = make_sig(*clsdict.get('_fields',[]))
        return super().__new__(cls, clsname, bases, clsdict)


class Structure(metaclass=StructureMeta):
    _fields = []

    def __init__(self, *args, **kwargs):
        bound_values = self.__signature__.bind(*args, **kwargs)
        for name, value in bound_values.arguments.items():
            setattr(self, name, value)


class Stock(Structure):
    _fields = ['name', 'shares', 'price']

In [53]:
s1 = Stock('ACME', 100, 490.1)

In [54]:
try:
    s2 = Stock('ACME', 100)
except TypeError as e:
    print('TypeError:', e)

TypeError: missing a required argument: 'price'


In [55]:
try:
    s3 = Stock('ACME', 100, 490.1, shares=50)
except TypeError as e:
    print('TypeError:', e)

TypeError: multiple values for argument 'shares'


---

### 在类上强制使用编程规约

阻止使用混合大小写作为方法名

In [56]:
class NoMixedCaseMeta(type):
    def __new__(cls, clsname, bases, clsdict):
        for name in clsdict:
            if name.lower() != name:
                raise TypeError('Bad attribute name: ' + name)
        return super().__new__(cls, clsname, bases, clsdict)


class Root(metaclass=NoMixedCaseMeta):
    pass

In [57]:
class A(Root):

    def foo_bar(self):
        pass

In [58]:
try:
    class B(Root):

        def fooBar(self):
            pass

except TypeError as e:
    print('TypeError:', e)

TypeError: Bad attribute name: fooBar


检查子类重载方法，若调用参数与父类原始方法的参数签名不同，则发出警告

In [59]:
from inspect import signature
import logging


class MatchSignaturesMeta(type):

    def __init__(self, clsname, bases, clsdict):
        super().__init__(clsname, bases, clsdict)
        sup = super(self, self)
        for name, value in clsdict.items():
            if name.startswith('_') or not callable(value):
                continue
            prev_dfn = getattr(sup,name,None)
            if prev_dfn:
                prev_sig = signature(prev_dfn)
                val_sig = signature(value)
                if prev_sig != val_sig:
                    logging.warning('Signature mismatch in %s. %s != %s',
                                    value.__qualname__, prev_sig, val_sig)


class Root(metaclass=MatchSignaturesMeta):
    pass

In [60]:
class A(Root):

    def foo(self, x, y):
        pass

    def spam(self, x, *, z):
        pass

In [61]:
class B(A):

    def foo(self, a, b):
        pass

    def spam(self, x, z):
        pass



---

### 以编程方式定义类

使用参数新建类

In [62]:
# stock.py
import types

def __init__(self, name, shares, price):
    self.name = name
    self.shares = shares
    self.price = price
def cost(self):
    return self.shares * self.price

cls_dict = {
    '__init__' : __init__,
    'cost' : cost,
}

# 第二个参数是父类元组，第三个是类属性，第四个是 __prepare__ 方法返回的任意对象
Stock = types.new_class(
    'Stock', (), {}, lambda ns: ns.update(cls_dict))
Stock.__module__ = __name__

In [63]:
s = Stock('ACME', 50, 91.1)

In [64]:
s

<__main__.Stock at 0x1101831d0>

In [65]:
s.cost()

4555.0

---

### 在定义的时候初始化类的成员

实现类似于命令元组的类

In [66]:
import operator


class StructTupleMeta(type):

    def __init__(cls, *args, **kwargs):
        super().__init__(*args, **kwargs)
        for n, name in enumerate(cls._fields):
            # itemgetter 创建一个访问器函数，然后 property 函数将其转换成属性
            setattr(cls, name, property(operator.itemgetter(n)))


class StructTuple(tuple, metaclass=StructTupleMeta):
    _fields = []

    def __new__(cls, *args):
        if len(args) != len(cls._fields):
            raise ValueError(
                '{} arguments required'.format(len(cls._fields)))
        return super().__new__(cls, args)

In [67]:
class Stock(StructTuple):
    _fields = ['name', 'shares', 'price']


class Point(StructTuple):
    _fields = ['x', 'y']

In [68]:
s = Stock('ACME', 50, 91.1)

In [69]:
s

('ACME', 50, 91.1)

In [70]:
s[0]

'ACME'

In [71]:
s.shares * s.price

4555.0

In [72]:
try:
    s.shares = 23
except AttributeError as e:
    print('AttributeError:', e)

AttributeError: can't set attribute


---

### 利用函数注解实现方法重载

使用元类的实现，但不支持关键字参数，继承方面也有限制

In [73]:
import inspect
import types


class MultiMethod:
    '''
    实现方法重载，允许同名但不同参数类型的方法
    '''
    def __init__(self, name):
        self._methods = {}
        self.__name__ = name

    def register(self, meth):
        '''
        注册一个新方法作为重载方法
        '''
        sig = inspect.signature(meth)

        # 从方法的参数签名中绑定类型签名
        types = []
        for name, parm in sig.parameters.items():
            if name == 'self':
                continue
            if parm.annotation is inspect.Parameter.empty:
                raise TypeError(
                    'Argument {} must be annotated with a type'.format(name)
                )
            if not isinstance(parm.annotation, type):
                raise TypeError(
                    'Argument {} annotation must be a type'.format(name)
                )
            if parm.default is not inspect.Parameter.empty:
                self._methods[tuple(types)] = meth
            types.append(parm.annotation)

        self._methods[tuple(types)] = meth

    def __call__(self, *args):
        types = tuple(type(arg) for arg in args[1:])
        meth = self._methods.get(types, None)
        if meth:
            return meth(*args)
        else:
            raise TypeError('No matching method for types {}'.format(types))

    def __get__(self, instance, cls):
        if instance is not None:
            return types.MethodType(self, instance)
        else:
            return self


class MultiDict(dict):
    '''
    实现一个特殊字典用于在元类中构建重载方法
    '''
    def __setitem__(self, key, value):
        if key in self:
            # 如果 key 已经存在，则其必须是重载方法或者注册为重载方法
            current_value = self[key]
            if isinstance(current_value, MultiMethod):
                current_value.register(value)
            else:
                mvalue = MultiMethod(key)
                mvalue.register(current_value)
                mvalue.register(value)
                super().__setitem__(key, mvalue)
        else:
            super().__setitem__(key, value)


class MultipleMeta(type):
    '''
    实现元类允许类定义重载方法
    '''
    def __new__(cls, clsname, bases, clsdict):
        return type.__new__(cls, clsname, bases, dict(clsdict))

    @classmethod
    def __prepare__(cls, clsname, bases):
        return MultiDict()

In [74]:
import time


class Date(metaclass=MultipleMeta):

    def __init__(self, year: int, month:int, day:int):
        self.year = year
        self.month = month
        self.day = day

    def __init__(self):
        t = time.localtime()
        self.__init__(t.tm_year, t.tm_mon, t.tm_mday)

    def __repr__(self):
        return 'Date({0.year!r}, {0.month!r}, {0.day!r})'.format(self)

In [75]:
Date(2012, 12, 21)

Date(2012, 12, 21)

In [76]:
Date()

Date(2018, 10, 31)

In [77]:
try:
    Date('2012', '12', '21')
except TypeError as e:
    print('TypeError:', e)

TypeError: No matching method for types (<class 'str'>, <class 'str'>, <class 'str'>)


注意不支持关键字参数

In [78]:
try:
    Date(2012, 12, day=21)
except TypeError as e:
    print('TypeError:', e)

TypeError: __call__() got an unexpected keyword argument 'day'


对于继承有特殊情况的限制

In [79]:
class A:
    pass


class B(A):
    pass


class C:
    pass


class Spam(metaclass=MultipleMeta):

    def foo(self, x:A):
        print('Foo 1:', x)

    def foo(self, x:C):
        print('Foo 2:', x)

In [80]:
s = Spam()
a = A()
b = B()
c = C()

In [81]:
s.foo(a)

Foo 1: <__main__.A object at 0x11019f080>


In [82]:
s.foo(c)

Foo 2: <__main__.C object at 0x110185be0>


原因是因为 x:A 注解不能成功匹配子类实例（比如B的实例）

In [83]:
try:
    s.foo(b)
except TypeError as e:
    print('TypeError:', e)

TypeError: No matching method for types (<class '__main__.B'>,)


---
使用描述器的实现，同样也不支持关键字参数和有继承的限制

In [84]:
import types


class multimethod:

    def __init__(self, func):
        self._methods = {}
        self.__name__ = func.__name__
        self._default = func

    def match(self, *types):
        def register(func):
            ndefaults = len(func.__defaults__) if func.__defaults__ else 0
            for n in range(ndefaults+1):
                self._methods[types[:len(types) - n]] = func
            return self
        return register

    def __call__(self, *args):
        types = tuple(type(arg) for arg in args[1:])
        meth = self._methods.get(types, None)
        if meth:
            return meth(*args)
        else:
            return self._default(*args)

    def __get__(self, instance, cls):
        if instance is not None:
            return types.MethodType(self, instance)
        else:
            return self

In [85]:
class Spam:

    @multimethod
    def bar(self, *args):
        raise TypeError('No matching method for bar')

    @bar.match(int, int)
    def bar(self, x, y):
        print('Bar 1:', x, y)

    @bar.match(str, int)
    def bar(self, s, n = 0):
        print('Bar 2:', s, n)

In [86]:
s = Spam()

In [87]:
s.bar(2, 3)

Bar 1: 2 3


In [88]:
s.bar('hello')

Bar 2: hello 0


In [89]:
try:
    s.bar(2, 'hello')
except TypeError as e:
    print('TypeError:', e)

TypeError: No matching method for bar


---

### 避免重复的属性方法

检查属性的类型

In [90]:
from functools import partial


def typed_property(name, expected_type):
    storage_name = '_' + name

    @property
    def prop(self):
        return getattr(self, storage_name)

    @prop.setter
    def prop(self, value):
        if not isinstance(value, expected_type):
            raise TypeError('{} must be a {}'.format(name, expected_type))
        setattr(self, storage_name, value)

    return prop


class Person:
    name = typed_property('name', str)
    age = typed_property('age', int)

    def __init__(self, name, age):
        self.name = name
        self.age = age

In [91]:
Person('jim', 13)

<__main__.Person at 0x1101c9f60>

In [92]:
try:
    Person('jim', '13')
except TypeError as e:
    print('TypeError:', e)

TypeError: age must be a <class 'int'>


---

### 定义上下文管理器的简单方法

利用 `contextmanager` 装饰器，`yield` 之前的代码会作为 `__enter__` 方法执行，之后的代码会作为 `__exit__` 方法执行，如果出现了异常，异常会在 `yield` 语句那里抛出

In [93]:
import time
from contextlib import contextmanager

@contextmanager
def list_transaction(orig_list):
    working = list(orig_list)
    yield working
    orig_list[:] = working

In [94]:
items = [1, 2, 3]

with list_transaction(items) as working:
    working.append(4)
    working.append(5)

In [95]:
items

[1, 2, 3, 4, 5]

In [96]:
try:
    with list_transaction(items) as working:
        working.append(6)
        working.append(7)
        raise RuntimeError('oops')
except RuntimeError as e:
    print('RuntimeError:', e)

RuntimeError: oops


In [97]:
items

[1, 2, 3, 4, 5]

---

### 在局部变量域中执行代码

正常情况下 exec 的执行不会影响实际局部变量，因为 exec 获得是局部变量的拷贝

In [98]:
def test1():
    x = 0
    exec('x += 1')
    print(x)

In [99]:
test1()

0


想得到 exec 的结果，必须调用 `locals()`，来获取传递给 exec 局部变量的拷贝

In [100]:
def test2():
    x = 0
    loc = locals()
    print('before:', loc)
    exec('x += 1')
    print('after:', loc)
    print('loc_x =', loc['x'])

In [101]:
test2()

before: {'x': 0}
after: {'x': 1, 'loc': {...}}
loc_x = 1


不使用 locals 的替代方案

In [102]:
def test3():
    x = 0
    loc = {'x' : x }
    glb = {}
    exec('x += 1', glb, loc)
    a = loc['x']
    print(a)

In [103]:
test3()

1
