# 类与对象
---

### 自定义类实例的字符串显示

In [1]:
class Date:

    _formats = {
    'ymd' : '{d.year}-{d.month}-{d.day}',
    'mdy' : '{d.month}/{d.day}/{d.year}',
    'dmy' : '{d.day}/{d.month}/{d.year}'
    }

    def __init__(self, year, month, day):
        self.year = year
        self.month = month
        self.day = day

    def __repr__(self):
        return 'Pair(year={0.year!r}, month={0.month!r}, day={0.day!r})'.format(self)

    def __str__(self):
        return '({0.year!s}, {0.month!s}, {0.day!s})'.format(self)

    def __format__(self, code=''):
        if code == '':
            code = 'ymd'
        fmt = self._formats[code]
        return fmt.format(d=self)

In [2]:
d = Date(2012, 12, 21)

In [3]:
d    # __repr__() output

Pair(year=2012, month=12, day=21)

In [4]:
print(d)    # __str__() output

(2012, 12, 21)


In [5]:
format(d)    # __format__() output

'2012-12-21'

---

### 让对象支持上下文管理协议

In [6]:
from socket import socket, AF_INET, SOCK_STREAM

class LazyConnection:

    def __init__(self, address, family=AF_INET, type=SOCK_STREAM):
        self.address = address
        self.family = family
        self.type = type
        self.sock = None

    def __enter__(self):
        if self.sock is not None:
            raise RuntimeError('Already connected')
        self.sock = socket(self.family, self.type)
        self.sock.connect(self.address)
        return self.sock

    def __exit__(self, exc_ty, exc_val, tb):
        self.sock.close()
        self.sock = None

In [7]:
from functools import partial

conn = LazyConnection(('httpbin.org', 80))

with conn as s:
    s.send(b'GET /ip HTTP/1.0\r\n')
    s.send(b'Host: httpbin.org\r\n')
    s.send(b'\r\n')
    resp = b''.join(iter(partial(s.recv, 8192), b''))
    print(resp)

b'HTTP/1.1 200 OK\r\nConnection: close\r\nServer: gunicorn/19.9.0\r\nDate: Wed, 31 Oct 2018 12:12:37 GMT\r\nContent-Type: application/json\r\nContent-Length: 33\r\nAccess-Control-Allow-Origin: *\r\nAccess-Control-Allow-Credentials: true\r\nVia: 1.1 vegur\r\n\r\n{\n  "origin": "115.205.68.100"\n}\n'


---

### 创建大量对象时节省内存方法

当定义 `__slots__` 后，Python 就会为实例使用一种更加紧凑的内部表示。实例通过一个很小的固定大小的数组来构建，而不是为每个实例定义一个字典，这跟元组或列表很类似。

In [8]:
class Date:
    __slots__ = ['year', 'month', 'day']

    def __init__(self, year, month, day):
        self.year = year
        self.month = month
        self.day = day

---

### 简化数据结构的初始化

In [9]:
import math


class Structure:
    """
    该类及子类在生成实例时，会自动将 _fields 中的字段加入 __init__
    不在 _fields 的字段，也可以用关键词参数加入
    """
    _fields = []

    def __init__(self, *args, **kwargs):
        if len(args) != len(self._fields):
            raise TypeError('Expected {} arguments'.format(len(self._fields)))

        for name, value in zip(self._fields, args):
            setattr(self, name, value)

        extra_args = kwargs.keys() - self._fields
        for name in extra_args:
            setattr(self, name, kwargs.pop(name))

        if kwargs:
            raise TypeError('Duplicate values for {}'.format(','.join(kwargs)))


class Stock(Structure):
    _fields = ['name', 'shares', 'price']

In [10]:
s = Stock('ACME', 50, 91.1, date='2012-10-1')

In [11]:
s.name, s.shares, s.price, s.date

('ACME', 50, 91.1, '2012-10-1')

---

### 定义接口或者抽象基类

抽象类能确保子类实现了抽象方法，抽象类本身无法被实例化

In [12]:
from abc import ABCMeta, abstractmethod

class IStream(metaclass=ABCMeta):

    @abstractmethod
    def read(self, maxbytes=-1):
        pass

    @abstractmethod
    def write(self, data):
        pass

除了继承抽象类的方式，也可以使用 register 方法

In [13]:
import io
IStream.register(io.IOBase)

io.IOBase

---

### 实现数据模型的类型约束

In [14]:
class Descriptor:

    def __init__(self, name=None, **opts):
        self.name = name
        for key, value in opts.items():
            setattr(self, key, value)

    def __set__(self, instance, value):
        instance.__dict__[self.name] = value


def Typed(expected_type, cls=None):

    if cls is None:
        return lambda cls: Typed(expected_type, cls)
    super_set = cls.__set__

    def __set__(self, instance, value):
        if not isinstance(value, expected_type):
            raise TypeError('expected ' + str(expected_type))
        super_set(self, instance, value)

    cls.__set__ = __set__
    return cls


def Unsigned(cls):

    super_set = cls.__set__

    def __set__(self, instance, value):
        if value < 0:
            raise ValueError('Expected must be >= 0')
        super_set(self, instance, value)

    cls.__set__ = __set__
    return cls


def MaxSized(cls):

    super_init = cls.__init__

    def __init__(self, name=None, **opts):
        if 'size' not in opts:
            raise TypeError('missing size option')
        super_init(self, name, **opts)

    cls.__init__ = __init__

    super_set = cls.__set__

    def __set__(self, instance, value):
        if len(value) >= self.size:
            raise ValueError('size must be < ' + str(self.size))
        super_set(self, instance, value)

    cls.__set__ = __set__
    return cls


@Typed(int)
class Integer(Descriptor):
    pass


@Unsigned
class UnsignedInteger(Integer):
    pass


@Typed(float)
class Float(Descriptor):
    pass


@Unsigned
class UnsignedFloat(Float):
    pass


@Typed(str)
class String(Descriptor):
    pass


@MaxSized
class SizedString(String):
    pass

In [15]:
class Stock:

    name = SizedString('name', size=8)
    shares = UnsignedInteger('shares')
    price = UnsignedFloat('price')

    def __init__(self, name, shares, price):
        self.name = name
        self.shares = shares
        self.price = price

In [16]:
try:
    Stock('TooLongName', 50, 91.1)
except ValueError as e:
    print('ValueError:', e)

ValueError: size must be < 8


In [17]:
try:
    Stock('ACME', 50, -21.5)
except ValueError as e:
    print('ValueError:', e)

ValueError: Expected must be >= 0


In [18]:
s = Stock('ACME', 50, 91.1)

In [19]:
s.name, s.shares, s.price

('ACME', 50, 91.1)

---

### 实现自定义容器

In [20]:
import collections

class SortedItems(collections.Sequence):
    """ 有序容器，支持集合的所有方法 """

    def __init__(self, initial=None):
        self._items = sorted(initial) if initial is not None else []

    def __getitem__(self, index):
        return self._items[index]

    def __len__(self):
        return len(self._items)

    def add(self, item):
        bisect.insort(self._items, item)


class Items(collections.MutableSequence):
    """ 列表容器，支持列表的所有方法 """

    def __init__(self, initial=None):
        self._items = list(initial) if initial is not None else []

    def __getitem__(self, index):
        print('Getting:', index)
        return self._items[index]

    def __setitem__(self, index, value):
        print('Setting:', index, value)
        self._items[index] = value

    def __delitem__(self, index):
        print('Deleting:', index)
        del self._items[index]

    def insert(self, index, value):
        print('Inserting:', index, value)
        self._items.insert(index, value)

    def __len__(self):
        print('Len')
        return len(self._items)

---

### 绕过 \__init__ 方法创建实例

In [21]:
from time import localtime

class Date:

    def __init__(self, year, month, day):
        self.year = year
        self.month = month
        self.day = day

    @classmethod
    def today(cls):
        d = cls.__new__(cls)
        t = localtime()
        d.year = t.tm_year
        d.month = t.tm_mon
        d.day = t.tm_mday
        return d

In [22]:
d = Date.today()

In [23]:
d.year, d.month, d.day

(2018, 10, 31)

---

### 利用Mixins扩展类功能

当想扩展其他类的功能，又不想继承类的时候，可以定义混入类  
方法一，使用多继承

In [24]:
class LoggedMappingMixin:
    # 混入类没有实例变量，因为直接实例化混入类没有任何意义
    __slots__ = ()

    def __getitem__(self, key):
        print('Getting ' + str(key))
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        print('Setting {} = {!r}'.format(key, value))
        return super().__setitem__(key, value)

    def __delitem__(self, key):
        print('Deleting ' + str(key))
        return super().__delitem__(key)


class SetOnceMappingMixin:
    __slots__ = ()

    def __setitem__(self, key, value):
        if key in self:
            raise KeyError(str(key) + ' already set')
        return super().__setitem__(key, value)


class StringKeysMappingMixin:
    __slots__ = ()

    def __setitem__(self, key, value):
        if not isinstance(key, str):
            raise TypeError('keys must be strings')
        return super().__setitem__(key, value)

In [25]:
class LoggedDict(LoggedMappingMixin, dict):
    pass

In [26]:
d = LoggedDict()

In [27]:
d['x'] = 23

Setting x = 23


In [28]:
d['x']

Getting x


23

In [29]:
del d['x']

Deleting x


方法二，使用类装饰器的实现

In [30]:
def LoggedMapping(cls):

    # 这里必须先赋值到变量，否则会无限递归
    cls_getitem = cls.__getitem__
    cls_setitem = cls.__setitem__
    cls_delitem = cls.__delitem__

    def __getitem__(self, key):
        print('Getting ' + str(key))
        return cls_getitem(self, key)

    def __setitem__(self, key, value):
        print('Setting {} = {!r}'.format(key, value))
        return cls_setitem(self, key, value)

    def __delitem__(self, key):
        print('Deleting ' + str(key))
        return cls_delitem(self, key)

    cls.__getitem__ = __getitem__
    cls.__setitem__ = __setitem__
    cls.__delitem__ = __delitem__
    return cls

In [31]:
@LoggedMapping
class LoggedDict(dict):
    pass

In [32]:
d = LoggedDict()

In [33]:
d['x'] = 13

Setting x = 13


In [34]:
d['x']

Getting x


13

---

### 实现状态对象或者状态机

普通方案，缺点是条件判断过多，执行效率低下，不容易扩展

In [35]:
class Connection:

    def __init__(self):
        self.state = 'CLOSED'

    def read(self):
        if self.state != 'OPEN':
            raise RuntimeError('Not open')
        print('reading')

    def write(self, data):
        if self.state != 'OPEN':
            raise RuntimeError('Not open')
        print('writing')

    def open(self):
        if self.state == 'OPEN':
            raise RuntimeError('Already open')
        self.state = 'OPEN'

    def close(self):
        if self.state == 'CLOSED':
            raise RuntimeError('Already closed')
        self.state = 'CLOSED'

新方案，对每个状态定义一个类，并相互切换

In [36]:
class Connection:

    def __init__(self):
        self.new_state(ClosedConnectionState)

    def new_state(self, newstate):
        self._state = newstate

    def read(self):
        return self._state.read(self)

    def write(self, data):
        return self._state.write(self, data)

    def open(self):
        return self._state.open(self)

    def close(self):
        return self._state.close(self)


class ConnectionState:

    @staticmethod
    def read(conn):
        raise NotImplementedError()

    @staticmethod
    def write(conn, data):
        raise NotImplementedError()

    @staticmethod
    def open(conn):
        raise NotImplementedError()

    @staticmethod
    def close(conn):
        raise NotImplementedError()


class ClosedConnectionState(ConnectionState):

    @staticmethod
    def read(conn):
        raise RuntimeError('Not open')

    @staticmethod
    def write(conn, data):
        raise RuntimeError('Not open')

    @staticmethod
    def open(conn):
        conn.new_state(OpenConnectionState)

    @staticmethod
    def close(conn):
        raise RuntimeError('Already closed')


class OpenConnectionState(ConnectionState):

    @staticmethod
    def read(conn):
        print('reading')

    @staticmethod
    def write(conn, data):
        print('writing')

    @staticmethod
    def open(conn):
        raise RuntimeError('Already open')

    @staticmethod
    def close(conn):
        conn.new_state(ClosedConnectionState)

---

### 通过字符串调用对象方法

In [37]:
import math

class Point:

    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __repr__(self):
        return 'Point({!r:},{!r:})'.format(self.x, self.y)

    def distance(self, x, y):
        return math.hypot(self.x - x, self.y - y)

方法一

In [38]:
getattr(Point(2, 3), 'distance')(0, 0)

3.605551275463989

方法二，适用于通过相同参数多次调用某个方法

In [39]:
import operator

points = [
    Point(1, 2),
    Point(3, 0),
    Point(10, -3),
    Point(-5, -7),
    Point(-1, 8),
    Point(3, 2)
]

points.sort(key=operator.methodcaller('distance', 0, 0))

In [40]:
points

[Point(1,2), Point(3,0), Point(3,2), Point(-1,8), Point(-5,-7), Point(10,-3)]

---

### 实现访问者模式

假设写一个表示数学表达式的程序

In [41]:
class Node:
    pass

class UnaryOperator(Node):
    def __init__(self, operand):
        self.operand = operand

class BinaryOperator(Node):
    def __init__(self, left, right):
        self.left = left
        self.right = right

class Add(BinaryOperator):
    pass

class Sub(BinaryOperator):
    pass

class Mul(BinaryOperator):
    pass

class Div(BinaryOperator):
    pass

class Negate(UnaryOperator):
    pass

class Number(Node):
    def __init__(self, value):
        self.value = value

In [42]:
# 数学表达式：1 + 2 * (3 - 4) / 5
t1 = Sub(Number(3), Number(4))
t2 = Mul(Number(2), t1)
t3 = Div(t2, Number(5))
t4 = Add(Number(1), t3)

---
利用递归，实现访问者模式

In [43]:
class NodeVisitor:

    def visit(self, node):
        methname = 'visit_' + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        raise RuntimeError('No {} method'.format(
            'visit_' + type(node).__name__))


class Evaluator(NodeVisitor):
    """ 定义一个类对数学表达式求值 """

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        return self.visit(node.left) + self.visit(node.right)

    def visit_Sub(self, node):
        return self.visit(node.left) - self.visit(node.right)

    def visit_Mul(self, node):
        return self.visit(node.left) * self.visit(node.right)

    def visit_Div(self, node):
        return self.visit(node.left) / self.visit(node.right)

    def visit_Negate(self, node):
        return -node.operand


class StackCode(NodeVisitor):
    """ 定义一个类在一个栈上面将一个表达式转换成多个操作序列 """

    def generate_code(self, node):
        self.instructions = []
        self.visit(node)
        return self.instructions

    def visit_Number(self, node):
        self.instructions.append(('PUSH', node.value))

    def binop(self, node, instruction):
        self.visit(node.left)
        self.visit(node.right)
        self.instructions.append((instruction,))

    def visit_Add(self, node):
        self.binop(node, 'ADD')

    def visit_Sub(self, node):
        self.binop(node, 'SUB')

    def visit_Mul(self, node):
        self.binop(node, 'MUL')

    def visit_Div(self, node):
        self.binop(node, 'DIV')

    def unaryop(self, node, instruction):
        self.visit(node.operand)
        self.instructions.append((instruction,))

    def visit_Negate(self, node):
        self.unaryop(node, 'NEG')

In [44]:
e = Evaluator()
e.visit(t4)

0.6

In [45]:
s = StackCode()
s.generate_code(t4)

[('PUSH', 1),
 ('PUSH', 2),
 ('PUSH', 3),
 ('PUSH', 4),
 ('SUB',),
 ('MUL',),
 ('PUSH', 5),
 ('DIV',),
 ('ADD',)]

---
不使用递归的实现方式

In [46]:
import types


class NodeVisitor:

    def visit(self, node):
        stack = [node]
        last_result = None
        while stack:
            try:
                last = stack[-1]
                if isinstance(last, types.GeneratorType):
                    stack.append(last.send(last_result))
                    last_result = None
                elif isinstance(last, Node):
                    stack.append(self._visit(stack.pop()))
                else:
                    last_result = stack.pop()
            except StopIteration:
                stack.pop()

        return last_result

    def _visit(self, node):
        methname = 'visit_' + type(node).__name__
        meth = getattr(self, methname, None)
        if meth is None:
            meth = self.generic_visit
        return meth(node)

    def generic_visit(self, node):
        raise RuntimeError('No {} method'.format(
            'visit_' + type(node).__name__))


class Evaluator(NodeVisitor):

    def visit_Number(self, node):
        return node.value

    def visit_Add(self, node):
        yield (yield node.left) + (yield node.right)

    def visit_Sub(self, node):
        yield (yield node.left) - (yield node.right)

    def visit_Mul(self, node):
        yield (yield node.left) * (yield node.right)

    def visit_Div(self, node):
        yield (yield node.left) / (yield node.right)

    def visit_Negate(self, node):
        yield - (yield node.operand)

In [47]:
a = Number(0)
for n in range(1,10000):
    a = Add(a, Number(n))

e = Evaluator()
e.visit(a)

49995000

---

### 实现类的比较操作

通常定义 `__eq__` 和 `__lt__` 两个方法，然后搭配使用 total_ordering 装饰器，  
如果不使用，不能支持 <= 和 >=

In [48]:
from functools import total_ordering


class Room:

    def __init__(self, name, length, width):
        self.name = name
        self.length = length
        self.width = width
        self.square_feet = self.length * self.width


@total_ordering
class House:

    def __init__(self, name, style):
        self.name = name
        self.style = style
        self.rooms = list()

    @property
    def living_space_footage(self):
        return sum(r.square_feet for r in self.rooms)

    def add_room(self, room):
        self.rooms.append(room)

    def __str__(self):
        return '{}: {} square foot {}'.format(self.name,
                self.living_space_footage,
                self.style)

    def __eq__(self, other):
        return self.living_space_footage == other.living_space_footage

    def __lt__(self, other):
        return self.living_space_footage < other.living_space_footage

In [49]:
h1 = House('h1', 'Cape')
h1.add_room(Room('Master Bedroom', 14, 21))
h1.add_room(Room('Living Room', 18, 20))
h1.add_room(Room('Kitchen', 12, 16))
h1.add_room(Room('Office', 12, 12))
h2 = House('h2', 'Ranch')
h2.add_room(Room('Master Bedroom', 14, 21))
h2.add_room(Room('Living Room', 18, 20))
h2.add_room(Room('Kitchen', 12, 16))
h3 = House('h3', 'Split')
h3.add_room(Room('Master Bedroom', 14, 21))
h3.add_room(Room('Living Room', 18, 20))
h3.add_room(Room('Office', 12, 16))
h3.add_room(Room('Kitchen', 15, 17))
houses = [h1, h2, h3]
print('Is h1 bigger than h2?', h1 > h2)
print('Is h2 smaller than h3?', h2 < h3)
print('Is h2 greater than or equal to h1?', h2 >= h1)
print('Which one is biggest?', max(houses))
print('Which is smallest?', min(houses))

Is h1 bigger than h2? True
Is h2 smaller than h3? True
Is h2 greater than or equal to h1? False
Which one is biggest? h3: 1101 square foot Split
Which is smallest? h2: 846 square foot Ranch


---

### 实现单例工厂

In [50]:
import weakref


class CachedSpamManager:

    def __init__(self):
        self._cache = weakref.WeakValueDictionary()

    def get_spam(self, name):
        if name not in self._cache:
            temp = Spam._new(name)
            self._cache[name] = temp
        else:
            temp = self._cache[name]
        return temp

    def clear(self):
        self._cache.clear()


class Spam:

    def __init__(self, *args, **kwargs):
        raise RuntimeError("Can't instantiate directly")

    @classmethod
    def _new(cls, name):
        self = cls.__new__(cls)
        self.name = name
        return self

In [51]:
csm = CachedSpamManager()

In [52]:
a = csm.get_spam('foo')
b = csm.get_spam('foo')
c = csm.get_spam('bar')

In [53]:
a is b

True

In [54]:
a is c

False