Skip to content

Commit 2f89f0a

Browse files
[update] Separate the base class and child classes
1 parent 707d18e commit 2f89f0a

14 files changed

+529
-87
lines changed

CMakeLists.txt

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ include_directories(
4545
## Declare a C++ library
4646
add_library(${PROJECT_NAME}
4747
src/pytorch_enet_ros.cpp
48-
src/pytorch_cpp_wrapper.cpp
48+
src/pytorch_seg_trav_path.cpp
49+
src/pytorch_cpp_wrapper_seg_trav.cpp
50+
src/pytorch_cpp_wrapper_seg_trav_path.cpp
51+
src/pytorch_cpp_wrapper_base.cpp
4952
)
5053

5154
target_link_libraries(${PROJECT_NAME}
@@ -60,10 +63,17 @@ set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)
6063
## With catkin_make all packages are built within a single CMake context
6164
## The recommended prefix ensures that target names across packages don't collide
6265
add_executable(${PROJECT_NAME}_node src/pytorch_enet_ros_node.cpp)
66+
add_executable(pytorch_seg_trav_path_node src/pytorch_seg_trav_path_node.cpp)
6367

6468
## Specify libraries to link a library or executable target against
6569
target_link_libraries(${PROJECT_NAME}_node
6670
${catkin_LIBRARIES}
6771
${PROJECT_NAME}
6872
${TORCH_LIBRARIES}
6973
)
74+
75+
target_link_libraries(pytorch_seg_trav_path_node
76+
${catkin_LIBRARIES}
77+
${PROJECT_NAME}
78+
${TORCH_LIBRARIES}
79+
)

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper.h

Lines changed: 0 additions & 35 deletions
This file was deleted.
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
#ifndef PYTORCH_CPP_WRAPPER_BASE
2+
#define PYTORCH_CPP_WRAPPER_BASE
3+
4+
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
7+
#include <opencv2/opencv.hpp>
8+
9+
#include <iostream>
10+
#include <memory>
11+
12+
/**
13+
* @brief this class is a base class of C++ wrapper of PyTorch
14+
*/
15+
class PyTorchCppWrapperBase {
16+
protected :
17+
torch::jit::script::Module module_;
18+
19+
public:
20+
PyTorchCppWrapperBase();
21+
PyTorchCppWrapperBase(const std::string & filename);
22+
PyTorchCppWrapperBase(const char* filename);
23+
// virtual ~PyTorchCppWrapperBase();
24+
25+
/**
26+
* @brief import a network
27+
* @param filename
28+
* @return true if import succeeded
29+
*/
30+
bool import_module(const std::string & filename);
31+
32+
/**
33+
* @brief convert an image(cv::Mat) to a tensor (at::Tensor)
34+
* @param[in] img
35+
* @param[out] tensor
36+
* @param[in] whether to use GPU
37+
*/
38+
void img2tensor(cv::Mat & img, at::Tensor & tensor, const bool & use_gpu = true);
39+
40+
/**
41+
* @brief convert a tensor (at::Tensor) to an image (cv::Mat)
42+
* @param[in] tensor
43+
* @param[out] img
44+
*/
45+
void tensor2img(at::Tensor tensor, cv::Mat & img);
46+
47+
/**
48+
* @brief convert a tensor (at::Tensor) to an image (cv::Mat)
49+
* @param[in] tensor
50+
* @param[out] img
51+
*/
52+
at::Tensor get_argmax(at::Tensor input_tensor);
53+
54+
/**
55+
* @brief convert a tensor (at::Tensor) to an image (cv::Mat)
56+
* @param[in] input_tensor
57+
* @param[out] Output from the network (depends on the implementation)
58+
*/
59+
// virtual auto get_output(at::Tensor & input_tensor) = 0;
60+
61+
};
62+
//}
63+
#endif
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#ifndef PYTORCH_CPP_WRAPPER
2+
#define PYTORCH_CPP_WRAPPER
3+
4+
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
7+
#include <opencv2/opencv.hpp>
8+
#include "opencv2/highgui/highgui.hpp"
9+
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
10+
11+
#include <iostream>
12+
#include <memory>
13+
14+
class PyTorchCppWrapperSegTrav : public PyTorchCppWrapperBase {
15+
private :
16+
// c = P(s|y=1) in PU learning, calculated during training
17+
float c_{0.3};
18+
19+
public:
20+
std::tuple<at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
21+
};
22+
#endif
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef PYTORCH_CPP_WRAPPER
2+
#define PYTORCH_CPP_WRAPPER
3+
4+
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
7+
#include <opencv2/opencv.hpp>
8+
#include "opencv2/highgui/highgui.hpp"
9+
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
10+
11+
#include <iostream>
12+
#include <memory>
13+
14+
15+
class PyTorchCppWrapperSegTravPath : public PyTorchCppWrapperBase {
16+
private :
17+
// c = P(s|y=1) in PU learning, calculated during training
18+
float c_{0.3};
19+
20+
public:
21+
std::tuple<at::Tensor, at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
22+
};
23+
#endif

include/pytorch_enet_ros/pytorch_enet_ros.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
//#include<semantic_segmentation_srvs/GetLabelImage.h>
1616
#include<semantic_segmentation_srvs/GetLabelAndProbability.h>
1717

18-
#include"pytorch_cpp_wrapper/pytorch_cpp_wrapper.h"
18+
#include"pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav.h"
1919

2020
#include <iostream>
2121
#include <memory>
@@ -34,7 +34,7 @@ class PyTorchENetROS {
3434
image_transport::Publisher pub_color_image_;
3535
image_transport::Publisher pub_prob_image_;
3636

37-
PyTorchCppWrapper pt_wrapper_;
37+
PyTorchCppWrapperSegTrav pt_wrapper_;
3838

3939
cv::Mat colormap_;
4040

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
/*
2+
* A ROS node to do inference using PyTorch model
3+
* Shigemichi Matsuzaki
4+
*
5+
*/
6+
7+
#ifndef PYTORCH_SEG_TRAV_PATH
8+
#define PYTORCH_SEG_TRAV_PATH
9+
10+
#include <ros/ros.h>
11+
12+
#include <opencv2/opencv.hpp>
13+
#include<image_transport/image_transport.h>
14+
#include<cv_bridge/cv_bridge.h>
15+
#include<geometry_msgs/Point.h>
16+
//#include<semantic_segmentation_srvs/GetLabelImage.h>
17+
#include<semantic_segmentation_srvs/GetLabelAndProbability.h>
18+
19+
#include"pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav_path.h"
20+
21+
#include <iostream>
22+
#include <memory>
23+
#include <tuple>
24+
25+
class PyTorchSegTravPathROS {
26+
private:
27+
ros::NodeHandle nh_;
28+
29+
ros::ServiceServer get_label_image_server_;
30+
31+
image_transport::ImageTransport it_;
32+
33+
image_transport::Subscriber sub_image_;
34+
image_transport::Publisher pub_label_image_;
35+
image_transport::Publisher pub_color_image_;
36+
image_transport::Publisher pub_prob_image_;
37+
38+
// PyTorchCppWrapper pt_wrapper_;
39+
PyTorchCppWrapperSegTravPath pt_wrapper_;
40+
41+
cv::Mat colormap_;
42+
43+
public:
44+
PyTorchSegTravPathROS(ros::NodeHandle & nh);
45+
46+
void image_callback(const sensor_msgs::ImageConstPtr& msg);
47+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, geometry_msgs::PointPtr, geometry_msgs::PointPtr> inference(cv::Mat & input_image);
48+
bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
49+
semantic_segmentation_srvs::GetLabelAndProbability::Response & res);
50+
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);
51+
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::Image msg);
52+
void label_to_color(cv::Mat& label, cv::Mat& color_label);
53+
};
54+
55+
#endif

launch/pytorch_enet_ros.launch

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
<launch>
33
<arg name="image" default="/camera/rgb/image_rect_color" />
44

5-
<node pkg="pytorch_enet_ros" type="pytorch_enet_ros_node" name="pytorch_enet_ros_node" output="screen">
5+
<node pkg="pytorch_enet_ros" type="pytorch_seg_trav_path_node" name="pytorch_seg_trav_path_node" output="screen">
66
<!--
77
<remap from="~image" to="/kinect2/qhd/image_color_rect" />
88
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_camvid.pt" />
99
<param name="colormap" value="$(find pytorch_enet_ros)/images/camvid12.png" />
1010
-->
1111
<remap from="~image" to="$(arg image)" />
12-
<param name="model_file" value="$(find pytorch_enet_ros)/models/espdnet_ue_trav_20210115-151110.pt" />
12+
<param name="model_file" value="$(find pytorch_enet_ros)/models/espdnet_ue_trav_path_20210518-221714.pt" />
1313
<!--
1414
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_greenhouse.pt" />
1515
-->

src/pytorch_cpp_wrapper.cpp renamed to src/pytorch_cpp_wrapper_base.cpp

Lines changed: 8 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -6,41 +6,28 @@
66

77

88
#include <torch/torch.h>
9-
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper.h"
9+
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
1010
#include <torch/script.h> // One-stop header.
1111
#include <torch/data/transforms/tensor.h> // One-stop header.
1212
#include <c10/util/ArrayRef.h>
1313
#include <opencv2/opencv.hpp>
1414
#include "opencv2/highgui/highgui.hpp"
1515
#include <typeinfo>
1616

17-
//namespace mpl {
17+
PyTorchCppWrapperBase::PyTorchCppWrapperBase() {}
1818

19-
PyTorchCppWrapper::PyTorchCppWrapper() {
20-
std::vector<float> mean_vec{0.485, 0.456, 0.406};
21-
std::vector<float> std_vec{0.229, 0.224, 0.225};
22-
23-
}
24-
25-
PyTorchCppWrapper::PyTorchCppWrapper(const std::string filename) {
19+
PyTorchCppWrapperBase::PyTorchCppWrapperBase(const std::string & filename) {
2620
// Import
2721
import_module(filename);
28-
std::vector<float> mean_vec{0.485, 0.456, 0.406};
29-
std::vector<float> std_vec{0.229, 0.224, 0.225};
30-
3122
}
3223

33-
PyTorchCppWrapper::PyTorchCppWrapper(const char* filename) {
24+
PyTorchCppWrapperBase::PyTorchCppWrapperBase(const char* filename) {
3425
// Import
3526
import_module(std::string(filename));
36-
37-
std::vector<float> mean_vec{0.485, 0.456, 0.406};
38-
std::vector<float> std_vec{0.229, 0.224, 0.225};
39-
4027
}
4128

4229
bool
43-
PyTorchCppWrapper::import_module(const std::string filename)
30+
PyTorchCppWrapperBase::import_module(const std::string & filename)
4431
{
4532
try {
4633
// Deserialize the ScriptModule from a file using torch::jit::load().
@@ -59,7 +46,7 @@ PyTorchCppWrapper::import_module(const std::string filename)
5946
}
6047

6148
void
62-
PyTorchCppWrapper::img2tensor(cv::Mat & img, at::Tensor & tensor, const bool use_gpu)
49+
PyTorchCppWrapperBase::img2tensor(cv::Mat & img, at::Tensor & tensor, const bool & use_gpu)
6350
{
6451
// Get the size of the input image
6552
int height = img.size().height;
@@ -77,7 +64,7 @@ PyTorchCppWrapper::img2tensor(cv::Mat & img, at::Tensor & tensor, const bool use
7764
}
7865

7966
void
80-
PyTorchCppWrapper::tensor2img(at::Tensor tensor, cv::Mat & img)
67+
PyTorchCppWrapperBase::tensor2img(at::Tensor tensor, cv::Mat & img)
8168
{
8269
// Get the size of the input image
8370
int height = tensor.sizes()[0];
@@ -89,36 +76,11 @@ PyTorchCppWrapper::tensor2img(at::Tensor tensor, cv::Mat & img)
8976
img = cv::Mat(height, width, CV_8U, tensor. template data<uint8_t>());
9077
}
9178

92-
//at::Tensor
93-
std::tuple<at::Tensor, at::Tensor>
94-
PyTorchCppWrapper::get_output(at::Tensor input_tensor)
95-
{
96-
// Execute the model and turn its output into a tensor.
97-
auto outputs_tmp = module_.forward({input_tensor}); //.toTuple();
98-
99-
auto outputs = outputs_tmp.toTuple();
100-
101-
at::Tensor output1 = outputs->elements()[0].toTensor();
102-
at::Tensor output2 = outputs->elements()[1].toTensor();
103-
at::Tensor prob = outputs->elements()[2].toTensor();
104-
105-
// Divide probability by c
106-
prob = torch::sigmoid(prob) / 0.3;
107-
// Limit the values in range [0, 1]
108-
prob = at::clamp(prob, 0.0, 1.0);
109-
110-
// return output1 + 0.5 * output2;
111-
at::Tensor segmentation = output1 + 0.5 * output2;
112-
113-
return std::forward_as_tuple(segmentation, prob);
114-
}
115-
11679
at::Tensor
117-
PyTorchCppWrapper::get_argmax(at::Tensor input_tensor)
80+
PyTorchCppWrapperBase::get_argmax(at::Tensor input_tensor)
11881
{
11982
// Calculate argmax to get a label on each pixel
12083
at::Tensor output = at::argmax(input_tensor, 1).to(torch::kCPU).to(at::kByte);
12184

12285
return output;
12386
}
124-
//} // namespace mpl

0 commit comments

Comments
 (0)