-
Notifications
You must be signed in to change notification settings - Fork 21.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ghstack-source-id: 18fa05bdc2b2daed93a2987b1731d7448a08d5bc Pull Request resolved: #49547
- Loading branch information
Showing
7 changed files
with
244 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# Basic CMake setup | ||
cmake_minimum_required(VERSION 3.1 FATAL_ERROR) | ||
project(jit_hooks) | ||
|
||
find_package(Torch REQUIRED) | ||
|
||
add_executable(test_jit_hooks test_jit_hooks.cpp) | ||
set_property(TARGET test_jit_hooks PROPERTY CXX_STANDARD 14) | ||
target_link_libraries(test_jit_hooks "${TORCH_LIBRARIES}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import argparse | ||
import os | ||
import sys | ||
import torch | ||
|
||
# grab modules from test_jit_hooks.cpp | ||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) | ||
sys.path.append(pytorch_test_dir) | ||
from jit.test_hooks_modules import * | ||
|
||
# Create saved modules for JIT forward hooks and pre-hooks | ||
def main(): | ||
parser = argparse.ArgumentParser( | ||
description="Serialize a script modules with hooks attached" | ||
) | ||
parser.add_argument("--export-script-module-to", required=True) | ||
options = parser.parse_args() | ||
global save_name | ||
save_name = options.export_script_module_to + "_" | ||
|
||
tests = [ | ||
("test_submodule_forward_single_input", create_submodule_forward_single_input()), | ||
("test_submodule_forward_multiple_inputs", create_submodule_forward_multiple_inputs()), | ||
("test_submodule_multiple_hooks_single_input", create_submodule_multiple_hooks_single_input()), | ||
("test_submodule_multiple_hooks_multiple_inputs", create_submodule_multiple_hooks_multiple_inputs()), | ||
("test_submodule_hook_return_nothing", create_submodule_hook_return_nothing()), | ||
("test_submodule_same_hook_repeated", create_submodule_same_hook_repeated()), | ||
|
||
("test_module_forward_single_input", create_module_forward_single_input()), | ||
("test_module_forward_multiple_inputs", create_module_forward_multiple_inputs()), | ||
("test_module_multiple_hooks_single_input", create_module_multiple_hooks_single_input()), | ||
("test_module_multiple_hooks_multiple_inputs", create_module_multiple_hooks_multiple_inputs()), | ||
("test_module_hook_return_nothing", create_module_hook_return_nothing()), | ||
("test_module_same_hook_repeated", create_module_same_hook_repeated()), | ||
|
||
("test_module_no_forward_input", create_module_no_forward_input()), | ||
("test_forward_tuple_input", create_forward_tuple_input()), | ||
("test_submodule_to_call_directly_with_hooks", create_submodule_to_call_directly_with_hooks()) | ||
] | ||
|
||
for name, model in tests: | ||
m_scripted = torch.jit.script(model) | ||
filename = save_name + name + ".pt" | ||
torch.jit.save(m_scripted, filename) | ||
|
||
print("OK: completed saving modules with hooks!") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
#include <torch/script.h> | ||
|
||
#include <memory> | ||
#include <string> | ||
#include <sstream> | ||
#include <vector> | ||
|
||
#include <iostream> | ||
|
||
void test_module_forward_invocation_no_hooks_run( | ||
const std::string &path_to_exported_script_module) { | ||
std::cout << "testing: " | ||
<< "test_module_forward_invocation_no_hooks_run" << std::endl; | ||
torch::jit::Module module = | ||
torch::jit::load(path_to_exported_script_module + "_" + | ||
"test_module_forward_multiple_inputs" + ".pt"); | ||
std::vector<torch::jit::IValue> inputs = {torch::List<std::string>({"a"}), | ||
torch::jit::IValue("no_pre_hook")}; | ||
|
||
auto output = module(inputs); | ||
auto output_forward = module.forward(inputs); | ||
torch::jit::IValue correct_direct_output = | ||
std::tuple<torch::List<std::string>, std::string>( | ||
{"a", "outer_mod_name", "inner_mod_name"}, "no_pre_hook_"); | ||
std::cout << "----- module output: " << output << std::endl; | ||
std::cout << "----- module forward output: " << output_forward << std::endl; | ||
AT_ASSERT(correct_direct_output == output_forward); | ||
} | ||
|
||
void test_submodule_called_directly_with_hooks( | ||
const std::string &path_to_exported_script_module) { | ||
std::cout << "testing: " | ||
<< "test_submodule_to_call_directly_with_hooks" << std::endl; | ||
torch::jit::Module module = | ||
torch::jit::load(path_to_exported_script_module + "_" + | ||
"test_submodule_to_call_directly_with_hooks" + ".pt"); | ||
torch::jit::Module submodule = *module.modules().begin(); | ||
std::vector<torch::jit::IValue> inputs = {"a"}; | ||
|
||
auto output = submodule(inputs); | ||
torch::jit::IValue correct_output = "pre_hook_override_name_inner_mod_fh"; | ||
std::cout << "----- submodule's output: " << output << std::endl; | ||
std::cout << "----- expected output : " << correct_output << std::endl; | ||
AT_ASSERT(correct_output == correct_output); | ||
} | ||
|
||
struct HooksTestCase { | ||
std::string name; | ||
std::vector<torch::jit::IValue> inputs; | ||
torch::jit::IValue output; | ||
HooksTestCase(std::string name, std::vector<torch::jit::IValue> inputs, | ||
torch::jit::IValue output) | ||
: name(name), inputs(std::move(inputs)), output(std::move(output)) {} | ||
}; | ||
|
||
int main(int argc, const char *argv[]) { | ||
if (argc != 2) { | ||
std::cerr << "usage: test_jit_hooks <path-to-exported-script-module>\n"; | ||
return -1; | ||
} | ||
const std::string path_to_exported_script_module = argv[1]; | ||
std::cout << "path to exported module:" << path_to_exported_script_module | ||
<< std::endl; | ||
std::cout << "Tesing JIT Hooks in CPP" << std::endl; | ||
|
||
// Note: Modules loaded in this file are produced in /test/jit_hooks/model.py | ||
|
||
std::vector<HooksTestCase> test_cases = { | ||
HooksTestCase("test_submodule_multiple_hooks_single_input", | ||
{torch::jit::IValue("a")}, | ||
"pre_hook_override_name2_inner_mod_fwh1"), | ||
HooksTestCase("test_submodule_hook_return_nothing", | ||
{torch::jit::IValue("a")}, "a_outermod_inner_mod"), | ||
HooksTestCase("test_submodule_same_hook_repeated", | ||
{torch::jit::IValue("a")}, | ||
"a_outermod_ph_ph_inner_mod_fh_fh"), | ||
HooksTestCase("test_submodule_forward_single_input", | ||
{torch::jit::IValue("a")}, | ||
"pre_hook_override_name_inner_mod"), | ||
HooksTestCase( | ||
"test_submodule_multiple_hooks_multiple_inputs", | ||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")}, | ||
std::tuple<torch::List<std::string>, std::string>( | ||
{"pre_hook_override_name", "inner_mod_name"}, | ||
"pre_hook_override2_fh1_fh2")), | ||
HooksTestCase( | ||
"test_submodule_forward_multiple_inputs", | ||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")}, | ||
std::tuple<torch::List<std::string>, std::string>( | ||
{"pre_hook_override_name", "inner_mod_name"}, | ||
"pre_hook_override_fh")), | ||
HooksTestCase("test_module_forward_single_input", | ||
{torch::jit::IValue("a")}, | ||
"pre_hook_override_name_outermod_inner_mod_fh"), | ||
HooksTestCase("test_module_multiple_hooks_single_input", | ||
{torch::jit::IValue("a")}, | ||
"pre_hook_override_name2_outermod_inner_mod_fh1_fh2"), | ||
HooksTestCase("test_module_hook_return_nothing", | ||
{torch::jit::IValue("a")}, "a_outermod_inner_mod"), | ||
HooksTestCase("test_module_same_hook_repeated", {torch::jit::IValue("a")}, | ||
"a_ph_ph_outermod_inner_mod_fh_fh"), | ||
HooksTestCase( | ||
"test_module_forward_multiple_inputs", | ||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")}, | ||
std::tuple<torch::List<std::string>, std::string>( | ||
{"pre_hook_override_name", "outer_mod_name", "inner_mod_name"}, | ||
"pre_hook_override_fh")), | ||
HooksTestCase( | ||
"test_module_multiple_hooks_multiple_inputs", | ||
{torch::List<std::string>({"a"}), torch::jit::IValue("no_pre_hook")}, | ||
std::tuple<torch::List<std::string>, std::string>( | ||
{"pre_hook_override_name2", "outer_mod_name", "inner_mod_name"}, | ||
"pre_hook_override_fh1_fh2")), | ||
HooksTestCase("test_module_no_forward_input", {}, torch::jit::IValue()), | ||
HooksTestCase("test_forward_tuple_input", {std::tuple<int>(11)}, | ||
{std::tuple<int>(11)}), | ||
}; | ||
|
||
for (HooksTestCase &test_case : test_cases) { | ||
std::cout << "testing: " << test_case.name << std::endl; | ||
torch::jit::Module module = torch::jit::load( | ||
path_to_exported_script_module + "_" + test_case.name + ".pt"); | ||
torch::jit::IValue output = module(test_case.inputs); | ||
std::cout << "----- module's output: " << output << std::endl; | ||
std::cout << "----- expected output: " << test_case.output << std::endl; | ||
AT_ASSERT(output == test_case.output); | ||
} | ||
|
||
// special test cases that don't call the imported module directly | ||
test_module_forward_invocation_no_hooks_run(path_to_exported_script_module); | ||
test_submodule_called_directly_with_hooks(path_to_exported_script_module); | ||
|
||
std::cout << "JIT CPP Hooks okay!" << std::endl; | ||
|
||
return 0; | ||
} |