In [4]:
import torch
from torch import nn
from torch.nn import functional as F

# Seed the global PyTorch RNG once, before the first random draw, so that running
# the notebook top to bottom reproduces the random parameter inits and inputs.
# The value 0 is arbitrary.
torch.manual_seed(0)

net = nn.Sequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))

X = torch.rand(2, 20)  # 2 examples x 20 features, uniform in [0, 1)
net(X)

tensor([[ 0.3004, -0.0015,  0.0696,  0.0112, -0.3014,  0.1669, -0.0074,  0.1977,
         -0.1973,  0.1037],
        [ 0.2911, -0.0655,  0.1491, -0.0371, -0.2153,  0.0720, -0.0199,  0.1473,
         -0.1231, -0.0095]], grad_fn=<AddmmBackward0>)

自定义块

In [None]:
class MLP(nn.Module):
    """Two-layer perceptron: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs.

    Written as a custom ``nn.Module`` subclass (rather than ``nn.Sequential``)
    to show how a block declares its layers in ``__init__`` and its
    computation in ``forward``.
    """

    def __init__(self):
        # nn.Module's own initialisation must run before any submodule is
        # assigned as an attribute.
        super().__init__()
        self.hidden = nn.Linear(20, 256)  # hidden layer
        self.out = nn.Linear(256, 10)     # output layer

    def forward(self, X):
        """Map ``X`` (last dimension 20) to outputs with last dimension 10.

        Normally invoked as ``net(X)``: ``nn.Module.__call__`` runs any
        registered hooks and then dispatches to this method.
        """
        # Functional ReLU from torch.nn.functional: it has no parameters, so it
        # does not need to be a registered layer.
        hidden_activation = F.relu(self.hidden(X))
        return self.out(hidden_activation)

In [5]:
# Instantiate the custom MLP block (fresh random weights) and apply it to the
# (2, 20) batch X created in the first cell; the result has shape (2, 10).
net = MLP()
net(X)

tensor([[-0.1129, -0.0212, -0.0575, -0.1617, -0.0322, -0.2607, -0.2167,  0.0974,
         -0.2039,  0.1847],
        [-0.0450, -0.0599, -0.1080, -0.1670, -0.0351, -0.1135, -0.1825, -0.0379,
         -0.1302,  0.1196]], grad_fn=<AddmmBackward0>)

顺序块

In [6]:
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Every positional argument is stored as a child module under the string key
    ``"0"``, ``"1"``, ... and applied in that order by ``forward``.
    """

    def __init__(self, *args):
        super().__init__()
        # ``self._modules`` is the insertion-ordered mapping nn.Module uses for
        # its children; entries placed there are registered submodules, so their
        # parameters appear in ``parameters()`` and ``state_dict()``.
        self._modules.update((str(position), block) for position, block in enumerate(args))

    def forward(self, X):
        """Feed ``X`` through every stored block, in the order they were given."""
        output = X
        # Iteration follows insertion order, i.e. constructor-argument order.
        for layer in self._modules.values():
            output = layer(output)
        return output

In [7]:
# Same layer stack as the nn.Sequential in the first cell, but assembled with the
# hand-written MySequential; applied to the same (2, 20) batch X.
net = MySequential(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
net(X)

tensor([[-0.2490, -0.0816, -0.0283, -0.0777, -0.3367,  0.2357, -0.0926, -0.0827,
          0.0377,  0.3202],
        [-0.1729, -0.1149, -0.0616,  0.0312, -0.3250,  0.1550,  0.1133,  0.0240,
          0.1642,  0.2296]], grad_fn=<AddmmBackward0>)

In [11]:
class FixedHiddenMLP(nn.Module):
    """Block mixing a trainable layer, a fixed random matrix and Python control flow.

    Forward pass: linear -> ``relu(X @ rand_weight + 1)`` -> the *same* linear
    layer again -> halve until ``sum(|X|) <= 1`` -> return the scalar ``X.sum()``.
    Because of ``torch.mm``, ``X`` must be 2-D with 20 columns.
    """

    def __init__(self):
        super().__init__()
        # Random weights that are never trained (created without requires_grad).
        # Registered as a buffer instead of a plain tensor attribute so that
        # ``.to(device)`` / ``.double()`` / ``.cuda()`` move and cast it together
        # with the parameters; ``persistent=False`` keeps it out of
        # ``state_dict()``, so checkpoint keys are unchanged.
        self.register_buffer("rand_weight", torch.rand((20, 20)), persistent=False)
        self.linear = nn.Linear(20, 20)

    def forward(self, X):
        X = self.linear(X)
        # Combine the fixed (non-trainable) matrix with the relu and mm functions.
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # Reusing the same linear layer: equivalent to two fully connected
        # layers that share their parameters.
        X = self.linear(X)
        # Data-dependent Python control flow: halve until the L1 norm is <= 1.
        # Out-of-place division avoids mutating a tensor autograd is tracking.
        # NOTE(review): if X contains inf this loop never terminates.
        while X.abs().sum() > 1:
            X = X / 2
        return X.sum()


In [13]:
# FixedHiddenMLP reduces its output to a scalar (X.sum() after the halving
# loop), so this displays a 0-dim tensor.
net = FixedHiddenMLP()
net(X)

tensor(0.0631, grad_fn=<SumBackward0>)

可以混合搭配各种组合块

In [16]:
class NestMLP(nn.Module):
    """Custom block that nests an ``nn.Sequential`` inside itself.

    ``self.net`` maps 20 -> 64 (ReLU) -> 32 (ReLU); ``self.linear`` then maps
    32 -> 16.
    """

    def __init__(self):
        super().__init__()
        feature_layers = [nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 32), nn.ReLU()]
        self.net = nn.Sequential(*feature_layers)
        self.linear = nn.Linear(32, 16)

    def forward(self, X):
        """Run the nested MLP, then the final linear projection."""
        features = self.net(X)
        return self.linear(features)
    
# Mix and match blocks: a custom nested block, a plain Linear layer and
# FixedHiddenMLP chained in one Sequential; feature sizes go 20 -> 16 -> 20,
# and FixedHiddenMLP reduces to a scalar.
chimera = nn.Sequential(NestMLP(), nn.Linear(16, 20), FixedHiddenMLP())  # chimera: a mythical hybrid monster
chimera(X)

tensor(0.1310, grad_fn=<SumBackward0>)

练习

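将 MySequential 中存储块的方式改为 Python 列表：前向传播仍然可以运行，但列表中的块不会被注册为子模块，因此 `net.parameters()` 和 `state_dict()` 中看不到它们的参数（优化器也就无法更新这些参数）。改用 `nn.ModuleList` 即可解决这个问题。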

In [17]:
class MySequential2(nn.Module):
    """Exercise variant of ``MySequential`` that keeps its blocks in a list.

    A plain Python ``list`` attribute is invisible to ``nn.Module``. The blocks
    would still run in ``forward``, but their parameters would be missing from
    ``parameters()`` and ``state_dict()`` and ignored by ``.to()``/``.cuda()``,
    so an optimizer would never update them. ``nn.ModuleList`` behaves like a
    list (indexing, ``len``, ``append``, iteration) *and* registers every
    element as a child module.
    """

    def __init__(self, *args):
        super().__init__()
        self.sequential = nn.ModuleList(args)

    def forward(self, X):
        """Apply each stored block to ``X`` in the order they were passed in."""
        # ModuleList iterates in insertion order, just like a Python list.
        for block in self.sequential:
            X = block(X)
        return X
    
# The forward pass works however the blocks are stored; inspect
# `list(net.parameters())` to check whether they were registered as submodules
# (it is empty when the blocks live in a plain Python list).
net = MySequential2(nn.Linear(20, 256), nn.ReLU(), nn.Linear(256, 10))
net(X)

tensor([[ 0.0748, -0.1826,  0.0725,  0.1342,  0.2197,  0.0456,  0.0558,  0.1023,
         -0.1748, -0.0581],
        [ 0.1278, -0.2798, -0.1362,  0.2228,  0.0908,  0.0526, -0.0211,  0.0905,
         -0.1827, -0.0699]], grad_fn=<AddmmBackward0>)

实现一个平行块：它以两个块（例如 net1 和 net2）为参数，在前向传播中把同一个输入分别送入这两个网络，并返回两者输出的串联（拼接）结果。

In [None]:
net1 = nn.Sequential(nn.Linear(256, 20))
net2 = nn.Sequential(nn.Linear(256, 10))

class mergenet(nn.Module):
    """Parallel block: run one input through two blocks and concatenate the results.

    Args:
        block1, block2: the two sub-networks. Both receive the same input and
            must produce outputs whose shapes agree on every dimension except
            ``dim``. They default to the module-level ``net1`` / ``net2``, so
            the zero-argument call ``mergenet()`` keeps working. Note that every
            default-constructed instance reuses those same two module objects
            and therefore shares their parameters.
        dim: dimension along which the two outputs are concatenated
            (default 1, the feature dimension for 2-D ``(batch, features)``
            outputs like the ones here).
    """

    def __init__(self, block1=None, block2=None, dim=1):
        super().__init__()
        self.block1 = net1 if block1 is None else block1
        self.block2 = net2 if block2 is None else block2
        self.dim = dim

    def forward(self, X):
        """Return ``torch.cat([block1(X), block2(X)], dim)``."""
        return torch.cat([self.block1(X), self.block2(X)], dim=self.dim)

X = torch.rand(2, 256)
net = mergenet(net1, net2)
net(X)  # (2, 30): the 20 outputs of net1 followed by the 10 outputs of net2

tensor([[ 2.8568e-01,  1.5830e-02, -1.6584e-01, -5.6773e-01,  3.6934e-01,
         -3.2824e-01,  3.4321e-01, -1.5514e-01, -8.7827e-02,  3.8812e-01,
          1.9380e-01,  4.2359e-01, -1.7288e-01, -9.5152e-01,  1.4269e-01,
          4.1053e-02,  9.4714e-02, -1.0936e-02, -5.2025e-01,  1.9197e-01,
         -2.4881e-01, -1.3404e-01, -3.6038e-01,  6.5159e-01,  9.4595e-02,
          4.0100e-01, -1.5314e-01,  1.6119e-01, -2.9076e-01,  4.7518e-01],
        [ 1.9477e-01,  5.8216e-01, -3.2636e-01, -4.7791e-01,  3.1160e-01,
         -1.3815e-01,  5.1588e-01, -1.2087e-01, -1.2493e-01,  1.6020e-01,
         -3.4580e-02,  3.1538e-01, -4.9339e-02, -6.4446e-01,  1.2129e-01,
          3.0762e-02,  2.8115e-01,  2.5723e-01, -2.2428e-01,  2.3390e-01,
         -6.3308e-01, -3.3962e-01, -3.1455e-01,  6.7073e-01,  1.4683e-01,
          6.6115e-01,  9.5396e-04,  2.4056e-04, -2.8462e-01,  8.1216e-01]],
       grad_fn=<CatBackward0>)

In [20]:
# Scratch cell: displays a fresh random (2, 256) tensor, the same shape as the
# input fed to mergenet above. NOTE(review): consider showing only `.shape` or a
# small slice instead of dumping all 512 values into the saved notebook.
torch.rand(2, 256)

tensor([[0.0203, 0.5889, 0.1904, 0.8056, 0.2326, 0.9335, 0.3181, 0.2885, 0.4204,
         0.2408, 0.9753, 0.3264, 0.4023, 0.3978, 0.7310, 0.4318, 0.1090, 0.6453,
         0.7487, 0.2230, 0.9060, 0.8040, 0.0374, 0.0418, 0.9539, 0.2429, 0.6423,
         0.3163, 0.5829, 0.5764, 0.6383, 0.4113, 0.0980, 0.2311, 0.8769, 0.0660,
         0.7674, 0.2595, 0.5674, 0.6688, 0.8120, 0.5547, 0.3177, 0.3633, 0.9289,
         0.5007, 0.0990, 0.1641, 0.4705, 0.1240, 0.9538, 0.2681, 0.5460, 0.0818,
         0.4137, 0.6361, 0.5267, 0.5399, 0.6742, 0.2943, 0.9308, 0.9509, 0.3483,
         0.5792, 0.5896, 0.8851, 0.2023, 0.1578, 0.0939, 0.9470, 0.1008, 0.0281,
         0.9477, 0.1121, 0.8407, 0.6089, 0.9535, 0.5369, 0.4287, 0.7420, 0.6949,
         0.9552, 0.6196, 0.6838, 0.5749, 0.8095, 0.6305, 0.1335, 0.9491, 0.4728,
         0.9970, 0.7911, 0.5351, 0.7858, 0.2103, 0.8462, 0.2459, 0.5608, 0.2128,
         0.2208, 0.6296, 0.9951, 0.0032, 0.6167, 0.1951, 0.7064, 0.7845, 0.7224,
         0.3729, 0.3789, 0.1