In [1]:
import torch
import torch.nn.functional as F
from torch import nn

In [2]:
class CenterLayer(nn.Module):
    """Parameter-free layer that shifts its input so the overall mean is zero."""

    def __init__(self):
        # No parameters of its own; only the base-class setup is required.
        super().__init__()

    def forward(self, X):
        """Return ``X`` with the mean of all its elements subtracted."""
        # X.mean() without a dim argument reduces over every element,
        # so the same scalar is subtracted from each entry.
        center = X.mean()
        return X - center

In [3]:
# Sanity check: centering 1..5 (mean 3) should give values symmetric around zero.
layer = CenterLayer()
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32))

tensor([-2., -1.,  0.,  1.,  2.])

In [4]:
# Compose a built-in linear layer with the custom centering layer.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenterLayer(),
)

# After centering, the mean of the output should be zero up to float rounding.
Y = net(torch.rand((4, 8)))
Y.mean()

tensor(-5.5879e-09, grad_fn=<MeanBackward0>)

# 带参数的层

In [5]:
class MyLinear(nn.Module):
    """Fully connected layer with its own weight and bias, followed by ReLU.

    Args:
        in_units: Size of the last dimension of the input.
        units: Number of output features.
    """

    def __init__(self, in_units, units):
        super().__init__()
        # nn.Parameter registers the tensors with the module, so they are
        # returned by .parameters() and tracked by autograd.
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.zeros(units,))

    def forward(self, X):
        """Return ``relu(X @ weight + bias)``."""
        # Use the parameters directly instead of `.data`: `.data` bypasses
        # autograd, so no gradient would ever reach weight or bias and the
        # layer could not be trained.
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)

# Instantiate the layer and inspect its registered parameters.
dense = MyLinear(in_units=5, units=3)
dense.weight, dense.bias

(Parameter containing:
 tensor([[-0.0435, -2.0116,  1.7458],
         [ 0.4108, -1.0088, -0.0713],
         [-2.4034, -1.7676,  0.5303],
         [-0.2744,  2.5604, -0.1749],
         [-0.3735, -0.4937, -1.9737]], requires_grad=True),
 Parameter containing:
 tensor([0., 0., 0.], requires_grad=True))

In [6]:
# Forward pass on a random input of shape (2, 5).
dense(torch.rand((2, 5)))

tensor([[0.0000, 0.0000, 0.3220],
        [0.0000, 0.0000, 0.2517]])

In [7]:
# Custom layers compose like built-in ones: 64 -> 8 -> 1 features.
net = nn.Sequential(
    MyLinear(in_units=64, units=8),
    MyLinear(in_units=8, units=1),
)
net(torch.rand((2, 64)))

tensor([[1.1884],
        [0.0000]])