From 385070febd2f665fb21cb1056f7183e16fe7f849 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Fri, 22 Aug 2025 14:02:59 -0700 Subject: [PATCH] Access Method directly from TrainingModule. (#13602) Summary: . Reviewed By: JacobSzwejbka, mergennachin Differential Revision: D80821085 --- extension/training/module/training_module.cpp | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/extension/training/module/training_module.cpp b/extension/training/module/training_module.cpp index 57514355f5e..a379e044503 100644 --- a/extension/training/module/training_module.cpp +++ b/extension/training/module/training_module.cpp @@ -162,15 +162,16 @@ TrainingModule::named_attributes(const std::string& method_name) { method_named_attributes_.insert({method_name, {}}); // get method metadata - auto meta_res = executorch::extension::Module::method_meta(method_name); + auto meta_res = method_meta(method_name); if (!meta_res.ok()) { return meta_res.error(); } // get method - auto method_res = executorch::extension::Module::method(method_name); - if (!method_res.ok()) { - return method_res.error(); + auto e = load_method(method_name); + if (e != runtime::Error::Ok) { + return e; } + auto& method = methods_.at(method_name).method; // get tensor by name for (int idx = 0; idx < meta_res->num_attributes(); idx++) { const auto tensor_res = meta_res->attribute_tensor_meta(idx); @@ -178,7 +179,7 @@ TrainingModule::named_attributes(const std::string& method_name) { return tensor_res.error(); } const auto tensorName = tensor_res.get().name(); - const auto attribute_res = (*method_res)->get_attribute(tensorName); + const auto attribute_res = method->get_attribute(tensorName); if (!attribute_res.ok()) { return attribute_res.error(); }