# 4.4 自定义层

In [1]:
import torch
from torch import nn

定义一个不含模型参数的自定义层：它将输入减去输入自身的均值（即中心化）。由于 forward 函数只需要一个输入 x，所以该层只接受一个 x 作为输入。

In [2]:
class CenteredLayer(nn.Module):
    """A parameter-free layer that subtracts the mean of its input.

    The mean is taken over *all* elements of the input tensor (a single
    scalar), not per row or per feature.
    """

    def __init__(self, **unused_kwargs):
        # Extra keyword arguments are accepted but deliberately ignored;
        # they are not forwarded to nn.Module.
        super(CenteredLayer, self).__init__()

    def forward(self, x):
        """Return ``x`` minus the scalar mean of all of its elements."""
        overall_mean = x.mean()
        return x.sub(overall_mean)


# Instantiate the layer and run a forward pass on a small float vector;
# the displayed result is the input shifted so that its mean is 0.
layer = CenteredLayer()
sample_input = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float)
layer(sample_input)

tensor([-2., -1.,  0.,  1.,  2.])

把该层作为组件放进一个更复杂的模型中调用看看。由于 CenteredLayer 会减去均值，输出的均值理论上应为 0；实际打印出的是一个极小的非零数（如 -2e-09），这是浮点运算的舍入误差造成的。

In [3]:
# Seed the global RNG so both the random initial weights of nn.Linear and the
# random input below are reproducible under "Restart & Run All".
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
print(net)
y = net(torch.rand(4, 8))
# CenteredLayer subtracts the overall mean, so the output mean is 0 up to
# floating-point rounding error (a tiny non-zero value, not exactly 0).
mean_of_output = y.mean().item()
assert abs(mean_of_output) < 1e-5, mean_of_output
mean_of_output

Sequential(
  (0): Linear(in_features=8, out_features=128, bias=True)
  (1): CenteredLayer()
)


-2.0954757928848267e-09

下面是ParameterList和ParameterDict的用法

In [4]:
# Besides defining parameters directly as nn.Parameter attributes (section 4.2.1),
# parameters can also be grouped with ParameterList (a list) or ParameterDict (a dict).
class MyListDense(nn.Module):
    """A chain of bias-free matrix multiplications whose weights live in an nn.ParameterList.

    Holds three 4x4 weights followed by one 4x1 weight.
    """

    def __init__(self):
        super(MyListDense, self).__init__()
        # Three square 4x4 weights, built with a list comprehension.
        square_weights = [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        self.params = nn.ParameterList(square_weights)
        # ParameterList supports list-style append; the appended weight is registered too.
        self.params.append(nn.Parameter(torch.randn(4, 1)))

    def forward(self, x):
        """Multiply ``x`` by every stored weight in order using ``torch.mm``.

        ``torch.mm`` is plain 2-D matrix multiplication, so ``x`` must be a
        2-D tensor with 4 columns (matching the first 4x4 weight).
        """
        for weight in self.params:
            x = torch.mm(x, weight)
        return x


# Build the ParameterList-based layer and print its registered parameters.
net = MyListDense()
print(repr(net))

MyListDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [5]:
class MyDictDense(nn.Module):
    """Holds several named weight matrices in an nn.ParameterDict.

    ``forward`` multiplies the input by one of them, selected by name.
    """

    def __init__(self):
        super().__init__()
        # Initialise the ParameterDict from an ordinary dict ...
        initial_weights = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(initial_weights)
        # ... merge in more entries with dict-style ``update`` ...
        extra_weights = {'linear3': nn.Parameter(torch.randn(4, 2))}
        self.params.update(extra_weights)
        # ... or add a single entry by assigning to a new key.
        self.params['linear4'] = nn.Parameter(torch.randn(4, 3))

    def forward(self, x, choice='linear1'):
        """Return ``torch.mm(x, W)``, where ``W`` is the stored weight named ``choice``.

        ``choice`` defaults to ``'linear1'``.
        """
        selected = self.params[choice]
        return torch.mm(x, selected)


net = MyDictDense()
print(net)

x = torch.ones(1, 4)
# Arguments passed to ``net(...)`` are forwarded to ``MyDictDense.forward``;
# ``choice`` may be given positionally or by keyword (see the last call).
# Omitting ``choice`` uses the default 'linear1', so net(x) == net(x, 'linear1').
for weight_name in ('linear1', 'linear2', 'linear3'):
    print(net(x, weight_name))
print(net(x=x, choice='linear4'))

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
      (linear4): Parameter containing: [torch.FloatTensor of size 4x3]
  )
)
tensor([[ 2.0170, -2.3575,  2.9080, -0.4367]], grad_fn=<MmBackward>)
tensor([[-3.5715]], grad_fn=<MmBackward>)
tensor([[0.5690, 1.6447]], grad_fn=<MmBackward>)
tensor([[0.2689, 0.0280, 0.7486]], grad_fn=<MmBackward>)


In [6]:
# Custom layers compose like built-in ones. With the default choice 'linear1',
# MyDictDense maps the (1, 4) input ``x`` from the previous cell to (1, 4),
# and MyListDense then maps that to (1, 1).
stacked_layers = [MyDictDense(), MyListDense()]
net = nn.Sequential(*stacked_layers)
print(net)
print(net(x))

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
        (linear4): Parameter containing: [torch.FloatTensor of size 4x3]
    )
  )
  (1): MyListDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[-22.9917]], grad_fn=<MmBackward>)
