@@ -124,9 +124,6 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
124
124
at::Tensor prob;
125
125
at::Tensor points;
126
126
std::tie (segmentation, prob, points) = pt_wrapper_.get_output (input_tensor);
127
- // at::Tensor output = pt_wrapper_.get_output(input_tensor);
128
- // Calculate argmax to get a label on each pixel
129
- // at::Tensor output_args = pt_wrapper_.get_argmax(output);
130
127
131
128
at::Tensor output_args = pt_wrapper_.get_argmax (segmentation);
132
129
@@ -155,16 +152,29 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
155
152
return std::forward_as_tuple (label_msg, color_label_msg, prob_msg, start_point_msg, end_point_msg);
156
153
}
157
154
155
+ /* *
156
+ * @brief Convert a tensor with a size of (1, 4) to start and end points (x, y)
157
+ * @param[in] point_tensor (1, 4) tensor
158
+ * @param[in] width Original width of the image
159
+ * @param[in] height Original height of the image
160
+ * @return A tuple of start and end points as geometry_msgs::PointStampedPtr
161
+ */
158
162
std::tuple<geometry_msgs::PointStampedPtr, geometry_msgs::PointStampedPtr>
159
163
PyTorchSegTravPathROS::tensor_to_points (const at::Tensor point_tensor, const int & width, const int & height) {
160
164
geometry_msgs::PointStampedPtr start_point_msg (new geometry_msgs::PointStamped), end_point_msg (new geometry_msgs::PointStamped);
165
+ // Important: put the data on the CPU before accessing the data.
166
+ // Absense of this code will result in runtime error.
161
167
at::Tensor points = point_tensor.to (torch::kCPU );
162
168
auto points_a = points.accessor <float , 2 >();
169
+
170
+ // Initialize messgaes
163
171
start_point_msg->header .stamp = ros::Time::now ();
164
172
start_point_msg->header .frame_id = " kinect2_rgb_optical_frame" ;
165
173
end_point_msg->header .stamp = ros::Time::now ();
166
174
end_point_msg->header .frame_id = " kinect2_rgb_optical_frame" ;
167
- start_point_msg->point .x = points_a[0 ][0 ] * width;
175
+ // Point tensor has coordinate values normalized with the width and height.
176
+ // Therefore each value is multiplied by width or height.
177
+ start_point_mse->point .x = points_a[0 ][0 ] * width;
168
178
start_point_msg->point .y = points_a[0 ][1 ] * height;
169
179
end_point_msg->point .x = points_a[0 ][2 ] * width;
170
180
end_point_msg->point .y = points_a[0 ][3 ] * height;
0 commit comments