Skip to content

Commit a431adc

Browse files
[update] Implement uncertainty calculation based on the entropy
1 parent be7976f commit a431adc

6 files changed

+102
-29
lines changed

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
class PyTorchCppWrapperBase {
1616
protected :
1717
torch::jit::script::Module module_;
18+
int class_num_;
19+
float max_entropy_;
1820

1921
public:
2022
PyTorchCppWrapperBase();
21-
PyTorchCppWrapperBase(const std::string & filename);
22-
PyTorchCppWrapperBase(const char* filename);
23+
PyTorchCppWrapperBase(const std::string & filename, const int class_num);
24+
PyTorchCppWrapperBase(const char* filename, const int class_num);
2325

2426
/**
2527
* @brief import a network
@@ -55,7 +57,7 @@ protected :
5557
* @param[in] tensor
5658
* @param[out] tensor that has the entropy of each element (normalized to [0, 1] when requested)
5759
*/
58-
at::Tensor get_entropy(at::Tensor input_tensor);
60+
at::Tensor get_entropy(at::Tensor input_tensor, const bool normalize);
5961

6062
};
6163
//}

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_seg_trav_path.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ private :
1818
float c_{0.3};
1919

2020
public:
21+
PyTorchCppWrapperSegTravPath(const std::string & filename, const int class_num);
22+
PyTorchCppWrapperSegTravPath(const char* filename, const int class_num);
23+
2124
/**
2225
* @brief Get outputs from the model
2326
* @param[in] input_tensor Input tensor

include/pytorch_ros/pytorch_seg_trav_path_ros.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,12 @@ class PyTorchSegTravPathROS {
3636
image_transport::Publisher pub_label_image_;
3737
image_transport::Publisher pub_color_image_;
3838
image_transport::Publisher pub_prob_image_;
39+
image_transport::Publisher pub_uncertainty_image_;
3940
ros::Publisher pub_start_point_;
4041
ros::Publisher pub_end_point_;
4142
ros::Time stamp_of_current_image_;
4243

43-
PyTorchCppWrapperSegTravPath pt_wrapper_;
44+
std::shared_ptr<PyTorchCppWrapperSegTravPath> pt_wrapper_ptr_;
4445

4546
// Used to convert a label image to a color image
4647
cv::Mat colormap_;
@@ -59,7 +60,8 @@ class PyTorchSegTravPathROS {
5960
* @param[in] input_image OpenCV image
6061
* @return A tuple of messages of the inference results
6162
*/
62-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr> inference(cv::Mat & input_image);
63+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
64+
inference(cv::Mat & input_image);
6365

6466
/**
6567
* @brief Service callback
@@ -99,6 +101,13 @@ class PyTorchSegTravPathROS {
99101
* @return A tuple of start and end points as geometry_msgs::PointStampedPtr
100102
*/
101103
std::tuple<geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr> tensor_to_points(const at::Tensor point_tensor, const int & width, const int & height);
104+
105+
/**
106+
* @brief Normalize a tensor to feed in a model
107+
* @param[in] input Tensor
108+
*/
109+
void normalize_tensor(at::Tensor & input_tensor);
110+
102111
};
103112

104113
#endif

src/pytorch_cpp_wrapper_base.cpp

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
*
55
*/
66

7-
87
#include <torch/torch.h>
98
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h"
109
#include <torch/script.h> // One-stop header.
@@ -13,17 +12,38 @@
1312
#include <opencv2/opencv.hpp>
1413
#include "opencv2/highgui/highgui.hpp"
1514
#include <typeinfo>
15+
#include <cmath>
1616

1717
PyTorchCppWrapperBase::PyTorchCppWrapperBase() {}
1818

19-
PyTorchCppWrapperBase::PyTorchCppWrapperBase(const std::string & filename) {
19+
PyTorchCppWrapperBase::PyTorchCppWrapperBase(const std::string & filename, const int class_num)
20+
: class_num_(class_num)
21+
{
2022
// Import model
2123
import_module(filename);
24+
25+
// Calculate the maximum possible entropy
26+
// to normalize the entropy value in [0, 1].
27+
max_entropy_ = 0;
28+
const float prob = (float) 1.0 / class_num_;
29+
for(int i = 0; i < class_num_; ++i) {
30+
max_entropy_ += -prob * std::log(prob);
31+
}
2232
}
2333

24-
PyTorchCppWrapperBase::PyTorchCppWrapperBase(const char* filename) {
34+
PyTorchCppWrapperBase::PyTorchCppWrapperBase(const char* filename, const int class_num)
35+
: class_num_(class_num)
36+
{
2537
// Import model
2638
import_module(std::string(filename));
39+
40+
// Calculate the maximum possible entropy
41+
// to normalize the entropy value in [0, 1].
42+
max_entropy_ = 0;
43+
const float prob = (float) 1.0 / class_num_;
44+
for(int i = 0; i < class_num_; ++i) {
45+
max_entropy_ += -prob * std::log(prob);
46+
}
2747
}
2848

2949
/**
@@ -101,7 +121,7 @@ at::Tensor
101121
PyTorchCppWrapperBase::get_argmax(at::Tensor input_tensor)
102122
{
103123
// Calculate argmax to get a label on each pixel
104-
at::Tensor output = at::argmax(input_tensor, 1).to(torch::kCPU).to(at::kByte);
124+
at::Tensor output = at::argmax(input_tensor, /*dim=*/1).to(torch::kCPU).to(at::kByte);
105125

106126
return output;
107127
}
@@ -112,7 +132,7 @@ PyTorchCppWrapperBase::get_argmax(at::Tensor input_tensor)
112132
* @param[out] tensor that has the entropy of each element (normalized to [0, 1] when requested)
113133
*/
114134
at::Tensor
115-
PyTorchCppWrapperBase::get_entropy(at::Tensor input_tensor)
135+
PyTorchCppWrapperBase::get_entropy(at::Tensor input_tensor, const bool normalize = true)
116136
{
117137
input_tensor.to(torch::kCUDA);
118138
// Calculate the entropy at each pixel
@@ -121,5 +141,11 @@ PyTorchCppWrapperBase::get_entropy(at::Tensor input_tensor)
121141

122142
at::Tensor entropy = -torch::sum(p * log_p, /*dim=*/1);
123143

144+
if(normalize)
145+
entropy = entropy / max_entropy_;
146+
124147
return entropy;
125148
}
149+
150+
151+

src/pytorch_cpp_wrapper_seg_trav_path.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,14 @@
1414
#include "opencv2/highgui/highgui.hpp"
1515
#include <typeinfo>
1616

17+
PyTorchCppWrapperSegTravPath::PyTorchCppWrapperSegTravPath(const std::string & filename, const int class_num)
18+
: PyTorchCppWrapperBase(filename, class_num)
19+
{ }
20+
21+
PyTorchCppWrapperSegTravPath::PyTorchCppWrapperSegTravPath(const char* filename, const int class_num)
22+
: PyTorchCppWrapperBase(filename, class_num)
23+
{ }
24+
1725
/**
1826
* @brief Get outputs from the model
1927
* @param[in] input_tensor Input tensor

src/pytorch_seg_trav_path_ros.cpp

Lines changed: 44 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,16 @@ PyTorchSegTravPathROS::PyTorchSegTravPathROS(ros::NodeHandle & nh)
1313
pub_label_image_ = it_.advertise("label", 1);
1414
pub_color_image_ = it_.advertise("color_label", 1);
1515
pub_prob_image_ = it_.advertise("prob", 1);
16+
pub_uncertainty_image_ = it_.advertise("uncertainty", 1);
1617
pub_start_point_ = nh_.advertise<geometry_msgs::PointStamped>("start_point", 1);
1718
pub_end_point_ = nh_.advertise<geometry_msgs::PointStamped>("end_point", 1);
1819
get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchSegTravPathROS::image_inference_srv_callback, this);
1920

2021
// Import the model
2122
std::string filename;
2223
nh_.param<std::string>("model_file", filename, "");
23-
if(!pt_wrapper_.import_module(filename)) {
24+
pt_wrapper_ptr_.reset(new PyTorchCppWrapperSegTravPath(filename, 4));
25+
if(!pt_wrapper_ptr_->import_module(filename)) {
2426
ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
2527
ros::shutdown();
2628
}
@@ -53,19 +55,22 @@ PyTorchSegTravPathROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
5355
sensor_msgs::ImagePtr label_msg;
5456
sensor_msgs::ImagePtr color_label_msg;
5557
sensor_msgs::ImagePtr prob_msg;
58+
sensor_msgs::ImagePtr uncertainty_msg;
5659
geometry_msgs::PointStampedPtr start_point_msg;
5760
geometry_msgs::PointStampedPtr end_point_msg;
58-
std::tie(label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg) = inference(cv_ptr->image);
61+
std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg) = inference(cv_ptr->image);
5962

6063
// Set header
6164
label_msg->header = msg->header;
6265
color_label_msg->header = msg->header;
6366
prob_msg->header = msg->header;
67+
uncertainty_msg->header = msg->header;
6468

6569
// Publish the messages
6670
pub_label_image_.publish(label_msg);
6771
pub_color_image_.publish(color_label_msg);
6872
pub_prob_image_.publish(prob_msg);
73+
pub_uncertainty_image_.publish(uncertainty_msg);
6974
pub_start_point_.publish(start_point_msg);
7075
pub_end_point_.publish(end_point_msg);
7176
}
@@ -88,9 +93,10 @@ PyTorchSegTravPathROS::image_inference_srv_callback(semantic_segmentation_srvs::
8893
sensor_msgs::ImagePtr label_msg;
8994
sensor_msgs::ImagePtr color_label_msg;
9095
sensor_msgs::ImagePtr prob_msg;
96+
sensor_msgs::ImagePtr uncertainty_msg;
9197
geometry_msgs::PointStampedPtr start_point_msg;
9298
geometry_msgs::PointStampedPtr end_point_msg;
93-
std::tie(label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg) = inference(cv_ptr->image);
99+
std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg) = inference(cv_ptr->image);
94100

95101
res.label_img = *label_msg;
96102
res.colorlabel_img = *color_label_msg;
@@ -105,7 +111,8 @@ PyTorchSegTravPathROS::image_inference_srv_callback(semantic_segmentation_srvs::
105111
* @param[in] res Response
106112
* @return True if the service succeeded
107113
*/
108-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
114+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr,
115+
geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
109116
PyTorchSegTravPathROS::inference(cv::Mat & input_img)
110117
{
111118

@@ -118,16 +125,9 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
118125
cv::resize(input_img, input_img, s);
119126

120127
at::Tensor input_tensor;
121-
pt_wrapper_.img2tensor(input_img, input_tensor);
128+
pt_wrapper_ptr_->img2tensor(input_img, input_tensor);
122129

123-
// Normalize from [0, 255] -> [0, 1]
124-
input_tensor /= 255.0;
125-
// z-normalization
126-
std::vector<float> mean_vec{0.485, 0.456, 0.406};
127-
std::vector<float> std_vec{0.229, 0.224, 0.225};
128-
for(int i = 0; i < mean_vec.size(); i++) {
129-
input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
130-
}
130+
normalize_tensor(input_tensor);
131131

132132
// Execute the model and turn its output into a tensor.
133133
at::Tensor segmentation;
@@ -136,25 +136,32 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
136136
// segmentation: raw output for segmentation (before softmax)
137137
// prob: traversability
138138
// points: coordinates of the line points
139-
std::tie(segmentation, prob, points) = pt_wrapper_.get_output(input_tensor);
139+
std::tie(segmentation, prob, points) = pt_wrapper_ptr_->get_output(input_tensor);
140140

141141
// Get class label map by taking argmax of 'segmentation'
142-
at::Tensor output_args = pt_wrapper_.get_argmax(segmentation);
142+
at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);
143143

144144
// Uncertainty of segmentation
145-
at::Tensor uncertainty = pt_wrapper_.get_entropy(segmentation);
145+
at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
146146

147147
// Convert to OpenCV
148148
cv::Mat label;
149149
cv::Mat prob_cv;
150-
pt_wrapper_.tensor2img(output_args[0], label);
151-
pt_wrapper_.tensor2img((prob[0][0]*255).to(torch::kByte), prob_cv);
150+
cv::Mat uncertainty_cv;
151+
// Segmentation label
152+
pt_wrapper_ptr_->tensor2img(output_args[0], label);
153+
// Traversability
154+
pt_wrapper_ptr_->tensor2img((prob[0][0]*255).to(torch::kByte), prob_cv);
155+
// Uncertainty of segmentation
156+
pt_wrapper_ptr_->tensor2img((uncertainty[0]*255).to(torch::kByte), uncertainty_cv);
152157

153158
// Set the size
154159
cv::Size s_orig(width_orig, height_orig);
155160
// Resize the input image back to the original size
156161
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
157162
cv::resize(prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
163+
cv::resize(uncertainty_cv, uncertainty_cv, s_orig, cv::INTER_LINEAR);
164+
158165
// Generate color label image
159166
cv::Mat color_label;
160167
label_to_color(label, color_label);
@@ -163,10 +170,11 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
163170
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
164171
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
165172
sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
173+
sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();
166174
geometry_msgs::PointStampedPtr start_point_msg(new geometry_msgs::PointStamped), end_point_msg(new geometry_msgs::PointStamped);
167175
std::tie(start_point_msg, end_point_msg) = tensor_to_points(points, width_orig, height_orig);
168176

169-
return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg);
177+
return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg);
170178
}
171179

172180
/**
@@ -259,3 +267,20 @@ PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::Image msg)
259267

260268
return cv_ptr;
261269
}
270+
271+
/**
272+
* @brief Normalize a tensor to feed in a model
273+
* @param[in,out] input_tensor Tensor to be normalized in place
274+
*/
275+
void
276+
PyTorchSegTravPathROS::normalize_tensor(at::Tensor & input_tensor)
277+
{
278+
// Normalize from [0, 255] -> [0, 1]
279+
input_tensor /= 255.0;
280+
// z-normalization
281+
std::vector<float> mean_vec{0.485, 0.456, 0.406};
282+
std::vector<float> std_vec{0.229, 0.224, 0.225};
283+
for(int i = 0; i < mean_vec.size(); i++) {
284+
input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
285+
}
286+
}

0 commit comments

Comments
 (0)