Skip to content

Commit cc6e38b

Browse files
[fix] Fix the problem where the value of prob/uncertainty is sometimes substituted to the variable of another
1 parent a431adc commit cc6e38b

File tree

3 files changed

+39
-10
lines changed

3 files changed

+39
-10
lines changed

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,13 @@ protected :
4545
*/
4646
void tensor2img(at::Tensor tensor, cv::Mat & img);
4747

48+
/**
49+
* @brief convert a tensor (at::Tensor) to an image (cv::Mat)
50+
* @param[in] tensor
51+
* @return converted CV image
52+
*/
53+
cv::Mat tensor2img(at::Tensor tensor);
54+
4855
/**
4956
* @brief Take element-wise argmax
5057
* @param[in] tensor

src/pytorch_cpp_wrapper_base.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,12 +106,28 @@ PyTorchCppWrapperBase::tensor2img(at::Tensor tensor, cv::Mat & img)
106106
int height = tensor.sizes()[0];
107107
int width = tensor.sizes()[1];
108108

109-
tensor = tensor.to(torch::kCPU);
109+
// tensor = tensor.to(torch::kCPU);
110110

111111
// Convert to OpenCV
112112
img = cv::Mat(height, width, CV_8U, tensor. template data<uint8_t>());
113113
}
114114

115+
/**
116+
* @brief convert a tensor (at::Tensor) to an image (cv::Mat)
117+
* @param[in] tensor
118+
* @return converted CV image
119+
*/
120+
cv::Mat
121+
PyTorchCppWrapperBase::tensor2img(at::Tensor tensor)
122+
{
123+
// Get the size of the input image
124+
int height = tensor.sizes()[0];
125+
int width = tensor.sizes()[1];
126+
127+
// Convert to OpenCV
128+
return cv::Mat(height, width, CV_8U, tensor. template data<uint8_t>());
129+
}
130+
115131
/**
116132
* @brief Take element-wise argmax
117133
* @param[in] tensor

src/pytorch_seg_trav_path_ros.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,26 +137,31 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
137137
// prob: traversability
138138
// points: coordinates of the line points
139139
std::tie(segmentation, prob, points) = pt_wrapper_ptr_->get_output(input_tensor);
140+
prob = (prob[0][0]*255).to(torch::kCPU).to(torch::kByte);
140141

141142
// Get class label map by taking argmax of 'segmentation'
142143
at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);
143144

144145
// Uncertainty of segmentation
145146
at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
147+
uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);
148+
// at::Tensor uncertainty = torch::zeros_like(prob[0]);
149+
150+
// Set the size
151+
cv::Size s_orig(width_orig, height_orig);
146152

147153
// Convert to OpenCV
148154
cv::Mat label;
149155
cv::Mat prob_cv;
150-
cv::Mat uncertainty_cv;
156+
cv::Mat uncertainty_cv = cv::Mat::zeros(s_orig.height, s_orig.width, CV_8U);
151157
// Segmentation label
152-
pt_wrapper_ptr_->tensor2img(output_args[0], label);
153-
// Traverability
154-
pt_wrapper_ptr_->tensor2img((prob[0][0]*255).to(torch::kByte), prob_cv);
158+
label = pt_wrapper_ptr_->tensor2img(output_args[0]);
155159
// Segmentation label
156-
pt_wrapper_ptr_->tensor2img((uncertainty[0]*255).to(torch::kByte), uncertainty_cv);
160+
uncertainty_cv = pt_wrapper_ptr_->tensor2img(uncertainty);
161+
// uncertainty_cv = pt_wrapper_ptr_->tensor2img((uncertainty*255).to(torch::kCPU).to(torch::kByte));
162+
// Traverability
163+
prob_cv = pt_wrapper_ptr_->tensor2img(prob);
157164

158-
// Set the size
159-
cv::Size s_orig(width_orig, height_orig);
160165
// Resize the input image back to the original size
161166
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
162167
cv::resize(prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
@@ -169,6 +174,7 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
169174
// Generate an image message and point messages
170175
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
171176
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
177+
// Problem: Wrong data is sometimes assigned to 'prob_cv'
172178
sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
173179
sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();
174180
geometry_msgs::PointStampedPtr start_point_msg(new geometry_msgs::PointStamped), end_point_msg(new geometry_msgs::PointStamped);
@@ -233,7 +239,7 @@ PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg)
233239
// Convert the image message to a cv_bridge object
234240
try
235241
{
236-
cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
242+
cv_ptr = cv_bridge::toCvCopy(msg, msg->encoding);
237243
}
238244
catch (cv_bridge::Exception& e)
239245
{
@@ -257,7 +263,7 @@ PyTorchSegTravPathROS::msg_to_cv_bridge(sensor_msgs::Image msg)
257263
// Convert the image message to a cv_bridge object
258264
try
259265
{
260-
cv_ptr = cv_bridge::toCvCopy(msg, sensor_msgs::image_encodings::BGR8);
266+
cv_ptr = cv_bridge::toCvCopy(msg, msg.encoding);
261267
}
262268
catch (cv_bridge::Exception& e)
263269
{

0 commit comments

Comments
 (0)