Skip to content

Commit 1b14c50

Browse files
Refactoring
1 parent d8de6a1 commit 1b14c50

File tree

5 files changed

+49
-29
lines changed

5 files changed

+49
-29
lines changed

CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
cmake_minimum_required(VERSION 2.8.3)
22
project(pytorch_enet_ros)
33

4-
## Compile as C++11, supported in ROS Kinetic and newer
5-
add_compile_options(-std=c++11)
4+
add_compile_options(-std=c++14)
65

76
# Locate the cmake file of torchlib
87
set(Torch_DIR "/opt/pytorch/pytorch/torch/share/cmake/Torch/")
@@ -21,6 +20,7 @@ find_package(catkin REQUIRED COMPONENTS
2120
)
2221

2322
find_package(Torch REQUIRED)
23+
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
2424
find_package(OpenCV REQUIRED)
2525

2626
catkin_package(
@@ -54,6 +54,7 @@ target_link_libraries(${PROJECT_NAME}
5454
${Open_CV_LIBS}
5555
opencv_core opencv_highgui opencv_imgcodecs
5656
)
57+
set_property(TARGET ${PROJECT_NAME} PROPERTY CXX_STANDARD 14)
5758

5859
## Declare a C++ executable
5960
## With catkin_make all packages are built within a single CMake context
@@ -64,4 +65,5 @@ add_executable(${PROJECT_NAME}_node src/pytorch_enet_ros_node.cpp)
6465
target_link_libraries(${PROJECT_NAME}_node
6566
${catkin_LIBRARIES}
6667
${PROJECT_NAME}
68+
${TORCH_LIBRARIES}
6769
)

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define PYTORCH_CPP_WRAPPER
33

44
#include <torch/script.h> // One-stop header.
5+
#include <torch/data/transforms/tensor.h> // One-stop header.
6+
#include <c10/util/ArrayRef.h>
57
#include <opencv2/opencv.hpp>
68
#include "opencv2/highgui/highgui.hpp"
79

@@ -11,7 +13,9 @@
1113
//namespace mpl {
1214
class PyTorchCppWrapper {
1315
private :
14-
std::shared_ptr<torch::jit::script::Module> module_;
16+
// std::shared_ptr<torch::jit::script::Module> module_;
17+
torch::jit::script::Module module_;
18+
// torch::data::transforms::Normalize<at::Tensor> normalizer_;
1519

1620
public:
1721
PyTorchCppWrapper();

launch/pytorch_enet_ros.launch

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
<param name="colormap" value="$(find pytorch_enet_ros)/images/camvid12.png" />
1010
-->
1111
<remap from="~image" to="$(arg image)" />
12+
<param name="model_file" value="$(find pytorch_enet_ros)/models/espdnet_ue_uest_trav.pt" />
13+
<!--
1214
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_greenhouse.pt" />
15+
-->
1316
<param name="colormap" value="$(find pytorch_enet_ros)/images/greenhouse4.png" />
1417
<param name="model_name" value="greenhouse" />
1518
</node>

src/pytorch_cpp_wrapper.cpp

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,36 @@
66

77

88
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper.h"
9+
#include <torch/script.h> // One-stop header.
10+
#include <torch/data/transforms/tensor.h> // One-stop header.
11+
#include <c10/util/ArrayRef.h>
12+
#include <opencv2/opencv.hpp>
13+
#include "opencv2/highgui/highgui.hpp"
14+
#include <typeinfo>
915

1016
//namespace mpl {
1117

1218
PyTorchCppWrapper::PyTorchCppWrapper() {
13-
19+
std::vector<float> mean_vec{0.485, 0.456, 0.406};
20+
std::vector<float> std_vec{0.229, 0.224, 0.225};
21+
1422
}
1523

1624
PyTorchCppWrapper::PyTorchCppWrapper(const std::string filename) {
1725
// Import
1826
import_module(filename);
27+
std::vector<float> mean_vec{0.485, 0.456, 0.406};
28+
std::vector<float> std_vec{0.229, 0.224, 0.225};
29+
1930
}
2031

2132
PyTorchCppWrapper::PyTorchCppWrapper(const char* filename) {
2233
// Import
2334
import_module(std::string(filename));
35+
36+
std::vector<float> mean_vec{0.485, 0.456, 0.406};
37+
std::vector<float> std_vec{0.229, 0.224, 0.225};
38+
2439
}
2540

2641
bool
@@ -30,8 +45,8 @@ PyTorchCppWrapper::import_module(const std::string filename)
3045
// Deserialize the ScriptModule from a file using torch::jit::load().
3146
module_ = torch::jit::load(filename);
3247
// Set evaluation mode
33-
module_->eval();
34-
std::cout << module_->is_training() << std::endl;
48+
module_.eval();
49+
std::cout << module_.is_training() << std::endl;
3550

3651
std::cout << "Import succeeded" << std::endl;
3752
return true;
@@ -76,9 +91,14 @@ at::Tensor
7691
PyTorchCppWrapper::get_output(at::Tensor input_tensor)
7792
{
7893
// Execute the model and turn its output into a tensor.
79-
at::Tensor output = module_->forward({input_tensor}).toTensor();
94+
auto outputs_tmp = module_.forward({input_tensor}); //.toTuple();
8095

81-
return output;
96+
auto outputs = outputs_tmp.toTuple();
97+
98+
at::Tensor output1 = outputs->elements()[0].toTensor();
99+
at::Tensor output2 = outputs->elements()[1].toTensor();
100+
101+
return output1 + 0.5 * output2;
82102
}
83103

84104
at::Tensor

src/pytorch_enet_ros.cpp

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@
99
PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
1010
: it_(nh), nh_(nh)
1111
{
12-
// nh_ = ros::NodeHandle(nh);
13-
// it_ = image_transport::ImageTransport(nh_);
14-
// ROS_INFO("[PyTorchENetROS] Constructor");
15-
1612
sub_image_ = it_.subscribe("image", 1, &PyTorchENetROS::image_callback, this);
1713
pub_label_image_ = it_.advertise("label", 1);
1814
pub_color_image_ = it_.advertise("color_label", 1);
@@ -67,7 +63,6 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
6763
{
6864
ROS_INFO("[PyTorchENetROS image_inference_srv_callback] Start");
6965

70-
7166
// Convert the image message to a cv_bridge object
7267
cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge(req.img);
7368

@@ -93,28 +88,30 @@ PyTorchENetROS::inference(cv::Mat & input_img)
9388
int height_orig = input_img.size().height;
9489
int width_orig = input_img.size().width;
9590

96-
cv::Size s(480, 264);
91+
cv::Size s(480, 256);
9792
// Resize the input image
9893
cv::resize(input_img, input_img, s);
9994

100-
// ROS_INFO("[PyTorchENetROS inference] Start");
10195
at::Tensor input_tensor;
10296
pt_wrapper_.img2tensor(input_img, input_tensor);
10397

98+
// Normalize from [0, 255] -> [0, 1]
99+
input_tensor /= 255.0;
100+
// z-normalization
101+
std::vector<float> mean_vec{0.485, 0.456, 0.406};
102+
std::vector<float> std_vec{0.229, 0.224, 0.225};
103+
for(int i = 0; i < mean_vec.size(); i++) {
104+
input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
105+
}
104106
std::cout << input_tensor.sizes() << std::endl;
107+
105108
// Execute the model and turn its output into a tensor.
106-
// at::Tensor output = module->forward({input_tensor}).toTensor();
107-
// ROS_INFO("[PyTorchENetROS inference] get_output");
108109
at::Tensor output = pt_wrapper_.get_output(input_tensor);
109110
// Calculate argmax to get a label on each pixel
110-
// at::Tensor output_args = at::argmax(output, 1).to(torch::kCPU).to(at::kByte);
111-
// ROS_INFO("[PyTorchENetROS inference] get_argmax");
112111
at::Tensor output_args = pt_wrapper_.get_argmax(output);
113112

114113
// Convert to OpenCV
115-
// cv::Mat mat(height, width, CV_8U, output_args[0]. template data<uint8_t>());
116114
cv::Mat label;
117-
// ROS_INFO("[PyTorchENetROS inference] tensor2img");
118115
pt_wrapper_.tensor2img(output_args[0], label);
119116

120117
// Set the size
@@ -123,18 +120,12 @@ PyTorchENetROS::inference(cv::Mat & input_img)
123120
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
124121

125122
cv::Mat color_label;
126-
// cv::applyColorMap(mat, color_label, cv::COLORMAP_JET);
127123
label_to_color(label, color_label);
128124

129125
// Generate an image message
130-
// ROS_INFO("[PyTorchENetROS inference] cv_bridge to image msg");
131126
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
132127
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
133128

134-
135-
// sensor_msgs::ImagePtr color_msg = cv_bridge::CvImage(std_msgs::Header(), "bgr8", image).toImageMsg();
136-
137-
// ROS_INFO("[PyTorchENetROS inference] Publish");
138129
return std::forward_as_tuple(label_msg, color_label_msg);
139130
}
140131

@@ -155,11 +146,11 @@ cv_bridge::CvImagePtr
155146
PyTorchENetROS::msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg)
156147
{
157148
cv_bridge::CvImagePtr cv_ptr;
149+
158150
// Convert the image message to a cv_bridge object
159151
try
160152
{
161153
cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
162-
// ROS_INFO("[PyTorchENetROS image_callback] Convert to cv_bridge object");
163154
}
164155
catch (cv_bridge::Exception& e)
165156
{
@@ -177,11 +168,11 @@ cv_bridge::CvImagePtr
177168
PyTorchENetROS::msg_to_cv_bridge(sensor_msgs::Image msg)
178169
{
179170
cv_bridge::CvImagePtr cv_ptr;
171+
180172
// Convert the image message to a cv_bridge object
181173
try
182174
{
183175
cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
184-
// ROS_INFO("[PyTorchENetROS image_callback] Convert to cv_bridge object");
185176
}
186177
catch (cv_bridge::Exception& e)
187178
{

0 commit comments

Comments (0)