In [1]:
import torch
from torch import nn
from torch.nn import functional as F


In [2]:
class MLP(nn.Module):
    """Two-layer perceptron: 20 inputs -> 256 hidden units (ReLU) -> 10 outputs."""

    def __init__(self):
        super().__init__()
        # Registration order fixes parameter order in state_dict() and the
        # order in which the default initialisers draw from the RNG.
        self.hidden = nn.Linear(in_features=20, out_features=256)
        self.out = nn.Linear(in_features=256, out_features=10)

    def forward(self, X):
        """Hidden affine layer, ReLU, then the output affine layer."""
        activated = F.relu(self.hidden(X))
        return self.out(activated)

In [3]:
# Seed the global RNG before the first random draw so the input below and the
# randomly initialised networks in later cells are reproducible when the
# notebook is re-run from a fresh kernel.
torch.manual_seed(42)
X = torch.rand(2, 20)  # random input of shape (2, 20)
net = MLP()
net(X)

tensor([[ 0.1489, -0.1984, -0.0550,  0.0403, -0.0409,  0.0919, -0.1565, -0.0189,
          0.0466, -0.2755],
        [ 0.1290, -0.0771, -0.0286, -0.0059, -0.1250,  0.1260, -0.0217,  0.0494,
          0.0099, -0.1919]], grad_fn=<AddmmBackward0>)

In [4]:
class MySequential(nn.Module):
    """Minimal re-implementation of ``nn.Sequential``.

    Each positional argument is registered as a child module named by its
    position ("0", "1", ...), and ``forward`` feeds the input through the
    children in the order they were passed.
    """
    def __init__(self, *args):
        super().__init__()
        for idx, block in enumerate(args):
            # Child names must be strings: PyTorch concatenates them into
            # dotted names such as "0.weight" for parameters(), state_dict()
            # and repr(). Keying by position also keeps a block that is passed
            # twice from overwriting its earlier slot.
            self._modules[str(idx)] = block
    def forward(self, X):
        # Children are stored in insertion order, so blocks run as given.
        for block in self._modules.values():
            X = block(X)
        return X
net = MySequential(
    nn.Linear(20, 256),
    nn.ReLU(),
    nn.Linear(256, 10),
)
net(X)  # same (2, 20) input as before, passed through the custom container

tensor([[ 9.8187e-02, -1.9830e-01, -3.7636e-01, -1.5317e-01, -2.2701e-01,
          1.2153e-01, -3.7157e-03, -2.3273e-01, -1.8156e-02, -6.6392e-02],
        [ 1.3989e-01, -3.3146e-01, -2.0212e-01, -8.6904e-02, -9.7533e-02,
         -1.7201e-02,  1.5043e-04, -2.2386e-01,  1.4840e-01, -6.0031e-02]],
       grad_fn=<AddmmBackward0>)

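在前向传播中执行代码：`forward` 中可以包含任意 Python 代码。下面的 `FixedHiddenMLP` 使用了一个不参与训练的固定随机权重 `rand_weight`。它两次复用同一个 `self.linear` 层，即两次调用共享参数。最后用 `while` 循环反复将输出减半，直到其 L1 范数不超过 1。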

In [5]:
class FixedHiddenMLP(nn.Module):
    """Network showing a constant weight, a reused layer and control flow in ``forward``.

    ``rand_weight`` is a fixed random (20, 20) matrix that is never trained.
    ``self.linear`` is applied twice, so both applications share parameters.
    ``forward`` returns a scalar: the sum of the final activations after they
    have been halved until their L1 norm is at most 1.
    """
    def __init__(self):
        super().__init__()
        # Constant, untrained weight. Registered as a non-persistent buffer so
        # that .to(device) / .double() move and cast it together with the
        # parameters, while state_dict() keeps exactly the same keys.
        self.register_buffer("rand_weight", torch.rand((20, 20)), persistent=False)
        self.linear = nn.Linear(20, 20)
    def forward(self, X):
        # torch.mm below requires X to be 2-D with 20 columns.
        X = self.linear(X)
        X = F.relu(torch.mm(X, self.rand_weight) + 1)
        # Reuse the same layer: equivalent to two layers sharing parameters.
        X = self.linear(X)
        # Halve until the L1 norm is <= 1. Also stop on non-finite values:
        # halving inf yields inf, which would otherwise loop forever.
        # Out-of-place division avoids mutating a tensor autograd has recorded.
        while X.abs().sum() > 1 and torch.isfinite(X).all():
            X = X / 2
        return X.sum()

net = FixedHiddenMLP()
net(X)  # forward returns a scalar tensor

tensor(0.3996, grad_fn=<SumBackward0>)

In [9]:
class NestMLP(nn.Module):
    """An nn.Sequential sub-network (20 -> 64 -> ReLU -> 32) followed by a 32 -> 16 layer."""
    def __init__(self):
        super().__init__()
        # Built before self.linear, matching the original RNG draw order.
        stages = [
            nn.Linear(in_features=20, out_features=64),
            nn.ReLU(),
            nn.Linear(in_features=64, out_features=32),
        ]
        self.net = nn.Sequential(*stages)
        self.linear = nn.Linear(in_features=32, out_features=16)
    def forward(self, X):
        features = self.net(X)
        return self.linear(features)

# Custom modules compose with built-in ones: a nested MLP, a plain layer that
# maps 16 features back to 20, and FixedHiddenMLP, chained in one Sequential.
chimera = nn.Sequential(
    NestMLP(),
    nn.Linear(16, 20),
    FixedHiddenMLP(),
)
chimera(X)

tensor(0.2514, grad_fn=<SumBackward0>)