# 保存模型

In [1]:
import torch
import torchvision

# Build a VGG16 with randomly initialised weights (pretrained=False: nothing is downloaded).
vgg16 = torchvision.models.vgg16(pretrained=False)

# Save method 1: pickle the entire nn.Module object, i.e. the model structure
# together with its parameters. Loading it back later requires the model's
# class to be importable in the loading process.
torch.save(obj=vgg16, f="vgg16_method01.pth")

# Save method 2 (the officially recommended way): store only the parameters,
# in the form of the module's state_dict.
torch.save(obj=vgg16.state_dict(), f="vgg16_method02.pth")



# 加载模型

In [2]:
import torch
import torchvision
 
# Load method 1: counterpart of save method 1 (the whole pickled model object).
# The file holds a full nn.Module, not just tensors, so the unrestricted
# unpickler is needed: weights_only=False must be passed explicitly, because
# recent PyTorch releases warn when it is omitted and PyTorch >= 2.6 defaults
# to weights_only=True, which rejects this file.
# SECURITY: unrestricted unpickling can execute arbitrary code — only load
# files you created yourself or otherwise trust.
vgg16_method01 = torch.load("vgg16_method01.pth", weights_only=False)
print(f"vgg16_method01 : {vgg16_method01}")

# Load method 2: counterpart of save method 2 — rebuild the architecture first,
# then copy the saved parameters into it. A state_dict contains only tensors,
# so the restricted (and safer) weights_only=True loader is sufficient.
vgg16_method02 = torchvision.models.vgg16(pretrained=False)
vgg16_method02.load_state_dict(torch.load("vgg16_method02.pth", weights_only=True))
print(f"vgg16_method02 : {vgg16_method02}")

  vgg16_method01 = torch.load("vgg16_method01.pth")


vgg16_method01 : VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padd

  vgg16_method02.load_state_dict(torch.load("vgg16_method02.pth"))


vgg16_method02 : VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padd

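**使用 method 1 保存和加载模型存在的陷阱**

方法1 用 pickle 保存整个模型对象，但 pickle 只记录模型类的引用（例如 `__main__.MyModule`），并不保存类的定义。因此在另一个 notebook 中用方法1 加载时，必须先定义或导入 `MyModule`，否则加载会失败（例如 `AttributeError: Can't get attribute 'MyModule' on <module '__main__'>`）。另外，较新版本的 PyTorch 中加载完整模型还需要显式传入 `weights_only=False`。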

In [3]:
# 自定义模型
import torch.nn as nn
class MyModule(nn.Module):
    """Tiny two-layer convolutional network used to demonstrate the save/load pitfall.

    Layers (both use the nn.Conv2d defaults stride=1, padding=0):
        conv1: 3 -> 64 channels, 3x3 kernel.
        conv2: 64 -> 128 channels, 3x3 kernel.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3)

    def forward(self, x):
        """Apply conv1 followed directly by conv2 (no activation in between).

        With 3x3 kernels and no padding, each spatial dimension shrinks by 2
        per convolution.
        """
        return self.conv2(self.conv1(x))
    
myModule = MyModule()
# Save with method 1 (the whole pickled module). Another notebook,
# [PyTorch_Learn/15_save_and_lode_notify.ipynb], loads this file directly with
# method 1. The pickle stores only a reference to the MyModule class, so that
# notebook must have the class defined or importable for loading to succeed.
torch.save(obj=myModule, f="./myModule.pth")

**使用 method 2 保存和加载模型，验证是否存在上述问题**

In [4]:
# Save with method 2 (parameters only, as a state_dict). The loading side, in
# another notebook, is expected to rebuild MyModule and restore these weights
# with load_state_dict.
torch.save(obj=myModule.state_dict(), f="./myModule_state_dict.pth")