In [1]:
import torch
import torch.nn.functional as F
from torch import nn


class CenteredLayer(nn.Module):
    """Parameter-free layer that subtracts the overall mean from its input."""

    def __init__(self):
        super().__init__()

    def forward(self, X):
        """Return ``X`` shifted by the mean taken over all of its elements."""
        # The mean is a single scalar over every element (no per-row or
        # per-column reduction), and it is broadcast across X in the subtraction.
        center = torch.mean(X)
        return X - center

In [2]:
# Apply the centering layer to a small 1-D float tensor; the result is the
# input shifted so that its mean is zero.
layer = CenteredLayer()
# torch.tensor with an explicit dtype replaces the legacy typed constructor
# torch.FloatTensor and yields the same float32 values.
layer(torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32))

tensor([-2., -1.,  0.,  1.,  2.])

In [3]:
# Chain a built-in fully connected layer (8 -> 128 features) with the custom
# centering layer, showing that it composes like any other module.
net = nn.Sequential(
    nn.Linear(8, 128),
    CenteredLayer(),
)

In [12]:
# Feed a random 4 x 8 batch through the network; because the last layer
# subtracts the overall mean, the mean of the output should be zero up to
# floating-point rounding error.
Y = net(torch.rand((4, 8)))
Y.mean()

tensor(5.5879e-09, grad_fn=<MeanBackward0>)

In [13]:
class MyLinear(nn.Module):
    """Fully connected layer followed by a ReLU activation.

    Args:
        in_units: Size of the last dimension of the input.
        units: Number of output features.
    """

    def __init__(self, in_units, units):
        super().__init__()
        # Wrapping the tensors in nn.Parameter registers them with the module,
        # so they show up in parameters() / named_parameters().
        self.weight = nn.Parameter(torch.randn(in_units, units))
        self.bias = nn.Parameter(torch.randn(units,))

    def forward(self, X):
        """Return ``relu(X @ weight + bias)``.

        Args:
            X: Input tensor whose last dimension equals ``in_units``.

        Returns:
            Tensor with last dimension ``units``; all entries are non-negative.
        """
        # Use the Parameters themselves rather than their `.data` attribute:
        # `.data` is detached from the autograd graph, so no gradient would
        # ever reach weight or bias and the layer could not be trained.
        linear = torch.matmul(X, self.weight) + self.bias
        return F.relu(linear)

In [14]:
# Build a layer mapping 5 input features to 3 outputs and inspect its
# randomly initialised weight matrix (shape 5 x 3).
linear = MyLinear(in_units=5, units=3)
linear.weight

Parameter containing:
tensor([[ 1.4501, -0.6779, -1.0085],
        [ 0.6436, -1.1348,  1.2841],
        [ 0.7625, -1.0834,  0.3185],
        [ 2.2953,  1.5010,  1.0036],
        [ 0.1678, -0.0503,  0.5345]], requires_grad=True)

In [16]:
# Run a forward pass on a random batch of 2 samples with 5 features each.
linear(torch.rand((2, 5)))

tensor([[2.5102, 0.0000, 1.8611],
        [2.8845, 0.0000, 2.3553]])

In [17]:
# Custom layers compose inside nn.Sequential just like built-in ones:
# 64 -> 8 -> 1 features, evaluated on a random batch of 2 samples.
net = nn.Sequential(
    MyLinear(in_units=64, units=8),
    MyLinear(in_units=8, units=1),
)
net(torch.rand((2, 64)))

tensor([[0.0000],
        [1.6926]])

named_parameters() 的用法：它返回的是一个迭代器，每次迭代产生一个 (参数名, 参数) 元组，因此通常在 for 循环中遍历，并用两个变量分别接收参数名和参数本身。

In [20]:
# Iterate over the network's registered parameters; each item is unpacked
# into two loop variables (the parameter's name and the parameter itself)
# and both are printed.
for name, param in net.named_parameters():
    print(name, param)

0.weight Parameter containing:
tensor([[ 0.6767, -0.9463,  0.0729, -0.7194,  0.0044,  2.9269,  0.6308,  0.0359],
        [-0.4147, -0.5932, -1.8727,  0.1398,  0.1084, -1.3612,  0.2159,  1.7722],
        [ 1.0150,  1.5477, -0.8887,  0.0078,  0.0247,  0.7180, -0.5159,  1.2353],
        [ 1.2624, -0.7854,  0.2429, -0.1433,  1.5152, -0.3669, -0.0358,  1.2142],
        [-0.0065,  0.0478,  0.3753,  0.6810, -0.1199,  2.0344, -1.8449, -0.3373],
        [ 0.3094,  0.6463,  0.4752, -0.2605,  0.9843, -0.7700, -1.4067, -1.6815],
        [-1.0596, -0.4426,  0.3403,  0.1628, -0.7078, -0.2570, -0.7920, -1.3645],
        [ 0.4612, -0.0172,  1.2038,  1.2423,  0.1893, -0.8509,  2.3262,  1.4448],
        [-0.4506,  0.6799,  0.4561, -0.3370, -1.3604, -1.3694, -1.1338,  1.2169],
        [ 0.4461,  1.1958, -1.8547, -0.9699, -0.0277, -0.0315, -2.1106, -0.6221],
        [ 0.1748, -1.4248,  0.5598, -1.1912,  0.1092, -0.4359, -0.0076, -0.4169],
        [ 0.0974,  0.9876,  0.6487,  0.2755,  1.2006,  1.4148,  0.3