## x.1 Object-Oriented Programming (OOP)

Inspired by libraries such as PyTorch Lightning, we want to create three classes: `Module`, `DataModule`, and `Trainer`.

1. Module

- models
- losses
- optimization

2. DataModule

- data loaders for training and validation

3. Trainer

- train models on a variety of hardware platforms (CPUs, GPUs, parallel training), running the optimization algorithm configured by `Module`

## x.2 Some Basic Classes

### x.2.1 The `add_to_class` Function

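First, we need a convenient way to add methods to a class; otherwise, every change would mean rewriting the entire class definition. **After a class has been created, we can register functions as its methods.**

This relies on Python decorators: an extra layer of function is nested around the function being registered. `add_to_class(Class)` returns the inner function `wrapper`, which attaches the decorated function to `Class` with `setattr`.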

```python
def add_to_class(Class):
    """
    intro:
        Register functions as methods in created class.
    """
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper
```


Next, let's test the code above by applying it as a decorator.

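```python
class A:
    def __init__(self) -> None:
        self.b = 1

a = A()

@add_to_class(A)
def do(self):
    print("Instance attribute 'b' is", self.b)

a.do()
```

```
Instance attribute 'b' is 1
```

Note that `a` was created before `do` was registered, yet `a.do()` still works: attribute lookup on an instance falls back to its class at call time, so a newly added method is visible to existing instances as well.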
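Because `wrapper` has no `return` statement, the decorated name ends up bound to `None` at module level; the function survives only as a method of the class. The sketch below makes this visible by writing the decorator call out by hand, `add_to_class(A)(greet)`, which is what the `@` syntax amounts to. The class `A` and the functions `greet` and `shout` are illustrative only.

```python
def add_to_class(Class):
    """Register the decorated function as a method of Class."""
    def wrapper(obj):
        setattr(Class, obj.__name__, obj)
    return wrapper

class A:
    def __init__(self):
        self.b = 1

def greet(self):
    return f"b = {self.b}"

# Equivalent to decorating `greet` with @add_to_class(A)
result = add_to_class(A)(greet)
print(result)        # None: wrapper returns nothing
print(A().greet())   # greet is now a method of A

@add_to_class(A)
def shout(self):
    return "hi"

print(shout)         # the module-level name `shout` is None as well
print(A().shout())
```

If you also want to keep the module-level name usable, `wrapper` could end with `return obj`; the version above follows the original, which discards it.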


### x.2.2 The `HyperParameters` Class

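A subclass of `HyperParameters` can save all arguments of its `__init__` method as attributes of the instance (not of the class) by calling `self.save_hyperparameters()`.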

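The placeholder version signals that the method is only a stub in the base class and must be implemented by subclasses. The idiomatic way to express this is `raise NotImplementedError`. Writing `raise NotImplemented` is a common slip: `NotImplemented` is a special constant returned by binary operators, not an exception class, so raising it produces a `TypeError` instead.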

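However, `save_hyperparameters` is defined a second time right after the placeholder, in the same class body. The second definition simply rebinds the name and replaces the first, so the placeholder never runs. This is ordinary name rebinding, not polymorphism in the sense of encapsulation, inheritance, and polymorphism.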
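Both points are easy to check in plain Python. The class `Base` and the function `bad_placeholder` below are made-up examples: the second `method` replaces the first inside the class body, and raising `NotImplemented` turns out to be a `TypeError`.

```python
class Base:
    def method(self):
        raise NotImplementedError  # placeholder, silently replaced below

    def method(self):
        return "second definition"

# Only one attribute named `method` exists; the last definition wins
print(Base().method())

def bad_placeholder():
    raise NotImplemented  # wrong: NotImplemented is not an exception class

try:
    bad_placeholder()
except TypeError as e:
    print("TypeError:", e)
```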

The full class, including both definitions of `save_hyperparameters`:

```python
import inspect

class HyperParameters:
    """
    intro:
        The base class of hyperparameters.
    """
    def save_hyperparameters(self, ignore=[]):
        raise NotImplementedError

    def save_hyperparameters(self, ignore=[]):
        """
        intro:
            Save function arguments into instance attributes.
        """
        # The caller's frame, i.e. the __init__ that called this method
        frame = inspect.currentframe().f_back
        # getargvalues returns (args, varargs, varkw, locals); keep the locals
        _, _, _, local_vars = inspect.getargvalues(frame)
        # Keep every local except `self`, ignored names, and private names
        self.hparams = {k: v for k, v in local_vars.items()
                        if k not in set(ignore + ['self']) and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)
```

```python
class B(HyperParameters):
    def __init__(self, a, b, c):
        self.save_hyperparameters(ignore=['c'])
        print(self.a, self.b)
        print("There is no self.c = ", not hasattr(self, 'c'))

b = B(a=1, b=2, c=3)
```

```
1 2
There is no self.c =  True
```
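Because `inspect.getargvalues` returns the caller frame's entire local namespace, `save_hyperparameters` saves every local variable that exists at the moment it is called, not only the arguments. The hypothetical classes `C` and `D` below show the consequence: a temporary computed before the call is saved too, which is why the call should come first in `__init__`, or the name should be passed via `ignore`.

```python
import inspect

class HyperParameters:
    def save_hyperparameters(self, ignore=[]):
        """Save the caller's local variables as instance attributes."""
        frame = inspect.currentframe().f_back
        _, _, _, local_vars = inspect.getargvalues(frame)
        self.hparams = {k: v for k, v in local_vars.items()
                        if k not in set(ignore + ['self']) and not k.startswith('_')}
        for k, v in self.hparams.items():
            setattr(self, k, v)

class C(HyperParameters):  # hypothetical example class
    def __init__(self, a):
        tmp = a * 2  # a local variable, not an argument
        self.save_hyperparameters()

class D(HyperParameters):  # same, but the temporary is ignored
    def __init__(self, a):
        tmp = a * 2
        self.save_hyperparameters(ignore=['tmp'])

print(C(3).hparams)
print(D(3).hparams)
```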


### x.2.3 The `ProgressBoard` Class

Similar to TensorBoard, the `ProgressBoard` class plots the progress of an experiment interactively while it is running.



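```python
class ProgressBoard(HyperParameters):
    """
    intro:
        The board that plots data points in animation.
    """
    def __init__(self, xlabel=None, ylabel=None, xlim=None,
                 ylim=None, xscale='linear', yscale='linear',
                 ls=['-', '--', '-.', ':'], colors=['C0', 'C1', 'C2', 'C3'],
                 fig=None, axes=None, figsize=(3.5, 2.5), display=True):
        """
        intro:
            Store the plotting settings as attributes.
        args:
            :param str xlabel: label of the x-axis
            :param str ylabel: label of the y-axis
            :param xlim: limits of the x-axis
            :param ylim: limits of the y-axis
            :param str xscale: scale of the x-axis, e.g. 'linear' or 'log'
            :param str yscale: scale of the y-axis
            :param list ls: line styles, one per curve
            :param list colors: line colors, one per curve
            :param fig: an existing matplotlib figure to draw on
            :param axes: existing matplotlib axes to draw on
            :param tuple figsize: figure size in inches
            :param bool display: whether to display the figure
        """
        self.save_hyperparameters()

    def draw(self, x, y, label, every_n=1):
        raise NotImplementedError
```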
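`draw` is left as a placeholder here. As a rough illustration of what its `every_n` argument can be used for, the hypothetical `SmoothingBuffer` below collects raw points per `label` and records one averaged point every `every_n` calls. This is an assumption about the intended smoothing, not the real implementation, and it only records the averaged points without drawing anything.

```python
from collections import defaultdict

class SmoothingBuffer:  # hypothetical stand-in for ProgressBoard.draw
    def __init__(self):
        self.raw = defaultdict(list)     # pending (x, y) points per label
        self.points = defaultdict(list)  # averaged points per label

    def draw(self, x, y, label, every_n=1):
        self.raw[label].append((x, y))
        if len(self.raw[label]) < every_n:
            return  # not enough points collected yet
        xs, ys = zip(*self.raw[label])
        self.points[label].append((sum(xs) / every_n, sum(ys) / every_n))
        self.raw[label].clear()

buf = SmoothingBuffer()
for step, loss in enumerate([4.0, 2.0, 3.0, 1.0], start=1):
    buf.draw(step, loss, 'train_loss', every_n=2)
print(buf.points['train_loss'])  # [(1.5, 3.0), (3.5, 2.0)]
```

Averaging like this keeps a noisy per-batch loss curve readable: with `every_n=2`, four raw points become two plotted ones.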

## x.3 The Three Core Classes

Below, we describe the abstract base classes `Module`, `DataModule`, and `Trainer` in detail.

### x.3.1 The `Module` Class

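`Module` is the base class of all models. Together, the following methods define a model (the last one is optional):

- `forward`: how the model computes its output from its parameters and the input data
- `loss`: the loss function
- `configure_optimizers`: returns the optimization algorithm, which updates the parameters to minimize the loss
- `training_step`: accepts a data batch and returns the loss value
- `validation_step` (optional): reports evaluation measures on a validation batch

In the base class below, `forward`, `training_step`, and `validation_step` already have default implementations, while `loss` and `configure_optimizers` raise `NotImplementedError` and must be overridden by subclasses.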

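`Module` is a subclass of `torch.nn.Module`. The benefit is that after you override `forward`, you can call the model instance directly, as in `self(X)` inside `training_step`: the built-in `__call__` method inherited from `nn.Module` invokes `forward` for you, along with any registered hooks.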
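A quick check of this dispatch, using a made-up `Doubler` module: calling the instance runs `forward`, and a hook registered with `register_forward_hook` fires as part of `__call__`. Calling `forward` directly skips the hook, which is one reason to call the instance instead.

```python
import torch
from torch import nn

class Doubler(nn.Module):  # hypothetical example module
    def forward(self, X):
        return 2 * X

calls = []
m = Doubler()
# Forward hooks receive (module, inputs, output) and run inside __call__
m.register_forward_hook(lambda module, inputs, output: calls.append(output))

X = torch.tensor([1.0, 2.0])
print(m(X))        # __call__ -> forward, then the hook
print(len(calls))  # the hook ran once
m.forward(X)       # calling forward directly bypasses __call__
print(len(calls))  # still 1
```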

```python
import torch
from torch import nn

class Module(nn.Module, HyperParameters):
    """The base class of models."""
    def __init__(self, plot_train_per_epoch=2, plot_valid_per_epoch=1):
        super().__init__()
        self.save_hyperparameters()
        self.board = ProgressBoard()

    def loss(self, y_hat, y):
        raise NotImplementedError

    def forward(self, X):
        assert hasattr(self, 'net'), 'Neural network is not defined'
        return self.net(X)

    def plot(self, key, value, train):
        """Plot a point in animation."""
        assert hasattr(self, 'trainer'), 'Trainer is not initialized'
        self.board.xlabel = 'epoch'
        if train:
            # x is measured in epochs (fractional within an epoch)
            x = self.trainer.train_batch_idx / \
                self.trainer.num_train_batches
            # n batches per plotted point, i.e. plot_train_per_epoch points per epoch
            n = self.trainer.num_train_batches / \
                self.plot_train_per_epoch
        else:
            x = self.trainer.epoch + 1
            n = self.trainer.num_val_batches / \
                self.plot_valid_per_epoch
        self.board.draw(x, value.to(torch.device("cpu")).detach().numpy(),
                        ('train_' if train else 'val_') + key,
                        every_n=int(n))

    def training_step(self, batch):
        # All but the last element of the batch are inputs; the last is the label
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=True)
        return l

    def validation_step(self, batch):
        l = self.loss(self(*batch[:-1]), batch[-1])
        self.plot('loss', l, train=False)

    def configure_optimizers(self):
        raise NotImplementedError
```
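The arithmetic in `plot` decides where training points land. Assuming the trainer counts `train_batch_idx` cumulatively across epochs (`fit` resets it only once, before the epoch loop; `fit_epoch` is not shown), `x` is the position measured in epochs and `n` is the number of batches that go into each plotted point. The helper `plot_position` below is hypothetical and just reproduces those two lines.

```python
def plot_position(train_batch_idx, num_train_batches, plot_train_per_epoch):
    """Reproduce the x / every_n computation of Module.plot for training."""
    x = train_batch_idx / num_train_batches       # position in epochs
    n = num_train_batches / plot_train_per_epoch  # batches per plotted point
    return x, int(n)

# 100 batches per epoch, 2 plotted points per epoch (the default)
for idx in (50, 100, 150):
    print(idx, plot_position(idx, 100, 2))
```

With 100 batches per epoch and the default `plot_train_per_epoch=2`, every 50 batches feed one point, so batch 150 sits at `x = 1.5` epochs.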

### x.3.2 The `DataModule` Class

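This class has two required methods and one optional method:

- `__init__`: downloads and preprocesses the data
- `train_dataloader`: returns the data loader for the training set
- `val_dataloader` (optional): returns the data loader for the validation set

Both loaders delegate to `get_dataloader(train)`, which concrete subclasses implement. A data loader can be written as a Python generator, i.e. a function that uses `yield`.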

```python
class DataModule(HyperParameters):
    """The base class of data."""
    def __init__(self, root='../data', num_workers=4):
        self.save_hyperparameters()

    def get_dataloader(self, train):
        raise NotImplementedError

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)
```
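Here is a minimal sketch of a concrete `get_dataloader` written as a Python generator. The `ToyData` class, its `X`/`y` lists, and `batch_size` are hypothetical, and the class is trimmed to plain Python (no `HyperParameters`, no tensors) so that it runs on its own. One caveat: `Trainer.prepare_data` calls `len()` on the loaders, and a generator object has no length, so a loader used with that `Trainer` must also support `len()` (for example `torch.utils.data.DataLoader`).

```python
import random

class ToyData:  # hypothetical, trimmed DataModule
    def __init__(self, num_train=6, num_val=4, batch_size=2):
        self.num_train, self.num_val = num_train, num_val
        self.batch_size = batch_size
        self.X = list(range(num_train + num_val))
        self.y = [2 * x for x in self.X]  # toy labels

    def get_dataloader(self, train):
        # Training uses the first num_train examples in shuffled order;
        # validation uses the remaining examples in order
        if train:
            indices = list(range(self.num_train))
            random.shuffle(indices)
        else:
            indices = list(range(self.num_train, self.num_train + self.num_val))
        for i in range(0, len(indices), self.batch_size):
            batch = indices[i:i + self.batch_size]
            yield [self.X[j] for j in batch], [self.y[j] for j in batch]

    def train_dataloader(self):
        return self.get_dataloader(train=True)

    def val_dataloader(self):
        return self.get_dataloader(train=False)

data = ToyData()
for X, y in data.val_dataloader():
    print(X, y)
```

Each call to `train_dataloader()` creates a fresh generator, so every epoch gets a new shuffle.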

### x.3.3 The `Trainer` Class

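The `Trainer` class rarely needs to change. Its most important method is `fit`:

- `fit`: takes a `Module` instance and a `DataModule` instance, prepares both, and runs `fit_epoch` once per epoch for `max_epochs` epochs.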

```python
class Trainer(HyperParameters):  #@save
    """The base class for training models with data."""
    def __init__(self, max_epochs, num_gpus=0, gradient_clip_val=0):
        self.save_hyperparameters()
        assert num_gpus == 0, 'No GPU support yet'

    def prepare_data(self, data):
        self.train_dataloader = data.train_dataloader()
        self.val_dataloader = data.val_dataloader()
        self.num_train_batches = len(self.train_dataloader)
        self.num_val_batches = (len(self.val_dataloader)
                                if self.val_dataloader is not None else 0)

    def prepare_model(self, model):
        model.trainer = self
        model.board.xlim = [0, self.max_epochs]
        self.model = model

    def fit(self, model, data):
        self.prepare_data(data)
        self.prepare_model(model)
        self.optim = model.configure_optimizers()
        self.epoch = 0
        self.train_batch_idx = 0
        self.val_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        raise NotImplementedError
```
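`fit_epoch` is left for subclasses. The sketch below shows one plausible shape for it under heavy simplifications: `LinearRegression` and `SimpleTrainer` are hypothetical, plotting and `HyperParameters` are dropped, the model is called directly instead of through `training_step` (which would need a `ProgressBoard` with a working `draw`), and gradient clipping uses `torch.nn.utils.clip_grad_norm_`, which is our choice for interpreting `gradient_clip_val`. It trains a linear model on synthetic data only to show that the loop runs and reduces the loss.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

class LinearRegression(nn.Module):  # hypothetical, trimmed Module subclass
    def __init__(self, lr):
        super().__init__()
        self.lr = lr
        self.net = nn.Linear(2, 1)

    def forward(self, X):
        return self.net(X)

    def loss(self, y_hat, y):
        return nn.functional.mse_loss(y_hat, y)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.lr)

class SimpleTrainer:  # hypothetical, trimmed Trainer
    def __init__(self, max_epochs, gradient_clip_val=0):
        self.max_epochs = max_epochs
        self.gradient_clip_val = gradient_clip_val

    def fit(self, model, train_loader):
        self.model, self.train_loader = model, train_loader
        self.optim = model.configure_optimizers()
        self.train_batch_idx = 0
        for self.epoch in range(self.max_epochs):
            self.fit_epoch()

    def fit_epoch(self):
        self.model.train()
        for X, y in self.train_loader:
            loss = self.model.loss(self.model(X), y)
            self.optim.zero_grad()
            loss.backward()
            if self.gradient_clip_val > 0:
                # Rescale gradients so their total norm is at most the limit
                nn.utils.clip_grad_norm_(self.model.parameters(),
                                         self.gradient_clip_val)
            self.optim.step()
            self.train_batch_idx += 1  # counted across epochs

# Synthetic data: y = X w + b + small noise
torch.manual_seed(0)
X = torch.randn(1000, 2)
true_w = torch.tensor([[2.0], [-3.4]])
y = X @ true_w + 4.2 + 0.01 * torch.randn(1000, 1)
loader = DataLoader(TensorDataset(X, y), batch_size=32, shuffle=True)

model = LinearRegression(lr=0.03)
with torch.no_grad():
    loss_before = model.loss(model(X), y).item()
trainer = SimpleTrainer(max_epochs=3, gradient_clip_val=10.0)
trainer.fit(model, loader)
with torch.no_grad():
    loss_after = model.loss(model(X), y).item()
print(f"loss before: {loss_before:.4f}, after: {loss_after:.4f}")
```

Unlike a generator, `DataLoader` supports `len()`, so a loader like this one also satisfies the `len()` calls in `Trainer.prepare_data`.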

In the next section, we will introduce concrete implementations of the `Module`, `DataModule`, and `Trainer` classes.

Finally, we have added these essential base classes to the `core.py` script.