Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion extension/module/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ namespace extension {
/**
* A facade class for loading programs and executing methods within them.
*/
class Module final {
class Module {
public:
/**
* Enum to define loading behavior.
Expand Down Expand Up @@ -337,6 +337,8 @@ class Module final {
std::unique_ptr<runtime::MemoryAllocator> memory_allocator_;
std::unique_ptr<runtime::MemoryAllocator> temp_allocator_;
std::unique_ptr<runtime::EventTracer> event_tracer_;

protected:
std::unordered_map<std::string, MethodHolder> methods_;

friend class ExecuTorchJni;
Expand Down
8 changes: 8 additions & 0 deletions extension/training/module/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

# Instantiate the shared targets declared in targets.bzl for this directory.
define_common_targets()
28 changes: 28 additions & 0 deletions extension/training/module/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets():
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.
    """

    # Build two variants of the library: one backed by ATen tensors ("_aten"
    # suffix) and one using the portable ExecuTorch tensor implementation.
    for aten_mode in (True, False):
        aten_suffix = ("_aten" if aten_mode else "")

        runtime.cxx_library(
            name = "training_module" + aten_suffix,
            srcs = [
                "training_module.cpp",
            ],
            exported_headers = [
                "training_module.h",
            ],
            visibility = [
                "@EXECUTORCH_CLIENTS",
            ],
            # Exported so that downstream users of training_module also see the
            # base Module and EValue types that appear in its public interface.
            exported_deps = [
                "//executorch/extension/module:module" + aten_suffix,
                "//executorch/runtime/core:evalue" + aten_suffix,
            ],
        )
8 changes: 8 additions & 0 deletions extension/training/module/test/TARGETS
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Any targets that should be shared between fbcode and xplat must be defined in
# targets.bzl. This file can contain fbcode-only targets.

load(":targets.bzl", "define_common_targets")

oncall("executorch")

# is_fbcode = True enables the tests that reference fbcode-only model files.
define_common_targets(is_fbcode = True)
34 changes: 34 additions & 0 deletions extension/training/module/test/targets.bzl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

def define_common_targets(is_fbcode = False):
    """Defines targets that should be shared between fbcode and xplat.

    The directory containing this targets.bzl file should also contain both
    TARGETS and BUCK files that call this function.

    Args:
        is_fbcode: True when invoked from the fbcode TARGETS file. The test
            below references model artifacts that only exist in fbcode, so it
            is only defined in that case.
    """

    # TODO(dbort): Find a way to make these run for ANDROID/APPLE in xplat. The
    # android and ios test determinators don't like the reference to the model
    # file in fbcode. See https://fburl.com/9esapdmd
    if not runtime.is_oss and is_fbcode:
        modules_env = {
            # The tests use this var to find the program file to load. This uses
            # an fbcode target path because the authoring/export tools
            # intentionally don't work in xplat (since they're host-only tools).
            "ET_MODULE_ADD_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleAdd.pte])",
            "ET_MODULE_SIMPLE_TRAIN_PATH": "$(location fbcode//executorch/test/models:exported_programs[ModuleSimpleTrain.pte])",
        }

        runtime.cxx_test(
            name = "training_module_test",
            srcs = [
                "training_module_test.cpp",
            ],
            deps = [
                "//executorch/extension/training/module:training_module",
                "//executorch/extension/data_loader:file_data_loader",
                "//executorch/runtime/core/exec_aten/testing_util:tensor_util",
                "//executorch/kernels/portable:generated_lib",
            ],
            # Expose the exported-program file locations to the test binary.
            env = modules_env,
        )
107 changes: 107 additions & 0 deletions extension/training/module/test/training_module_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/training/module/training_module.h>

#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
#include <executorch/runtime/platform/runtime.h>
#include <gtest/gtest.h>

// @lint-ignore-every CLANGTIDY facebook-hte-CArray

using namespace ::testing;
using exec_aten::ScalarType;
using exec_aten::Tensor;
using torch::executor::Error;
using torch::executor::Span;
using torch::executor::testing::TensorFactory;

// Test fixture for TrainingModule tests.
class TrainingModuleTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // Initialize the ExecuTorch runtime before each test runs.
    torch::executor::runtime_init();
  }
};

// Verifies that a training (joint-graph) program exposes user outputs,
// gradients, and parameters through TrainingModule.
TEST_F(TrainingModuleTest, JointGraphTest) {
  // Create a loader for the serialized ModuleSimpleTrain program.
  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_res = torch::executor::util::FileDataLoader::from(path);
  ASSERT_EQ(loader_res.error(), Error::Ok);
  auto loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_res.get()));

  auto mod = executorch::extension::training::TrainingModule(std::move(loader));

  TensorFactory<ScalarType::Float> tf;
  Tensor input = tf.make({3}, {1.0, 1.0, 1.0});
  Tensor label = tf.make({3}, {1.0, 0.0, 0.0});

  std::vector<executorch::runtime::EValue> inputs;
  inputs.push_back(input);
  inputs.push_back(label);

  // Forward/backward should succeed and produce exactly one user output
  // (the loss).
  auto res = mod.execute_forward_backward("forward", inputs);
  ASSERT_EQ(res.error(), Error::Ok);
  ASSERT_EQ(res.get().size(), 1);

  // Test Gradients
  auto grad_res = mod.named_gradients("forward");
  ASSERT_EQ(grad_res.error(), Error::Ok);
  auto& grad = grad_res.get();
  ASSERT_EQ(grad.size(), 2);
  ASSERT_NE(grad.find("linear.weight"), grad.end());
  ASSERT_NE(grad.find("linear.bias"), grad.end());

  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[0], 3);
  ASSERT_EQ(grad.find("linear.weight")->second.sizes()[1], 3);
  ASSERT_EQ(grad.find("linear.weight")->second.dim(), 2);
  ASSERT_EQ(grad.find("linear.bias")->second.sizes()[0], 3);
  ASSERT_EQ(grad.find("linear.bias")->second.dim(), 1);

  // Test Parameters
  auto param_res = mod.named_parameters("forward");
  ASSERT_EQ(param_res.error(), Error::Ok);
  // Bug fix: this previously aliased grad_res.get(), so the assertions below
  // re-checked the gradients map instead of the parameters map.
  auto& param = param_res.get();
  ASSERT_EQ(param.size(), 2);
  ASSERT_NE(param.find("linear.weight"), param.end());
  ASSERT_NE(param.find("linear.bias"), param.end());

  ASSERT_EQ(param.find("linear.weight")->second.sizes()[0], 3);
  ASSERT_EQ(param.find("linear.weight")->second.sizes()[1], 3);
  ASSERT_EQ(param.find("linear.weight")->second.dim(), 2);
  ASSERT_EQ(param.find("linear.bias")->second.sizes()[0], 3);
  ASSERT_EQ(param.find("linear.bias")->second.dim(), 1);
}

// Verifies that loading a plain (non-training) program and attempting a
// forward/backward pass fails with InvalidArgument.
TEST_F(TrainingModuleTest, NonTrainingModuleTest) {
  // Load the serialized ModuleAdd program from disk.
  const char* program_path = std::getenv("ET_MODULE_ADD_PATH");
  executorch::runtime::Result<torch::executor::util::FileDataLoader>
      loader_result = torch::executor::util::FileDataLoader::from(program_path);
  ASSERT_EQ(loader_result.error(), Error::Ok);

  auto data_loader = std::make_unique<torch::executor::util::FileDataLoader>(
      std::move(loader_result.get()));
  auto training_mod =
      executorch::extension::training::TrainingModule(std::move(data_loader));

  // Build two 2x2 float inputs for the add method.
  TensorFactory<ScalarType::Float> tf;
  std::vector<executorch::runtime::EValue> inputs;
  inputs.push_back(tf.make({2, 2}, {1.0, 1.0, 1.0, 1.0}));
  inputs.push_back(tf.make({2, 2}, {1.0, 0.0, 0.0, 0.0}));

  // Non-training module should fail to execute forward/backward as it cant
  // find the gradients or params.
  auto result = training_mod.execute_forward_backward("forward", inputs);
  ASSERT_EQ(result.error(), Error::InvalidArgument);
}
135 changes: 135 additions & 0 deletions extension/training/module/training_module.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <executorch/extension/training/module/training_module.h>

namespace executorch {
namespace extension {
namespace training {

namespace {
// Name prefixes for the auxiliary methods that the training export flow embeds
// in a program alongside each user-facing method. Appending the user method
// name to a prefix yields the auxiliary method that reports, respectively: the
// output index where gradients start, the output index where parameters start,
// and the fully-qualified names of the trainable parameters.
// Marked const: these are fixed protocol strings and are never reassigned.
const std::string gradients_method_prefix = "__et_training_gradients_index_";
const std::string parameters_method_prefix = "__et_training_parameters_index_";
const std::string fqn_method_prefix = "__et_training_fqn_";
} // namespace

// Runs the joint forward/backward method `method_name` on `input` and returns
// only the user-visible outputs (everything before the gradient outputs).
// On the first successful run for a method, caches its named gradient tensors
// so that named_gradients() can serve them afterwards.
// Returns an error if the program lacks the auxiliary training methods (i.e.
// it was not exported for training) or if execution fails.
runtime::Result<std::vector<runtime::EValue>>
TrainingModule::execute_forward_backward(
    const std::string& method_name,
    const std::vector<runtime::EValue>& input) {
  // Find where the user outputs end.
  const std::string gradients_method_name =
      gradients_method_prefix + method_name;
  auto res = executorch::extension::Module::execute(gradients_method_name);
  if (!res.ok()) {
    // A non-training program has no such method; propagate the error.
    return res.error();
  }
  uint64_t grad_start = res.get()[0].toInt();

  const std::string parameters_method_name =
      parameters_method_prefix + method_name;
  // get params start.
  auto param_res =
      executorch::extension::Module::execute(parameters_method_name);
  if (!param_res.ok()) {
    return param_res.error();
  }

  uint64_t param_start = param_res.get()[0].toInt();

  // Execute the forward and backward pass.
  // Consistency fix: use executorch::extension::Module (as everywhere else in
  // this file) rather than the legacy torch::executor::Module alias.
  auto outputs = executorch::extension::Module::execute(method_name, input);
  if (!outputs.ok()) {
    return outputs.error();
  }

  // Extract the user outputs (outputs [0, grad_start)).
  std::vector<runtime::EValue> user_outputs;
  user_outputs.reserve(grad_start);
  for (size_t i = 0; i < grad_start; ++i) {
    user_outputs.push_back(outputs.get().at(i));
  }

  // Extract and store the gradients the first time this method is executed.
  if (method_named_gradients_.find(method_name) ==
      method_named_gradients_.end()) {
    method_named_gradients_.insert({method_name, {}});

    auto& gradients_map = method_named_gradients_.at(method_name);
    // Get the fully-qualified names of the trainable parameters.
    const std::string fqn_method_name = fqn_method_prefix + method_name;
    auto fqn_res = executorch::extension::Module::execute(fqn_method_name);
    if (!fqn_res.ok()) {
      return fqn_res.error();
    }
    const auto& fqn_list = fqn_res.get();

    // Only have to initialize the dict once because the tensors in the dict and
    // the tensors in the method alias the same TensorImpl, so updating one will
    // update the other.
    size_t name_index = 0;
    for (size_t grad_index = grad_start; grad_index < param_start;
         ++grad_index, ++name_index) {
      exec_aten::string_view fqn = fqn_list.at(name_index).toString();
      gradients_map.insert({fqn, outputs.get().at(grad_index).toTensor()});
    }
  }

  return user_outputs;
}

// Returns a map from fully-qualified parameter name to the parameter tensor
// for `method_name`, built from the method's trailing outputs. Returns an
// error if the program lacks the auxiliary training methods.
runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_parameters(const std::string& method_name) {
  // Fetch the fully-qualified parameter names via the auxiliary fqn method.
  auto fqn_outputs =
      executorch::extension::Module::execute(fqn_method_prefix + method_name);
  if (!fqn_outputs.ok()) {
    return fqn_outputs.error();
  }

  // Fetch the output index at which the parameters begin.
  auto start_outputs = executorch::extension::Module::execute(
      parameters_method_prefix + method_name);
  if (!start_outputs.ok()) {
    return start_outputs.error();
  }
  uint64_t first_param_index = start_outputs.get()[0].toInt();

  auto& method = methods_.at(method_name).method;

  // Pair each parameter name with the corresponding method output tensor.
  std::map<exec_aten::string_view, exec_aten::Tensor> result;
  const auto& names = fqn_outputs.get();
  for (size_t out_index = first_param_index, fqn_index = 0;
       out_index < method->outputs_size();
       ++out_index, ++fqn_index) {
    result.insert(
        {names.at(fqn_index).toString(),
         method->get_output(out_index).toTensor()});
  }
  return result;
}

// Returns the cached gradient map for `method_name`. Gradients are populated
// lazily by execute_forward_backward(), so calling this before a successful
// forward/backward pass yields InvalidArgument.
runtime::Result<const std::map<exec_aten::string_view, exec_aten::Tensor>>
TrainingModule::named_gradients(const std::string& method_name) {
  auto it = method_named_gradients_.find(method_name);
  if (it == method_named_gradients_.end()) {
    ET_LOG(Error, "No gradients found for method %s", method_name.c_str());
    return executorch::runtime::Error::InvalidArgument;
  }
  return it->second;
}

} // namespace training
} // namespace extension
} // namespace executorch
Loading
Loading