TRTorch v0.3.0

@narendasan narendasan released this 14 May 00:55

Support for PyTorch 1.8.x (by default 1.8.1), Introducing Plugin Library, PTQ from Python, Arbitrary TRT engine embedding, Preview Release of Partial Compilation, New Converters, Bug Fixes

This is the third beta release of TRTorch, targeting PyTorch 1.8.x, CUDA 11.1 (on x86_64), TensorRT 7.2, cuDNN 8. TRTorch 0.3.0 binary releases target PyTorch 1.8.1 specifically; these builds are not compatible with 1.8.0, though the source code remains compatible with any PyTorch 1.8.x version. On aarch64 TRTorch targets JetPack 4.5.x.

This release introduces libtrtorch_plugins.so, a portable distribution of all TensorRT plugins used in TRTorch. The intended use case is to support TRTorch programs that use TensorRT plugins and are deployed on systems with only the runtime library available, or cases where TRTorch was used to create a TensorRT engine that relies on TRTorch plugins and is run outside the TRTorch runtime. An example of how to use this library can be found here: https://www.github.com/NVIDIA/TRTorch/tree/v0.3.0/examples/sample_rt_app.

TRTorch 0.3.0 also allows users to repurpose PyTorch Dataloaders to do post training quantization in Python, similar to the workflow currently supported in C++. It also introduces a new API to wrap arbitrary TensorRT engines in a PyTorch Module wrapper, making them serializable by torch.jit.save and completely compatible with other PyTorch modules.

Finally, TRTorch 0.3.0 includes a preview of the new partial compilation capability of the TRTorch compiler. With this feature, users can instruct TRTorch to keep operations that are not supported by TRTorch/TensorRT in PyTorch. Partial compilation should be considered alpha stability, and we are seeking feedback on bugs, pain points and feature requests surrounding this feature.
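As a rough illustration of how partial compilation might be requested from Python, the sketch below builds a compile spec dictionary with a fallback section. The key names (`torch_fallback`, `enabled`, `min_block_size`, `forced_fallback_ops`) are assumptions inferred from the changelog entries in this release and should be checked against the v0.3.0 Python API documentation; the helper `make_fallback_spec` is hypothetical and exists only for this example.

```python
from typing import Any, Dict, List


def make_fallback_spec(input_shape: List[int],
                       forced_ops: List[str],
                       min_block_size: int = 3) -> Dict[str, Any]:
    """Build a compile spec dict that asks for partial compilation.

    All key names here are assumptions, not a confirmed schema.
    """
    if min_block_size < 1:
        raise ValueError("min_block_size must be at least 1")
    return {
        # One entry per module input; a single static shape in this sketch.
        "input_shapes": [list(input_shape)],
        "torch_fallback": {
            # Keep unsupported operations in PyTorch instead of failing.
            "enabled": True,
            # Assumed meaning: smallest run of ops worth converting to TensorRT.
            "min_block_size": min_block_size,
            # Ops the user wants to run in PyTorch even if a converter exists.
            "forced_fallback_ops": list(forced_ops),
        },
    }


spec = make_fallback_spec([1, 3, 224, 224], ["aten::max_pool2d"])
```

A spec like this would presumably be passed to `trtorch.compile` together with a TorchScript module. Since the feature is alpha, the accepted keys may change between releases.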

Dependencies:

- Bazel 4.0.0
- LibTorch 1.8.1 (on x86_64), 1.8.0 (on aarch64)
- CUDA 11.1 (on x86_64, by default; newer CUDA 11 versions are supported with a compatible PyTorch build), 10.2 (on aarch64)
- cuDNN 8.1.1
- TensorRT 7.2.3.4

0.3.0 (2021-05-13)

Bug Fixes

  • //plugins: Readding cuBLAS BUILD to allow linking of libnvinfer_plugin on Jetson (a8008f4)

  • //tests/../concat: Concat test fix (2432fb8)

  • //tests/core/partitioning: Fixing some issues with the partition (ff89059)

  • erase the repetitive nodes in dependency analysis (80b1038)

  • fix a typo for debug (c823ebd)

  • fix typo bug (e491bb5)

  • aten::linear: Fixes new issues in 1.8 that cause script based (c5057f8)

  • register the torch_fallback attribute in Python API (8b7919f)

  • support expand/repeat with IValue type input (a4882c6)

  • support shape inference for add_, support non-tensor arguments for segmented graphs (46950bb)

  • feat!: Updating versions of CUDA, cuDNN, TensorRT and PyTorch (71c4dcb)

  • feat(WORKSPACE)!: Updating PyTorch version to 1.8.1 (c9aa99a)

Features

  • //.github: Linter throws 1 when there needs to be style changes to (a39dea7)
  • //core: New API to register arbitrary TRT engines in TorchScript (3ec836e)
  • //core/conversion/conversionctx: Adding logging for truncated (96245ee)
  • //core/partitioing: Adding ostream for Partition Info (b3589c5)
  • //core/partitioning: Add an ostream implementation for (ee536b6)
  • //core/partitioning: Refactor top level partitioning API, fix a bug with (abc63f6)
  • //core/plugins: Gating plugin logging based on global config (1d5a088)
  • added user level API for fallback (f4c29b4)
  • allow users to set fallback block size and ops (6d3064a)
  • insert nodes by dependencies for nonTensor inputs/outputs (4e32eff)
  • support aten::arange converter (014e381)
  • support aten::transpose with negative dim (4a1d2f3)
  • support Int/Bool and other constants' inputs/outputs for TensorRT segments (54e407e)
  • support prim::Param for fallback inputs (ec2bbf2)
  • support prim::Param for input type after refactor (3cebe97)
  • support Python APIs for Automatic Fallback (100b090)
  • support the case when the injected node is not supported in dependency analysis (c67d8f6)
  • support truncate long/double to int/float with option (740eb54)
  • Try to submit review before exit (9a9d7f0)
  • update truncate long/double python api (69e49e8)
  • //docker: Adding Docker 21.03 (9b326e8)
  • update truncate long/double warning message (60dba12)
  • //docker: Update CI container (df63467)
  • //py: Allowing people using the PyTorch backend to use TRTorch/TRT (6c3e0ad)
  • //py: Catch when bazel is not in path and error out when running (1da999d)
  • //py: Gate partial compilation from to_backend API (bf1b2d8)
  • //py: New API to embed engine in new module (88d07a9)
  • aten::floor: Adds floor.int evaluator (a6a46e5)

BREAKING CHANGES

  • PyTorch version has been bumped to 1.8.0
    Default CUDA version is CUDA 11.1
    TensorRT version is TensorRT 7.2.3.4
    cuDNN version is now cuDNN 8.1

Signed-off-by: Naren Dasan naren@narendasan.com
Signed-off-by: Naren Dasan narens@nvidia.com

  • Due to compatibility issues between PyTorch 1.8.0
    and 1.8.1 in the Torch Python API, TRTorch 0.3.0 compiled for 1.8.0 does not
    work with PyTorch 1.8.1 and will show an error about use_input_stats.
    If you see this error, make sure the version of libtorch you are
    compiling with is PyTorch 1.8.1.

TRTorch 0.3.0 will target PyTorch 1.8.1. There is no backwards
compatibility with 1.8.0. If you need this specific version, compile from
source with the dependencies in WORKSPACE changed.

Signed-off-by: Naren Dasan naren@narendasan.com
Signed-off-by: Naren Dasan narens@nvidia.com
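
Because the binary releases are tied to PyTorch 1.8.1, it can help to check the installed version before loading TRTorch. The sketch below is a minimal, hypothetical helper (`matches_binary_release` is not part of TRTorch) that compares a version string such as the value of `torch.__version__` against the expected release, ignoring a local build suffix like `+cu111`.

```python
EXPECTED_TORCH = "1.8.1"  # version targeted by the TRTorch 0.3.0 binaries


def matches_binary_release(torch_version: str,
                           expected: str = EXPECTED_TORCH) -> bool:
    """Return True if torch_version is the release the binaries target."""
    # Drop a local build suffix such as "+cu111" before comparing.
    base = torch_version.strip().split("+", 1)[0]
    return base == expected


# Example inputs; in practice the argument would be torch.__version__.
ok_with_suffix = matches_binary_release("1.8.1+cu111")
ok_old_release = matches_binary_release("1.8.0")
```

Here `ok_with_suffix` is `True` and `ok_old_release` is `False`, mirroring the note above that 1.8.0 requires a build from source.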

Supported Operators in TRTorch v0.3.0

Operators Currently Supported Through Converters

  • aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> (Tensor)
  • aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> (Tensor)
  • aten::abs(Tensor self) -> (Tensor)
  • aten::acos(Tensor self) -> (Tensor)
  • aten::acosh(Tensor self) -> (Tensor)
  • aten::adaptive_avg_pool2d(Tensor self, int[2] output_size) -> (Tensor)
  • aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::add.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::asin(Tensor self) -> (Tensor)
  • aten::asinh(Tensor self) -> (Tensor)
  • aten::atan(Tensor self) -> (Tensor)
  • aten::atanh(Tensor self) -> (Tensor)
  • aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[0], bool ceil_mode=False, bool count_include_pad=True) -> (Tensor)
  • aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> (Tensor)
  • aten::batch_norm(Tensor input, Tensor? gamma, Tensor? beta, Tensor? mean, Tensor? var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor)
  • aten::cat(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::ceil(Tensor self) -> (Tensor)
  • aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> (Tensor)
  • aten::cos(Tensor self) -> (Tensor)
  • aten::cosh(Tensor self) -> (Tensor)
  • aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::div.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::div_.Scalar(Tensor(a!) self, Scalar other) -> (Tensor(a!))
  • aten::div_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> (Tensor)
  • aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)
  • aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::eq.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::erf(Tensor self) -> (Tensor)
  • aten::exp(Tensor self) -> (Tensor)
  • aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> (Tensor(a))
  • aten::expand_as(Tensor(a) self, Tensor other) -> (Tensor(a))
  • aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> (Tensor)
  • aten::floor(Tensor self) -> (Tensor)
  • aten::floor_divide(Tensor self, Tensor other) -> (Tensor)
  • aten::floor_divide.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ge.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::gt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::gt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor)
  • aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> (Tensor(a!))
  • aten::le.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::le.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)
  • aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> (Tensor(a!))
  • aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> (Tensor)
  • aten::log(Tensor self) -> (Tensor)
  • aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)
  • aten::lt.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::lt.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::matmul(Tensor self, Tensor other) -> (Tensor)
  • aten::max(Tensor self) -> (Tensor)
  • aten::max.other(Tensor self, Tensor other) -> (Tensor)
  • aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=[], int[1] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=[0, 0], int[2] dilation=[1, 1], bool ceil_mode=False) -> (Tensor)
  • aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=[], int[3] dilation=[], bool ceil_mode=False) -> (Tensor)
  • aten::mean(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::mean.dim(Tensor self, int[] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::min(Tensor self) -> (Tensor)
  • aten::min.other(Tensor self, Tensor other) -> (Tensor)
  • aten::mul.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::mul.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> (Tensor(a!))
  • aten::narrow(Tensor(a) self, int dim, int start, int length) -> (Tensor(a))
  • aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> (Tensor(a))
  • aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)
  • aten::ne.Tensor(Tensor self, Tensor other) -> (Tensor)
  • aten::neg(Tensor self) -> (Tensor)
  • aten::permute(Tensor(a) self, int[] dims) -> (Tensor(a))
  • aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)
  • aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> (Tensor)
  • aten::prelu(Tensor self, Tensor weight) -> (Tensor)
  • aten::prod(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::reciprocal(Tensor self) -> (Tensor)
  • aten::relu(Tensor input) -> (Tensor)
  • aten::relu_(Tensor(a!) self) -> (Tensor(a!))
  • aten::repeat(Tensor self, int[] repeats) -> (Tensor)
  • aten::reshape(Tensor self, int[] shape) -> (Tensor)
  • aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> (Tensor)
  • aten::rsub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))
  • aten::sigmoid(Tensor input) -> (Tensor)
  • aten::sigmoid_(Tensor(a!) self) -> (Tensor(a!))
  • aten::sin(Tensor self) -> (Tensor)
  • aten::sinh(Tensor self) -> (Tensor)
  • aten::slice.Tensor(Tensor(a) self, int dim=0, int start=0, int end=9223372036854775807, int step=1) -> (Tensor(a))
  • aten::softmax.int(Tensor self, int dim, int? dtype=None) -> (Tensor)
  • aten::split(Tensor self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::split.Tensor(Tensor(a) self, int split_size, int dim=0) -> (Tensor[])
  • aten::split_with_sizes(Tensor(a) self, int[] split_sizes, int dim=0) -> (Tensor[])
  • aten::sqrt(Tensor self) -> (Tensor)
  • aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::stack(Tensor[] tensors, int dim=0) -> (Tensor)
  • aten::sub.Tensor(Tensor self, Tensor other, Scalar alpha=1) -> (Tensor)
  • aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> (Tensor(a!))
  • aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)
  • aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)
  • aten::tan(Tensor self) -> (Tensor)
  • aten::tanh(Tensor input) -> (Tensor)
  • aten::tanh_(Tensor(a!) self) -> (Tensor(a!))
  • aten::topk(Tensor self, int k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)
  • aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> (Tensor(a))
  • aten::unsqueeze(Tensor(a) self, int dim) -> (Tensor(a))
  • aten::upsample_bilinear2d(Tensor self, int[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_bilinear2d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_linear1d(Tensor self, int[1] output_size, bool align_corners, float? scales=None) -> (Tensor)
  • aten::upsample_linear1d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest1d(Tensor self, int[1] output_size, float? scales=None) -> (Tensor)
  • aten::upsample_nearest1d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest2d(Tensor self, int[2] output_size, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest2d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_nearest3d(Tensor self, int[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_nearest3d.vec(Tensor input, int[]? output_size, float[]? scale_factors) -> (Tensor)
  • aten::upsample_trilinear3d(Tensor self, int[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> (Tensor)
  • aten::upsample_trilinear3d.vec(Tensor input, int[]? output_size, bool align_corners, float[]? scale_factors) -> (Tensor)
  • aten::view(Tensor(a) self, int[] size) -> (Tensor(a))
  • trt::const(Tensor self) -> (Tensor)

Operators Currently Supported Through Evaluators

  • aten::Bool.float(float b) -> (bool)
  • aten::Bool.int(int a) -> (bool)
  • aten::Float.Scalar(Scalar a) -> float
  • aten::Float.bool(bool a) -> float
  • aten::Float.int(int a) -> float
  • aten::and(int a, int b) -> (bool)
  • aten::getitem.t(t list, int idx) -> (t(*))
  • aten::is(t1 self, t2 obj) -> bool
  • aten::isnot(t1 self, t2 obj) -> bool
  • aten::not(bool self) -> bool
  • aten::or(int a, int b) -> (bool)
  • aten::__round_to_zero_floordiv(int a, int b) -> (int)
  • aten::xor(int a, int b) -> (bool)
  • aten::add.float(float a, float b) -> (float)
  • aten::add.int(int a, int b) -> (int)
  • aten::add_.t(t self, t[] b) -> (t[])
  • aten::append.t(t self, t(c -> *) el) -> (t)
  • aten::dim(Tensor self) -> int
  • aten::div.float(float a, float b) -> (float)
  • aten::div.int(int a, int b) -> (float)
  • aten::eq.bool(bool a, bool b) -> (bool)
  • aten::eq.float(float a, float b) -> (bool)
  • aten::eq.float_int(float a, int b) -> (bool)
  • aten::eq.int(int a, int b) -> (bool)
  • aten::eq.int_float(int a, float b) -> (bool)
  • aten::floor.float(float a) -> (int)
  • aten::floordiv.float(float a, float b) -> (int)
  • aten::floordiv.int(int a, int b) -> (int)
  • aten::ge.bool(bool a, bool b) -> (bool)
  • aten::ge.float(float a, float b) -> (bool)
  • aten::ge.float_int(float a, int b) -> (bool)
  • aten::ge.int(int a, int b) -> (bool)
  • aten::ge.int_float(int a, float b) -> (bool)
  • aten::gt.bool(bool a, bool b) -> (bool)
  • aten::gt.float(float a, float b) -> (bool)
  • aten::gt.float_int(float a, int b) -> (bool)
  • aten::gt.int(int a, int b) -> (bool)
  • aten::gt.int_float(int a, float b) -> (bool)
  • aten::le.bool(bool a, bool b) -> (bool)
  • aten::le.float(float a, float b) -> (bool)
  • aten::le.float_int(float a, int b) -> (bool)
  • aten::le.int(int a, int b) -> (bool)
  • aten::le.int_float(int a, float b) -> (bool)
  • aten::len.t(t[] a) -> (int)
  • aten::lt.bool(bool a, bool b) -> (bool)
  • aten::lt.float(float a, float b) -> (bool)
  • aten::lt.float_int(float a, int b) -> (bool)
  • aten::lt.int(int a, int b) -> (bool)
  • aten::lt.int_float(int a, float b) -> (bool)
  • aten::mul.float(float a, float b) -> (float)
  • aten::mul.int(int a, int b) -> (int)
  • aten::ne.bool(bool a, bool b) -> (bool)
  • aten::ne.float(float a, float b) -> (bool)
  • aten::ne.float_int(float a, int b) -> (bool)
  • aten::ne.int(int a, int b) -> (bool)
  • aten::ne.int_float(int a, float b) -> (bool)
  • aten::neg.int(int a) -> (int)
  • aten::numel(Tensor self) -> int
  • aten::size(Tensor self) -> (int[])
  • aten::size.int(Tensor self, int dim) -> (int)
  • aten::slice.t(t[] l, int start, int end=9223372036854775807, int step=1) -> (t[])
  • aten::sub.float(float a, float b) -> (float)
  • aten::sub.int(int a, int b) -> (int)
  • prim::max.bool(bool a, bool b) -> (bool)
  • prim::max.float(float a, float b) -> (bool)
  • prim::max.float_int(float a, int b) -> (bool)
  • prim::max.int(int a, int b) -> (bool)
  • prim::max.int_float(int a, float b) -> (bool)
  • prim::max.self_int(int[] self) -> (int)
  • prim::min.bool(bool a, bool b) -> (bool)
  • prim::min.float(float a, float b) -> (bool)
  • prim::min.float_int(float a, int b) -> (bool)
  • prim::min.int(int a, int b) -> (bool)
  • prim::min.int_float(int a, float b) -> (bool)
  • prim::min.self_int(int[] self) -> (int)
  • prim::shape(Tensor a) -> (int[])