#### 张量的保存和加载

In [None]:
import torch

# Create a random 1-D tensor with 10 elements and display it.
a = torch.rand(10)
print(a)

In [None]:
import os
# Make sure the output directory exists (no error if it is already there).
os.makedirs("model", exist_ok=True)
# Serialize the single tensor `a` to the file model/tensor-a.
torch.save(a, 'model/tensor-a')

In [None]:
# Read the tensor back from disk. torch.load is pickle-based and the default
# of `weights_only` differs between PyTorch releases, so request the
# restricted unpickler explicitly; a plain tensor loads fine under it.
torch.load('model/tensor-a', weights_only=True)

In [None]:
# Draw three random 10-element tensors (sampled in a, b, c order) and write
# them to a single file as one list.
a, b, c = (torch.rand(10) for _ in range(3))
torch.save([a, b, c], 'model/tensor-abc')

In [None]:
# Load the saved list of tensors. `weights_only=True` is passed explicitly so
# the restricted (safe) unpickler is used regardless of the installed
# PyTorch version's default; a list of tensors is supported by it.
torch.load('model/tensor-abc', weights_only=True)

In [None]:
# Draw three random 10-element tensors (sampled in a, b, c order), key them by
# name in a dict, and write the whole dict to one file.
a, b, c = (torch.rand(10) for _ in range(3))
tensor_dict = dict(a=a, b=b, c=c)
torch.save(tensor_dict, 'model/tensor_dict')

In [None]:
# Load the saved name -> tensor dict. `weights_only=True` is passed explicitly
# so the restricted (safe) unpickler is used regardless of the installed
# PyTorch version's default; a dict of tensors is supported by it.
torch.load('model/tensor_dict', weights_only=True)

#### 模型参数的保存和加载

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms

# Define the MLP network
class MLP(nn.Module):
    """Multilayer perceptron: two ReLU-activated hidden layers and a linear output.

    Args:
        input_size: Number of features in each input sample.
        hidden_size: Width of both hidden layers.
        num_classes: Number of output values produced per sample.
    """

    def __init__(self, input_size: int, hidden_size: int, num_classes: int) -> None:
        super().__init__()
        # The attribute names below become the state_dict key prefixes, so
        # they must stay stable for saved parameters to load.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=hidden_size)
        self.fc3 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through fc1 -> ReLU -> fc2 -> ReLU -> fc3 and return the result."""
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        return self.fc3(hidden)

# Hyperparameters: input size (28 * 28), hidden-layer size, and output size
# (number of classes).
input_size, hidden_size, num_classes = 28 * 28, 512, 10

In [None]:
# Instantiate the MLP network.
model = MLP(input_size, hidden_size, num_classes)
# Random batch of 2 samples. The feature width is taken from input_size
# (instead of repeating the literal 28*28) so it always matches the width
# the model's first layer was built with.
X = torch.randn(size=(2, input_size))

In [None]:
import os

# Saving fails if the parent directory is missing, so make sure "model"
# exists; this lets the section run without the earlier tensor-saving cells.
os.makedirs("model", exist_ok=True)
# Persist only the model's parameters (its state_dict), not the module object.
torch.save(model.state_dict(), 'model/mlp.params')

In [None]:
# Load the saved state_dict. `weights_only=True` is passed explicitly so the
# restricted (safe) unpickler is used regardless of the installed PyTorch
# version's default; a state_dict of tensors is supported by it.
params = torch.load('model/mlp.params', weights_only=True)
# Build a fresh network with the same constructor arguments, then copy the
# saved parameters into it.
model_load = MLP(input_size, hidden_size, num_classes)
model_load.load_state_dict(params)

In [12]:
# Forward pass of the original model on the batch X; print the result.
output1 = model(X)
print(output1)

tensor([[-0.0107,  0.0414,  0.0170, -0.0564,  0.1039, -0.0627, -0.0256,  0.1142,
         -0.1233,  0.1592],
        [-0.0226, -0.0832, -0.0352, -0.1938,  0.0435, -0.0203,  0.0838,  0.0771,
         -0.2488,  0.2506]], grad_fn=<AddmmBackward0>)


In [13]:
# Forward pass of the reloaded model on the same batch X; print the result so
# it can be compared with the original model's output.
output2 = model_load(X)
print(output2)

tensor([[-0.0107,  0.0414,  0.0170, -0.0564,  0.1039, -0.0627, -0.0256,  0.1142,
         -0.1233,  0.1592],
        [-0.0226, -0.0832, -0.0352, -0.1938,  0.0435, -0.0203,  0.0838,  0.0771,
         -0.2488,  0.2506]], grad_fn=<AddmmBackward0>)
