## 自定义层
---

### 不含模型参数的自定义层

In [1]:
import torch
from torch import nn

class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the overall mean from its input.

    The mean is computed over every element of the tensor (no ``dim`` is
    given), so the output keeps the input's shape and averages to roughly zero.
    """

    def __init__(self, **kwargs):
        # Keyword arguments are accepted but neither used nor forwarded.
        super().__init__()

    def forward(self, x):
        """Return ``x`` shifted so that the mean over all elements is zero."""
        overall_mean = torch.mean(x)
        return torch.sub(x, overall_mean)

# Quick check on the values 1..5: their mean is 3, so the result is centred on 0.
layer = CenteredLayer()
layer(torch.arange(1, 6, dtype=torch.float))

tensor([-2., -1.,  0.,  1.,  2.])

In [6]:
# Use the centering layer as the last stage of a larger model.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)
y = net(torch.rand(4, 8))
y

tensor([[-0.7948, -0.5475, -0.2065, -0.0209, -0.4834,  0.8275, -0.3894, -0.0943,
          0.3236,  0.5518, -0.1782,  0.3414, -0.9232,  0.2551,  0.3652,  0.4657,
          0.4744,  0.7016, -0.0055,  0.4329, -0.2925,  0.0590, -0.0853,  0.4360,
          0.5644, -0.2686,  0.7717,  0.1796,  0.7375,  0.0037, -0.4272,  0.4416,
         -0.1655,  0.7557,  0.7372,  0.3936,  0.1899,  0.6716, -0.4527, -0.0849,
          0.2648,  0.1903, -0.2491,  0.2744, -0.7896,  0.3739, -0.3029,  0.2557,
          0.8442,  0.0376, -0.0537,  0.2780, -0.8638, -0.6050, -0.0913,  0.0412,
         -0.4727, -0.5476,  0.1350,  0.2016,  0.3654,  0.0502, -0.1267,  0.1767,
          0.0373, -0.7600,  0.5537,  0.3082,  0.8094, -0.0383, -0.0692, -0.0272,
          0.1905, -0.6772, -0.0014,  0.7888, -0.9996, -0.3068,  0.1754,  0.5232,
          0.0705, -0.5180, -0.7525,  0.1491, -0.1703,  0.4155, -0.1160, -0.0334,
         -0.4525,  0.1529,  0.6055, -0.2037, -0.4627,  0.3659,  0.0407, -1.0556,
         -0.3826, -0.1823, -

In [9]:
# The overall mean should be zero up to floating-point rounding error.
torch.mean(y).item()

-2.561137080192566e-09

### 含模型参数的自定义层

In [10]:
class MyDense(nn.Module):
    """Chain of four weight matrices applied one after another with ``torch.mm``.

    The weights are stored in an ``nn.ParameterList`` so that they are
    registered as parameters of the module: three 4x4 matrices followed by a
    single 4x1 matrix.
    """

    def __init__(self):
        super().__init__()
        # Three 4x4 weights drawn with torch.randn.
        square_weights = [nn.Parameter(torch.randn(4, 4)) for _ in range(3)]
        self.param = nn.ParameterList(square_weights)
        # Final 4x1 weight; note that it is drawn with torch.rand, not torch.randn.
        self.param.append(nn.Parameter(torch.rand(4, 1)))

    def forward(self, x):
        """Multiply ``x`` by each stored weight in order and return the product.

        Every step uses ``torch.mm``, so ``x`` has to be a 2-D tensor with four
        columns; the result has one column.
        """
        for weight in self.param:
            x = torch.mm(x, weight)
        return x
# Instantiate the layer; printing it lists the entries held in its ParameterList.
net = MyDense()
print(net)

MyDense(
  (param): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)


In [13]:
class MyDictDense(nn.Module):
    """Layer whose weights are stored by name in an ``nn.ParameterDict``.

    Three weights are registered: ``linear1`` (4x4), ``linear2`` (4x1) and
    ``linear3`` (4x2). The caller selects which one to apply on each forward
    pass.
    """

    def __init__(self):
        super().__init__()
        initial_weights = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(initial_weights)
        # Add a further entry after construction, dict-style, via update().
        extra_weights = {'linear3': nn.Parameter(torch.randn(4, 2))}
        self.params.update(extra_weights)

    def forward(self, x, choice='linear1'):
        """Return ``torch.mm`` of ``x`` with the weight stored under ``choice``.

        ``choice`` must be one of the keys in ``self.params``; it defaults to
        ``'linear1'``.
        """
        weight = self.params[choice]
        return torch.mm(x, weight)

# Instantiate the layer; printing it lists the named entries of its ParameterDict.
net = MyDictDense()
print(net)

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)


In [14]:
x = torch.ones(1, 4)
# Apply each named weight in turn; the output width follows the chosen matrix.
for choice in ('linear1', 'linear2', 'linear3'):
    print(net(x, choice))

tensor([[ 0.5632, -1.5351, -0.8762,  0.8945]], grad_fn=<MmBackward>)
tensor([[2.1383]], grad_fn=<MmBackward>)
tensor([[0.5623, 1.5582]], grad_fn=<MmBackward>)


In [15]:
# The custom layers compose like built-in modules. MyDictDense runs with its
# default choice ('linear1') because Sequential only passes the input along.
net = nn.Sequential(MyDictDense(), MyDense())
print(net)
print(net(x))

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyDense(
    (param): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)
tensor([[25.0221]], grad_fn=<MmBackward>)
