Skip to content

Commit bab6f13

Browse files
Add some comments
1 parent 5e4be21 commit bab6f13

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

src/pytorch_seg_trav_path_ros.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,6 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
124124
at::Tensor prob;
125125
at::Tensor points;
126126
std::tie(segmentation, prob, points) = pt_wrapper_.get_output(input_tensor);
127-
// at::Tensor output = pt_wrapper_.get_output(input_tensor);
128-
// Calculate argmax to get a label on each pixel
129-
// at::Tensor output_args = pt_wrapper_.get_argmax(output);
130127

131128
at::Tensor output_args = pt_wrapper_.get_argmax(segmentation);
132129

@@ -155,16 +152,29 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
155152
return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg);
156153
}
157154

155+
/**
156+
* @brief Convert a tensor with a size of (1, 4) to start and end points (x, y)
157+
* @param[in] point_tensor (1, 4) tensor
158+
* @param[in] width Original width of the image
159+
* @param[in] height Original height of the image
160+
* @return A tuple of start and end points as geometry_msgs::PointStampedPtr
161+
*/
158162
std::tuple<geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
159163
PyTorchSegTravPathROS::tensor_to_points(const at::Tensor point_tensor, const int & width, const int & height) {
160164
geometry_msgs::PointStampedPtr start_point_msg(new geometry_msgs::PointStamped), end_point_msg(new geometry_msgs::PointStamped);
165+
// Important: put the data on the CPU before accessing the data.
166+
// Absense of this code will result in runtime error.
161167
at::Tensor points = point_tensor.to(torch::kCPU);
162168
auto points_a = points.accessor<float, 2>();
169+
170+
// Initialize messgaes
163171
start_point_msg->header.stamp = ros::Time::now();
164172
start_point_msg->header.frame_id = "kinect2_rgb_optical_frame";
165173
end_point_msg->header.stamp = ros::Time::now();
166174
end_point_msg->header.frame_id = "kinect2_rgb_optical_frame";
167-
start_point_msg->point.x = points_a[0][0] * width;
175+
// Point tensor has coordinate values normalized with the width and height.
176+
// Therefore each value is multiplied by width or height.
177+
start_point_mse->point.x = points_a[0][0] * width;
168178
start_point_msg->point.y = points_a[0][1] * height;
169179
end_point_msg->point.x = points_a[0][2] * width;
170180
end_point_msg->point.y = points_a[0][3] * height;

0 commit comments

Comments
 (0)