Skip to content

Commit

Permalink
update test/cpp/api/serialize.cpp to test optimizer options serializi…
Browse files Browse the repository at this point in the history
…ng, format code
  • Loading branch information
daizhirui committed May 4, 2024
1 parent 6294f5d commit 1f59509
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
6 changes: 6 additions & 0 deletions test/cpp/api/serialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,12 @@ void test_serialize_optimizer(
auto optim2_2 = OptimizerClass(model2->parameters(), options);
auto optim3 = OptimizerClass(model3->parameters(), options);
auto optim3_2 = OptimizerClass(model3->parameters(), options);
for (auto& param_group : optim3_2.param_groups()) {
const double lr = param_group.options().get_lr();
// change the learning rate, which will be overwritten by the loading
// otherwise, test cannot check if options are saved and loaded correctly
param_group.options().set_lr(lr + 1);
}

auto x = torch::ones({10, 5});

Expand Down
6 changes: 4 additions & 2 deletions torch/csrc/api/include/torch/optim/serialize.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,10 @@ void serialize(serialize::InputArchive& archive, Optimizer& optimizer) {
}
}

auto &saved_options = reinterpret_cast<DerivedOptimizerParamOptions&>(*saved_param_groups[i].second);
auto &current_options = reinterpret_cast<DerivedOptimizerParamOptions&>(optimizer.param_groups()[i].options());
auto& saved_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
*saved_param_groups[i].second);
auto& current_options = reinterpret_cast<DerivedOptimizerParamOptions&>(
optimizer.param_groups()[i].options());
current_options = saved_options;
}
}
Expand Down

0 comments on commit 1f59509

Please sign in to comment.