From 2ae7cd1a6ef5305da705b411c7ffa30be862adaf Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 18 Feb 2021 01:20:55 -0600
Subject: [PATCH 1/5] Created ELU converter, compiled it to .so and run it in Python

Signed-off-by: Bo Wang
---
 examples/README.md                           | 159 +++++++++++++++++++
 examples/elu_converter/elu_converter.cpp     |  33 ++++
 examples/elu_converter/elu_converter_test.py |  44 +++++
 examples/elu_converter/setup.py              |  18 +++
 4 files changed, 254 insertions(+)
 create mode 100644 examples/README.md
 create mode 100644 examples/elu_converter/elu_converter.cpp
 create mode 100644 examples/elu_converter/elu_converter_test.py
 create mode 100644 examples/elu_converter/setup.py

diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 0000000000..921ecec4c8
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,159 @@
+# Create a new op in C++, compile it to a .so library and load it in Python
+
+There are some operators in the PyTorch library that are not supported in TRTorch.
+To support these ops, users can register converters for missing ops. For example,
+if we try to compile a graph with a build of TRTorch that doesn't support the
+[ELU](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html) operation,
+we will get the following error:
+
+> Unable to convert node: %result.2 : Tensor = aten::elu(%x.1, %2, %3, %3) # /home/bowa/.local/lib/python3.6/site-packages/torch/nn/functional.py:1227:17 (conversion.AddLayer)
+Schema: aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
+Converter for aten::elu requested, but no such converter was found.
+If you need a converter for this operator, you can try implementing one yourself
+or request a converter: https://www.github.com/NVIDIA/TRTorch/issues
+
+## Writing Converter in C++
+We can register a converter for this operator in our application. You can find more
+information on all the details of writing converters in the contributors documentation
+([Writing Converters](https://nvidia.github.io/TRTorch/contributors/writing_converters.html)).
+Once we are clear about these rules and writing patterns, we can create a separate new C++ source file as:
+
+```c++
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto in = args[0].ITensorOrFreeze(ctx);
+       auto alpha = args[1].unwrapToDouble();
+
+       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
+       TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
+
+       new_layer->setAlpha(alpha);
+       new_layer->setName(util::node_info(n).c_str());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+
+       LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+       return true;
+     }});
+
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch
+```
+
+## Generate `.so` library
+To use this converter in Python, it is recommended to use PyTorch's
+[C++/CUDA Extension](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions).
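If TRTorch was built from source with bazel, it is worth confirming that `libtrtorch.so` is actually where the `setup.py` below expects it before compiling the extension; a minimal sketch of such a check (the relative path is an assumption taken from the script below):
```python
import os

# Hypothetical pre-flight check: the setup.py below links against libtrtorch.so
# from the bazel output tree, so fail early if it has not been built yet.
lib_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../../bazel-bin/cpp/api/lib/")
assert os.path.isfile(os.path.join(lib_dir, "libtrtorch.so")), "build TRTorch first"
```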
+We give an example here of how to wrap the converter into a `.so`
+library so that you can load and use it in a Python application.
+```python
+import os
+from setuptools import setup, Extension
+from torch.utils import cpp_extension
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+ext_modules = [
+    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
+                                library_dirs=[(
+                                    dir_path + "/../../bazel-bin/cpp/api/lib/"
+                                )],
+                                libraries=["trtorch"],
+                                include_dirs=[dir_path + "/../../"]
+                               )
+]
+
+setup(
+    name='elu_converter',
+    ext_modules=ext_modules,
+    cmdclass={'build_ext': cpp_extension.BuildExtension},
+)
+```
+Make sure to include the path for header files in `include_dirs` and the path
+for dependent libraries in `library_dirs`. You could also add other compilation
+flags in cpp_extension if needed. Then, run the above Python script as:
+```shell
+python3 setup.py install --user
+```
+You should see output similar to the contents indicated [here](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions) after running
+`python setup.py install`. You should find a couple of new folders generated
+by the command above. In the build folder, you can find the generated `.so` library,
+which can be loaded in our Python application.
+
+## Load `.so` in Python Application
+With the newly generated library, TRTorch now supports the newly developed converter.
+We use `torch.ops.load_library` to load the `.so`. For example, we could load the ELU
+converter and use it in our application:
+```python
+import torch
+import trtorch
+
+torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
+
+class Elu(torch.nn.Module):
+    def __init__(self):
+        super(Elu, self).__init__()
+        self.elu = torch.nn.ELU()
+
+    def forward(self, x):
+        return self.elu(x)
+
+def main():
+    data = torch.randn((1, 1, 2, 2)).to("cuda")
+    model = Elu().eval() #.cuda()
+
+    scripted_model = torch.jit.script(model)
+    print(scripted_model.graph)
+    compile_settings = {
+        "input_shapes": [{
+            "min": [1024, 1, 32, 32],
+            "opt": [1024, 1, 33, 33],
+            "max": [1024, 1, 34, 34],
+        }],
+        "op_precision":
+            torch.half # Run with FP16
+    }
+    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
+    input_data = torch.randn((1024, 1, 32, 32))
+    print(input_data[0, :, :, 0])
+    input_data = input_data.half().to("cuda")
+    result = trt_ts_module(input_data)
+    print(result[0, :, :, 0])
+
+if __name__ == "__main__":
+    main()
+
+```
+Running this script, we can inspect the tensor before and after the ELU operator.
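The printed values follow ELU's definition: the activation is the identity for non-negative inputs and `alpha * (exp(x) - 1)` for negative ones, which is what the TensorRT ELU layer computes with the alpha passed to `setAlpha` in the converter. A minimal eager-mode sketch of that definition (illustrative only):
```python
import torch

x = torch.randn(8)
alpha = 1.0  # matches the default Scalar alpha=1 in the aten::elu schema
# ELU: identity for x >= 0, alpha * (exp(x) - 1) otherwise
expected = torch.where(x >= 0, x, alpha * (torch.exp(x) - 1))
assert torch.allclose(torch.nn.functional.elu(x, alpha), expected, atol=1e-6)
```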
+### Example Output
+```bash
+graph(%self : __torch__.Elu,
+      %x.1 : Tensor):
+  %2 : __torch__.torch.nn.modules.activation.ELU = prim::GetAttr[name="elu"](%self)
+  %4 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # elu_converter_test.py:13:15
+  return (%4)
+
+tensor([[ 1.3482,  1.9848, -1.0818, -1.3252,  0.2470,  0.7011,  0.3174, -1.8349,
+          0.3024, -0.0453, -0.0681, -1.7377,  1.5909,  0.2549, -0.3029,  0.2583,
+          0.0242,  2.0748, -0.5454,  0.7137,  1.6688,  0.7108, -0.8681,  0.2486,
+         -1.3981,  1.0241,  1.2413,  0.2725,  1.4265,  0.9329,  0.4020, -2.6813]])
+tensor([[ 1.3486,  1.9844, -0.6611, -0.7344,  0.2471,  0.7012,  0.3174, -0.8403,
+          0.3025, -0.0443, -0.0659, -0.8242,  1.5908,  0.2549, -0.2615,  0.2583,
+          0.0242,  2.0742, -0.4204,  0.7139,  1.6689,  0.7109, -0.5801,  0.2485,
+         -0.7529,  1.0244,  1.2412,  0.2725,  1.4268,  0.9331,  0.4021, -0.9316]],
+       device='cuda:0', dtype=torch.float16)
+
+```

diff --git a/examples/elu_converter/elu_converter.cpp b/examples/elu_converter/elu_converter.cpp
new file mode 100644
index 0000000000..c10b49f007
--- /dev/null
+++ b/examples/elu_converter/elu_converter.cpp
@@ -0,0 +1,33 @@
+#include "core/conversion/converters/converters.h"
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace conversion {
+namespace converters {
+namespace impl {
+namespace {
+
+auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
+    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
+     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+       auto in = args[0].ITensorOrFreeze(ctx);
+       auto alpha = args[1].unwrapToDouble();
+
+       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
+       TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
+
+       new_layer->setAlpha(alpha);
+       new_layer->setName(util::node_info(n).c_str());
+       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
+
+       LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
+       return true;
+     }});
+
+} // namespace
+} // namespace impl
+} // namespace converters
+} // namespace conversion
+} // namespace core
+} // namespace trtorch

diff --git a/examples/elu_converter/elu_converter_test.py b/examples/elu_converter/elu_converter_test.py
new file mode 100644
index 0000000000..342dbbffe7
--- /dev/null
+++ b/examples/elu_converter/elu_converter_test.py
@@ -0,0 +1,44 @@
+import torch
+import trtorch
+
+torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
+
+
+class Elu(torch.nn.Module):
+
+    def __init__(self):
+        super(Elu, self).__init__()
+        self.elu = torch.nn.ELU()
+
+    def forward(self, x):
+        return self.elu(x)
+
+
+def main():
+    data = torch.randn((1, 1, 2, 2)).to("cuda")
+    model = Elu().eval() #.cuda()
+
+    # traced_model = torch.jit.trace(model, [data])
+    scripted_model = torch.jit.script(model)
+    print(scripted_model.graph)
+    # torch.jit.save(scripted_model, 'elu.jit')
+    compile_settings = {
+        "input_shapes": [{
+            "min": [1024, 1, 32, 32],
+            "opt": [1024, 1, 33, 33],
+            "max": [1024, 1, 34, 34],
+        }],
+        "op_precision":
+            torch.half # Run with FP16
+    }
+    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
+    input_data = torch.randn((1024, 1, 32, 32))
+    print(input_data[0, :, :, 0])
+    input_data = input_data.half().to("cuda")
+    result = trt_ts_module(input_data)
+    print(result[0, :, :, 0])
+    # torch.jit.save(trt_ts_module, "trt_ts_module.ts")
+
+
+if __name__ == "__main__":
+    main()

diff --git a/examples/elu_converter/setup.py b/examples/elu_converter/setup.py
new file mode 100644
index 0000000000..cc62d4f5cd
--- /dev/null
+++ b/examples/elu_converter/setup.py
@@ -0,0 +1,18 @@
+import os
+from setuptools import setup, Extension
+from torch.utils import cpp_extension
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+
+ext_modules = [
+    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
+                                library_dirs=[(dir_path + "/../../bazel-bin/cpp/api/lib/")],
+                                libraries=["trtorch"],
+                                include_dirs=[dir_path + "/../../"])
+]
+
+setup(
+    name='elu_converter',
+    ext_modules=ext_modules,
+    cmdclass={'build_ext': cpp_extension.BuildExtension},
+)

From 8b6d80cfc75ff5aaceb4e0a00541c9b93f12ff0e Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 18 Feb 2021 14:12:39 -0600
Subject: [PATCH 2/5] Updated the code for requested changes

Signed-off-by: Bo Wang
---
 examples/README.md                           | 45 +++++++++++---------
 examples/elu_converter/elu_converter_test.py | 22 ++++++----
 examples/elu_converter/setup.py              | 10 ++++-
 3 files changed, 47 insertions(+), 30 deletions(-)

diff --git a/examples/README.md b/examples/README.md
index 921ecec4c8..c608a0748a 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -66,14 +66,17 @@ from torch.utils import cpp_extension
 
 dir_path = os.path.dirname(os.path.realpath(__file__))
 
+# library_dirs should point to the libtrtorch.so, include_dirs should point to the dir that includes the headers
+# 1) download the latest package from https://github.com/NVIDIA/TRTorch/releases/
+# 2) Extract the files from the downloaded package; we will get the "trtorch" directory
+# 3) Set trtorch_path to that directory
+trtorch_path = os.path.abspath("trtorch")
+
 ext_modules = [
     cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
-                                library_dirs=[(
-                                    dir_path + "/../../bazel-bin/cpp/api/lib/"
-                                )],
+                                library_dirs=[(trtorch_path + "/lib/")],
                                 libraries=["trtorch"],
-                                include_dirs=[dir_path + "/../../"]
-                               )
+                                include_dirs=[trtorch_path + "/include/trtorch/"])
 ]
 
 setup(
@@ -83,7 +86,9 @@ setup(
 )
 ```
 Make sure to include the path for header files in `include_dirs` and the path
-for dependent libraries in `library_dirs`. You could also add other compilation
+for dependent libraries in `library_dirs`. Generally speaking, you should download
+the latest package from [here](https://github.com/NVIDIA/TRTorch/releases), extract
+the files, and set `trtorch_path` to it. You could also add other compilation
 flags in cpp_extension if needed. Then, run the above Python script as:
 ```shell
 python3 setup.py install --user
@@ -140,20 +145,20 @@ if __name__ == "__main__":
 Running this script, we can inspect the tensor before and after the ELU operator.
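Because the engine is built with `"op_precision": torch.half`, bit-exact agreement with eager PyTorch is not expected, only agreement within FP16 rounding; on top of printing the maximum difference, the comparison can be automated with a tolerance assertion (a sketch, assuming the two outputs produced by the script above):
```python
import torch

def assert_close(pytorch_out, trtorch_out, atol=1e-2):
    # FP16 rounding differs between TensorRT and eager execution,
    # so compare with a loose absolute tolerance.
    assert torch.allclose(pytorch_out.float(), trtorch_out.float(), atol=atol)
```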
 ### Example Output
 ```bash
-graph(%self : __torch__.Elu,
-      %x.1 : Tensor):
-  %2 : __torch__.torch.nn.modules.activation.ELU = prim::GetAttr[name="elu"](%self)
-  %4 : Tensor = prim::CallMethod[name="forward"](%2, %x.1) # elu_converter_test.py:13:15
-  return (%4)
-
-tensor([[ 1.3482,  1.9848, -1.0818, -1.3252,  0.2470,  0.7011,  0.3174, -1.8349,
-          0.3024, -0.0453, -0.0681, -1.7377,  1.5909,  0.2549, -0.3029,  0.2583,
-          0.0242,  2.0748, -0.5454,  0.7137,  1.6688,  0.7108, -0.8681,  0.2486,
-         -1.3981,  1.0241,  1.2413,  0.2725,  1.4265,  0.9329,  0.4020, -2.6813]])
-tensor([[ 1.3486,  1.9844, -0.6611, -0.7344,  0.2471,  0.7012,  0.3174, -0.8403,
-          0.3025, -0.0443, -0.0659, -0.8242,  1.5908,  0.2549, -0.2615,  0.2583,
-          0.0242,  2.0742, -0.4204,  0.7139,  1.6689,  0.7109, -0.5801,  0.2485,
-         -0.7529,  1.0244,  1.2412,  0.2725,  1.4268,  0.9331,  0.4021, -0.9316]],
+PyTorch output:
+ tensor([[ 0.8804,  2.4355, -0.7920, -0.2070, -0.5352,  0.4775,  1.3604, -0.3350,
+          -0.1802, -0.7563, -0.1758,  0.4067,  1.2510, -0.7100, -0.6221, -0.7207,
+          -0.1118,  0.9966,  1.6396, -0.1367, -0.5742,  0.5859,  0.8511,  0.6572,
+          -0.3481,  0.5933, -0.0488, -0.4287, -0.4102, -0.7402,  0.7515, -0.7710]],
+        device='cuda:0', dtype=torch.float16)
+TRTorch output:
+ tensor([[ 0.8804,  2.4355, -0.7920, -0.2070, -0.5356,  0.4775,  1.3604, -0.3347,
+          -0.1802, -0.7563, -0.1758,  0.4067,  1.2510, -0.7100, -0.6221, -0.7207,
+          -0.1117,  0.9966,  1.6396, -0.1368, -0.5747,  0.5859,  0.8511,  0.6572,
+          -0.3484,  0.5933, -0.0486, -0.4285, -0.4102, -0.7402,  0.7515, -0.7710]],
         device='cuda:0', dtype=torch.float16)
+Maximum difference between TRTorch and PyTorch:
+ tensor(0.0005, device='cuda:0', dtype=torch.float16)
+
 ```

diff --git a/examples/elu_converter/elu_converter_test.py b/examples/elu_converter/elu_converter_test.py
index 342dbbffe7..d118777fe1 100644
--- a/examples/elu_converter/elu_converter_test.py
+++ b/examples/elu_converter/elu_converter_test.py
@@ -1,6 +1,7 @@
 import torch
 import trtorch
 
+# After "python3 setup.py install", you should find this .so file under the generated "build" directory
 torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
 
 
@@ -14,14 +15,17 @@ def forward(self, x):
         return self.elu(x)
 
 
+def MaxDiff(pytorch_out, trtorch_out):
+    diff = torch.sub(pytorch_out, trtorch_out)
+    abs_diff = torch.abs(diff)
+    max_diff = torch.max(abs_diff)
+    print("Maximum difference between TRTorch and PyTorch: \n", max_diff)
+
+
 def main():
-    data = torch.randn((1, 1, 2, 2)).to("cuda")
     model = Elu().eval() #.cuda()
 
-    # traced_model = torch.jit.trace(model, [data])
     scripted_model = torch.jit.script(model)
-    print(scripted_model.graph)
-    # torch.jit.save(scripted_model, 'elu.jit')
     compile_settings = {
         "input_shapes": [{
             "min": [1024, 1, 32, 32],
@@ -33,11 +37,13 @@ def main():
     }
     trt_ts_module = trtorch.compile(scripted_model, compile_settings)
     input_data = torch.randn((1024, 1, 32, 32))
-    print(input_data[0, :, :, 0])
     input_data = input_data.half().to("cuda")
-    result = trt_ts_module(input_data)
-    print(result[0, :, :, 0])
-    # torch.jit.save(trt_ts_module, "trt_ts_module.ts")
+    pytorch_out = model.forward(input_data)
+
+    trtorch_out = trt_ts_module(input_data)
+    print('PyTorch output: \n', pytorch_out[0, :, :, 0])
+    print('TRTorch output: \n', trtorch_out[0, :, :, 0])
+    MaxDiff(pytorch_out, trtorch_out)
 
 
 if __name__ == "__main__":

diff --git a/examples/elu_converter/setup.py b/examples/elu_converter/setup.py
index cc62d4f5cd..af16120928 100644
--- a/examples/elu_converter/setup.py
+++ b/examples/elu_converter/setup.py
@@ -4,11 +4,17 @@
 dir_path = os.path.dirname(os.path.realpath(__file__))
 
+# library_dirs should point to the libtrtorch.so, include_dirs should point to the dir that includes the headers
+# 1) download the latest package from https://github.com/NVIDIA/TRTorch/releases/
+# 2) Extract the files from the downloaded package; we will get the "trtorch" directory
+# 3) Set trtorch_path to that directory
+trtorch_path = os.path.abspath("trtorch")
+
 ext_modules = [
     cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
-                                library_dirs=[(dir_path + "/../../bazel-bin/cpp/api/lib/")],
+                                library_dirs=[(trtorch_path + "/lib/")],
                                 libraries=["trtorch"],
-                                include_dirs=[dir_path + "/../../"])
+                                include_dirs=[trtorch_path + "/include/trtorch/"])
 ]
 
 setup(

From 13765365b5f62aa3dd2f610c53fa4614da183b0c Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 18 Feb 2021 14:47:53 -0600
Subject: [PATCH 3/5] Updated README.md

Signed-off-by: Bo Wang
---
 examples/README.md | 23 ++++++++++++++++++-----
 1 file changed, 18 insertions(+), 5 deletions(-)

diff --git a/examples/README.md b/examples/README.md
index c608a0748a..17828f509f 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -106,9 +106,12 @@ converter and use it in our application:
 import torch
 import trtorch
 
+# After "python3 setup.py install", you should find this .so file under the generated "build" directory
 torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
 
+
 class Elu(torch.nn.Module):
+
     def __init__(self):
         super(Elu, self).__init__()
         self.elu = torch.nn.ELU()
@@ -116,12 +119,18 @@ class Elu(torch.nn.Module):
     def forward(self, x):
         return self.elu(x)
 
+
+def MaxDiff(pytorch_out, trtorch_out):
+    diff = torch.sub(pytorch_out, trtorch_out)
+    abs_diff = torch.abs(diff)
+    max_diff = torch.max(abs_diff)
+    print("Maximum difference between TRTorch and PyTorch: \n", max_diff)
+
+
 def main():
-    data = torch.randn((1, 1, 2, 2)).to("cuda")
     model = Elu().eval() #.cuda()
 
     scripted_model = torch.jit.script(model)
-    print(scripted_model.graph)
     compile_settings = {
         "input_shapes": [{
             "min": [1024, 1, 32, 32],
@@ -133,10 +142,14 @@ def main():
     }
     trt_ts_module = trtorch.compile(scripted_model, compile_settings)
     input_data = torch.randn((1024, 1, 32, 32))
-    print(input_data[0, :, :, 0])
     input_data = input_data.half().to("cuda")
-    result = trt_ts_module(input_data)
-    print(result[0, :, :, 0])
+    pytorch_out = model.forward(input_data)
+
+    trtorch_out = trt_ts_module(input_data)
+    print('PyTorch output: \n', pytorch_out[0, :, :, 0])
+    print('TRTorch output: \n', trtorch_out[0, :, :, 0])
+    MaxDiff(pytorch_out, trtorch_out)
+
 
 if __name__ == "__main__":
     main()

From 8ab5465e27c41de40346aa21bc3acd1cfe15f024 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Fri, 19 Feb 2021 00:25:57 -0600
Subject: [PATCH 4/5] Lowering pass for SiLU

Signed-off-by: Bo Wang
---
 core/lowering/lowering.cpp                              |  1 +
 core/lowering/passes/BUILD                              |  3 +-
 core/lowering/passes/passes.h                           |  1 +
 .../passes/silu_to_sigmoid_multiplication.cpp           | 31 +++++++++++++++++++
 tests/core/lowering/BUILD                               |  4 +++
 .../test_silu_to_sigmoid_multiplication.cpp             | 29 +++++++++++++++++
 6 files changed, 68 insertions(+), 1 deletion(-)
 create mode 100644 core/lowering/passes/silu_to_sigmoid_multiplication.cpp
 create mode 100644 tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp

diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp
index 08ce69b72c..aec9ebb8e1 100644
--- a/core/lowering/lowering.cpp
+++ b/core/lowering/lowering.cpp
@@ -49,6 +49,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
   passes::UnpackLogSoftmax(g);
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
+  passes::SiluToSigmoidMultipication(g);
   torch::jit::EliminateDeadCode(g);
   LOG_GRAPH(*g);
 }

diff --git a/core/lowering/passes/BUILD b/core/lowering/passes/BUILD
index e22d0a59b1..f213a2539a 100644
--- a/core/lowering/passes/BUILD
+++ b/core/lowering/passes/BUILD
@@ -25,7 +25,8 @@ cc_library(
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_log_softmax.cpp",
-        "op_aliasing.cpp"
+        "op_aliasing.cpp",
+        "silu_to_sigmoid_multiplication.cpp"
     ],
     deps = [
         "//core/util:prelude",

diff --git a/core/lowering/passes/passes.h b/core/lowering/passes/passes.h
index d6bf083a18..770982f67f 100644
--- a/core/lowering/passes/passes.h
+++ b/core/lowering/passes/passes.h
@@ -20,6 +20,7 @@ void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
+void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 
 } // namespace passes
 } // namespace lowering

diff --git a/core/lowering/passes/silu_to_sigmoid_multiplication.cpp b/core/lowering/passes/silu_to_sigmoid_multiplication.cpp
new file mode 100644
index 0000000000..782e659788
--- /dev/null
+++ b/core/lowering/passes/silu_to_sigmoid_multiplication.cpp
@@ -0,0 +1,31 @@
+#include <torch/csrc/jit/passes/subgraph_rewrite.h>
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string silu_pattern = R"IR(
+    graph(%x):
+      %1 : Tensor = aten::silu(%x)
+      return (%1))IR";
+
+  std::string sigmoid_multiplication_pattern = R"IR(
+    graph(%x):
+      %1 : Tensor = aten::sigmoid(%x)
+      %2 : Tensor = aten::mul(%x, %1)
+      return (%2))IR";
+
+  torch::jit::SubgraphRewriter map_silu_to_sigmoid_multiplication;
+  map_silu_to_sigmoid_multiplication.RegisterRewritePattern(silu_pattern, sigmoid_multiplication_pattern);
+  map_silu_to_sigmoid_multiplication.runOnGraph(graph);
+  LOG_GRAPH("Post map silu -> x * sigmoid(x): " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch

diff --git a/tests/core/lowering/BUILD b/tests/core/lowering/BUILD
index dd77960b8c..7742a07e06 100644
--- a/tests/core/lowering/BUILD
+++ b/tests/core/lowering/BUILD
@@ -23,6 +23,10 @@ lowering_test(
     name = "test_operator_aliasing_pass",
 )
 
+lowering_test(
+    name = "test_silu_to_sigmoid_multiplication",
+)
+
 test_suite(
     name = "lowering_tests",
     tests = [

diff --git a/tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp b/tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp
new file mode 100644
index 0000000000..fec02711ed
--- /dev/null
+++ b/tests/core/lowering/test_silu_to_sigmoid_multiplication.cpp
@@ -0,0 +1,29 @@
+#include <string>
+#include "core/compiler.h"
+#include "core/lowering/passes/passes.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/csrc/jit/ir/subgraph_matcher.h"
+
+TEST(LoweringPasses, RemoveSiluLowersCorrectly) {
+  std::string source_graph = R"IR(
+    graph(%x.1 : Tensor):
+      %2 : Tensor = aten::silu(%x.1)
+      return (%2))IR";
+  std::string target_graph = R"IR(
+    graph(%x.1):
+      %2 : Tensor = aten::sigmoid(%x.1)
+      %3 : Tensor = aten::mul(%x.1, %2)
+      return (%3))IR";
+
+  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
+  auto sg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(source_graph, &*sg);
+  trtorch::core::lowering::passes::SiluToSigmoidMultipication(sg);
+
+  auto tg = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(target_graph, &*tg);
+
+  ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
+}
\ No newline at end of file

From 5d0ab489597c4af1f33b6680d3a016668d5c58af Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Fri, 19 Feb 2021 00:50:35 -0600
Subject: [PATCH 5/5] delete the changes from elu

Signed-off-by: Bo Wang
---
 examples/README.md                           | 177 -------------------
 examples/elu_converter/elu_converter.cpp     |  33 ----
 examples/elu_converter/elu_converter_test.py |  50 ------
 examples/elu_converter/setup.py              |  24 ---
 4 files changed, 284 deletions(-)
 delete mode 100644 examples/README.md
 delete mode 100644 examples/elu_converter/elu_converter.cpp
 delete mode 100644 examples/elu_converter/elu_converter_test.py
 delete mode 100644 examples/elu_converter/setup.py

diff --git a/examples/README.md b/examples/README.md
deleted file mode 100644
index 17828f509f..0000000000
--- a/examples/README.md
+++ /dev/null
@@ -1,177 +0,0 @@
-# Create a new op in C++, compile it to a .so library and load it in Python
-
-There are some operators in the PyTorch library that are not supported in TRTorch.
-To support these ops, users can register converters for missing ops. For example,
-if we try to compile a graph with a build of TRTorch that doesn't support the
-[ELU](https://pytorch.org/docs/stable/generated/torch.nn.ELU.html) operation,
-we will get the following error:
-
-> Unable to convert node: %result.2 : Tensor = aten::elu(%x.1, %2, %3, %3) # /home/bowa/.local/lib/python3.6/site-packages/torch/nn/functional.py:1227:17 (conversion.AddLayer)
-Schema: aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
-Converter for aten::elu requested, but no such converter was found.
-If you need a converter for this operator, you can try implementing one yourself
-or request a converter: https://www.github.com/NVIDIA/TRTorch/issues
-
-## Writing Converter in C++
-We can register a converter for this operator in our application. You can find more
-information on all the details of writing converters in the contributors documentation
-([Writing Converters](https://nvidia.github.io/TRTorch/contributors/writing_converters.html)).
-Once we are clear about these rules and writing patterns, we can create a separate new C++ source file as:
-
-```c++
-#include "core/conversion/converters/converters.h"
-#include "core/util/prelude.h"
-
-namespace trtorch {
-namespace core {
-namespace conversion {
-namespace converters {
-namespace impl {
-namespace {
-
-auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto in = args[0].ITensorOrFreeze(ctx);
-       auto alpha = args[1].unwrapToDouble();
-
-       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
-       TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
-
-       new_layer->setAlpha(alpha);
-       new_layer->setName(util::node_info(n).c_str());
-       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
-
-       LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
-       return true;
-     }});
-
-} // namespace
-} // namespace impl
-} // namespace converters
-} // namespace conversion
-} // namespace core
-} // namespace trtorch
-```
-
-## Generate `.so` library
-To use this converter in Python, it is recommended to use PyTorch's
-[C++/CUDA Extension](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions).
-We give an example here of how to wrap the converter into a `.so`
-library so that you can load and use it in a Python application.
-```python
-import os
-from setuptools import setup, Extension
-from torch.utils import cpp_extension
-
-dir_path = os.path.dirname(os.path.realpath(__file__))
-
-# library_dirs should point to the libtrtorch.so, include_dirs should point to the dir that includes the headers
-# 1) download the latest package from https://github.com/NVIDIA/TRTorch/releases/
-# 2) Extract the files from the downloaded package; we will get the "trtorch" directory
-# 3) Set trtorch_path to that directory
-trtorch_path = os.path.abspath("trtorch")
-
-ext_modules = [
-    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
-                                library_dirs=[(trtorch_path + "/lib/")],
-                                libraries=["trtorch"],
-                                include_dirs=[trtorch_path + "/include/trtorch/"])
-]
-
-setup(
-    name='elu_converter',
-    ext_modules=ext_modules,
-    cmdclass={'build_ext': cpp_extension.BuildExtension},
-)
-```
-Make sure to include the path for header files in `include_dirs` and the path
-for dependent libraries in `library_dirs`. Generally speaking, you should download
-the latest package from [here](https://github.com/NVIDIA/TRTorch/releases), extract
-the files, and set `trtorch_path` to it. You could also add other compilation
-flags in cpp_extension if needed. Then, run the above Python script as:
-```shell
-python3 setup.py install --user
-```
-You should see output similar to the contents indicated [here](https://pytorch.org/tutorials/advanced/cpp_extension.html#custom-c-and-cuda-extensions) after running
-`python setup.py install`. You should find a couple of new folders generated
-by the command above. In the build folder, you can find the generated `.so` library,
-which can be loaded in our Python application.
-
-## Load `.so` in Python Application
-With the newly generated library, TRTorch now supports the newly developed converter.
-We use `torch.ops.load_library` to load the `.so`. For example, we could load the ELU
-converter and use it in our application:
-```python
-import torch
-import trtorch
-
-# After "python3 setup.py install", you should find this .so file under the generated "build" directory
-torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
-
-
-class Elu(torch.nn.Module):
-
-    def __init__(self):
-        super(Elu, self).__init__()
-        self.elu = torch.nn.ELU()
-
-    def forward(self, x):
-        return self.elu(x)
-
-
-def MaxDiff(pytorch_out, trtorch_out):
-    diff = torch.sub(pytorch_out, trtorch_out)
-    abs_diff = torch.abs(diff)
-    max_diff = torch.max(abs_diff)
-    print("Maximum difference between TRTorch and PyTorch: \n", max_diff)
-
-
-def main():
-    model = Elu().eval() #.cuda()
-
-    scripted_model = torch.jit.script(model)
-    compile_settings = {
-        "input_shapes": [{
-            "min": [1024, 1, 32, 32],
-            "opt": [1024, 1, 33, 33],
-            "max": [1024, 1, 34, 34],
-        }],
-        "op_precision":
-            torch.half # Run with FP16
-    }
-    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
-    input_data = torch.randn((1024, 1, 32, 32))
-    input_data = input_data.half().to("cuda")
-    pytorch_out = model.forward(input_data)
-
-    trtorch_out = trt_ts_module(input_data)
-    print('PyTorch output: \n', pytorch_out[0, :, :, 0])
-    print('TRTorch output: \n', trtorch_out[0, :, :, 0])
-    MaxDiff(pytorch_out, trtorch_out)
-
-
-if __name__ == "__main__":
-    main()
-
-```
-Running this script, we can inspect the tensor before and after the ELU operator.
-### Example Output
-```bash
-PyTorch output:
- tensor([[ 0.8804,  2.4355, -0.7920, -0.2070, -0.5352,  0.4775,  1.3604, -0.3350,
-          -0.1802, -0.7563, -0.1758,  0.4067,  1.2510, -0.7100, -0.6221, -0.7207,
-          -0.1118,  0.9966,  1.6396, -0.1367, -0.5742,  0.5859,  0.8511,  0.6572,
-          -0.3481,  0.5933, -0.0488, -0.4287, -0.4102, -0.7402,  0.7515, -0.7710]],
-        device='cuda:0', dtype=torch.float16)
-TRTorch output:
- tensor([[ 0.8804,  2.4355, -0.7920, -0.2070, -0.5356,  0.4775,  1.3604, -0.3347,
-          -0.1802, -0.7563, -0.1758,  0.4067,  1.2510, -0.7100, -0.6221, -0.7207,
-          -0.1117,  0.9966,  1.6396, -0.1368, -0.5747,  0.5859,  0.8511,  0.6572,
-          -0.3484,  0.5933, -0.0486, -0.4285, -0.4102, -0.7402,  0.7515, -0.7710]],
-        device='cuda:0', dtype=torch.float16)
-Maximum difference between TRTorch and PyTorch:
- tensor(0.0005, device='cuda:0', dtype=torch.float16)
-
-
-```

diff --git a/examples/elu_converter/elu_converter.cpp b/examples/elu_converter/elu_converter.cpp
deleted file mode 100644
index c10b49f007..0000000000
--- a/examples/elu_converter/elu_converter.cpp
+++ /dev/null
@@ -1,33 +0,0 @@
-#include "core/conversion/converters/converters.h"
-#include "core/util/prelude.h"
-
-namespace trtorch {
-namespace core {
-namespace conversion {
-namespace converters {
-namespace impl {
-namespace {
-
-auto acthardtanh TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto in = args[0].ITensorOrFreeze(ctx);
-       auto alpha = args[1].unwrapToDouble();
-
-       auto new_layer = ctx->net->addActivation(*in, nvinfer1::ActivationType::kELU);
-       TRTORCH_CHECK(new_layer, "Unable to create layer for aten::elu");
-
-       new_layer->setAlpha(alpha);
-       new_layer->setName(util::node_info(n).c_str());
-       auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
-
-       LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
-       return true;
-     }});
-
-} // namespace
-} // namespace impl
-} // namespace converters
-} // namespace conversion
-} // namespace core
-} // namespace trtorch

diff --git a/examples/elu_converter/elu_converter_test.py b/examples/elu_converter/elu_converter_test.py
deleted file mode 100644
index d118777fe1..0000000000
--- a/examples/elu_converter/elu_converter_test.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import torch
-import trtorch
-
-# After "python3 setup.py install", you should find this .so file under the generated "build" directory
-torch.ops.load_library('./build/lib.linux-x86_64-3.6/elu_converter.cpython-36m-x86_64-linux-gnu.so')
-
-
-class Elu(torch.nn.Module):
-
-    def __init__(self):
-        super(Elu, self).__init__()
-        self.elu = torch.nn.ELU()
-
-    def forward(self, x):
-        return self.elu(x)
-
-
-def MaxDiff(pytorch_out, trtorch_out):
-    diff = torch.sub(pytorch_out, trtorch_out)
-    abs_diff = torch.abs(diff)
-    max_diff = torch.max(abs_diff)
-    print("Maximum difference between TRTorch and PyTorch: \n", max_diff)
-
-
-def main():
-    model = Elu().eval() #.cuda()
-
-    scripted_model = torch.jit.script(model)
-    compile_settings = {
-        "input_shapes": [{
-            "min": [1024, 1, 32, 32],
-            "opt": [1024, 1, 33, 33],
-            "max": [1024, 1, 34, 34],
-        }],
-        "op_precision":
-            torch.half # Run with FP16
-    }
-    trt_ts_module = trtorch.compile(scripted_model, compile_settings)
-    input_data = torch.randn((1024, 1, 32, 32))
-    input_data = input_data.half().to("cuda")
-    pytorch_out = model.forward(input_data)
-
-    trtorch_out = trt_ts_module(input_data)
-    print('PyTorch output: \n', pytorch_out[0, :, :, 0])
-    print('TRTorch output: \n', trtorch_out[0, :, :, 0])
-    MaxDiff(pytorch_out, trtorch_out)
-
-
-if __name__ == "__main__":
-    main()

diff --git a/examples/elu_converter/setup.py b/examples/elu_converter/setup.py
deleted file mode 100644
index af16120928..0000000000
--- a/examples/elu_converter/setup.py
+++ /dev/null
@@ -1,24 +0,0 @@
-import os
-from setuptools import setup, Extension
-from torch.utils import cpp_extension
-
-dir_path = os.path.dirname(os.path.realpath(__file__))
-
-# library_dirs should point to the libtrtorch.so, include_dirs should point to the dir that includes the headers
-# 1) download the latest package from https://github.com/NVIDIA/TRTorch/releases/
-# 2) Extract the files from the downloaded package; we will get the "trtorch" directory
-# 3) Set trtorch_path to that directory
-trtorch_path = os.path.abspath("trtorch")
-
-ext_modules = [
-    cpp_extension.CUDAExtension('elu_converter', ['elu_converter.cpp'],
-                                library_dirs=[(trtorch_path + "/lib/")],
-                                libraries=["trtorch"],
-                                include_dirs=[trtorch_path + "/include/trtorch/"])
-]
-
-setup(
-    name='elu_converter',
-    ext_modules=ext_modules,
-    cmdclass={'build_ext': cpp_extension.BuildExtension},
-)
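The lowering pass in patch 4 is sound because SiLU is defined as `x * sigmoid(x)`, so rewriting `aten::silu` into `aten::mul(x, aten::sigmoid(x))` does not change numerics; a quick eager-mode check of that identity (a sketch, independent of the patches above):
```python
import torch

x = torch.randn(1024)
# aten::silu(x) and x * sigmoid(x) compute the same values,
# which is what the subgraph rewrite relies on.
assert torch.allclose(torch.nn.functional.silu(x), x * torch.sigmoid(x))
```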