In [1]:
import torch
import torch.nn as nn

In [2]:
class Model(nn.Module):
    """Logistic-regression model: a single linear layer followed by a sigmoid.

    Args:
        n_input_features: size of the last dimension of the input tensor.
    """

    def __init__(self, n_input_features):
        super().__init__()
        # The attribute name `linear` fixes the state_dict keys
        # ("linear.weight", "linear.bias"); renaming it would break loading
        # previously saved state dicts / checkpoints.
        self.linear = nn.Linear(n_input_features, 1)

    def forward(self, x):
        """Apply the linear layer, then squash the result into (0, 1) with a sigmoid.

        The output's last dimension has size 1 (one score per input row).
        """
        logits = self.linear(x)
        return torch.sigmoid(logits)

In [3]:
# Instantiate a model that expects 6 input features (weights are randomly initialised).
model = Model(n_input_features=6)

In [4]:
# Lazy method: pickle the entire module object in one call.
# Per the PyTorch docs, a file saved this way is bound to the `Model` class
# definition, which must be available again when the file is loaded.
FILE = "model.pth"
torch.save(obj=model, f=FILE)

In [5]:
# Load the whole pickled module written by the lazy-method cell.
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
# arbitrary pickled objects such as a full nn.Module, so opt out explicitly.
# Unpickling can execute arbitrary code: only do this for files you trust.
# (The `weights_only` argument requires torch >= 1.13.)
FILE = "model.pth"
model = torch.load(FILE, weights_only=False)
model.eval()  # switch the module to evaluation mode before using it for inference

for param in model.parameters():
    print(param)

Parameter containing:
tensor([[ 0.2835,  0.3547,  0.0802,  0.2204, -0.4073, -0.0075]],
       requires_grad=True)
Parameter containing:
tensor([0.2868], requires_grad=True)


In [7]:
# Preferred way: save only the learned parameters (the state_dict), not the
# pickled module object. The model class is rebuilt in code when loading.
# Note: this reuses "model.pth" and so overwrites the file written by the
# lazy-method cell.
FILE = "model.pth"
torch.save(model.state_dict(), FILE)

In [8]:
# Rebuild the same architecture, then copy the saved tensors into it.
# load_state_dict fills an existing module; it does not construct one.
loaded_model = Model(n_input_features=6)
saved_state = torch.load(FILE)
loaded_model.load_state_dict(saved_state)
loaded_model.eval()  # evaluation mode before inference

# These should match the parameters printed from the original model above.
for tensor in loaded_model.parameters():
    print(tensor)

Parameter containing:
tensor([[ 0.2835,  0.3547,  0.0802,  0.2204, -0.4073, -0.0075]],
       requires_grad=True)
Parameter containing:
tensor([0.2868], requires_grad=True)


In [13]:
# Fresh model plus a plain SGD optimizer. The optimizer's state_dict holds its
# hyperparameters (`param_groups`) and per-parameter `state` (empty here).
learning_rate = 0.01
model = Model(n_input_features=6)
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'params': [0, 1]}]}


In [14]:
# Checkpoint: bundle the epoch counter with the model and optimizer state
# dicts and write them to a single file.
checkpoint = dict(
    epoch=90,
    model_state=model.state_dict(),
    optim_state=optimizer.state_dict(),
)
torch.save(checkpoint, "checkpoint.pth")

In [15]:
# Restore a checkpoint from disk. Build fresh model/optimizer objects first:
# load_state_dict copies saved values into existing objects.
loaded_checkpoint = torch.load("checkpoint.pth")
epoch = loaded_checkpoint["epoch"]  # kept for resuming training; not used further here

model = Model(n_input_features=6)
# Deliberately a different lr: loading the optimizer state below restores the
# saved param_groups, so the printed lr is the checkpointed value, not 0.0.
learning_rate = 0.0
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# Restore from the file just read (`loaded_checkpoint`), not from the
# in-memory `checkpoint` dict built in the previous cell.
model.load_state_dict(loaded_checkpoint["model_state"])
optimizer.load_state_dict(loaded_checkpoint["optim_state"])

print(optimizer.state_dict())

{'state': {}, 'param_groups': [{'lr': 0.01, 'momentum': 0, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'maximize': False, 'params': [0, 1]}]}
