#### 读取模型中的参数

创建一个模型

In [1]:
import torch


class MyModel(torch.nn.Module):
    """Toy network used to demonstrate module / parameter traversal.

    Three top-level submodules are registered, in this order:

    * ``layer1`` -- ``Sequential(Linear(3, 4), Linear(4, 3))``
    * ``layer2`` -- ``Linear(3, 6)``
    * ``layer3`` -- ``Sequential(Linear(6, 7), Linear(7, 5))``

    The assignment order in ``__init__`` is the order in which
    ``modules()``, ``children()``, ``parameters()`` and ``state_dict()``
    enumerate the submodules and tensors, so it should not be changed.

    Note: there are no activation functions between the layers, so the
    whole network is a composition of affine maps. That is fine for a
    traversal demo but would not be a useful model to train.
    """

    def __init__(self):
        super(MyModel, self).__init__()
        # feature sizes 3 -> 4 -> 3
        self.layer1 = torch.nn.Sequential(torch.nn.Linear(3, 4), torch.nn.Linear(4, 3))
        # feature sizes 3 -> 6
        self.layer2 = torch.nn.Linear(3, 6)
        # feature sizes 6 -> 7 -> 5
        self.layer3 = torch.nn.Sequential(torch.nn.Linear(6, 7), torch.nn.Linear(7, 5))

    def forward(self, x):
        """Run ``x`` through layer1, layer2 and layer3 in sequence.

        The first Linear has ``in_features=3`` and the last has
        ``out_features=5``, so the input must have 3 features and the
        output has 5.
        """
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        return x


# Instantiate the demo model and print its module tree.
# Printing an nn.Module shows its repr, which nests each registered submodule.
net = MyModel()
print(repr(net))





MyModel(
  (layer1): Sequential(
    (0): Linear(in_features=3, out_features=4, bias=True)
    (1): Linear(in_features=4, out_features=3, bias=True)
  )
  (layer2): Linear(in_features=3, out_features=6, bias=True)
  (layer3): Sequential(
    (0): Linear(in_features=6, out_features=7, bias=True)
    (1): Linear(in_features=7, out_features=5, bias=True)
  )
)


递归获取每一个模块（包括模型本身，即下面输出中编号为 0 的项）

In [13]:

# -----------------------------------------------------------
# net.modules() / net.named_modules()
# Both walk the module tree recursively in pre-order, starting with
# the root model itself. named_modules() also yields the dotted
# attribute path ("" for the root, "layer1.0", ...).
# -----------------------------------------------------------
print("不带名字")
for i, layer in enumerate(net.modules()):
    print(i, " -> ", type(layer))

print("带名字")
# enumerate replaces the hand-maintained counter, matching the loop above.
for i, (name, layer) in enumerate(net.named_modules()):
    print(i, " -> ", name, type(layer))

不带名字
0  ->  <class '__main__.MyModel'>
1  ->  <class 'torch.nn.modules.container.Sequential'>
2  ->  <class 'torch.nn.modules.linear.Linear'>
3  ->  <class 'torch.nn.modules.linear.Linear'>
4  ->  <class 'torch.nn.modules.linear.Linear'>
5  ->  <class 'torch.nn.modules.container.Sequential'>
6  ->  <class 'torch.nn.modules.linear.Linear'>
7  ->  <class 'torch.nn.modules.linear.Linear'>
带名字
0  ->   <class '__main__.MyModel'>
1  ->  layer1 <class 'torch.nn.modules.container.Sequential'>
2  ->  layer1.0 <class 'torch.nn.modules.linear.Linear'>
3  ->  layer1.1 <class 'torch.nn.modules.linear.Linear'>
4  ->  layer2 <class 'torch.nn.modules.linear.Linear'>
5  ->  layer3 <class 'torch.nn.modules.container.Sequential'>
6  ->  layer3.0 <class 'torch.nn.modules.linear.Linear'>
7  ->  layer3.1 <class 'torch.nn.modules.linear.Linear'>


只获取直接子模块（即模块树的第一层节点，不包括模型本身，也不递归进入 `Sequential` 内部）

In [15]:

# -----------------------------------------------------------
# net.children() / net.named_children()
# Unlike modules(), these are NOT recursive: they yield only the
# direct submodules (layer1, layer2, layer3), not the model itself
# and not the Linear layers nested inside the Sequentials.
# -----------------------------------------------------------
print("不带名字")
for layer in net.children():
    # str() of an nn.Module is its multi-line repr.
    print(str(layer))

print("带名字")
for name, layer in net.named_children():
    print(f"{name} {layer}")

不带名字
Sequential(
  (0): Linear(in_features=3, out_features=4, bias=True)
  (1): Linear(in_features=4, out_features=3, bias=True)
)
Linear(in_features=3, out_features=6, bias=True)
Sequential(
  (0): Linear(in_features=6, out_features=7, bias=True)
  (1): Linear(in_features=7, out_features=5, bias=True)
)
带名字
layer1 Sequential(
  (0): Linear(in_features=3, out_features=4, bias=True)
  (1): Linear(in_features=4, out_features=3, bias=True)
)
layer2 Linear(in_features=3, out_features=6, bias=True)
layer3 Sequential(
  (0): Linear(in_features=6, out_features=7, bias=True)
  (1): Linear(in_features=7, out_features=5, bias=True)
)


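获取参数。`parameters()` / `named_parameters()` 只返回可学习参数；`state_dict()` 还会包含 buffer，但本模型只有 `Linear` 层、没有 buffer，所以后两者输出的名字和形状完全一致。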

In [18]:

# -----------------------------------------------------------
# net.parameters() / net.named_parameters() / net.state_dict()
# parameters() yields the parameter tensors in registration order;
# named_parameters() pairs each one with its dotted path
# (e.g. "layer1.0.weight"). state_dict() returns a name -> tensor
# mapping; for this model its keys match named_parameters().
# -----------------------------------------------------------
print("不带名字")
for param in net.parameters():
    print(param.size())

print("带名字")
for name, param in net.named_parameters():
    print(f"{name} {param.size()}")

print("也是带名字")
for key, value in net.state_dict().items():
    print(f"{key} {value.size()}")

不带名字
torch.Size([4, 3])
torch.Size([4])
torch.Size([3, 4])
torch.Size([3])
torch.Size([6, 3])
torch.Size([6])
torch.Size([7, 6])
torch.Size([7])
torch.Size([5, 7])
torch.Size([5])
带名字
layer1.0.weight torch.Size([4, 3])
layer1.0.bias torch.Size([4])
layer1.1.weight torch.Size([3, 4])
layer1.1.bias torch.Size([3])
layer2.weight torch.Size([6, 3])
layer2.bias torch.Size([6])
layer3.0.weight torch.Size([7, 6])
layer3.0.bias torch.Size([7])
layer3.1.weight torch.Size([5, 7])
layer3.1.bias torch.Size([5])
也是带名字
layer1.0.weight torch.Size([4, 3])
layer1.0.bias torch.Size([4])
layer1.1.weight torch.Size([3, 4])
layer1.1.bias torch.Size([3])
layer2.weight torch.Size([6, 3])
layer2.bias torch.Size([6])
layer3.0.weight torch.Size([7, 6])
layer3.0.bias torch.Size([7])
layer3.1.weight torch.Size([5, 7])
layer3.1.bias torch.Size([5])
