Skip to content

Commit 753341b

Browse files
Merge pull request #3 from ActiveIntelligentSystemsLab/update-class-design
Update class design
2 parents 1b14c50 + d4ac5f2 commit 753341b

17 files changed

+747
-116
lines changed

CMakeLists.txt

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cmake_minimum_required(VERSION 2.8.3)
2-
project(pytorch_enet_ros)
2+
project(pytorch_ros)
33

44
add_compile_options(-std=c++14)
55

@@ -17,6 +17,8 @@ find_package(catkin REQUIRED COMPONENTS
1717
image_transport
1818
cv_bridge
1919
semantic_segmentation_srvs
20+
tf2
21+
tf2_ros
2022
)
2123

2224
find_package(Torch REQUIRED)
@@ -25,9 +27,7 @@ find_package(OpenCV REQUIRED)
2527

2628
catkin_package(
2729
INCLUDE_DIRS include
28-
# LIBRARIES pytorch_enet_ros
2930
CATKIN_DEPENDS roscpp rospy std_msgs
30-
# DEPENDS system_lib
3131
)
3232

3333
###########
@@ -44,8 +44,10 @@ include_directories(
4444

4545
## Declare a C++ library
4646
add_library(${PROJECT_NAME}
47-
src/pytorch_enet_ros.cpp
48-
src/pytorch_cpp_wrapper.cpp
47+
src/pytorch_seg_trav_path_ros.cpp
48+
src/pytorch_cpp_wrapper_seg_trav.cpp
49+
src/pytorch_cpp_wrapper_seg_trav_path.cpp
50+
src/pytorch_cpp_wrapper_base.cpp
4951
)
5052

5153
target_link_libraries(${PROJECT_NAME}
@@ -59,10 +61,10 @@ set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)
5961
## Declare a C++ executable
6062
## With catkin_make all packages are built within a single CMake context
6163
## The recommended prefix ensures that target names across packages don't collide
62-
add_executable(${PROJECT_NAME}_node src/pytorch_enet_ros_node.cpp)
64+
add_executable(pytorch_seg_trav_path_node src/pytorch_seg_trav_path_node.cpp)
6365

6466
## Specify libraries to link a library or executable target against
65-
target_link_libraries(${PROJECT_NAME}_node
67+
target_link_libraries(pytorch_seg_trav_path_node
6668
${catkin_LIBRARIES}
6769
${PROJECT_NAME}
6870
${TORCH_LIBRARIES}

README.md

Lines changed: 71 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,72 @@
1-
# pytorch_enet_ros
1+
# pytorch_ros
22

3-
A Docker environment for both training network on PyTorch, and inference using ROS is [here](https://github.com/ActiveIntelligentSystemsLab/pytorch-enet-docker).
4-
This package is only tested on the virtual environment.
3+
## 1. Overview
4+
5+
A ROS package to use [LibTorch](https://pytorch.org/cppdocs/), a PyTorch C++ API, for inference on a trained model.
6+
7+
A Docker environment for running this package is [here](https://github.com/ActiveIntelligentSystemsLab/pytorch-enet-docker).
8+
This package is **only tested in the virtual environment**.
9+
10+
## 2. Nodes
11+
12+
### 2.1 `pytorch_seg_trav_path_node`
13+
14+
#### **2.1.1 Subscribed topics**
15+
16+
- `image` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))
17+
18+
An input image
19+
20+
#### **2.1.2 Published topics**
21+
22+
- `label` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))
23+
24+
Image that stores label indices of each pixel
25+
26+
- `color_label` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))
27+
28+
Image that stores color labels of each pixel (for visualization)
29+
30+
- `prob` ([sensor_msgs/Image](http://docs.ros.org/melodic/api/sensor_msgs/html/msg/Image.html))
31+
32+
Image that stores *traversability* of each pixel
33+
34+
- `start_point` ([geometry_msgs/PointStamped](http://docs.ros.org/en/melodic/api/geometry_msgs/html/msg/PointStamped.html))
35+
36+
Start point of the estimated path line
37+
38+
- `end_point` ([geometry_msgs/PointStamped](http://docs.ros.org/en/melodic/api/geometry_msgs/html/msg/PointStamped.html))
39+
40+
End point of the estimated path line
41+
42+
#### **2.1.3 Service**
43+
44+
- `get_label_image` ([semantic_segmentation_srvs/GetLabelAndProbability](https://github.com/ActiveIntelligentSystemsLab/aisl_utils/blob/master/aisl_srvs/semantic_segmentation_srv/srv/GetLabelAndProbability.srv))
45+
46+
Return inference results (segmentation and traversability) for a given image.
47+
48+
## 3. How to run the node
49+
50+
```
51+
roslaunch pytorch_enet_ros.launch image:=<image topic name> model_name:=<model name>
52+
```
53+
54+
## 4. Weight files
55+
56+
The ROS nodes in this package use models saved as a serialized Torch Script file.
57+
58+
At this moment, we don't provide a script to generate the weight files.
59+
60+
Refer to [this page](https://pytorch.org/tutorials/advanced/cpp_export.html) to get the weight file.
61+
62+
### CAUTION
63+
If the version of PyTorch that runs this ROS package and that you generate your weight file (serialized Torch Script) do not match, the ROS node may fail to import the weights.
64+
65+
For example, if you use [our Docker environment](https://github.com/ActiveIntelligentSystemsLab/pytorch-enet-docker), the weights should be generated using PyTorch 1.5.0.
66+
67+
## 5. Color map
68+
69+
For visualization of semantic segmentation, we use a color map image.
70+
71+
It is a 1xC PNG image file (C: The number of classes), where
72+
the color of class i is stored in the pixel at (1, i).

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper.h

Lines changed: 0 additions & 32 deletions
This file was deleted.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#ifndef PYTORCH_CPP_WRAPPER_BASE
2+
#define PYTORCH_CPP_WRAPPER_BASE
3+
4+
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
7+
#include <opencv2/opencv.hpp>
8+
9+
#include <iostream>
10+
#include <memory>
11+
12+
/**
13+
* @brief this class is a base class of C++ wrapper of PyTorch
14+
*/
15+
class PyTorchCppWrapperBase {
16+
protected :
17+
torch::jit::script::Module module_;
18+
19+
public:
20+
PyTorchCppWrapperBase();
21+
PyTorchCppWrapperBase(const std::string & filename);
22+
PyTorchCppWrapperBase(const char* filename);
23+
24+
/**
25+
* @brief import a network
26+
* @param filename
27+
* @return true if import succeeded
28+
*/
29+
bool import_module(const std::string & filename);
30+
31+
/**
32+
* @brief convert an image(cv::Mat) to a tensor (at::Tensor)
33+
* @param[in] img
34+
* @param[out] tensor
35+
* @param[in] whether to use GPU
36+
*/
37+
void img2tensor(cv::Mat & img, at::Tensor & tensor, const bool & use_gpu = true);
38+
39+
/**
40+
* @brief convert a tensor (at::Tensor) to an image (cv::Mat)
41+
* @param[in] tensor
42+
* @param[out] img
43+
*/
44+
void tensor2img(at::Tensor tensor, cv::Mat & img);
45+
46+
/**
47+
* @brief Take element-wise argmax
48+
* @param[in] tensor
49+
* @param[out] tensor that has index of max value in each element
50+
*/
51+
at::Tensor get_argmax(at::Tensor input_tensor);
52+
};
53+
//}
54+
#endif
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#ifndef PYTORCH_CPP_WRAPPER
2+
#define PYTORCH_CPP_WRAPPER
3+
4+
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
7+
#include <opencv2/opencv.hpp>
8+
#include "opencv2/highgui/highgui.hpp"
9+
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
10+
11+
#include <iostream>
12+
#include <memory>
13+
14+
class PyTorchCppWrapperSegTrav : public PyTorchCppWrapperBase {
15+
private :
16+
// c = P(s|y=1) in PU learning, calculated during training
17+
float c_{0.3};
18+
19+
public:
20+
/**
21+
* @brief Get outputs from the model
22+
* @param[in] input_tensor Input tensor
23+
* @return A tuple of output tensors (segmentation and traversability)
24+
*/
25+
std::tuple<at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
26+
};
27+
#endif
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#ifndef PYTORCH_CPP_WRAPPER
2+
#define PYTORCH_CPP_WRAPPER
3+
4+
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
7+
#include <opencv2/opencv.hpp>
8+
#include "opencv2/highgui/highgui.hpp"
9+
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
10+
11+
#include <iostream>
12+
#include <memory>
13+
14+
15+
class PyTorchCppWrapperSegTravPath : public PyTorchCppWrapperBase {
16+
private :
17+
// c = P(s|y=1) in PU learning, calculated during training
18+
float c_{0.3};
19+
20+
public:
21+
/**
22+
* @brief Get outputs from the model
23+
* @param[in] input_tensor Input tensor
24+
* @return A tuple of output tensors (segmentation, traversability, and path (points))
25+
*/
26+
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
27+
};
28+
#endif

include/pytorch_enet_ros/pytorch_enet_ros.h renamed to include/pytorch_ros/pytorch_enet_ros.h

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
#include <opencv2/opencv.hpp>
1313
#include<image_transport/image_transport.h>
1414
#include<cv_bridge/cv_bridge.h>
15-
#include<semantic_segmentation_srvs/GetLabelImage.h>
15+
//#include<semantic_segmentation_srvs/GetLabelImage.h>
16+
#include<semantic_segmentation_srvs/GetLabelAndProbability.h>
1617

17-
#include"pytorch_cpp_wrapper/pytorch_cpp_wrapper.h"
18+
#include"pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav.h"
1819

1920
#include <iostream>
2021
#include <memory>
@@ -31,18 +32,19 @@ class PyTorchENetROS {
3132
image_transport::Subscriber sub_image_;
3233
image_transport::Publisher pub_label_image_;
3334
image_transport::Publisher pub_color_image_;
35+
image_transport::Publisher pub_prob_image_;
3436

35-
PyTorchCppWrapper pt_wrapper_;
37+
PyTorchCppWrapperSegTrav pt_wrapper_;
3638

3739
cv::Mat colormap_;
3840

3941
public:
4042
PyTorchENetROS(ros::NodeHandle & nh);
4143

4244
void image_callback(const sensor_msgs::ImageConstPtr& msg);
43-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
44-
bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelImage::Request & req,
45-
semantic_segmentation_srvs::GetLabelImage::Response & res);
45+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
46+
bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
47+
semantic_segmentation_srvs::GetLabelAndProbability::Response & res);
4648
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);
4749
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::Image msg);
4850
void label_to_color(cv::Mat& label, cv::Mat& color_label);

0 commit comments

Comments
 (0)