diff --git a/recipes_source/recipes/saving_and_loading_a_general_checkpoint.py b/recipes_source/recipes/saving_and_loading_a_general_checkpoint.py index a31f43970f6..cc872ae8042 100644 --- a/recipes_source/recipes/saving_and_loading_a_general_checkpoint.py +++ b/recipes_source/recipes/saving_and_loading_a_general_checkpoint.py @@ -129,7 +129,7 @@ def forward(self, x): # model = Net() -optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) +optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict'])