Skip to content

Commit

Permalink
Add conv1d support in BYOC TRT by converting conv1d to conv2d (apache…
Browse files Browse the repository at this point in the history
…#9324)

Co-authored-by: ziyu.guo <ziyu.guo@bytedance.com>
  • Loading branch information
2 people authored and ylc committed Jan 7, 2022
1 parent b3e4950 commit ea5f101
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 0 deletions.
19 changes: 19 additions & 0 deletions python/tvm/relay/op/contrib/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def partition_for_tensorrt(
transform.RemoveUnusedFunctions(),
transform.ConvertLayout(
{
"nn.conv1d": ["NCW", "default"],
"nn.conv2d": ["NCHW", "default"],
"nn.conv3d": ["NCDHW", "default"],
"nn.conv2d_transpose": ["NCHW", "default"],
Expand Down Expand Up @@ -383,6 +384,23 @@ def softmax_annotate_fn(expr): # pylint: disable=unused-variable
return True


@_register_external_dynamic_check_func("nn.conv1d")
def conv1d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv1d is supported by TensorRT."""

attrs, args = expr.attrs, expr.args
if any([x.checked_type.dtype != "float32" for x in args]):
logger.info("Only float32 inputs are supported for TensorRT.")
return False
if attrs.data_layout != "NCW":
logger.info("nn.conv1d: data_layout is %s but must be NCW.", attrs.data_layout)
return False
if attrs.kernel_layout != "OIW":
logger.info("nn.conv1d: kernel_layout is %s but must be OIW.", attrs.kernel_layout)
return False
return True


@_register_external_dynamic_check_func("nn.conv2d")
def conv2d_annotate_fn(expr): # pylint: disable=unused-variable
"""Check if nn.conv2d is supported by TensorRT."""
Expand Down Expand Up @@ -921,6 +939,7 @@ def __init__(self):
def visit_call(self, call):
compute_intensive_ops = set(
[
"nn.conv1d",
"nn.conv2d",
"nn.conv2d_transpose",
"nn.conv3d",
Expand Down
50 changes: 50 additions & 0 deletions src/runtime/contrib/tensorrt/tensorrt_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,55 @@ class ElementWiseBinaryOpConverter : public TensorRTOpConverter {
}
};

class Conv1DOpConverter : public TensorRTOpConverter {
public:
Conv1DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}

void Convert(TensorRTOpConverterParams* params) const {
auto input_tensor = params->inputs.at(0).tensor;
auto input_dims = TrtDimsToVector(input_tensor->getDimensions());
auto weight_shape = params->inputs.at(1).weight_shape;
ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCW");
ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIW");
auto str_strides = params->node.GetAttr<std::vector<std::string>>("strides");
auto str_dilation = params->node.GetAttr<std::vector<std::string>>("dilation");
auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
int channels = weight_shape[0];
if (params->node.HasAttr("channels") &&
!params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
}

auto shuffle_layer = params->network->addShuffle(*input_tensor);
std::vector<int> new_shape = {input_dims[0], input_dims[1], 1};
shuffle_layer->setReshapeDimensions(VectorToTrtDims(new_shape));
input_tensor = shuffle_layer->getOutput(0);

const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], 1);
nvinfer1::Weights bias{nvinfer1::DataType::kFLOAT, nullptr, 0};

auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size,
params->inputs.at(1).weight, bias);
ICHECK(conv_layer != nullptr);
conv_layer->setPadding(nvinfer1::DimsHW(std::stoi(str_padding[0]), 0));
ICHECK_EQ(str_strides.size(), 1);
const auto strides = nvinfer1::DimsHW(std::stoi(str_strides[0]), 1);
conv_layer->setStride(strides);
ICHECK_EQ(str_dilation.size(), 1);
const auto dilation = nvinfer1::DimsHW(std::stoi(str_dilation[0]), 1);
conv_layer->setDilation(dilation);
conv_layer->setNbGroups(groups);
input_tensor = conv_layer->getOutput(0);

auto conv_output_dims = TrtDimsToVector(input_tensor->getDimensions());
std::vector<int> back_shape = {0, 0};
auto shuffle_back_layer = params->network->addShuffle(*input_tensor);
shuffle_back_layer->setReshapeDimensions(VectorToTrtDims(back_shape));
params->outputs.push_back(shuffle_back_layer->getOutput(0));
}
};

class Conv2DOpConverter : public TensorRTOpConverter {
public:
Conv2DOpConverter() : TensorRTOpConverter({kTensor, kWeight}) {}
Expand Down Expand Up @@ -1198,6 +1247,7 @@ GetOpConverters() {
map->emplace("nn.batch_norm", std::make_shared<BatchNormOpConverter>());
map->emplace("nn.layer_norm", std::make_shared<LayerNormOpConverter>());
map->emplace("nn.softmax", std::make_shared<SoftmaxOpConverter>());
map->emplace("nn.conv1d", std::make_shared<Conv1DOpConverter>());
map->emplace("nn.conv2d", std::make_shared<Conv2DOpConverter>());
map->emplace("nn.dense", std::make_shared<DenseOpConverter>());
map->emplace("nn.bias_add", std::make_shared<BiasAddOpConverter>());
Expand Down
28 changes: 28 additions & 0 deletions tests/python/contrib/test_tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,34 @@ def load_vm():
assert_result_dict_holds(result_dict)


def test_conv1d(run_module):
def get_graph(
x_shape=((1, 3, 224)),
k_shape=(10, 3, 3),
groups=1,
padding=(1, 1),
strides=(1),
dilation=(1),
channels=None,
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
out = relay.nn.conv1d(
x,
kernel,
kernel_size=k_shape[2:3],
groups=groups,
padding=padding,
strides=strides,
dilation=dilation,
channels=channels,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]

run_and_verify_func(get_graph(channels=10), run_module=run_module)


def test_conv2d(run_module):
def get_graph(
x_shape=(1, 32, 8, 8),
Expand Down

0 comments on commit ea5f101

Please sign in to comment.