
Is it possible to deploy spconv with tensorRT or any other methods? #33

Closed
bigsheep2012 opened this issue Apr 1, 2019 · 32 comments

@bigsheep2012

Hello Yan @traveller59,
I am trying to deploy models with spconv layers using TensorRT and found that I may need to dig into the PyTorch source code. I am wondering if there is an easier way to deploy models with spconv so that I could directly use C++ for fast inference?

Thanks in advance.

@traveller59
Owner

traveller59 commented Apr 1, 2019

What kind of method/API do you want for fast inference? TensorRT? torch.jit? If you use TensorRT, there is no need to dig into the PyTorch code.

  1. You need to modify the spconv code to support a fused sparseconv-batchnorm-relu operation and fixed-shape operation; these are the most important changes (see the folding sketch below).
  2. You need to write a TensorRT custom plugin, which isn't hard.

I will take a look at torch.jit, fixed-shape and fused operations for spconv, but you need to implement the TensorRT plugin yourself.
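
For reference, the batchnorm part of such a fusion is just a reweighting of the convolution. A minimal sketch of the folding math, shown for a dense nn.Conv2d for illustration (this is not spconv's actual fused kernel, only the weight transformation it would apply):

import torch
import torch.nn as nn

@torch.no_grad()
def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Fold BatchNorm statistics into the conv so conv -> bn -> relu becomes fused_conv -> relu."""
    fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                      conv.stride, conv.padding, conv.dilation, conv.groups, bias=True)
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)       # gamma / sqrt(var + eps)
    fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))  # scale each output channel
    bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.copy_((bias - bn.running_mean) * scale + bn.bias)  # beta + (b - mean) * scale
    return fused

The relu can then be applied directly on the folded conv output inside the plugin, so the sparse conv plugin only has to store the already-folded weights.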

@muzi2045

muzi2045 commented May 5, 2019

I have tried the PointPillars TensorRT code in the Autoware repo, but the inference speed doesn't improve much (45 ms ~ 50 ms on a 1050 Ti).
Does writing a TensorRT custom plugin mean I need to add custom layers to TensorRT?
@traveller59

@traveller59
Owner

@muzi2045
PointPillars doesn't contain sparse convolutions, so you don't need to write custom plugins (you may need to write a fused post-process layer).
45-50 ms seems good for a 1050 Ti; PointPillars runs in 16 ms on a 1080 Ti with TensorRT, and a 1080 Ti is almost 3x faster than a 1050 Ti.

@muzi2045

muzi2045 commented May 8, 2019

Thanks for your reply.
There is another thing about converting the PointPillars PyTorch model to ONNX: the ONNX model generated from the second repo looks really different from the one used by the Autoware TensorRT code (https://github.com/k0suke-murakami/kitti_pretrained_point_pillars).

The first model is converted from the second repo (pp_nu_config, PFE layer):
[screenshot of the exported PFE graph]

The second model is downloaded from https://github.com/k0suke-murakami/kitti_pretrained_point_pillars:
[screenshot of the Autoware PFE graph]

Hope for any advice on this; I just want to use an ONNX model I trained myself with TensorRT.

@traveller59

@traveller59
Owner

traveller59 commented May 8, 2019

@muzi2045 They use a special module to fit their C++ code. You can just write a module based on this graph.
Here is an example (not tested); be careful with shapes:

    def forward(self, pillar_x, pillar_y, pillar_z, pillar_i, x_sub_shaped, y_sub_shaped, num_points_per_pillar, mask):
        # Find distance of x, y, and z from cluster center
        pillars = torch.cat([pillar_x, pillar_y, pillar_z], dim=1)
        pillars_features = torch.cat([pillar_x, pillar_y, pillar_z, pillar_i], dim=1)
        points_mean = pillars.sum(dim=3, keepdim=True) / num_points_per_pillar.type_as(pillar_x).view(1, 1, -1, 1)
        f_cluster = pillars - points_mean
        # Find distance of x and y from pillar center
        f_center_x = pillar_x - x_sub_shaped
        f_center_y = pillar_y - y_sub_shaped
        f_center = torch.cat([f_center_x, f_center_y], dim=1)
        # Combine together feature decorations
        features_ls = [pillars_features, f_cluster, f_center]
        features = torch.cat(features_ls, dim=1)
        # Zero out decorations computed for padded (empty) points
        features *= mask
        # Forward pass through PFNLayers
        for pfn in self.pfn_layers:
            features = pfn(features)
        return features

@spchuang

@traveller59, thanks for the code pointer! That was very helpful.

I was able to export PFE and RPN to ONNX with the exact same architecture as the Autoware implementation. However, importing the model into Autoware and running PointPillars in TensorRT did not yield the right result. I suspect their preprocessing or postprocessing might differ from your SECOND repo. Do you happen to know if that is the case? It's strange, since inspecting the Autoware ONNX models' stack traces reveals that they were exported from a similar SECOND repo as well.

@traveller59
Owner

@spchuang

class PFNLayerTensorRT(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 use_norm=True,
                 last_layer=False):
        super().__init__()
        self.name = 'PFNLayerTensorRT'
        self.last_vfe = last_layer
        assert self.last_vfe is True, "tensorrt doesn't support this."
        if not self.last_vfe:
            out_channels = out_channels // 2
        self.units = out_channels

        if use_norm:
            BatchNorm2d = change_default_args(
                eps=1e-3, momentum=0.01)(nn.BatchNorm2d)
            Conv2d = change_default_args(bias=False)(nn.Conv2d)
        else:
            BatchNorm2d = Empty
            Conv2d = change_default_args(bias=True)(nn.Conv2d)

        self.linear = Conv2d(in_channels, self.units, 1)
        self.norm = BatchNorm2d(self.units)

    def forward(self, inputs):
        """inputs: [1, num_features, max_num_points, max_points_per_voxel]
        [1, 64, 12000, 100]
        """
        x = self.linear(inputs)
        x = self.norm(x)
        x = F.relu(x)
        # x: [N, C, numPoints, numPointPerVoxel]
        x_max = torch.max(x, dim=3, keepdim=True)[0]
        if self.last_vfe:
            return x_max
        else:
            # unreachable for now (see the assert above); TensorRT has no native repeat,
            # so this would need conv/broadcast/gather to implement.
            x_repeat = x_max.repeat(1, 1, 1, inputs.shape[3])
            x_concatenated = torch.cat([x, x_repeat], dim=1)
            return x_concatenated

@register_vfe
class PillarFeatureNetTensorRT(nn.Module):
    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=(64, ),
                 with_distance=False,
                 voxel_size=(0.2, 0.2, 4),
                 pc_range=(0, -40, -3, 70.4, 40, 1)):
        super().__init__()
        self.name = 'PillarFeatureNetTensorRT'
        assert len(num_filters) > 0
        num_input_features += 5
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance
        assert with_distance is False
        # Create PillarFeatureNetOld layers
        num_filters = [num_input_features] + list(num_filters)
        pfn_layers = []
        assert len(num_filters) == 2, "tensorrt doesn't support repeat"
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            if i < len(num_filters) - 2:
                last_layer = False
            else:
                last_layer = True
            pfn_layers.append(
                PFNLayerTensorRT(
                    in_filters, out_filters, use_norm, last_layer=last_layer))
        self.pfn_layers = nn.ModuleList(pfn_layers)

        # Need pillar (voxel) size and x/y offset in order to calculate pillar offset
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.x_offset = self.vx / 2 + pc_range[0]
        self.y_offset = self.vy / 2 + pc_range[1]

    def forward(self, features, num_voxels, coors, voxel_point_mask):
        """TensorRT PFE must use specified inputs. All tensorrt inputs must be float for now.
        features: [1, num_point_features, max_num_points, max_points_per_voxel]
        num_voxels: [1, 1, max_num_points, 1]
        coors: [1, 4, max_num_points, 1]
        voxel_point_mask: [1, 1, max_num_points, max_points_per_voxel]
        """
        device = features.device
        dtype = features.dtype
        # Find distance of x, y, and z from cluster center
        points_mean = features[:, :3].sum(
            dim=3, keepdim=True) / num_voxels.view(1, 1, -1, 1)
        f_cluster = features[:, :3] - points_mean
        
        # Find distance of x, y, and z from pillar center
        f_center_x = features[:, 0:1] - (coors[:, 3:4] * self.vx + self.x_offset)
        f_center_y = features[:, 1:2] - (coors[:, 2:3] * self.vy + self.y_offset)
        # Combine together feature decorations
        features_ls = [features, f_cluster, f_center_x, f_center_y]
        features = torch.cat(features_ls, dim=1)
        
        # The feature decorations were calculated without regard to whether pillar was empty. Need to ensure that
        # empty pillars remain set to zeros.
        features *= voxel_point_mask
        # Forward pass through PFNLayers
        for pfn in self.pfn_layers:
            features = pfn(features)
        # features shape: [1, 64, 12000, 1]
        return features # .t()

The code above is tested in TensorRT (via torch2trt) and generates the same results as the PyTorch model.

I don't know why the Autoware implementation contains two convs and has no reduce-max operator; note that reduce max is slow in TensorRT.
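
For anyone trying this, a quick shape sanity check of the module above with dummy inputs (a sketch only; it assumes the second.pytorch helpers change_default_args / Empty / register_vfe used above are importable, and 12000 pillars with 100 points each):

import torch

max_num_points, max_points_per_voxel = 12000, 100
pfe = PillarFeatureNetTensorRT(num_input_features=4).eval()

features = torch.rand(1, 4, max_num_points, max_points_per_voxel)
num_voxels = torch.randint(1, max_points_per_voxel, (1, 1, max_num_points, 1)).float()
coors = torch.randint(0, 400, (1, 4, max_num_points, 1)).float()
voxel_point_mask = torch.ones(1, 1, max_num_points, max_points_per_voxel)

with torch.no_grad():
    out = pfe(features, num_voxels, coors, voxel_point_mask)
print(out.shape)  # expected: torch.Size([1, 64, 12000, 1])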

@spchuang

spchuang commented May 15, 2019

Hi, thanks for the quick response! I came up with something similar:

class VoxelNet(nn.Module):
  ...
  def forward(self, example, export_onnx=False):
        # voxels, num_points, coors, dtype, device come from `example`, as in the original VoxelNet.forward (elided here)
        # Tweak the original input to fit the Autoware model format: first pad to max_num_voxels for the TensorRT input
        num_pillars = voxels.shape[0]
        max_num_pillars = self.voxel_generator._max_voxels # 12000
        max_num_points = self.voxel_generator._max_num_points

        pad_voxels = torch.zeros(max_num_pillars, max_num_points, 4, dtype=dtype, device=device)
        print(pad_voxels.size(), voxels.size())
        pad_voxels[:num_pillars, :, :] = voxels
        voxels = pad_voxels

        pad_num_points = torch.zeros(max_num_pillars, dtype=dtype, device=device)
        pad_num_points[:num_pillars] = num_points
        num_points = pad_num_points

        pad_coors = torch.zeros(max_num_pillars, 4, dtype=dtype, device=device)
        pad_coors[:num_pillars, :] = coors
        coors = pad_coors
        # now voxels should be [12000, 100, 4], num_points is [12000], and coors [12000, 4]

        data = self._prepareAutowareInput(self.voxel_feature_extractor, voxels, num_points, coors)
        # prepared_backend / W: an ONNX backend session and its input dict (setup not shown here)
        voxel_features = torch.from_numpy(prepared_backend.run(W)[0]).float().to(data[0].device)
        voxel_features = voxel_features.view(voxel_features.shape[2], voxel_features.shape[1])
 
        ...
        # keeping rest the same

  def _prepareAutowareInput(self, voxel_extractor, voxels, num_points, coors):
        dtype = voxels.dtype 

        pillar_x = voxels[:,:,0].unsqueeze(0).unsqueeze(0)
        pillar_y = voxels[:,:,1].unsqueeze(0).unsqueeze(0)
        pillar_z = voxels[:,:,2].unsqueeze(0).unsqueeze(0)
        pillar_i = voxels[:,:,3].unsqueeze(0).unsqueeze(0)

        num_points_per_pillars = num_points.unsqueeze(0).type_as(voxels)

        x_sub_shaped = coors[:, 3].to(dtype).unsqueeze(1) * float(voxel_extractor.vx) + voxel_extractor.x_offset
        y_sub_shaped = coors[:, 2].to(dtype).unsqueeze(1) * float(voxel_extractor.vy) + voxel_extractor.y_offset

        x_sub_shaped = x_sub_shaped.repeat(1, voxels.shape[1]).unsqueeze(0).unsqueeze(0)
        y_sub_shaped = y_sub_shaped.repeat(1, voxels.shape[1]).unsqueeze(0).unsqueeze(0)

        mask = get_paddings_indicator(num_points, voxels.shape[1], axis=0).type_as(voxels)
        mask = mask.unsqueeze(0).unsqueeze(0)

        # pillar_x,y,z,i is [1, 1, 12000, 100]
        # num_points_per_pillars is [1, 12000]
        # x,y_sub_shape is [1, 1, 12000, 100]
        # mask is [1, 1, 12000, 100]
        return (pillar_x, pillar_y, pillar_z, pillar_i, num_points_per_pillars, x_sub_shaped, y_sub_shaped, mask)

// PFE implementation

# based on model arch from https://github.com/k0suke-murakami/kitti_pretrained_point_pillars
class PFNLayer2(nn.Module):
    def __init__(self,):
        super().__init__()
        self.name = 'PFNLayer'
        
        self.conv1 = nn.Conv2d(9, 64, kernel_size=(1,1), stride=(1, 1), padding=(0, 0), dilation=(1, 1))
        self.norm = change_default_args(eps=1e-3, momentum=0.01)(nn.BatchNorm2d)(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=(1, 34), stride=(1, 1), padding=(0, 0), dilation=(1, 3))

    def forward(self, inputs):
        x = self.conv1(inputs)
        x = self.norm(x)
        x = F.relu(x)
        x = self.conv2(x)
        return x

class PillarFeatureNet(nn.Module):
    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=(64, ),
                 with_distance=False,
                 voxel_size=(0.2, 0.2, 4),
                 pc_range=(0, -40, -3, 70.4, 40, 1)):

        super().__init__()
        self.name = 'PillarFeatureNet'
        assert len(num_filters) > 0
        num_input_features += 5
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance

        # Create PillarFeatureNet layers
        num_filters = [num_input_features] + list(num_filters)
        pfn_layers = []
        for i in range(len(num_filters) - 1):
            in_filters = num_filters[i]
            out_filters = num_filters[i + 1]
            if i < len(num_filters) - 2:
                last_layer = False
            else:
                last_layer = True

            pfn_layers.append(PFNLayer2())
        self.pfn_layers = nn.ModuleList(pfn_layers)

        # Need pillar (voxel) size and x/y offset in order to calculate pillar offset
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.x_offset = self.vx / 2 + pc_range[0]
        self.y_offset = self.vy / 2 + pc_range[1]

    # Modify to fit autoware pointpillar's input format
    def forward(self, pillar_x, pillar_y, pillar_z, pillar_i, num_points_per_pillars, x_sub_shaped, y_sub_shaped, mask):
        device = pillar_x.device
        dtype = pillar_x.dtype
        
        pillars = torch.cat([pillar_x, pillar_y, pillar_z], dim=1)
        pillars_features = torch.cat([pillar_x, pillar_y, pillar_z, pillar_i], dim=1)
        points_mean = pillars.sum(dim=3, keepdim=True) / num_points_per_pillars.type_as(pillar_x).view(1, 1, -1, 1)
        f_cluster = pillars - points_mean
        # Find distance of x, y, and z from pillar center
        f_center_x = pillar_x - x_sub_shaped
        f_center_y = pillar_y - y_sub_shaped
        f_center = torch.cat([f_center_x, f_center_y], dim=1)

        # Combine together feature decorations
        features_ls = [pillars_features, f_cluster, f_center]
        features = torch.cat(features_ls, dim=1)
        features *= mask

        # Forward pass through PFNLayers
        for pfn in self.pfn_layers:
            features = pfn(features)

        return features

// export via torch.onnx

            input_data = (data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7])
            torch_out = torch.onnx._export( self.voxel_feature_extractor,             # model being run
                                input_data,                       # model input (or a tuple for multiple inputs)
                                PFE_FILE, # where to save the model (can be a file or file-like object)
                                input_names=['pillar_x', 'pillar_y', 'pillar_z', 'pillar_i', 'num_points_per_pillar', 'x_sub_shaped', 'y_sub_shaped', 'mask'],
                                export_params=True,
                                verbose=True)      # store the trained parameter weights inside the model file

I was able to retrain SECOND and the results look reasonable when running inference with this repo. What I'm struggling to get working is taking the ONNX models exported from this into Autoware and using their TensorRT implementation. I'm not sure whether there are any other discrepancies in the pre/post-processing of the two repos. I know it is outside the scope of this repo, but it would be great if you could shed some light!
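
One way to narrow down whether the exported pfe.onnx itself matches the PyTorch module (rather than Autoware's pre/post-processing being the problem) is to run it with onnxruntime on the same inputs and compare; a rough sketch, reusing PFE_FILE and input_data from the export snippet above:

import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession(PFE_FILE)
# feed the same 8 tensors that were passed to torch.onnx._export, keyed by the exported input names
feed = {inp.name: t.detach().cpu().numpy() for inp, t in zip(sess.get_inputs(), input_data)}
onnx_out = sess.run(None, feed)[0]

torch_out = self.voxel_feature_extractor(*input_data).detach().cpu().numpy()
print("max abs diff:", np.abs(onnx_out - torch_out).max())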

@traveller59
Owner

@spchuang Are you using a multi-class model? I recently changed the layout of anchors and outputs (this doesn't affect single-class detection).

@spchuang

spchuang commented May 15, 2019 via email

@muzi2045

The author of the Autoware PointPillars repo has released the training code:
https://github.com/k0suke-murakami/train_point_pillars

@muzi2045

muzi2045 commented May 17, 2019

How do I use torch2trt in the second repo?
I have modified the PointPillars layer code; do I need to retrain the whole model?
And should I export the TensorRT engine in the forward function once, the same way the ONNX model is exported, and after that load the .engine file to run the forward pass?

Here is the time cost of my test code without TensorRT on a 1050 Ti.
It looks like the RPN needs to be sped up with TensorRT, and the voxel generator also needs improvement.

prepare cloud time:  0.0009071826934814453
 input points shape: (27870, 4)
 voxel_generater cost time:  0.006397724151611328
 voxels convert to tensor time:  0.0029985904693603516
   voxel_feature_extractor cost time: 0.008017539978027344
   middle_feature_extractor cost time: 0.001955270767211914
   rpn cost time: 0.01836371421813965
   predict cost time: 0.009577035903930664
 network predict time cost: 0.04015803337097168
 remove_low_score cost: 0.0017924308776855469
 return result time: 0.00035858154296875
 total_run_time:  0.05408644676208496
network forward time:  0.054303884506225586
publish time cost:  0.0005733966827392578
total callback time:  0.05652928352355957

@traveller59

@traveller59
Owner

@muzi2045 You don't need to retrain the model (unless you want to use the code in Autoware); just reshape the weights.
The best way is to write a plugin for the pillar scatter and convert the whole SECOND network (except the postprocess, which should be implemented in a fused plugin (for Python inference) or in C++ code) to a single engine. You need to load the plugin library before loading the engine file.
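
For context, the pillar scatter itself is only an index copy into a dense BEV canvas, so a plugin mainly has to reproduce something like the following PyTorch sketch (shapes follow the SECOND PointPillarsScatter with batch size 1; the 64/400/352 sizes are placeholders, not fixed values):

import torch

def pillar_scatter(voxel_features, coors, nchannels=64, ny=400, nx=352):
    """voxel_features: [num_pillars, C]; coors: [num_pillars, 4] as (batch, z, y, x).
    Returns the dense pseudo-image [1, C, ny, nx] that the RPN consumes."""
    canvas = voxel_features.new_zeros(nchannels, ny * nx)
    indices = (coors[:, 2] * nx + coors[:, 3]).long()  # flatten (y, x) into a single column index
    canvas[:, indices] = voxel_features.t()            # copy each pillar's features to its BEV cell
    return canvas.view(1, nchannels, ny, nx)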

@traveller59
Owner

@spchuang Are you using the newest SECOND code? You may need to change the final permute operations to the commented version.

@muzi2045

You mean just convert the whole SECOND network to TensorRT?
And write a plugin, similar to spconv, to convert the whole network to TensorRT and speed it up?
@traveller59

@traveller59
Owner

@muzi2045

  • network speed:
    the preprocess time isn't important for now; we should focus on the network first.
    I have tested the TensorRT code. The RPN can reach 7.3 ms on a 1080 Ti in TensorRT (same as the paper), but the PillarFeatureNet is very slow (1.3 ms in the paper, 4.0 ms in my test with 13000 voxels). I think it's impossible to reach 1.3 ms without modifying the network architecture, because the max operation in PFNLayer is slow.
    postprocess speed can be improved by writing a fused postprocess layer (see the decode sketch below).

  • convert the whole network to tensorrt:
    if you use the Autoware code, you don't need to do this. If you want to write a custom codebase and run inference in Python, you need to write a TensorRT plugin and use a single engine.
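
To make the fused postprocess concrete, its core is the residual box decoding from the SECOND/PointPillars papers; a sketch of the math only (NMS and score filtering would be fused on top of this; the repo's own implementation lives in its box ops code):

import torch

def decode_boxes(box_preds, anchors):
    """box_preds, anchors: [N, 7] as (x, y, z, w, l, h, ry); standard SECOND residual decoding."""
    xa, ya, za, wa, la, ha, ra = torch.split(anchors, 1, dim=-1)
    xt, yt, zt, wt, lt, ht, rt = torch.split(box_preds, 1, dim=-1)
    za = za + ha / 2                            # move anchor z from box bottom to box center
    diagonal = torch.sqrt(la ** 2 + wa ** 2)    # offsets are normalized by the anchor diagonal
    xg = xt * diagonal + xa
    yg = yt * diagonal + ya
    zg = zt * ha + za
    wg = torch.exp(wt) * wa
    lg = torch.exp(lt) * la
    hg = torch.exp(ht) * ha
    rg = rt + ra
    zg = zg - hg / 2                            # back to the bottom-center convention
    return torch.cat([xg, yg, zg, wg, lg, hg, rg], dim=-1)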

@muzi2045

muzi2045 commented May 17, 2019

Thanks for your advice!

I'll try to learn how to write a TensorRT plugin.
@traveller59

@spchuang

I was originally training with fairly recent SECOND code.

@muzi2045

I am trying to simply convert the torch model to a TensorRT model in voxelnet.py, but it doesn't work.
I am looking for a tutorial on TensorRT conversion; hoping for any advice.

voxel_features_extractor_graph = torch2trt.GraphModule(self.voxel_feature_extractor,
                                                        voxels,
                                                        num_points,
                                                        coors,
                                                        param_exclude=".*AuxLogits.*")
with trt.Builder(TRT_LOGGER) as builder, builder.create_network() as trt_net:
    builder.max_workspace_size = 1 << 30
    with torch2trt.trt_network(trt_net):
        input_1 = trt_net.add_input(name="voxels", shape=voxels.shape, dtype=trt.float32)
        input_2 = trt_net.add_input(name="num_points", shape=num_points.shape, dtype=trt.float32)
        input_3 = trt_net.add_input(name="coors", shape=coors.shape, dtype=trt.float32)
        trt_mode_out = voxel_features_extractor_graph(input_1, input_2, input_3, verbose=True)
    trt_mode_out.name = "output_feature_extractor"
    trt_net.mark_output(tensor=trt_mode_out)
    engine = builder.build_cuda_engine(trt_net)
    engine_bin = engine.serialize()
    with open("feature_extractor.engine", "wb") as f:
        f.write(engine_bin)
voxel_features = self.voxel_feature_extractor(voxels, num_points, coors)
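
Once the .engine file is written, the loading side is independent of torch2trt. A rough sketch of deserializing and running it with the TensorRT Python API and pycuda (assumes a TensorRT version that has execute_v2; voxels_np, num_points_np, coors_np and out_shape are placeholders for arrays matching the bindings above):

import numpy as np
import pycuda.autoinit  # noqa: F401  (creates a CUDA context)
import pycuda.driver as cuda
import tensorrt as trt

TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
with open("feature_extractor.engine", "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()

host_arrays = [voxels_np, num_points_np, coors_np, np.empty(out_shape, dtype=np.float32)]
device_buffers = [cuda.mem_alloc(a.nbytes) for a in host_arrays]
for buf, arr in zip(device_buffers[:-1], host_arrays[:-1]):
    cuda.memcpy_htod(buf, np.ascontiguousarray(arr))  # copy the three inputs to the GPU

context.execute_v2(bindings=[int(b) for b in device_buffers])  # synchronous inference
cuda.memcpy_dtoh(host_arrays[-1], device_buffers[-1])          # fetch the PFE output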

@traveller59

@RuiAidriver

    # now voxels should be [12000, 100, 4], num_points is [12000], and coors [12000, 4]

    data = self._prepareAutowareInput(self.voxel_feature_extractor, voxels, num_points, coors)
    voxel_features = torch.from_numpy(prepared_backend.run(W)[0]).float().to(data[0].device)
    voxel_features = voxel_features.view(voxel_features.shape[2], voxel_features.shape[1])

    ...
    # keeping rest the same

Hi @spchuang,
I'm working on generating the PFE file for Autoware too.
Could you confirm whether you are using code like the below for W and prepared_backend.run(W)[0]?

prepared_backend = onnx_caffe2.backend.prepare(model)
W = {model.graph.input[0].name: model_inputs.numpy()}

Here I'm not sure what model and model_inputs are.
I've tried 'pfe.onnx', 'dummy.onnx' and self.voxel_feature_extractor as model, and voxels, data and data[0] as model_inputs,
but none of them works.
Would you mind telling me what you used as model and model_inputs?

Many thanks!

@SmallMunich

@muzi2045 Hi, I am trying to use the Autoware framework for the PointPillars algorithm. How can I export pfe.onnx & rpn.onnx successfully?

@SmallMunich

Update: I fixed some bugs and the model now converts to ONNX successfully.

@muzi2045

Are you able to get correct results with it?

@SmallMunich

Yes, I checked the exported ONNX model; its outputs match to within about 0.001 error.

@riyadshairi979
Contributor

riyadshairi979 commented Apr 30, 2020


@spchuang Can you please share some code? How did you convert to ONNX? I am getting errors like the one below when I tried to convert this: https://github.com/traveller59/spconv/blob/master/test/fake_train.py#L51

RuntimeError: Found an unsupported argument type c10::List<at::Tensor> in the JIT tracer. File a bug report. (addOutput at /pytorch/torch/csrc/jit/tracer.h:337)

@tzhong518

@riyadshairi979 Hi, I met the same problem (the c10::List<at::Tensor> JIT tracer error) when trying to convert to ONNX.
Did you solve it?
Thank you very much.

@riyadshairi979
Contributor


@tzhong518 No, I didn't. The spconv PyTorch code is not torchscriptable, hence it is not possible to convert it to ONNX.

@anyubin1001


Hi @bigsheep2012,
Have you implemented spconv with TensorRT? Could you share the code with me?

@bigsheep2018

@anyubin1001 Hello anyubin1001, I did re-implement it with CUDA and C++ (without TensorRT).
The general idea is to implement only the inference part and ignore the back-propagation machinery; the weights are loaded once at the beginning. As a result, I still use this repo's code for training, but use my CUDA implementation for inference.

I may not be able to share the code, as it is encrypted inside my company's software. The logic is basically the same as described in the SECOND paper.

@anyubin1001


Thank you very much for your ideas!

@github-actions github-actions bot added the Stale label Dec 5, 2021
@github-actions github-actions bot closed this as completed Dec 5, 2021
@GeneralJing


Has anyone implemented a TensorRT plugin for spconv? I want to implement this plugin, but I am new to plugin development. Can anyone give some advice?

@ArseniuML

Re: "Has anyone implemented a plugin for spconv?" There is one here:

https://github.com/jingyue202205/SE-SSD-AI-TRT/blob/master/sparseConv3dlayer.h

The question is: how can I use this plugin to convert PyTorch code with sparse convolutions to ONNX, or directly to TensorRT?
