In [2]:
# save and load tensors or the parameters of a network
import torch
from torch import nn
from torch.nn import functional as F

# Round-trip a single tensor through disk: torch.save serializes the object
# to the given file name and torch.load reads it back.
x = torch.arange(0, 4)
print(x)
torch.save(obj=x, f='x_file')

# The reloaded tensor is printed so it can be compared with the original.
x1 = torch.load(f='x_file')
print(x1)

tensor([0, 1, 2, 3])
tensor([0, 1, 2, 3])


In [3]:
# A Python list of tensors can be written as one object; loading it gives
# back a sequence that unpacks into the individual tensors.
y = torch.zeros(4)
torch.save(obj=[x, y], f='x_y_file')
print((x, y))

loaded_pair = torch.load(f='x_y_file')
x2, y2 = loaded_pair
print((x2, y2))

(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))
(tensor([0, 1, 2, 3]), tensor([0., 0., 0., 0.]))


In [4]:
# Save and load a dictionary that maps string keys to tensors.
# The mapping is bound to `tensor_dict` rather than `dict`: binding the name
# `dict` at module level would shadow the builtin type for the rest of the
# kernel session and break any later `dict(...)` call.
tensor_dict = {'x' : x, 'y' : y}
torch.save(tensor_dict, 'dict_file')
print(tensor_dict)

# Loading returns a dictionary with the same keys and tensor values.
dict1 = torch.load('dict_file')
print(dict1)

{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}
{'x': tensor([0, 1, 2, 3]), 'y': tensor([0., 0., 0., 0.])}


In [7]:
# save and load a model through its parameters
class MLP(nn.Module):
    """Two-layer perceptron: Linear(20 -> 256), ReLU, Linear(256 -> 10)."""

    def __init__(self):
        """Create the hidden and output linear layers."""
        super().__init__()
        # Attribute names determine the state_dict keys
        # ('hidden.*', 'output.*'), so they must stay stable for saved
        # parameter files to load.
        self.hidden = nn.Linear(in_features=20, out_features=256)
        self.output = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Apply hidden layer, ReLU, then output layer to `x`."""
        activated = F.relu(self.hidden(x))
        return self.output(activated)

# Build a model and run one forward pass so there is a reference output.
# NOTE: this rebinds `y`, which an earlier cell used for torch.zeros(4).
net = MLP()
X = torch.randn(size=(2, 20))
y = net(X)

# save: only the state_dict (the module's parameters) is written, not the
# class definition itself.
torch.save(net.state_dict(), 'mlp_param_file')
print(y)

# load: the file holds parameters only, so we first instantiate a fresh
# MLP() and then copy the saved parameters into it.
net_clone = MLP()
net_clone.load_state_dict(torch.load('mlp_param_file'))
# Switch to evaluation mode before inference. It makes no difference for this
# MLP, but it keeps the restored model correct if layers such as dropout or
# batch norm are added later.
net_clone.eval()
y_clone = net_clone(X)
print(y_clone)
# Same parameters and same input, so every element should compare equal.
print(y_clone == y)

tensor([[-0.0348,  0.0393,  0.0886,  0.2006, -0.3367,  0.1010,  0.0754, -0.1596,
         -0.0327,  0.2550],
        [-0.0833,  0.1148,  0.1399,  0.2755,  0.0047,  0.1125,  0.1894, -0.1546,
          0.3588,  0.0715]], grad_fn=<AddmmBackward0>)
tensor([[-0.0348,  0.0393,  0.0886,  0.2006, -0.3367,  0.1010,  0.0754, -0.1596,
         -0.0327,  0.2550],
        [-0.0833,  0.1148,  0.1399,  0.2755,  0.0047,  0.1125,  0.1894, -0.1546,
          0.3588,  0.0715]], grad_fn=<AddmmBackward0>)
tensor([[True, True, True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True, True, True]])
