diff --git a/extension/training/module/training_module.cpp b/extension/training/module/training_module.cpp index 57514355f5e..a379e044503 100644 --- a/extension/training/module/training_module.cpp +++ b/extension/training/module/training_module.cpp @@ -162,15 +162,16 @@ TrainingModule::named_attributes(const std::string& method_name) { method_named_attributes_.insert({method_name, {}}); // get method metadata - auto meta_res = executorch::extension::Module::method_meta(method_name); + auto meta_res = method_meta(method_name); if (!meta_res.ok()) { return meta_res.error(); } // get method - auto method_res = executorch::extension::Module::method(method_name); - if (!method_res.ok()) { - return method_res.error(); + auto e = load_method(method_name); + if (e != runtime::Error::Ok) { + return e; } + auto& method = methods_.at(method_name).method; // get tensor by name for (int idx = 0; idx < meta_res->num_attributes(); idx++) { const auto tensor_res = meta_res->attribute_tensor_meta(idx); @@ -178,7 +179,7 @@ TrainingModule::named_attributes(const std::string& method_name) { return tensor_res.error(); } const auto tensorName = tensor_res.get().name(); - const auto attribute_res = (*method_res)->get_attribute(tensorName); + const auto attribute_res = method->get_attribute(tensorName); if (!attribute_res.ok()) { return attribute_res.error(); }