@@ -137,26 +137,31 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
137
137
// prob: traversability
138
138
// points: coordinates of the line points
139
139
std::tie (segmentation, prob, points) = pt_wrapper_ptr_->get_output (input_tensor);
140
+ prob = (prob[0 ][0 ]*255 ).to (torch::kCPU ).to (torch::kByte );
140
141
141
142
// Get class label map by taking argmax of 'segmentation'
142
143
at::Tensor output_args = pt_wrapper_ptr_->get_argmax (segmentation);
143
144
144
145
// Uncertainty of segmentation
145
146
at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy (segmentation, true );
147
+ uncertainty = (uncertainty[0 ]*255 ).to (torch::kCPU ).to (torch::kByte );
148
+ // at::Tensor uncertainty = torch::zeros_like(prob[0]);
149
+
150
+ // Set the size
151
+ cv::Size s_orig (width_orig, height_orig);
146
152
147
153
// Convert to OpenCV
148
154
cv::Mat label;
149
155
cv::Mat prob_cv;
150
- cv::Mat uncertainty_cv;
156
+ cv::Mat uncertainty_cv = cv::Mat::zeros (s_orig. height , s_orig. width , CV_8U) ;
151
157
// Segmentation label
152
- pt_wrapper_ptr_->tensor2img (output_args[0 ], label);
153
- // Traverability
154
- pt_wrapper_ptr_->tensor2img ((prob[0 ][0 ]*255 ).to (torch::kByte ), prob_cv);
158
+ label = pt_wrapper_ptr_->tensor2img (output_args[0 ]);
155
159
// Segmentation label
156
- pt_wrapper_ptr_->tensor2img ((uncertainty[0 ]*255 ).to (torch::kByte ), uncertainty_cv);
160
+ uncertainty_cv = pt_wrapper_ptr_->tensor2img (uncertainty);
161
+ // uncertainty_cv = pt_wrapper_ptr_->tensor2img((uncertainty*255).to(torch::kCPU).to(torch::kByte));
162
+ // Traverability
163
+ prob_cv = pt_wrapper_ptr_->tensor2img (prob);
157
164
158
- // Set the size
159
- cv::Size s_orig (width_orig, height_orig);
160
165
// Resize the input image back to the original size
161
166
cv::resize (label, label, s_orig, cv::INTER_NEAREST);
162
167
cv::resize (prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
@@ -169,6 +174,7 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
169
174
// Generate an image message and point messages
170
175
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , label).toImageMsg ();
171
176
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage (std_msgs::Header (), " rgb8" , color_label).toImageMsg ();
177
+ // Problem: Wrong data is sometimes assigned to 'prob_cv'
172
178
sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , prob_cv).toImageMsg ();
173
179
sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , uncertainty_cv).toImageMsg ();
174
180
geometry_msgs::PointStampedPtr start_point_msg (new geometry_msgs::PointStamped), end_point_msg (new geometry_msgs::PointStamped);
@@ -233,7 +239,7 @@ PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg)
233
239
// Convert the image message to a cv_bridge object
234
240
try
235
241
{
236
- cv_ptr = cv_bridge::toCvCopy (msg, sensor_msgs::image_encodings::BGR8 );
242
+ cv_ptr = cv_bridge::toCvCopy (msg, msg-> encoding );
237
243
}
238
244
catch (cv_bridge::Exception& e)
239
245
{
@@ -257,7 +263,7 @@ PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::Image msg)
257
263
// Convert the image message to a cv_bridge object
258
264
try
259
265
{
260
- cv_ptr = cv_bridge::toCvCopy (msg, sensor_msgs::image_encodings::BGR8 );
266
+ cv_ptr = cv_bridge::toCvCopy (msg, msg. encoding );
261
267
}
262
268
catch (cv_bridge::Exception& e)
263
269
{
0 commit comments