From 9e7b86ee6718f47938de0b8fe778d713bca6c412 Mon Sep 17 00:00:00 2001
From: Jacob Szwejbka
Date: Wed, 4 Sep 2024 13:34:51 -0700
Subject: [PATCH] TrainingModule (#5077)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/5077

Add a TrainingModule that makes the on-device training UX a lot better.

Reviewed By: kirklandsign

Differential Revision: D62140852
---
 extension/module/module.h                      |   4 +-
 extension/training/module/TARGETS              |   8 ++
 extension/training/module/targets.bzl          |  28 ++++
 extension/training/module/test/TARGETS         |   8 ++
 extension/training/module/test/targets.bzl     |  34 +++++
 .../module/test/training_module_test.cpp       | 107 ++++++++++++++
 extension/training/module/training_module.cpp  | 135 ++++++++++++++++++
 extension/training/module/training_module.h    | 102 +++++++++++++
 8 files changed, 425 insertions(+), 1 deletion(-)
 create mode 100644 extension/training/module/TARGETS
 create mode 100644 extension/training/module/targets.bzl
 create mode 100644 extension/training/module/test/TARGETS
 create mode 100644 extension/training/module/test/targets.bzl
 create mode 100644 extension/training/module/test/training_module_test.cpp
 create mode 100644 extension/training/module/training_module.cpp
 create mode 100644 extension/training/module/training_module.h

diff --git a/extension/module/module.h b/extension/module/module.h
index 052489fb331..c1fe11147f7 100644
--- a/extension/module/module.h
+++ b/extension/module/module.h
@@ -22,7 +22,7 @@ namespace extension {
 /**
  * A facade class for loading programs and executing methods within them.
  */
-class Module final {
+class Module {
  public:
   /**
    * Enum to define loading behavior.
@@ -337,6 +337,8 @@ class Module final {
   std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
   std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
   std::unique_ptr<runtime::EventTracer> event_tracer_;
+
+ protected:
   std::unordered_map<std::string, MethodHolder> methods_;
 
   friend class ExecuTorchJni;
diff --git a/extension/training/module/TARGETS b/extension/training/module/TARGETS
new file mode 100644
index 00000000000..2341af9282f
--- /dev/null
+++ b/extension/training/module/TARGETS
@@ -0,0 +1,8 @@
+# Any targets that should be shared between fbcode and xplat must be defined in
+# targets.bzl. This file can contain fbcode-only targets.
+
+load(":targets.bzl", "define_common_targets")
+
+oncall("executorch")
+
+define_common_targets()
diff --git a/extension/training/module/targets.bzl b/extension/training/module/targets.bzl
new file mode 100644
index 00000000000..88da84ed131
--- /dev/null
+++ b/extension/training/module/targets.bzl
@@ -0,0 +1,28 @@
+load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")
+
+def define_common_targets():
+    """Defines targets that should be shared between fbcode and xplat.
+
+    The directory containing this targets.bzl file should also contain both
+    TARGETS and BUCK files that call this function.
+ """ + + for aten_mode in (True, False): + aten_suffix = ("_aten" if aten_mode else "") + + runtime.cxx_library( + name = "training_module" + aten_suffix, + srcs = [ + "training_module.cpp", + ], + exported_headers = [ + "training_module.h", + ], + visibility = [ + "@EXECUTORCH_CLIENTS", + ], + exported_deps = [ + "//executorch/extension/module:module" + aten_suffix, + "//executorch/runtime/core:evalue" + aten_suffix, + ], + ) diff --git a/extension/training/module/test/TARGETS b/extension/training/module/test/TARGETS new file mode 100644 index 00000000000..a6c52d105f6 --- /dev/null +++ b/extension/training/module/test/TARGETS @@ -0,0 +1,8 @@ +# Any targets that should be shared between fbcode and xplat must be defined in +# targets.bzl. This file can contain fbcode-only targets. + +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets(is_fbcode = True) diff --git a/extension/training/module/test/targets.bzl b/extension/training/module/test/targets.bzl new file mode 100644 index 00000000000..8b260e2a7e8 --- /dev/null +++ b/extension/training/module/test/targets.bzl @@ -0,0 +1,34 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(is_fbcode = False): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + + # TODO(dbort): Find a way to make these run for ANDROID/APPLE in xplat. The + # android and ios test determinators don't like the reference to the model + # file in fbcode. See https://fburl.com/9esapdmd + if not runtime.is_oss and is_fbcode: + modules_env = { + # The tests use this var to find the program file to load. This uses + # an fbcode target path because the authoring/export tools + # intentionally don't work in xplat (since they're host-only tools). + "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])", + "ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])", + } + + runtime.cxx_test( + name = "training_module_test", + srcs = [ + "training_module_test.cpp", + ], + deps = [ + "//executorch/extension/training/module:training_module", + "//executorch/extension/data_loader:file_data_loader", + "//executorch/runtime/core/exec_aten/testing_util:tensor_util", + "//executorch/kernels/portable:generated_lib", + ], + env = modules_env, + ) diff --git a/extension/training/module/test/training_module_test.cpp b/extension/training/module/test/training_module_test.cpp new file mode 100644 index 00000000000..58631c4cf44 --- /dev/null +++ b/extension/training/module/test/training_module_test.cpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */
+
+#include <executorch/extension/training/module/training_module.h>
+#include <gtest/gtest.h>
+
+#include <executorch/extension/data_loader/file_data_loader.h>
+#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
+#include <executorch/runtime/platform/runtime.h>
+
+// @lint-ignore-every CLANGTIDY facebook-hte-CArray
+
+using namespace ::testing;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using torch::executor::Error;
+using torch::executor::Span;
+using torch::executor::testing::TensorFactory;
+
+class TrainingModuleTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    torch::executor::runtime_init();
+  }
+};
+
+TEST_F(TrainingModuleTest, JointGraphTest) {
+  // Create a loader for the serialized ModuleSimpleTrain program.
+  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
+  executorch::runtime::Result<torch::executor::util::FileDataLoader>
+      loader_res = torch::executor::util::FileDataLoader::from(path);
+  ASSERT_EQ(loader_res.error(), Error::Ok);
+  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
+      std::move(loader_res.get()));
+
+  auto mod = executorch::extension::training::TrainingModule(std::move(loader));
+
+  TensorFactory<ScalarType::Float> tf;
+  Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
+  Tensor label = tf.make({3}, {1.0, 0.0, 0.0});
+
+  std::vector<executorch::runtime::EValue> inputs;
+  inputs.push_back(input);
+  inputs.push_back(label);
+
+  auto res = mod.execute_forward_backward("forward", inputs);
+  ASSERT_EQ(res.error(), Error::Ok);
+  ASSERT_EQ(res.get().size(), 1);
+
+  // Test Gradients
+  auto grad_res = mod.named_gradients("forward");
+  ASSERT_EQ(grad_res.error(), Error::Ok);
+  auto& grad = grad_res.get();
+  ASSERT_EQ(grad.size(), 2);
+  ASSERT_NE(grad.find("linear.weight"), grad.end());
+  ASSERT_NE(grad.find("linear.bias"), grad.end());
+
+  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[0], 3);
+  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[1], 3);
+  ASSERT_EQ(grad.find("linear.weight")->second.dim(), 2);
+  ASSERT_EQ(grad.find("linear.bias")->second.sizes()[0], 3);
+  ASSERT_EQ(grad.find("linear.bias")->second.dim(), 1);
+
+  // Test Parameters
+  auto param_res = mod.named_parameters("forward");
+  ASSERT_EQ(param_res.error(), Error::Ok);
+  auto& param = param_res.get();
+  ASSERT_EQ(param.size(), 2);
+  ASSERT_NE(param.find("linear.weight"), param.end());
+  ASSERT_NE(param.find("linear.bias"), param.end());
+
+  ASSERT_EQ(param.find("linear.weight")->second.sizes()[0], 3);
+  ASSERT_EQ(param.find("linear.weight")->second.sizes()[1], 3);
+  ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
+  ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
+  ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);
+}
+
+TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
+  // Create a loader for the serialized ModuleAdd program.
+  const char* path = std::getenv("ET_MODULE_ADD_PATH");
+  executorch::runtime::Result<torch::executor::util::FileDataLoader>
+      loader_res = torch::executor::util::FileDataLoader::from(path);
+  ASSERT_EQ(loader_res.error(), Error::Ok);
+  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
+      std::move(loader_res.get()));
+
+  auto mod = executorch::extension::training::TrainingModule(std::move(loader));
+
+  TensorFactory<ScalarType::Float> tf;
+  Tensor input = tf.make({2, 2}, {1.0, 1.0, 1.0, 1.0});
+  Tensor input2 = tf.make({2, 2}, {1.0, 0.0, 0.0, 0.0});
+
+  std::vector<executorch::runtime::EValue> inputs;
+  inputs.push_back(input);
+  inputs.push_back(input2);
+
+  // A non-training module should fail to execute forward/backward because it
+  // can't find the gradients or params.
+  auto res = mod.execute_forward_backward("forward", inputs);
+  ASSERT_EQ(res.error(), Error::InvalidArgument);
+}
diff --git a/extension/training/module/training_module.cpp b/extension/training/module/training_module.cpp
new file mode 100644
index 00000000000..7b38292fd1f
--- /dev/null
+++ b/extension/training/module/training_module.cpp
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/extension/training/module/training_module.h>
+
+namespace executorch {
+namespace extension {
+namespace training {
+
+namespace {
+std::string gradients_method_prefix = "__et_training_gradients_index_";
+std::string parameters_method_prefix = "__et_training_parameters_index_";
+std::string fqn_method_prefix = "__et_training_fqn_";
+} // namespace
+
+runtime::Result<std::vector<runtime::EValue>>
+TrainingModule::execute_forward_backward(
+    const std::string& method_name,
+    const std::vector<runtime::EValue>& input) {
+  // Find where the user outputs end.
+  const std::string gradients_method_name =
+      gradients_method_prefix + method_name;
+  auto res = executorch::extension::Module::execute(gradients_method_name);
+  if (!res.ok()) {
+    return res.error();
+  }
+  uint64_t grad_start = res.get()[0].toInt();
+
+  const std::string parameters_method_name =
+      parameters_method_prefix + method_name;
+  // Get params start.
+  auto param_res =
+      executorch::extension::Module::execute(parameters_method_name);
+  if (!param_res.ok()) {
+    return param_res.error();
+  }
+
+  uint64_t param_start = param_res.get()[0].toInt();
+
+  // Execute the forward and backward pass.
+  auto outputs = torch::executor::Module::execute(method_name, input);
+  if (!outputs.ok()) {
+    return outputs.error();
+  }
+
+  // Extract the user outputs.
+  std::vector<runtime::EValue> user_outputs;
+  user_outputs.reserve(grad_start);
+  for (size_t i = 0; i < grad_start; ++i) {
+    user_outputs.push_back(outputs.get().at(i));
+  }
+
+  // Extract and store the gradients.
+  if (method_named_gradients_.find(method_name) ==
+      method_named_gradients_.end()) {
+    method_named_gradients_.insert({method_name, {}});
+
+    auto& gradients_map = method_named_gradients_.at(method_name);
+    // Get names.
+    const std::string fqn_method_name = fqn_method_prefix + method_name;
+    auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
+    if (!fqn_res.ok()) {
+      return fqn_res.error();
+    }
+    const auto& fqn_list = fqn_res.get();
+
+    // Only have to initialize the dict once because the tensors in the dict
+    // and the tensors in the method alias the same TensorImpl, so updating
+    // one will update the other.
+    size_t name_index = 0;
+    for (size_t grad_index = grad_start; grad_index < param_start;
+         ++grad_index, ++name_index) {
+      exec_aten::string_view fqn = fqn_list.at(name_index).toString();
+      gradients_map.insert({fqn, outputs.get().at(grad_index).toTensor()});
+    }
+  }
+
+  return user_outputs;
+}
+
+runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
+TrainingModule::named_parameters(const std::string& method_name) {
+  std::map<exec_aten::string_view, exec_aten::Tensor> named_parameters;
+  const std::string fqn_method_name = fqn_method_prefix + method_name;
+  const std::string parameters_method_name =
+      parameters_method_prefix + method_name;
+
+  // Get names.
+  auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
+  if (!fqn_res.ok()) {
+    return fqn_res.error();
+  }
+  const auto& fqn_list = fqn_res.get();
+
+  // Get params start.
+  auto param_res =
+      executorch::extension::Module::execute(parameters_method_name);
+  if (!param_res.ok()) {
+    return param_res.error();
+  }
+
+  uint64_t param_start = param_res.get()[0].toInt();
+
+  auto& method = methods_.at(method_name).method;
+
+  // Create the dict.
+  size_t name_index = 0;
+  for (size_t param_index = param_start; param_index < method->outputs_size();
+       ++param_index, ++name_index) {
+    exec_aten::string_view fqn = fqn_list.at(name_index).toString();
+    exec_aten::Tensor param = method->get_output(param_index).toTensor();
+    named_parameters.insert({fqn, param});
+  }
+  return named_parameters;
+}
+
+runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
+TrainingModule::named_gradients(const std::string& method_name) {
+  if (method_named_gradients_.find(method_name) ==
+      method_named_gradients_.end()) {
+    ET_LOG(Error, "No gradients found for method %s", method_name.c_str());
+    return executorch::runtime::Error::InvalidArgument;
+  }
+  return method_named_gradients_.at(method_name);
+}
+
+} // namespace training
+} // namespace extension
+} // namespace executorch
diff --git a/extension/training/module/training_module.h b/extension/training/module/training_module.h
new file mode 100644
index 00000000000..7571aacecf6
--- /dev/null
+++ b/extension/training/module/training_module.h
@@ -0,0 +1,102 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <map>
+#include <memory>
+#include <string>
+#include <unordered_map>
+#include <vector>
+
+#include <executorch/extension/module/module.h>
+#include <executorch/runtime/core/evalue.h>
+#include <executorch/runtime/platform/compiler.h>
+
+namespace executorch {
+namespace extension {
+namespace training {
+
+/**
+ * A facade class for loading programs for on-device training and executing
+ * methods within them.
+ */
+class ET_EXPERIMENTAL TrainingModule final : executorch::extension::Module {
+ public:
+  explicit TrainingModule(
+      std::unique_ptr<runtime::DataLoader> data_loader,
+      std::unique_ptr<runtime::MemoryAllocator> memory_allocator = nullptr,
+      std::unique_ptr<runtime::MemoryAllocator> temp_allocator = nullptr,
+      std::unique_ptr<runtime::EventTracer> event_tracer = nullptr)
+      : executorch::extension::Module(
+            std::move(data_loader),
+            std::move(memory_allocator),
+            std::move(temp_allocator),
+            std::move(event_tracer)),
+        method_named_gradients_({}) {}
+
+  explicit TrainingModule(const Module&) = delete;
+  TrainingModule& operator=(const Module&) = delete;
+  explicit TrainingModule(Module&&) = delete;
+  TrainingModule& operator=(Module&&) = delete;
+
+  /**
+   * Execute a specific method with the given input and retrieve the output.
+   * Only valid if the specified method is a joint graph. Loads the program
+   * and method before executing if needed.
+   *
+   * @param[in] method_name The name of the joint graph method to execute.
+   * @param[in] input A vector of input values to be passed to the method.
+   *
+   * @returns A Result object containing the output values from the method or
+   * an error to indicate failure.
+   */
+  ET_EXPERIMENTAL runtime::Result<std::vector<runtime::EValue>>
+  execute_forward_backward(
+      const std::string& method_name,
+      const std::vector<runtime::EValue>& input);
+
+  /**
+   * Retrieve the trainable parameters for a joint graph method.
+   *
+   * @param[in] method_name The name of the joint graph method to get the
+   * parameters for.
+   *
+   * @returns A Result object containing a map of the fully qualified name to
+   * parameter tensor, or an error if the method is not a joint graph or has
+   * not been executed yet.
+   */
+  ET_EXPERIMENTAL
+  runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
+  named_parameters(const std::string& method_name);
+
+  /**
+   * Retrieve the latest gradients for a joint graph method.
+   *
+   * @param[in] method_name The name of the joint graph method to get the
+   * gradients for.
+   *
+   * @returns A Result object containing a map of the fully qualified name to
+   * the gradient tensor associated with that parameter from the latest
+   * forward_backward execution, or an error if the method is not a joint
+   * graph or has not been executed yet.
+   */
+  ET_EXPERIMENTAL
+  runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
+  named_gradients(const std::string& method_name);
+
+ private:
+  std::unordered_map<
+      std::string,
+      std::map<exec_aten::string_view, exec_aten::Tensor>>
+      method_named_gradients_;
+};
+
+} // namespace training
+} // namespace extension
+} // namespace executorch
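
Example usage (editor's sketch, not part of the diff): a minimal caller-side
training loop over the new API. The program path, the method name "forward",
the float parameter dtype, the step count, and the learning rate are
illustrative assumptions; a real caller supplies its own exported joint-graph
.pte file and input EValues.

```cpp
#include <cstddef>
#include <memory>
#include <vector>

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/training/module/training_module.h>

using executorch::extension::training::TrainingModule;

executorch::runtime::Error train(
    const char* pte_path, // Hypothetical path to a joint-graph .pte file.
    const std::vector<executorch::runtime::EValue>& inputs) {
  auto loader_res = torch::executor::util::FileDataLoader::from(pte_path);
  if (!loader_res.ok()) {
    return loader_res.error();
  }
  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_res.get()));
  TrainingModule mod(std::move(loader));

  constexpr float kLearningRate = 0.1f; // Hypothetical hyperparameter.
  for (int step = 0; step < 100; ++step) {
    // Run the joint graph: forward + backward. Only the user-visible outputs
    // (e.g. the loss) are returned; gradients are recorded internally and
    // exposed through named_gradients().
    auto outputs = mod.execute_forward_backward("forward", inputs);
    if (!outputs.ok()) {
      return outputs.error();
    }

    auto grads = mod.named_gradients("forward");
    auto params = mod.named_parameters("forward");
    if (!grads.ok() || !params.ok()) {
      return executorch::runtime::Error::InvalidArgument;
    }

    // Caller-side SGD update, assuming float parameters: param -= lr * grad.
    for (const auto& [fqn, grad] : grads.get()) {
      auto param = params.get().at(fqn);
      float* p = param.mutable_data_ptr<float>();
      const float* g = grad.const_data_ptr<float>();
      for (size_t i = 0; i < static_cast<size_t>(param.numel()); ++i) {
        p[i] -= kLearningRate * g[i];
      }
    }
  }
  return executorch::runtime::Error::Ok;
}
```

Because named_parameters() and named_gradients() return tensors that alias the
method's TensorImpls, the in-place update above feeds directly into the next
execute_forward_backward() call without any extra copying.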