@@ -12,6 +12,7 @@ PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
12
12
sub_image_ = it_.subscribe (" image" , 1 , &PyTorchENetROS::image_callback, this );
13
13
pub_label_image_ = it_.advertise (" label" , 1 );
14
14
pub_color_image_ = it_.advertise (" color_label" , 1 );
15
+ pub_prob_image_ = it_.advertise (" prob" , 1 );
15
16
get_label_image_server_ = nh_.advertiseService (" get_label_image" , &PyTorchENetROS::image_inference_srv_callback, this );
16
17
17
18
// Import the model
@@ -45,22 +46,24 @@ PyTorchENetROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
45
46
sensor_msgs::ImagePtr label_msg;
46
47
sensor_msgs::ImagePtr color_label_msg;
47
48
sensor_msgs::ImagePtr prob_msg;
48
- std::tie (label_msg, color_label_msg) = inference (cv_ptr->image );
49
+ std::tie (label_msg, color_label_msg, prob_msg ) = inference (cv_ptr->image );
49
50
50
51
// Set header
51
52
label_msg->header = msg->header ;
52
53
color_label_msg->header = msg->header ;
54
+ prob_msg->header = msg->header ;
53
55
54
56
pub_label_image_.publish (label_msg);
55
57
pub_color_image_.publish (color_label_msg);
58
+ pub_prob_image_.publish (prob_msg);
56
59
}
57
60
58
61
/*
59
62
* image_inference_srv_callback : Callback for the service
60
63
*/
61
64
bool
62
- PyTorchENetROS::image_inference_srv_callback (semantic_segmentation_srvs::GetLabelImage ::Request & req,
63
- semantic_segmentation_srvs::GetLabelImage ::Response & res)
65
+ PyTorchENetROS::image_inference_srv_callback (semantic_segmentation_srvs::GetLabelAndProbability ::Request & req,
66
+ semantic_segmentation_srvs::GetLabelAndProbability ::Response & res)
64
67
{
65
68
ROS_INFO (" [PyTorchENetROS image_inference_srv_callback] Start" );
66
69
@@ -71,18 +74,19 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
71
74
sensor_msgs::ImagePtr label_msg;
72
75
sensor_msgs::ImagePtr color_label_msg;
73
76
sensor_msgs::ImagePtr prob_msg;
74
- std::tie (label_msg, color_label_msg) = inference (cv_ptr->image );
77
+ std::tie (label_msg, color_label_msg, prob_msg ) = inference (cv_ptr->image );
75
78
76
79
res.label_img = *label_msg;
77
80
res.colorlabel_img = *color_label_msg;
81
+ res.prob_img = *prob_msg;
78
82
79
83
return true ;
80
84
}
81
85
82
86
/*
83
87
* inference : Forward the given input image through the network and return the inference result
84
88
*/
85
- std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
89
+ std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr >
86
90
PyTorchENetROS::inference (cv::Mat & input_img)
87
91
{
88
92
@@ -128,15 +132,16 @@ PyTorchENetROS::inference(cv::Mat & input_img)
128
132
// Resize the input image back to the original size
129
133
cv::resize (label, label, s_orig, cv::INTER_NEAREST);
130
134
cv::resize (prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
135
+ // Generate color label image
131
136
cv::Mat color_label;
132
137
label_to_color (label, color_label);
133
138
134
139
// Generate an image message
135
- // sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
136
- sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , prob_cv).toImageMsg ();
140
+ sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , label).toImageMsg ();
137
141
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage (std_msgs::Header (), " rgb8" , color_label).toImageMsg ();
142
+ sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , prob_cv).toImageMsg ();
138
143
139
- return std::forward_as_tuple (label_msg, color_label_msg);
144
+ return std::forward_as_tuple (label_msg, color_label_msg, prob_msg );
140
145
}
141
146
142
147
/*
0 commit comments