<a href="https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/04204.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## 4.4 自定义层

### 4.4.1 不含模型参数的自定义层

In [0]:
import torch
from torch import nn 

class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the mean of its entire input.

    The mean is taken over *all* elements (``x.mean()`` with no ``dim``), so the
    returned tensor has an overall mean of zero up to floating-point round-off.
    """

    def __init__(self, **kwargs):
        # Nothing to register -- just pass any nn.Module keyword arguments through.
        super().__init__(**kwargs)

    def forward(self, x):
        overall_mean = x.mean()
        return x - overall_mean

In [3]:
layer = CenteredLayer()
# The mean of 1..5 is 3, so the centered result is [-2, -1, 0, 1, 2].
layer(torch.arange(1, 6, dtype=torch.float))

tensor([-2., -1.,  0.,  1.,  2.])

In [0]:
# A parameter-free custom layer composes with built-in layers like any other module.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)

In [5]:
# CenteredLayer guarantees a zero overall mean only up to float32 round-off, so the
# displayed value is a tiny number rather than exactly 0. Check it against a tolerance
# so a regression fails loudly on Restart & Run All instead of hiding in the output.
y = net(torch.rand(4, 8))
assert abs(y.mean().item()) < 1e-5, "CenteredLayer output should have (near-)zero mean"
y.mean().item()

4.656612873077393e-10

### 4.4.2 含模型参数的自定义层

In [13]:
class MyListDense(nn.Module):
    """Chain of matrix multiplications whose weights are stored in an ``nn.ParameterList``.

    With the default arguments the list holds three 4x4 matrices followed by one 4x1
    matrix, so ``forward`` maps a ``(batch, 4)`` input to a ``(batch, 1)`` output.

    Args:
        in_features: number of input columns (and the size of the square hidden weights).
        num_hidden: how many ``in_features x in_features`` weights precede the last one.
        out_features: number of output columns produced by the final weight.
    """

    def __init__(self, in_features=4, num_hidden=3, out_features=1):
        super(MyListDense, self).__init__()
        # Every entry of a ParameterList is registered, so it appears in .parameters().
        self.params = nn.ParameterList(
            [nn.Parameter(torch.randn(in_features, in_features)) for _ in range(num_hidden)]
        )
        # append() on a ParameterList registers the new parameter as well.
        self.params.append(nn.Parameter(torch.randn(in_features, out_features)))

    def forward(self, x):
        """Multiply ``x`` (2-D, ``in_features`` columns) by each stored weight in order."""
        for weight in self.params:
            x = torch.mm(x, weight)
        return x

net = MyListDense()
net

MyListDense(
  (params): ParameterList(
      (0): Parameter containing: [torch.FloatTensor of size 4x4]
      (1): Parameter containing: [torch.FloatTensor of size 4x4]
      (2): Parameter containing: [torch.FloatTensor of size 4x4]
      (3): Parameter containing: [torch.FloatTensor of size 4x1]
  )
)

In [14]:
class MyDictDense(nn.Module):
    """Linear map whose weight is selected by name from an ``nn.ParameterDict``.

    It stores three weights -- ``'linear1'`` (4x4), ``'linear2'`` (4x1) and
    ``'linear3'`` (4x2) -- and ``forward`` multiplies the input by the chosen one.
    """

    def __init__(self):
        super().__init__()
        first_two = {
            'linear1': nn.Parameter(torch.randn(4, 4)),
            'linear2': nn.Parameter(torch.randn(4, 1)),
        }
        self.params = nn.ParameterDict(first_two)
        # Entries can also be added after construction via update(), like a dict.
        extra = {'linear3': nn.Parameter(torch.randn(4, 2))}
        self.params.update(extra)

    def forward(self, x, choice='linear1'):
        """Return ``torch.mm(x, params[choice])``; ``choice`` must be a key of ``self.params``."""
        weight = self.params[choice]
        return torch.mm(x, weight)

net = MyDictDense()
net

MyDictDense(
  (params): ParameterDict(
      (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
      (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
      (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
  )
)

In [15]:
# A single all-ones row with 4 features; `x` is reused by the cells below.
x = torch.ones((1, 4))
net(x, choice='linear1')

tensor([[ 0.0299, -0.1434, -1.7719,  2.3091]], grad_fn=<MmBackward>)

In [16]:
net(x, choice='linear2')  # 4x1 weight -> (1, 1) output

tensor([[-1.7349]], grad_fn=<MmBackward>)

In [17]:
net(x, choice='linear3')  # 4x2 weight -> (1, 2) output

tensor([[-1.9762,  1.9643]], grad_fn=<MmBackward>)

In [18]:
# Custom layers nest inside nn.Sequential just like built-in ones: the dict-based
# layer's output feeds the list-based layer.
net = nn.Sequential(MyDictDense(), MyListDense())
net

Sequential(
  (0): MyDictDense(
    (params): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x1]
        (linear3): Parameter containing: [torch.FloatTensor of size 4x2]
    )
  )
  (1): MyListDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
)

In [19]:
# nn.Sequential calls each module with only the running tensor, so MyDictDense uses its
# default choice='linear1' (4x4) and MyListDense then reduces the (1, 4) result to (1, 1).
net(x)

tensor([[-19.8875]], grad_fn=<MmBackward>)