<a href="https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04206.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 4.1 模型构造

### 4.1.1 继承 Module 类来构造模型

In [0]:
import torch 
from torch import nn

class MLP(nn.Module):
    """Multilayer perceptron with a single hidden layer: 784 -> 256 -> ReLU -> 10."""

    def __init__(self, **kwargs):
        # Keyword arguments are handed straight to nn.Module's constructor.
        super().__init__(**kwargs)
        # Sub-modules assigned as attributes are registered in assignment order,
        # which is also the order shown in the module's repr.
        self.hidden = nn.Linear(784, 256)
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)

    def forward(self, x):
        """Run `x` through hidden -> act -> output and return the result."""
        hidden_out = self.hidden(x)
        activated = self.act(hidden_out)
        return self.output(activated)

In [3]:
# Instantiate the MLP class; the bare `net` on the last line displays its repr.
net = MLP()
net

MLP(
  (hidden): Linear(in_features=784, out_features=256, bias=True)
  (act): ReLU()
  (output): Linear(in_features=256, out_features=10, bias=True)
)

In [4]:
# Random batch of 2 samples with 784 features each, matching MLP.hidden's in_features.
X = torch.rand(2, 784)
# Calling the module runs its forward(); the result has one row of 10 values per sample.
net(X)

tensor([[ 0.0459,  0.1347, -0.3975,  0.0757, -0.1330, -0.0754,  0.0330,  0.1541,
          0.0051,  0.1180],
        [ 0.0429,  0.2412, -0.2639,  0.1379,  0.0189,  0.0011, -0.1205,  0.1794,
          0.0137,  0.2713]], grad_fn=<AddmmBackward>)

### 4.1.2 Module 的子类

#### 1. Sequential 类

In [0]:
from collections import OrderedDict 

class MySequential(nn.Module):
    """Minimal re-implementation of a sequential container.

    Accepts either a single OrderedDict mapping names to modules, or any
    number of modules as positional arguments (named "0", "1", ... by index).
    """

    def __init__(self, *args):
        super().__init__()
        takes_named_dict = len(args) == 1 and isinstance(args[0], OrderedDict)
        if takes_named_dict:
            named_layers = args[0].items()
        else:
            named_layers = ((str(position), layer) for position, layer in enumerate(args))
        # add_module stores each layer in self._modules, keeping insertion order.
        for name, layer in named_layers:
            self.add_module(name, layer)

    def forward(self, input):
        """Feed `input` through every registered module in insertion order."""
        result = input
        for layer in self._modules.values():
            result = layer(result)
        return result

In [6]:
# Modules passed positionally are registered under the keys "0", "1", "2" (their argument index).
net = MySequential(
    nn.Linear(784, 256), 
    nn.ReLU(), 
    nn.Linear(256, 10), 
)
net 

MySequential(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

In [7]:
# Reuses the (2, 784) batch X created in an earlier cell.
net(X)

tensor([[-0.1612,  0.0973, -0.0472,  0.0649, -0.0788,  0.0777, -0.1005,  0.0287,
          0.2223,  0.0737],
        [-0.1212,  0.1584, -0.0714,  0.1273, -0.0960,  0.0352, -0.1249,  0.0433,
          0.1862, -0.1236]], grad_fn=<AddmmBackward>)

#### 2. ModuleList 类

In [8]:
# ModuleList is built from a list of modules and supports list-style append.
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10))
net

ModuleList(
  (0): Linear(in_features=784, out_features=256, bias=True)
  (1): ReLU()
  (2): Linear(in_features=256, out_features=10, bias=True)
)

In [9]:
# Negative indexing works as on a Python list: this is the last module in the ModuleList.
net[-1]

Linear(in_features=256, out_features=10, bias=True)

In [0]:
class Module_ModuleList(nn.Module):
    """Module that keeps its single Linear(10, 10) layer inside an nn.ModuleList."""

    def __init__(self):
        super().__init__()
        layers = [nn.Linear(10, 10)]
        # Wrapping the list in nn.ModuleList registers the layer as a sub-module,
        # so its weight and bias show up in parameters().
        self.linears = nn.ModuleList(layers)


class Module_List(nn.Module):
    """Module that keeps its single Linear(10, 10) layer in a plain Python list."""

    def __init__(self):
        super().__init__()
        layer = nn.Linear(10, 10)
        # Deliberately a built-in list: a layer stored this way is not registered
        # as a sub-module, so parameters() yields nothing for this module.
        self.linears = [layer]

In [11]:
net1 = Module_ModuleList()

print(net1)

# The Linear held in the ModuleList is registered, so its weight and bias sizes are printed.
for p in net1.parameters():
    print(p.size())

Module_ModuleList(
  (linears): ModuleList(
    (0): Linear(in_features=10, out_features=10, bias=True)
  )
)
torch.Size([10, 10])
torch.Size([10])


In [12]:
net2 = Module_List()

print(net2)

# The Linear sits in a plain Python list, so no parameters are registered and this loop prints nothing.
for p in net2.parameters():
    print(p.size())

Module_List()


#### 3. ModuleDict 类

In [13]:
# ModuleDict is built from a dict of modules and supports dict-style insertion and lookup.
net = nn.ModuleDict({
    'linear': nn.Linear(784, 256), 
    'act': nn.ReLU(), 
})
net['output'] = nn.Linear(256, 10)
net['linear']

Linear(in_features=784, out_features=256, bias=True)

In [14]:
# The entry stored under the key 'output' is also reachable as an attribute.
net.output

Linear(in_features=256, out_features=10, bias=True)

In [15]:
# Display the whole ModuleDict; the recorded output lists the entries as act, linear, output.
net

ModuleDict(
  (act): ReLU()
  (linear): Linear(in_features=784, out_features=256, bias=True)
  (output): Linear(in_features=256, out_features=10, bias=True)
)

### 4.1.3 构造复杂的模型

In [0]:
class FancyMLP(nn.Module):
    """MLP that mixes a trainable layer, a constant weight and Python control flow.

    The same ``linear`` layer is applied twice in ``forward``, so its parameters
    are shared between both applications. ``rand_weight`` is a fixed random
    20x20 matrix created without ``requires_grad``; it is not a parameter.
    """

    def __init__(self, **kwargs):
        super(FancyMLP, self).__init__(**kwargs)

        # Register the constant as a buffer instead of a bare tensor attribute so
        # that it follows the module through .to()/.cuda()/.double(); otherwise
        # torch.mm in forward() fails with a device/dtype mismatch after a move.
        # persistent=False keeps state_dict() limited to the linear layer's keys,
        # exactly as before.
        self.register_buffer('rand_weight', torch.rand((20, 20)), persistent=False)
        self.linear = nn.Linear(20, 20)

    def forward(self, x):
        """Return a scalar tensor computed from ``x``.

        Args:
            x: 2-D tensor with 20 columns (required by ``linear`` and ``torch.mm``).

        Returns:
            A 0-dim tensor: the sum of the rescaled output of the second
            ``linear`` application.
        """
        x = self.linear(x)
        # The buffer was created without requires_grad, so it can be used
        # directly; no .data access is needed to keep it out of autograd.
        x = nn.functional.relu(torch.mm(x, self.rand_weight) + 1)
        # Second pass through the same layer (shared parameters).
        x = self.linear(x)
        # Halve until the norm is at most 1 ...
        while x.norm().item() > 1:
            x /= 2
        # ... then scale up by 10 if the norm ended up below 0.8.
        if x.norm().item() < 0.8:
            x *= 10
        return x.sum()

In [17]:
# `rand_weight` is a tensor, not a sub-module, so only `linear` appears in the repr.
net = FancyMLP()
net 

FancyMLP(
  (linear): Linear(in_features=20, out_features=20, bias=True)
)

In [18]:
# FancyMLP's Linear expects 20 input features; forward() reduces the result to a scalar via sum().
X = torch.rand(2, 20)
net(X)

tensor(-13.4920, grad_fn=<SumBackward0>)

In [0]:
class NestMLP(nn.Module):
    """Wraps a small nn.Sequential (Linear 40 -> 30 followed by ReLU) as one module."""

    def __init__(self, **kwargs):
        # Keyword arguments are handed straight to nn.Module's constructor.
        super().__init__(**kwargs)
        projection = nn.Linear(40, 30)
        activation = nn.ReLU()
        self.net = nn.Sequential(projection, activation)

    def forward(self, x):
        """Delegate the computation to the wrapped Sequential."""
        out = self.net(x)
        return out

In [23]:
# Layer widths chain: NestMLP maps 40 -> 30, the Linear maps 30 -> 20, and FancyMLP takes 20 features.
net = nn.Sequential(NestMLP(), nn.Linear(30, 20), FancyMLP())
net

Sequential(
  (0): NestMLP(
    (net): Sequential(
      (0): Linear(in_features=40, out_features=30, bias=True)
      (1): ReLU()
    )
  )
  (1): Linear(in_features=30, out_features=20, bias=True)
  (2): FancyMLP(
    (linear): Linear(in_features=20, out_features=20, bias=True)
  )
)

In [24]:
# Input width 40 matches the first Linear inside NestMLP.
X = torch.rand(2, 40)
net(X)

tensor(10.6316, grad_fn=<SumBackward0>)