@@ -18,7 +18,7 @@ PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
18
18
std::string filename;
19
19
nh_.param <std::string>(" model_file" , filename, " " );
20
20
if (!pt_wrapper_.import_module (filename)) {
21
- ROS_ERROR (" Failed to import the model file [%s]" , filename.c_str ());
21
+ ROS_ERROR (" Failed to import the model file [%s]" , filename.c_str ());
22
22
ros::shutdown ();
23
23
}
24
24
@@ -44,6 +44,7 @@ PyTorchENetROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
44
44
// Run inference
45
45
sensor_msgs::ImagePtr label_msg;
46
46
sensor_msgs::ImagePtr color_label_msg;
47
+ sensor_msgs::ImagePtr prob_msg;
47
48
std::tie (label_msg, color_label_msg) = inference (cv_ptr->image );
48
49
49
50
// Set header
@@ -69,6 +70,7 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
69
70
// Run inference
70
71
sensor_msgs::ImagePtr label_msg;
71
72
sensor_msgs::ImagePtr color_label_msg;
73
+ sensor_msgs::ImagePtr prob_msg;
72
74
std::tie (label_msg, color_label_msg) = inference (cv_ptr->image );
73
75
74
76
res.label_img = *label_msg;
@@ -103,27 +105,35 @@ PyTorchENetROS::inference(cv::Mat & input_img)
103
105
for (int i = 0 ; i < mean_vec.size (); i++) {
104
106
input_tensor[0 ][i] = (input_tensor[0 ][i] - mean_vec[i]) / std_vec[i];
105
107
}
106
- std::cout << input_tensor.sizes () << std::endl;
108
+ // std::cout << input_tensor.sizes() << std::endl;
107
109
108
110
// Execute the model and turn its output into a tensor.
109
- at::Tensor output = pt_wrapper_.get_output (input_tensor);
111
+ at::Tensor segmentation;
112
+ at::Tensor prob;
113
+ std::tie (segmentation, prob) = pt_wrapper_.get_output (input_tensor);
114
+ // at::Tensor output = pt_wrapper_.get_output(input_tensor);
110
115
// Calculate argmax to get a label on each pixel
111
- at::Tensor output_args = pt_wrapper_.get_argmax (output);
116
+ // at::Tensor output_args = pt_wrapper_.get_argmax(output);
117
+
118
+ at::Tensor output_args = pt_wrapper_.get_argmax (segmentation);
112
119
113
120
// Convert to OpenCV
114
121
cv::Mat label;
122
+ cv::Mat prob_cv;
115
123
pt_wrapper_.tensor2img (output_args[0 ], label);
124
+ pt_wrapper_.tensor2img ((prob[0 ][0 ]*255 ).to (torch::kByte ), prob_cv);
116
125
117
126
// Set the size
118
127
cv::Size s_orig (width_orig, height_orig);
119
128
// Resize the input image back to the original size
120
129
cv::resize (label, label, s_orig, cv::INTER_NEAREST);
121
-
130
+ cv::resize (prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
122
131
cv::Mat color_label;
123
132
label_to_color (label, color_label);
124
133
125
134
// Generate an image message
126
- sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , label).toImageMsg ();
135
+ // sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
136
+ sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , prob_cv).toImageMsg ();
127
137
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage (std_msgs::Header (), " rgb8" , color_label).toImageMsg ();
128
138
129
139
return std::forward_as_tuple (label_msg, color_label_msg);
0 commit comments