Python API Cleanup #452
Conversation
We can't run partial compilation on modules passed through the to_backend API, because the backend is expected to return a handle to a single TRT engine rather than a full graph, so graph stitching is not possible. An exception is now thrown if someone tries to use fallback with to_backend, directing them to trtorch.compile instead.
Signed-off-by: Naren Dasan <naren@narendasan.com>
Signed-off-by: Naren Dasan <narens@nvidia.com>
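The guard described above can be illustrated with a minimal, self-contained sketch. The names below (`TorchFallback`, `check_backend_spec`) are simplified stand-ins for the real spec-parsing code in `_compile_spec.py`, not the actual TRTorch API:

```python
# Hypothetical sketch of the guard: when a compile spec is routed through
# the to_backend path, fallback cannot be honored because the backend must
# return a single TRT engine handle rather than a stitched graph.

class TorchFallback:
    def __init__(self, enabled=False, min_block_size=3):
        self.enabled = enabled
        self.min_block_size = min_block_size

def check_backend_spec(torch_fallback):
    """Reject fallback-enabled specs on the to_backend path."""
    if torch_fallback.enabled:
        raise RuntimeError(
            "Partial module compilation is not currently supported via the "
            "PyTorch TensorRT backend. If you need partial compilation, "
            "use trtorch.compile"
        )

# A spec with fallback disabled passes through silently.
check_backend_spec(TorchFallback(enabled=False))
```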
Force-pushed from 4222f01 to 9bf2456
There are some changes that do not conform to Python style guidelines:
Reformatting /workspace/cpp/ptq/training/vgg16/export_ckpt.py
Reformatting /workspace/cpp/ptq/training/vgg16/vgg16.py
Reformatting /workspace/cpp/ptq/training/vgg16/main.py
--- /workspace/py/trtorch/_compile_spec.py (original)
+++ /workspace/py/trtorch/_compile_spec.py (reformatted)
@@ -276,7 +276,9 @@
d._set_allow_gpu_fallback(parsed_spec.device.allow_gpu_fallback)
if parsed_spec.torch_fallback.enabled:
- raise RuntimeError("Partial module compilation is not currently supported via the PyTorch TensorRT backend. If you need partial compilation, use trtorch.compile")
+ raise RuntimeError(
+ "Partial module compilation is not currently supported via the PyTorch TensorRT backend. If you need partial compilation, use trtorch.compile"
+ )
torch_fallback = torch.classes.tensorrt._TorchFallback()
torch_fallback._set_enabled(parsed_spec.torch_fallback.enabled)
Reformatting /workspace/py/trtorch/ptq.py
Reformatting /workspace/py/trtorch/_compiler.py
Reformatting /workspace/py/trtorch/logging.py
Reformatting /workspace/py/trtorch/_compile_spec.py
Reformatting /workspace/py/trtorch/__init__.py
Reformatting /workspace/py/trtorch/_types.py
Reformatting /workspace/py/setup.py
ERROR: Some files do not conform to style guidelines
There are some changes that do not conform to C++ style guidelines:
diff --git a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp b/tmp/changes.txt
index db5e522..8733af6 100644
--- a/workspace/py/trtorch/csrc/register_tensorrt_classes.cpp
+++ b/tmp/changes.txt
@@ -5,23 +5,22 @@ namespace backend {
namespace {
#define ADD_FIELD_GET_SET_REGISTRATION(registry, class_name, field_name) \
- (registry).def("_set_" #field_name, &class_name::set_##field_name); \
+ (registry).def("_set_" #field_name, &class_name::set_##field_name); \
(registry).def("_get_" #field_name, &class_name::get_##field_name);
void RegisterTRTCompileSpec() {
static auto TRTORCH_UNUSED TRTInputRangeTSRegistration =
- torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
- .def(torch::init<>())
- .def("__str__", &trtorch::pyapi::InputRange::to_str);
+ torch::class_<trtorch::pyapi::InputRange>("tensorrt", "_InputRange")
+ .def(torch::init<>())
+ .def("__str__", &trtorch::pyapi::InputRange::to_str);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, min);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, opt);
ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, trtorch::pyapi::InputRange, max);
- static auto TRTORCH_UNUSED TRTDeviceTSRegistration =
- torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
- .def(torch::init<>())
- .def("__str__", &trtorch::pyapi::Device::to_str);
+ static auto TRTORCH_UNUSED TRTDeviceTSRegistration = torch::class_<trtorch::pyapi::Device>("tensorrt", "_Device")
+ .def(torch::init<>())
+ .def("__str__", &trtorch::pyapi::Device::to_str);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, device_type);
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, gpu_id);
@@ -29,9 +28,9 @@ void RegisterTRTCompileSpec() {
ADD_FIELD_GET_SET_REGISTRATION(TRTDeviceTSRegistration, trtorch::pyapi::Device, allow_gpu_fallback);
static auto TRTORCH_UNUSED TRTFallbackTSRegistration =
- torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
- .def(torch::init<>())
- .def("__str__", &trtorch::pyapi::TorchFallback::to_str);
+ torch::class_<trtorch::pyapi::TorchFallback>("tensorrt", "_TorchFallback")
+ .def(torch::init<>())
+ .def("__str__", &trtorch::pyapi::TorchFallback::to_str);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, enabled);
ADD_FIELD_GET_SET_REGISTRATION(TRTFallbackTSRegistration, trtorch::pyapi::TorchFallback, min_block_size);
diff --git a/workspace/py/trtorch/csrc/tensorrt_classes.h b/tmp/changes.txt
index 4dd7a8f..50b1114 100644
--- a/workspace/py/trtorch/csrc/tensorrt_classes.h
+++ b/tmp/changes.txt
@@ -93,7 +93,6 @@ struct TorchFallback : torch::CustomClassHolder {
std::string to_str();
};
-
enum class EngineCapability : int8_t {
kDEFAULT,
kSAFE_GPU,
ERROR: Some files do not conform to style guidelines
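The `ADD_FIELD_GET_SET_REGISTRATION` macro being reformatted above stamps out a paired `_get_<field>` / `_set_<field>` binding for each field of a TorchScript custom class. A rough Python analogue of that pattern (illustrative only, not the TRTorch code) looks like:

```python
# Sketch of the C++ macro's effect: for each field name, register a paired
# getter and setter under "_get_<field>" / "_set_<field>" keys.

def add_field_get_set_registration(registry, field_name):
    # getattr/setattr play the role of the &class_name::get_##field_name
    # member-pointer bindings in the C++ version.
    registry["_get_" + field_name] = lambda obj: getattr(obj, field_name)
    registry["_set_" + field_name] = lambda obj, v: setattr(obj, field_name, v)

class Device:
    gpu_id = 0
    allow_gpu_fallback = False

registry = {}
add_field_get_set_registration(registry, "gpu_id")
add_field_get_set_registration(registry, "allow_gpu_fallback")
```

The macro exists for the same reason the helper does here: each spec field needs identical boilerplate accessors, so generating them from the field name keeps the registration code short and consistent.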
Code conforms to Python style guidelines
Code conforms to C++ style guidelines
LGTM
Description
Gates the partial compilation feature from the to_backend API, because it does not currently work in a way that is compatible with graph stitching. Also adds documentation and quality-of-life improvements to the Python API.
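Why the two entry points differ can be shown with a small toy model. The function names and block representation below are illustrative, not the real TRTorch lowering code:

```python
# Toy model of the incompatibility: to_backend must hand back ONE engine
# for the whole module, while trtorch.compile owns the full graph and can
# stitch TRT engines around unsupported Torch blocks.

def lower_to_backend(blocks):
    # to_backend-style lowering: the entire module must become one engine,
    # so a module with any unsupported (fallback) block cannot be handled.
    if len(blocks) != 1 or not blocks[0]["supported"]:
        raise RuntimeError(
            "to_backend cannot stitch multiple blocks; "
            "use trtorch.compile for partial compilation")
    return ("trt_engine", blocks[0])

def lower_with_stitching(blocks):
    # trtorch.compile-style lowering: supported blocks become TRT engines,
    # unsupported blocks stay in Torch, and the pieces are stitched together.
    return [("trt_engine", b) if b["supported"] else ("torch_block", b)
            for b in blocks]
```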