@@ -13,14 +13,16 @@ PyTorchSegTravPathROS::PyTorchSegTravPathROS(ros::NodeHandle & nh)
13
13
pub_label_image_ = it_.advertise (" label" , 1 );
14
14
pub_color_image_ = it_.advertise (" color_label" , 1 );
15
15
pub_prob_image_ = it_.advertise (" prob" , 1 );
16
+ pub_uncertainty_image_ = it_.advertise (" uncertainty" , 1 );
16
17
pub_start_point_ = nh_.advertise <geometry_msgs::PointStamped>(" start_point" , 1 );
17
18
pub_end_point_ = nh_.advertise <geometry_msgs::PointStamped>(" end_point" , 1 );
18
19
get_label_image_server_ = nh_.advertiseService (" get_label_image" , &PyTorchSegTravPathROS::image_inference_srv_callback, this );
19
20
20
21
// Import the model
21
22
std::string filename;
22
23
nh_.param <std::string>(" model_file" , filename, " " );
23
- if (!pt_wrapper_.import_module (filename)) {
24
+ pt_wrapper_ptr_.reset (new PyTorchCppWrapperSegTravPath (filename, 4 ));
25
+ if (!pt_wrapper_ptr_->import_module (filename)) {
24
26
ROS_ERROR (" Failed to import the model file [%s]" , filename.c_str ());
25
27
ros::shutdown ();
26
28
}
@@ -53,19 +55,22 @@ PyTorchSegTravPathROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
53
55
sensor_msgs::ImagePtr label_msg;
54
56
sensor_msgs::ImagePtr color_label_msg;
55
57
sensor_msgs::ImagePtr prob_msg;
58
+ sensor_msgs::ImagePtr uncertainty_msg;
56
59
geometry_msgs::PointStampedPtr start_point_msg;
57
60
geometry_msgs::PointStampedPtr end_point_msg;
58
- std::tie (label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg) = inference (cv_ptr->image );
61
+ std::tie (label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg) = inference (cv_ptr->image );
59
62
60
63
// Set header
61
64
label_msg->header = msg->header ;
62
65
color_label_msg->header = msg->header ;
63
66
prob_msg->header = msg->header ;
67
+ uncertainty_msg->header = msg->header ;
64
68
65
69
// Publish the messages
66
70
pub_label_image_.publish (label_msg);
67
71
pub_color_image_.publish (color_label_msg);
68
72
pub_prob_image_.publish (prob_msg);
73
+ pub_uncertainty_image_.publish (uncertainty_msg);
69
74
pub_start_point_.publish (start_point_msg);
70
75
pub_end_point_.publish (end_point_msg);
71
76
}
@@ -88,9 +93,10 @@ PyTorchSegTravPathROS::image_inference_srv_callback(semantic_segmentation_srvs::
88
93
sensor_msgs::ImagePtr label_msg;
89
94
sensor_msgs::ImagePtr color_label_msg;
90
95
sensor_msgs::ImagePtr prob_msg;
96
+ sensor_msgs::ImagePtr uncertainty_msg;
91
97
geometry_msgs::PointStampedPtr start_point_msg;
92
98
geometry_msgs::PointStampedPtr end_point_msg;
93
- std::tie (label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg) = inference (cv_ptr->image );
99
+ std::tie (label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg) = inference (cv_ptr->image );
94
100
95
101
res.label_img = *label_msg;
96
102
res.colorlabel_img = *color_label_msg;
@@ -105,7 +111,8 @@ PyTorchSegTravPathROS::image_inference_srv_callback(semantic_segmentation_srvs::
105
111
* @param[in] res Response
106
112
* @return True if the service succeeded
107
113
*/
108
- std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
114
+ std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr,
115
+ geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
109
116
PyTorchSegTravPathROS::inference (cv::Mat & input_img)
110
117
{
111
118
@@ -118,16 +125,9 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
118
125
cv::resize (input_img, input_img, s);
119
126
120
127
at::Tensor input_tensor;
121
- pt_wrapper_. img2tensor (input_img, input_tensor);
128
+ pt_wrapper_ptr_-> img2tensor (input_img, input_tensor);
122
129
123
- // Normalize from [0, 255] -> [0, 1]
124
- input_tensor /= 255.0 ;
125
- // z-normalization
126
- std::vector<float > mean_vec{0.485 , 0.456 , 0.406 };
127
- std::vector<float > std_vec{0.229 , 0.224 , 0.225 };
128
- for (int i = 0 ; i < mean_vec.size (); i++) {
129
- input_tensor[0 ][i] = (input_tensor[0 ][i] - mean_vec[i]) / std_vec[i];
130
- }
130
+ normalize_tensor (input_tensor);
131
131
132
132
// Execute the model and turn its output into a tensor.
133
133
at::Tensor segmentation;
@@ -136,25 +136,32 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
136
136
// segmentation: raw output for segmentation (before softmax)
137
137
// prob: traversability
138
138
// points: coordinates of the line points
139
- std::tie (segmentation, prob, points) = pt_wrapper_. get_output (input_tensor);
139
+ std::tie (segmentation, prob, points) = pt_wrapper_ptr_-> get_output (input_tensor);
140
140
141
141
// Get class label map by taking argmax of 'segmentation'
142
- at::Tensor output_args = pt_wrapper_. get_argmax (segmentation);
142
+ at::Tensor output_args = pt_wrapper_ptr_-> get_argmax (segmentation);
143
143
144
144
// Uncertainty of segmentation
145
- at::Tensor uncertainty = pt_wrapper_. get_entropy (segmentation);
145
+ at::Tensor uncertainty = pt_wrapper_ptr_-> get_entropy (segmentation, true );
146
146
147
147
// Convert to OpenCV
148
148
cv::Mat label;
149
149
cv::Mat prob_cv;
150
- pt_wrapper_.tensor2img (output_args[0 ], label);
151
- pt_wrapper_.tensor2img ((prob[0 ][0 ]*255 ).to (torch::kByte ), prob_cv);
150
+ cv::Mat uncertainty_cv;
151
+ // Segmentation label
152
+ pt_wrapper_ptr_->tensor2img (output_args[0 ], label);
153
+ // Traverability
154
+ pt_wrapper_ptr_->tensor2img ((prob[0 ][0 ]*255 ).to (torch::kByte ), prob_cv);
155
+ // Segmentation label
156
+ pt_wrapper_ptr_->tensor2img ((uncertainty[0 ]*255 ).to (torch::kByte ), uncertainty_cv);
152
157
153
158
// Set the size
154
159
cv::Size s_orig (width_orig, height_orig);
155
160
// Resize the input image back to the original size
156
161
cv::resize (label, label, s_orig, cv::INTER_NEAREST);
157
162
cv::resize (prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
163
+ cv::resize (uncertainty_cv, uncertainty_cv, s_orig, cv::INTER_LINEAR);
164
+
158
165
// Generate color label image
159
166
cv::Mat color_label;
160
167
label_to_color (label, color_label);
@@ -163,10 +170,11 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
163
170
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , label).toImageMsg ();
164
171
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage (std_msgs::Header (), " rgb8" , color_label).toImageMsg ();
165
172
sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , prob_cv).toImageMsg ();
173
+ sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , uncertainty_cv).toImageMsg ();
166
174
geometry_msgs::PointStampedPtr start_point_msg (new geometry_msgs::PointStamped), end_point_msg (new geometry_msgs::PointStamped);
167
175
std::tie (start_point_msg, end_point_msg) = tensor_to_points (points, width_orig, height_orig);
168
176
169
- return std::forward_as_tuple (label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg);
177
+ return std::forward_as_tuple (label_msg, color_label_msg, prob_msg, uncertainty_msg, start_point_msg, end_point_msg);
170
178
}
171
179
172
180
/* *
@@ -259,3 +267,20 @@ PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::Image msg)
259
267
260
268
return cv_ptr;
261
269
}
270
+
271
+ /* *
272
+ * @brief Normalize a tensor to feed in a model
273
+ * @param[in] input Tensor
274
+ */
275
+ void
276
+ PyTorchSegTravPathROS::normalize_tensor (at::Tensor & input_tensor)
277
+ {
278
+ // Normalize from [0, 255] -> [0, 1]
279
+ input_tensor /= 255.0 ;
280
+ // z-normalization
281
+ std::vector<float > mean_vec{0.485 , 0.456 , 0.406 };
282
+ std::vector<float > std_vec{0.229 , 0.224 , 0.225 };
283
+ for (int i = 0 ; i < mean_vec.size (); i++) {
284
+ input_tensor[0 ][i] = (input_tensor[0 ][i] - mean_vec[i]) / std_vec[i];
285
+ }
286
+ }
0 commit comments