Skip to content

Commit c3f494c

Browse files
[update] Update 'PyTorchSegTrav' to output uncertainty
1 parent 592a2ec commit c3f494c

File tree

6 files changed

+55
-15
lines changed

6 files changed

+55
-15
lines changed

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@ private :
1717
float c_{0.3};
1818

1919
public:
20+
PyTorchCppWrapperSegTrav(const std::string & filename, const int class_num);
21+
PyTorchCppWrapperSegTrav(const char* filename, const int class_num);
22+
2023
/**
2124
* @brief Get outputs from the model
2225
* @param[in] input_tensor Input tensor

include/pytorch_ros/pytorch_enet_ros.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,18 @@ class PyTorchENetROS {
3333
image_transport::Publisher pub_label_image_;
3434
image_transport::Publisher pub_color_image_;
3535
image_transport::Publisher pub_prob_image_;
36+
image_transport::Publisher pub_uncertainty_image_;
3637

37-
PyTorchCppWrapperSegTrav pt_wrapper_;
38+
// PyTorchCppWrapperSegTrav pt_wrapper_;
39+
std::shared_ptr<PyTorchCppWrapperSegTrav> pt_wrapper_ptr_;
3840

3941
cv::Mat colormap_;
4042

4143
public:
4244
PyTorchENetROS(ros::NodeHandle & nh);
4345

4446
void image_callback(const sensor_msgs::ImageConstPtr& msg);
45-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
47+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
4648
bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
4749
semantic_segmentation_srvs::GetLabelAndProbability::Response & res);
4850
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);

launch/pytorch_enet_ros.launch

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
<?xml version="1.0"?>
22
<launch>
33
<arg name="image" default="/camera/rgb/image_rect_color" />
4+
<!--
45
<arg name="model_name" default="$(find pytorch_ros)/models/espdnet_ue_trav_path_20210712-134315.pt" />
6+
-->
7+
<arg name="model_name" default="$(find pytorch_ros)/models/espdnet_ue_trav_20210115-151110.pt" />
58

9+
<!--
610
<node pkg="pytorch_ros" type="pytorch_seg_trav_path_node" name="pytorch_seg_trav_path_node" output="screen">
711
<remap from="~image" to="$(arg image)" />
812
<param name="model_file" value="$(arg model_name)" />
913
<param name="colormap" value="$(find pytorch_ros)/images/greenhouse4.png" />
1014
<param name="model_name" value="greenhouse" />
1115
</node>
16+
-->
17+
<node pkg="pytorch_ros" type="pytorch_seg_trav_node" name="pytorch_seg_trav_node" output="screen">
18+
<remap from="~image" to="$(arg image)" />
19+
<param name="model_file" value="$(arg model_name)" />
20+
<param name="colormap" value="$(find pytorch_ros)/images/greenhouse4.png" />
21+
<param name="model_name" value="greenhouse" />
22+
</node>
1223

1324
<node pkg="pytorch_ros" type="visualizer.py" name="visualizer" output="screen">
1425
<remap from="image" to="$(arg image)" />

src/pytorch_cpp_wrapper_seg_trav.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,14 @@
1313
#include "opencv2/highgui/highgui.hpp"
1414
#include <typeinfo>
1515

16+
PyTorchCppWrapperSegTrav::PyTorchCppWrapperSegTrav(const std::string & filename, const int class_num)
17+
: PyTorchCppWrapperBase(filename, class_num)
18+
{ }
19+
20+
PyTorchCppWrapperSegTrav::PyTorchCppWrapperSegTrav(const char* filename, const int class_num)
21+
: PyTorchCppWrapperBase(filename, class_num)
22+
{ }
23+
1624
/**
1725
* @brief Get outputs from the model
1826
* @param[in] input_tensor Input tensor

src/pytorch_enet_ros.cpp

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
1313
pub_label_image_ = it_.advertise("label", 1);
1414
pub_color_image_ = it_.advertise("color_label", 1);
1515
pub_prob_image_ = it_.advertise("prob", 1);
16+
pub_uncertainty_image_ = it_.advertise("uncertainty", 1);
1617
get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchENetROS::image_inference_srv_callback, this);
1718

1819
// Import the model
1920
std::string filename;
2021
nh_.param<std::string>("model_file", filename, "");
21-
if(!pt_wrapper_.import_module(filename)) {
22+
pt_wrapper_ptr_.reset(new PyTorchCppWrapperSegTrav(filename, 4));
23+
if(!pt_wrapper_ptr_->import_module(filename)) {
2224
ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
2325
ros::shutdown();
2426
}
@@ -46,16 +48,19 @@ PyTorchENetROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
4648
sensor_msgs::ImagePtr label_msg;
4749
sensor_msgs::ImagePtr color_label_msg;
4850
sensor_msgs::ImagePtr prob_msg;
49-
std::tie(label_msg, color_label_msg, prob_msg) = inference(cv_ptr->image);
51+
sensor_msgs::ImagePtr uncertainty_msg;
52+
std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg) = inference(cv_ptr->image);
5053

5154
// Set header
5255
label_msg->header = msg->header;
5356
color_label_msg->header = msg->header;
5457
prob_msg->header = msg->header;
58+
uncertainty_msg->header = msg->header;
5559

5660
pub_label_image_.publish(label_msg);
5761
pub_color_image_.publish(color_label_msg);
5862
pub_prob_image_.publish(prob_msg);
63+
pub_uncertainty_image_.publish(uncertainty_msg);
5964
}
6065

6166
/*
@@ -74,19 +79,21 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
7479
sensor_msgs::ImagePtr label_msg;
7580
sensor_msgs::ImagePtr color_label_msg;
7681
sensor_msgs::ImagePtr prob_msg;
77-
std::tie(label_msg, color_label_msg, prob_msg) = inference(cv_ptr->image);
82+
sensor_msgs::ImagePtr uncertainty_msg;
83+
std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg) = inference(cv_ptr->image);
7884

7985
res.label_img = *label_msg;
8086
res.colorlabel_img = *color_label_msg;
8187
res.prob_img = *prob_msg;
88+
res.uncertainty_img = *uncertainty_msg;
8289

8390
return true;
8491
}
8592

8693
/*
8794
* inference : Forward the given input image through the network and return the inference result
8895
*/
89-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
96+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
9097
PyTorchENetROS::inference(cv::Mat & input_img)
9198
{
9299

@@ -99,7 +106,7 @@ PyTorchENetROS::inference(cv::Mat & input_img)
99106
cv::resize(input_img, input_img, s);
100107

101108
at::Tensor input_tensor;
102-
pt_wrapper_.img2tensor(input_img, input_tensor);
109+
pt_wrapper_ptr_->img2tensor(input_img, input_tensor);
103110

104111
// Normalize from [0, 255] -> [0, 1]
105112
input_tensor /= 255.0;
@@ -114,24 +121,33 @@ PyTorchENetROS::inference(cv::Mat & input_img)
114121
// Execute the model and turn its output into a tensor.
115122
at::Tensor segmentation;
116123
at::Tensor prob;
117-
std::tie(segmentation, prob) = pt_wrapper_.get_output(input_tensor);
118-
// at::Tensor output = pt_wrapper_.get_output(input_tensor);
124+
std::tie(segmentation, prob) = pt_wrapper_ptr_->get_output(input_tensor);
125+
prob = (prob[0][0]*255).to(torch::kCPU).to(torch::kByte);
126+
// at::Tensor output = pt_wrapper_ptr_->get_output(input_tensor);
119127
// Calculate argmax to get a label on each pixel
120-
// at::Tensor output_args = pt_wrapper_.get_argmax(output);
128+
// at::Tensor output_args = pt_wrapper_ptr_->get_argmax(output);
121129

122-
at::Tensor output_args = pt_wrapper_.get_argmax(segmentation);
130+
at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);
131+
132+
// Uncertainty of segmentation
133+
at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
134+
// at::Tensor uncertainty = torch::zeros_like(prob);
135+
uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);
123136

124137
// Convert to OpenCV
125138
cv::Mat label;
126139
cv::Mat prob_cv;
127-
pt_wrapper_.tensor2img(output_args[0], label);
128-
pt_wrapper_.tensor2img((prob[0][0]*255).to(torch::kByte), prob_cv);
140+
cv::Mat uncertainty_cv;
141+
pt_wrapper_ptr_->tensor2img(output_args[0], label);
142+
pt_wrapper_ptr_->tensor2img(prob, prob_cv);
143+
pt_wrapper_ptr_->tensor2img(uncertainty, uncertainty_cv);
129144

130145
// Set the size
131146
cv::Size s_orig(width_orig, height_orig);
132147
// Resize the input image back to the original size
133148
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
134149
cv::resize(prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
150+
cv::resize(uncertainty_cv, uncertainty_cv, s_orig, cv::INTER_LINEAR);
135151
// Generate color label image
136152
cv::Mat color_label;
137153
label_to_color(label, color_label);
@@ -140,8 +156,9 @@ PyTorchENetROS::inference(cv::Mat & input_img)
140156
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
141157
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
142158
sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
159+
sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();
143160

144-
return std::forward_as_tuple(label_msg, color_label_msg, prob_msg);
161+
return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, uncertainty_msg);
145162
}
146163

147164
/*

src/pytorch_seg_trav_path_ros.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
145145
// Uncertainty of segmentation
146146
at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
147147
uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);
148-
// at::Tensor uncertainty = torch::zeros_like(prob[0]);
149148

150149
// Set the size
151150
cv::Size s_orig(width_orig, height_orig);

0 commit comments

Comments
 (0)