diff --git a/docsrc/requirements.txt b/docsrc/requirements.txt
index e2e90ad307..34c6e4d8b5 100644
--- a/docsrc/requirements.txt
+++ b/docsrc/requirements.txt
@@ -2,5 +2,5 @@ sphinx==3.1.2
 breathe==4.19.2
 exhale
 sphinx_rtd_theme==0.4.3
-sphinx-material==0.0.30
+sphinx-material==0.0.35
 nbsphinx==0.8.6
\ No newline at end of file
diff --git a/py/trtorch/Device.py b/py/trtorch/Device.py
index 3e408bd951..41a1308518 100644
--- a/py/trtorch/Device.py
+++ b/py/trtorch/Device.py
@@ -105,6 +105,11 @@ def _from_torch_device(cls, torch_dev: torch.device):
         gpu_id = torch_dev.index
         return cls(gpu_id=gpu_id)
 
+    @classmethod
+    def _current_device(cls):
+        dev = trtorch._C._get_current_device()
+        return cls(gpu_id=dev.gpu_id)
+
     @staticmethod
     def _parse_device_str(s):
         s = s.lower()
diff --git a/py/trtorch/_compile_spec.py b/py/trtorch/_compile_spec.py
index d21ec5bdd1..a8e7d9a562 100644
--- a/py/trtorch/_compile_spec.py
+++ b/py/trtorch/_compile_spec.py
@@ -4,6 +4,7 @@
 from trtorch import _types
 from trtorch.Input import Input
 from trtorch.Device import Device
+from trtorch._types import EngineCapability
 
 import warnings
 
@@ -246,63 +247,80 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
     return info
 
-def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.CompileSpec:
-    """
-    Utility to create a formated spec dictionary for using the PyTorch TensorRT backend
-
-    Args:
-        compile_spec (dict): Compilation settings including operating precision, target device, etc.
-            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
-            to the graph as well as expect types and formats for those inputs. All other keys are optional.
-            Entries for each method to be compiled.
-
-            Note: Partial compilation of TorchScript modules is not supported through the PyTorch TensorRT backend
-            If you need this feature, use trtorch.compile to compile your module. Usage of the resulting module is
-            as if you were using the TensorRT integration.
-
-        .. code-block:: py
-
-            CompileSpec = {
-                "forward" : trtorch.TensorRTCompileSpec({
-                    "inputs": [
-                        trtorch.Input((1, 3, 224, 224)), # Static input shape for input #1
-                        trtorch.Input(
-                            min_shape=1, 3, 224, 224),
-                            opt_shape=(1, 3, 512, 512),
-                            max_shape=(1, 3, 1024, 1024),
-                            dtype=torch.int32
-                            format=torch.channel_last
-                        ) # Dynamic input shape for input #2
-                    ],
-                    "device": {
-                        "device_type": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
-                        "gpu_id": 0, # Target gpu id to run engine (Use Xavier as gpu id for DLA)
-                        "dla_core": 0, # (DLA only) Target dla core id to run engine
-                        "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
-                    },
-                    "enabled_precisions": {torch.half}, # Operating precision set to FP16
-                    "sparse_weights": Enable sparsity for convolution and fully connected layers.
- "disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas - "refit": False, # enable refit - "debug": False, # enable debuggable engine - "strict_types": False, # kernels should strictly run in operating precision - "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels - "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels - "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels - "workspace_size": 0, # Maximum size of workspace given to TensorRT - "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set) - "truncate_long_and_double": False, # Truncate long and double into int and float - }) - } - - Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using +def TensorRTCompileSpec(inputs=[], + device=Device._current_device(), + disable_tf32=False, + sparse_weights=False, + enabled_precisions=set(), + refit=False, + debug=False, + strict_types=False, + capability=EngineCapability.default, + num_min_timing_iters=2, + num_avg_timing_iters=1, + workspace_size=0, + max_batch_size=0, + truncate_long_and_double=False, + calibrator=None) -> torch.classes.tensorrt.CompileSpec: + """Utility to create a formated spec dictionary for using the PyTorch TensorRT backend + + Keyword Args: + inputs (List[Union(trtorch.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum - to select device type. - - Returns: + to select device type. :: + + input=[ + trtorch.Input((1, 3, 224, 224)), # Static NCHW input shape for input #1 + trtorch.Input( + min_shape=(1, 224, 224, 3), + opt_shape=(1, 512, 512, 3), + max_shape=(1, 1024, 1024, 3), + dtype=torch.int32 + format=torch.channel_last + ), # Dynamic input shape for input #2 + torch.randn((1, 3, 224, 244)) # Use an example tensor and let trtorch infer settings + ] + + device (Union(trtorch.Device, torch.device, dict)): Target device for TensorRT engines to run on :: + + device=trtorch.Device("dla:1", allow_gpu_fallback=True) + + disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas + sparse_weights (bool): Enable sparsity for convolution and fully connected layers. + enabled_precision (Set(Union(torch.dtype, trtorch.dtype))): The set of datatypes that TensorRT can use when selecting kernels + refit (bool): Enable refitting + debug (bool): Enable debuggable engine + strict_types (bool): Kernels should strictly run in a particular operating precision. 
+        capability (trtorch.EngineCapability): Restrict kernel selection to safe gpu kernels or safe dla kernels
+        num_min_timing_iters (int): Number of minimization timing iterations used to select kernels
+        num_avg_timing_iters (int): Number of averaging timing iterations used to select kernels
+        workspace_size (int): Maximum size of workspace given to TensorRT
+        max_batch_size (int): Maximum batch size (must be >= 1 to be set, 0 means not set)
+        truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
+        calibrator (Union(trtorch._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
+
+    Returns:
         torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
     """
+
+    compile_spec = {
+        "inputs": inputs,
+        "device": device,
+        "disable_tf32": disable_tf32, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
+        "sparse_weights": sparse_weights, # Enable sparsity for convolution and fully connected layers.
+        "enabled_precisions": enabled_precisions, # Set of datatypes that TensorRT can use when selecting kernels
+        "refit": refit, # enable refit
+        "debug": debug, # enable debuggable engine
+        "strict_types": strict_types, # kernels should strictly run in operating precision
+        "capability": capability, # Restrict kernel selection to safe gpu kernels or safe dla kernels
+        "num_min_timing_iters": num_min_timing_iters, # Number of minimization timing iterations used to select kernels
+        "num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels
+        "workspace_size": workspace_size, # Maximum size of workspace given to TensorRT
+        "max_batch_size": max_batch_size, # Maximum batch size (must be >= 1 to be set, 0 means not set)
+        "calibrator": calibrator,
+        "truncate_long_and_double": truncate_long_and_double
+    }
+
     parsed_spec = _parse_compile_spec(compile_spec)
 
     backend_spec = torch.classes.tensorrt.CompileSpec()
diff --git a/py/trtorch/_compiler.py b/py/trtorch/_compiler.py
index 78aad5912e..52f10751ac 100644
--- a/py/trtorch/_compiler.py
+++ b/py/trtorch/_compiler.py
@@ -4,7 +4,7 @@
 import trtorch._C
 from trtorch._types import EngineCapability
-from trtorch._compile_spec import _parse_compile_spec
+from trtorch._compile_spec import _parse_compile_spec, _parse_device
 from trtorch._version import __version__
 from trtorch.Device import Device
 from types import FunctionType
@@ -25,7 +25,10 @@ def compile(module: torch.jit.ScriptModule,
             max_batch_size=0,
             calibrator=None,
             truncate_long_and_double=False,
-            torch_fallback={"enabled": False}) -> torch.jit.ScriptModule:
+            require_full_compilation=True,
+            min_block_size=3,
+            torch_executed_ops=[],
+            torch_executed_modules=[]) -> torch.jit.ScriptModule:
     """Compile a TorchScript module for NVIDIA GPUs using TensorRT
 
     Takes a existing TorchScript module and a set of settings to configure the compiler
@@ -38,7 +41,7 @@ def compile(module: torch.jit.ScriptModule,
         ``torch.nn.Module``
 
     Keyword Arguments:
-        inputs (List[Union(trtorch.Input, torch.Tensor)]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+        inputs (List[Union(trtorch.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
             torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
             to select device type. ::
@@ -57,7 +60,7 @@ def compile(module: torch.jit.ScriptModule,
         device (Union(trtorch.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
 
             device=trtorch.Device("dla:1", allow_gpu_fallback=True)
-
+
         disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, trtorch.dtype))): The set of datatypes that TensorRT can use when selecting kernels
@@ -71,19 +74,10 @@ def compile(module: torch.jit.ScriptModule,
         max_batch_size (int): Maximum batch size (must be >= 1 to be set, 0 means not set)
         truncate_long_and_double (bool): Truncate weights provided in int64 or double (float64) to int32 and float32
         calibrator (Union(trtorch._C.IInt8Calibrator, tensorrt.IInt8Calibrator)): Calibrator object which will provide data to the PTQ system for INT8 Calibration
-        torch_fallback (dict): Settings related to partial compilation. Partial compilation will run any unsupported operations and any operators or submodules specified by the user in PyTorch ::
-
-            torch_fallback={
-                "enabled": True,
-                "force_fallback_ops": [
-                    "aten::max_pool2d" # List of specific ops to require running in PyTorch
-                ],
-                "force_fallback_modules": [
-                    "mypymod.mytorchmod" # List of specific torch modules to require running in PyTorch
-                ],
-                "min_block_size": 3 # Minimum number of ops an engine must incapsulate to be run in TensorRT
-            }
-
+        require_full_compilation (bool): Require modules to be compiled end to end or return an error, as opposed to returning a hybrid graph where operations that cannot be run in TensorRT are run in PyTorch
+        min_block_size (int): The minimum number of contiguous TensorRT-convertible operations required to run a set of operations in TensorRT
+        torch_executed_ops (List[str]): List of aten operators that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
+        torch_executed_modules (List[str]): List of modules that must be run in PyTorch. An error will be thrown if this list is not empty but ``require_full_compilation`` is True
 
     Returns:
         torch.jit.ScriptModule: Compiled TorchScript Module, when run it will execute via TensorRT
@@ -93,6 +87,11 @@ def compile(module: torch.jit.ScriptModule,
         raise TypeError(
             "torch.jit.ScriptFunction currently is not directly supported, wrap the function in a module to compile")
 
+    if require_full_compilation and (len(torch_executed_modules) > 0 or len(torch_executed_ops) > 0):
+        raise ValueError(
+            f"require_full_compilation is enabled however the list of modules and ops to run in torch is not empty. Found: torch_executed_ops: {torch_executed_ops}, torch_executed_modules: {torch_executed_modules}"
+        )
+
     spec = {
         "inputs": inputs,
         "device": device,
@@ -109,7 +108,11 @@
         "max_batch_size": max_batch_size, # Maximum batch size (must be >= 1 to be set, 0 means not set)
         "calibrator": calibrator,
         "truncate_long_and_double": truncate_long_and_double,
-        "torch_fallback": torch_fallback
+        "torch_fallback": {
+            "enabled": not require_full_compilation,
+            "force_fallback_ops": torch_executed_ops,
+            "force_fallback_modules": torch_executed_modules
+        }
     }
 
     compiled_cpp_mod = trtorch._C.compile_graph(module._c, _parse_compile_spec(spec))
@@ -117,21 +120,23 @@
     return compiled_module
 
-def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str, inputs=[],
-                                 device=Device._current_device(),
-                                 disable_tf32=False,
-                                 sparse_weights=False,
-                                 enabled_precisions=set(),
-                                 refit=False,
-                                 debug=False,
-                                 strict_types=False,
-                                 capability=EngineCapability.default,
-                                 num_min_timing_iters=2,
-                                 num_avg_timing_iters=1,
-                                 workspace_size=0,
-                                 max_batch_size=0,
-                                 truncate_long_and_double=False,
-                                 calibrator=None) -> str:
+def convert_method_to_trt_engine(module: torch.jit.ScriptModule,
+                                 method_name: str,
+                                 inputs=[],
+                                 device=Device._current_device(),
+                                 disable_tf32=False,
+                                 sparse_weights=False,
+                                 enabled_precisions=set(),
+                                 refit=False,
+                                 debug=False,
+                                 strict_types=False,
+                                 capability=EngineCapability.default,
+                                 num_min_timing_iters=2,
+                                 num_avg_timing_iters=1,
+                                 workspace_size=0,
+                                 max_batch_size=0,
+                                 truncate_long_and_double=False,
+                                 calibrator=None) -> str:
     """Convert a TorchScript module method to a serialized TensorRT engine
 
     Converts a specified method of a module to a serialized TensorRT engine given a dictionary of conversion settings
@@ -141,8 +146,8 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
         ``torch.nn.Module``
     method_name (str): Name of method to convert
 
-    Keyword Args:
-        inputs (List[Union(trtorch.Input, torch.Tensor)]): List of specifications of input shape, dtype and memory layout for inputs to the module. This argument is required. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
+    Keyword Args:
+        inputs (List[Union(trtorch.Input, torch.Tensor)]): **Required** List of specifications of input shape, dtype and memory layout for inputs to the module. Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
             torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
             to select device type. ::
@@ -161,7 +166,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
         device (Union(trtorch.Device, torch.device, dict)): Target device for TensorRT engines to run on ::
 
             device=trtorch.Device("dla:1", allow_gpu_fallback=True)
-
+
         disable_tf32 (bool): Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
         sparse_weights (bool): Enable sparsity for convolution and fully connected layers.
         enabled_precision (Set(Union(torch.dtype, trtorch.dtype))): The set of datatypes that TensorRT can use when selecting kernels
@@ -197,7 +202,6 @@
         "num_avg_timing_iters": num_avg_timing_iters, # Number of averaging timing iterations used to select kernels
         "workspace_size": workspace_size, # Maximum size of workspace given to TensorRT
         "max_batch_size": max_batch_size, # Maximum batch size (must be >= 1 to be set, 0 means not set)
-        "torch_fallback": {"enabled": False},
         "calibrator": calibrator,
         "truncate_long_and_double": truncate_long_and_double
     }
@@ -205,7 +209,7 @@
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
-def embed_engine_in_new_module(serialized_engine: bytes, device: Device) -> torch.jit.ScriptModule:
+def embed_engine_in_new_module(serialized_engine: bytes, device=Device._current_device()) -> torch.jit.ScriptModule:
     """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
 
     Takes a pre-built serialied TensorRT engine (as bytes) and embeds it within a TorchScript module.
@@ -215,13 +219,16 @@
     Module can be save with engine embedded with torch.jit.save and moved / loaded according to TRTorch portability rules
 
-    Args:
+    Arguments:
         serialized_engine (bytes): Serialized TensorRT engine from either TRTorch or TensorRT APIs
 
+    Keyword Arguments:
+        device (Union(trtorch.Device, torch.device, dict)): Target device to run engine on. Must be compatible with engine provided. Default: Current active device
+
     Returns:
         torch.jit.ScriptModule: New TorchScript module with engine embedded
     """
-    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine, device._to_internal())
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine, _parse_device(device))
     return torch.jit._recursive.wrap_cpp_module(cpp_mod)
 
@@ -231,7 +238,7 @@ def check_method_op_support(module: torch.jit.ScriptModule, method_name: str) ->
     Checks if a method of a TorchScript module can be compiled by TRTorch, if not, a list of operators
     that are not supported are printed out and the function returns false, else true.
-    Args:
+    Arguments:
         module (torch.jit.ScriptModule): Source module, a result of tracing or scripting a PyTorch
             ``torch.nn.Module``
         method_name (str): Name of method to check
diff --git a/py/trtorch/csrc/tensorrt_backend.cpp b/py/trtorch/csrc/tensorrt_backend.cpp
index 1e63c65bd4..13216b6aa6 100644
--- a/py/trtorch/csrc/tensorrt_backend.cpp
+++ b/py/trtorch/csrc/tensorrt_backend.cpp
@@ -27,7 +27,7 @@ c10::impl::GenericDict TensorRTBackend::compile(c10::IValue mod_val, c10::impl::
   auto g = graph_and_ivals.first;
   auto params = graph_and_ivals.second;
 
-  auto named_params = core::conversion::get_named_params(g->inputs(), params);
+  auto named_params = core::ir::get_static_params(g->inputs(), params);
 
   auto convert_cfg = std::move(cfg.convert_info);
   auto device_spec = convert_cfg.engine_settings.device;
diff --git a/py/trtorch/csrc/tensorrt_classes.cpp b/py/trtorch/csrc/tensorrt_classes.cpp
index bbab4942fb..cf037575a1 100644
--- a/py/trtorch/csrc/tensorrt_classes.cpp
+++ b/py/trtorch/csrc/tensorrt_classes.cpp
@@ -68,9 +68,9 @@ std::string to_str(TensorFormat value) {
 
 core::ir::Input Input::toInternalInput() {
   if (!input_is_dynamic) {
-    return core::ir::Input(opt, toTRTDataType(dtype), toTRTTensorFormat(format));
+    return core::ir::Input(opt, toTRTDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
   } else {
-    return core::ir::Input(min, opt, max, toTRTDataType(dtype), toTRTTensorFormat(format));
+    return core::ir::Input(min, opt, max, toTRTDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype);
   }
 }
diff --git a/py/trtorch/csrc/trtorch_py.cpp b/py/trtorch/csrc/trtorch_py.cpp
index 2ec9e38754..e8d3c9e696 100644
--- a/py/trtorch/csrc/trtorch_py.cpp
+++ b/py/trtorch/csrc/trtorch_py.cpp
@@ -321,7 +321,6 @@ PYBIND11_MODULE(_C, m) {
   m.def("set_device", &trtorch::pyapi::set_device, "Set CUDA device id");
   m.def("_get_current_device", &trtorch::pyapi::get_current_device, "Get the current active CUDA device");
-
   py::enum_<core::util::logging::LogLevel>(m, "LogLevel", py::arithmetic())
       .value("INTERNAL_ERROR", core::util::logging::LogLevel::kINTERNAL_ERROR)
       .value("ERROR", core::util::logging::LogLevel::kERROR)
diff --git a/tests/py/test_api.py b/tests/py/test_api.py
index 76ff13f2c1..c28cdaa27b 100644
--- a/tests/py/test_api.py
+++ b/tests/py/test_api.py
@@ -53,6 +53,12 @@ def test_device(self):
         same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
         self.assertTrue(same < 2e-2)
 
+    def test_default_device(self):
+        compile_spec = {"inputs": [self.input], "enabled_precisions": {torch.float}}
+
+        trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
+        same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
+        self.assertTrue(same < 2e-2)
 
     def test_compile_script_from_dict(self):
         compile_spec = {
@@ -133,11 +139,9 @@ def test_compile_script(self):
                 "allow_gpu_fallback": False,
                 "disable_tf32": False
             },
-            "torch_fallback": {
-                "enabled": True,
-                "forced_fallback_ops": ["aten::max_pool2d"],
-                "min_block_size": 1
-            }
+            "require_full_compilation": False,
+            "torch_executed_ops": ["aten::max_pool2d"],
+            "min_block_size": 1
         }
 
         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
@@ -161,11 +165,9 @@ def test_compile_script(self):
                 "allow_gpu_fallback": False,
                 "disable_tf32": False
            },
-            "torch_fallback": {
-                "enabled": True,
-                "forced_fallback_modules": ["torchvision.models.resnet.BasicBlock"],
-                "min_block_size": 1
-            }
+            "require_full_compilation": False,
+            "torch_executed_modules": ["torchvision.models.resnet.BasicBlock"],
+            "min_block_size": 1
         }
 
         trt_mod = trtorch.compile(self.scripted_model, **compile_spec)
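
For reference, a minimal sketch of the reworked partial-compilation API this patch introduces, mirroring the updated tests. It assumes a build with this patch, torchvision, and a CUDA device are available; the resnet18 model and input shape are illustrative only.

```python
import torch
import torchvision
import trtorch

# Script an off-the-shelf model to get a torch.jit.ScriptModule.
model = torchvision.models.resnet18(pretrained=True).eval().cuda()
scripted_model = torch.jit.script(model)

# The old torch_fallback={"enabled": True, ...} dict is replaced by top-level
# kwargs: require_full_compilation=False opts into a hybrid graph, while
# torch_executed_ops / torch_executed_modules pin specific pieces to PyTorch.
trt_mod = trtorch.compile(
    scripted_model,
    inputs=[trtorch.Input((1, 3, 224, 224))],
    enabled_precisions={torch.float},
    require_full_compilation=False,
    torch_executed_ops=["aten::max_pool2d"],
    min_block_size=1)

out = trt_mod(torch.randn((1, 3, 224, 224)).cuda())
```

Passing a non-empty torch_executed_ops or torch_executed_modules list while leaving require_full_compilation at its default of True raises the new ValueError added in _compiler.py.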