@@ -13,12 +13,14 @@ PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
   pub_label_image_ = it_.advertise("label", 1);
   pub_color_image_ = it_.advertise("color_label", 1);
   pub_prob_image_ = it_.advertise("prob", 1);
+  pub_uncertainty_image_ = it_.advertise("uncertainty", 1);
   get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchENetROS::image_inference_srv_callback, this);
 
   // Import the model
   std::string filename;
   nh_.param<std::string>("model_file", filename, "");
-  if (!pt_wrapper_.import_module(filename)) {
+  pt_wrapper_ptr_.reset(new PyTorchCppWrapperSegTrav(filename, 4));
+  if (!pt_wrapper_ptr_->import_module(filename)) {
     ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
     ros::shutdown();
   }
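For reference, a hedged sketch of the header-side members these constructor changes imply. The node's actual header is not part of this diff, so the declarations below are inferred from the usage above and the smart-pointer type is an assumption:

// Hypothetical excerpt of the node's class declaration, inferred from this diff only.
image_transport::Publisher pub_label_image_;
image_transport::Publisher pub_color_image_;
image_transport::Publisher pub_prob_image_;
image_transport::Publisher pub_uncertainty_image_;  // new: publishes the "uncertainty" topic

// The plain pt_wrapper_ member is replaced by a pointer so the wrapper can be
// constructed only after the model filename is read from the parameter server;
// the 4 passed to the constructor is presumably the number of output classes.
// std::shared_ptr is assumed here.
std::shared_ptr<PyTorchCppWrapperSegTrav> pt_wrapper_ptr_;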
@@ -46,16 +48,19 @@ PyTorchENetROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
   sensor_msgs::ImagePtr label_msg;
   sensor_msgs::ImagePtr color_label_msg;
   sensor_msgs::ImagePtr prob_msg;
-  std::tie(label_msg, color_label_msg, prob_msg) = inference(cv_ptr->image);
+  sensor_msgs::ImagePtr uncertainty_msg;
+  std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg) = inference(cv_ptr->image);
 
   // Set header
   label_msg->header = msg->header;
   color_label_msg->header = msg->header;
   prob_msg->header = msg->header;
+  uncertainty_msg->header = msg->header;
 
   pub_label_image_.publish(label_msg);
   pub_color_image_.publish(color_label_msg);
   pub_prob_image_.publish(prob_msg);
+  pub_uncertainty_image_.publish(uncertainty_msg);
 }
 
 /*
@@ -74,19 +79,21 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
   sensor_msgs::ImagePtr label_msg;
   sensor_msgs::ImagePtr color_label_msg;
   sensor_msgs::ImagePtr prob_msg;
-  std::tie(label_msg, color_label_msg, prob_msg) = inference(cv_ptr->image);
+  sensor_msgs::ImagePtr uncertainty_msg;
+  std::tie(label_msg, color_label_msg, prob_msg, uncertainty_msg) = inference(cv_ptr->image);
 
   res.label_img = *label_msg;
   res.colorlabel_img = *color_label_msg;
   res.prob_img = *prob_msg;
+  res.uncertainty_img = *uncertainty_msg;
 
   return true;
 }
 
 /*
  * inference : Forward the given input image through the network and return the inference result
  */
-std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
+std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
 PyTorchENetROS::inference(cv::Mat & input_img)
 {
 
@@ -99,7 +106,7 @@ PyTorchENetROS::inference(cv::Mat & input_img)
   cv::resize(input_img, input_img, s);
 
   at::Tensor input_tensor;
-  pt_wrapper_.img2tensor(input_img, input_tensor);
+  pt_wrapper_ptr_->img2tensor(input_img, input_tensor);
 
   // Normalize from [0, 255] -> [0, 1]
   input_tensor /= 255.0;
@@ -114,24 +121,33 @@ PyTorchENetROS::inference(cv::Mat & input_img)
   // Execute the model and turn its output into a tensor.
   at::Tensor segmentation;
   at::Tensor prob;
-  std::tie(segmentation, prob) = pt_wrapper_.get_output(input_tensor);
-  // at::Tensor output = pt_wrapper_.get_output(input_tensor);
+  std::tie(segmentation, prob) = pt_wrapper_ptr_->get_output(input_tensor);
+  prob = (prob[0][0]*255).to(torch::kCPU).to(torch::kByte);
+  // at::Tensor output = pt_wrapper_ptr_->get_output(input_tensor);
   // Calculate argmax to get a label on each pixel
-  // at::Tensor output_args = pt_wrapper_.get_argmax(output);
+  // at::Tensor output_args = pt_wrapper_ptr_->get_argmax(output);
 
-  at::Tensor output_args = pt_wrapper_.get_argmax(segmentation);
+  at::Tensor output_args = pt_wrapper_ptr_->get_argmax(segmentation);
+
+  // Uncertainty of segmentation
+  at::Tensor uncertainty = pt_wrapper_ptr_->get_entropy(segmentation, true);
+  // at::Tensor uncertainty = torch::zeros_like(prob);
+  uncertainty = (uncertainty[0]*255).to(torch::kCPU).to(torch::kByte);
 
   // Convert to OpenCV
   cv::Mat label;
   cv::Mat prob_cv;
-  pt_wrapper_.tensor2img(output_args[0], label);
-  pt_wrapper_.tensor2img((prob[0][0]*255).to(torch::kByte), prob_cv);
+  cv::Mat uncertainty_cv;
+  pt_wrapper_ptr_->tensor2img(output_args[0], label);
+  pt_wrapper_ptr_->tensor2img(prob, prob_cv);
+  pt_wrapper_ptr_->tensor2img(uncertainty, uncertainty_cv);
 
   // Set the size
   cv::Size s_orig(width_orig, height_orig);
   // Resize the input image back to the original size
   cv::resize(label, label, s_orig, cv::INTER_NEAREST);
   cv::resize(prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
+  cv::resize(uncertainty_cv, uncertainty_cv, s_orig, cv::INTER_LINEAR);
   // Generate color label image
   cv::Mat color_label;
   label_to_color(label, color_label);
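The body of get_entropy() is not shown in this diff, so the following is only a hedged sketch of the kind of normalized per-pixel entropy over the class scores it could compute with libtorch. The function name, the [N, C, H, W] shape of the scores, and the meaning of the boolean flag (normalize to [0, 1]) are assumptions taken from how the result is used above:

#include <torch/torch.h>
#include <cmath>

// Hypothetical sketch, not the actual PyTorchCppWrapperSegTrav::get_entropy().
// 'logits' is assumed to have shape [N, C, H, W]; the result has shape [N, H, W].
at::Tensor entropy_sketch(const at::Tensor & logits, bool normalize)
{
  at::Tensor logp = torch::log_softmax(logits, /*dim=*/1);  // log class probabilities
  at::Tensor p = logp.exp();                                // class probabilities
  at::Tensor ent = -(p * logp).sum(/*dim=*/1);              // -sum_c p_c * log p_c per pixel
  if (normalize) {
    // Divide by log(C) so the entropy lies in [0, 1] regardless of class count.
    ent = ent / std::log(static_cast<double>(logits.size(1)));
  }
  return ent;
}

With that assumption, scaling by 255 and casting to torch::kByte in the hunk above turns the normalized entropy into the same 8-bit range already used for the prob image before it is converted to cv::Mat and published as mono8.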
@@ -140,8 +156,9 @@ PyTorchENetROS::inference(cv::Mat & input_img)
   sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
   sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
   sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
+  sensor_msgs::ImagePtr uncertainty_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", uncertainty_cv).toImageMsg();
 
-  return std::forward_as_tuple(label_msg, color_label_msg, prob_msg);
+  return std::forward_as_tuple(label_msg, color_label_msg, prob_msg, uncertainty_msg);
 }
 
 /*
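Because the uncertainty image, like prob, is published as mono8 after scaling to [0, 255], a consumer has to divide by 255 to recover the [0, 1] range. A minimal, hypothetical subscriber illustrating that (the topic name is taken from this diff; the node is not part of the PR):

// Example consumer (not part of this PR): recovers uncertainty in [0, 1]
// from the mono8 image published on the "uncertainty" topic.
#include <ros/ros.h>
#include <cv_bridge/cv_bridge.h>
#include <sensor_msgs/Image.h>
#include <opencv2/core.hpp>

void uncertaintyCallback(const sensor_msgs::ImageConstPtr & msg)
{
  cv_bridge::CvImageConstPtr cv_ptr = cv_bridge::toCvShare(msg, "mono8");
  cv::Mat uncertainty_01;
  // Undo the *255 scaling applied before publishing.
  cv_ptr->image.convertTo(uncertainty_01, CV_32FC1, 1.0 / 255.0);
  ROS_INFO("Mean uncertainty: %f", cv::mean(uncertainty_01)[0]);
}

int main(int argc, char ** argv)
{
  ros::init(argc, argv, "uncertainty_listener");
  ros::NodeHandle nh;
  ros::Subscriber sub = nh.subscribe("uncertainty", 1, uncertaintyCallback);
  ros::spin();
  return 0;
}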