In [14]:
# 自定义层
# 首先介绍一个不含模型参数的自定义层
import torch
from torch import nn

class CenteredLayer(nn.Module):
    def __init__(self,**kargs):
        super(CenteredLayer,self).__init__(**kargs)
        
    def forward(self,x):
        return x - x.mean()
    
# layer = CenteredLayer()
# layer(torch.tensor([1,2,3,4,5],dtype=torch.float))

net = nn.Sequential(
    nn.Linear(8,128),
    CenteredLayer(),
)
y = net(torch.rand(2,8))
y.mean().item()

-9.778887033462524e-09

In [35]:
# 自定义层
# 含有模型参数的自定义层
import torch
from torch import nn

class MyDense(nn.Module):
    def __init__(self):
        super(MyDense,self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(4,4)) for i in range(3)])
        self.params.append(nn.Parameter(torch.randn(4,1)))
        
    def forward(self,x):
        for i in range(len(self.params)):
            x = torch.mm(x,self.params[i])
        return x
net = MyDense()
# print(net)
# a = {'v':1}
# print(a['v'])

class MyDictDense(nn.Module):
    def __init__(self):
        super(MyDictDense,self).__init__()
        self.paramdict = nn.ParameterDict({
            'linear1':nn.Parameter(torch.randn(4,4)),
            'linear2':nn.Parameter(torch.randn(4,3))
        })
#     def forward(self,x):
#         for key in self.paramdict.keys():
#             x = torch.mm(x,self.paramdict[key])
#         return x
    
    def forward(self, x, choice='linear1'):
        return torch.mm(x, self.params[choice])

    
net2 = MyDictDense()
x = torch.randn(2,4)

net = nn.Sequential(
    MyDense(),
    MyDictDense(),
)
print(net)

Sequential(
  (0): MyDense(
    (params): ParameterList(
        (0): Parameter containing: [torch.FloatTensor of size 4x4]
        (1): Parameter containing: [torch.FloatTensor of size 4x4]
        (2): Parameter containing: [torch.FloatTensor of size 4x4]
        (3): Parameter containing: [torch.FloatTensor of size 4x1]
    )
  )
  (1): MyDictDense(
    (paramdict): ParameterDict(
        (linear1): Parameter containing: [torch.FloatTensor of size 4x4]
        (linear2): Parameter containing: [torch.FloatTensor of size 4x3]
    )
  )
)
