9
9
PyTorchENetROS::PyTorchENetROS (ros::NodeHandle & nh)
10
10
: it_(nh), nh_(nh)
11
11
{
12
- // nh_ = ros::NodeHandle(nh);
13
- // it_ = image_transport::ImageTransport(nh_);
14
- // ROS_INFO("[PyTorchENetROS] Constructor");
15
-
16
12
sub_image_ = it_.subscribe (" image" , 1 , &PyTorchENetROS::image_callback, this );
17
13
pub_label_image_ = it_.advertise (" label" , 1 );
18
14
pub_color_image_ = it_.advertise (" color_label" , 1 );
@@ -67,7 +63,6 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
67
63
{
68
64
ROS_INFO (" [PyTorchENetROS image_inference_srv_callback] Start" );
69
65
70
-
71
66
// Convert the image message to a cv_bridge object
72
67
cv_bridge::CvImagePtr cv_ptr = msg_to_cv_bridge (req.img );
73
68
@@ -93,28 +88,30 @@ PyTorchENetROS::inference(cv::Mat & input_img)
93
88
int height_orig = input_img.size ().height ;
94
89
int width_orig = input_img.size ().width ;
95
90
96
- cv::Size s (480 , 264 );
91
+ cv::Size s (480 , 256 );
97
92
// Resize the input image
98
93
cv::resize (input_img, input_img, s);
99
94
100
- // ROS_INFO("[PyTorchENetROS inference] Start");
101
95
at::Tensor input_tensor;
102
96
pt_wrapper_.img2tensor (input_img, input_tensor);
103
97
98
+ // Normalize from [0, 255] -> [0, 1]
99
+ input_tensor /= 255.0 ;
100
+ // z-normalization
101
+ std::vector<float > mean_vec{0.485 , 0.456 , 0.406 };
102
+ std::vector<float > std_vec{0.229 , 0.224 , 0.225 };
103
+ for (int i = 0 ; i < mean_vec.size (); i++) {
104
+ input_tensor[0 ][i] = (input_tensor[0 ][i] - mean_vec[i]) / std_vec[i];
105
+ }
104
106
std::cout << input_tensor.sizes () << std::endl;
107
+
105
108
// Execute the model and turn its output into a tensor.
106
- // at::Tensor output = module->forward({input_tensor}).toTensor();
107
- // ROS_INFO("[PyTorchENetROS inference] get_output");
108
109
at::Tensor output = pt_wrapper_.get_output (input_tensor);
109
110
// Calculate argmax to get a label on each pixel
110
- // at::Tensor output_args = at::argmax(output, 1).to(torch::kCPU).to(at::kByte);
111
- // ROS_INFO("[PyTorchENetROS inference] get_argmax");
112
111
at::Tensor output_args = pt_wrapper_.get_argmax (output);
113
112
114
113
// Convert to OpenCV
115
- // cv::Mat mat(height, width, CV_8U, output_args[0]. template data<uint8_t>());
116
114
cv::Mat label;
117
- // ROS_INFO("[PyTorchENetROS inference] tensor2img");
118
115
pt_wrapper_.tensor2img (output_args[0 ], label);
119
116
120
117
// Set the size
@@ -123,18 +120,12 @@ PyTorchENetROS::inference(cv::Mat & input_img)
123
120
cv::resize (label, label, s_orig, cv::INTER_NEAREST);
124
121
125
122
cv::Mat color_label;
126
- // cv::applyColorMap(mat, color_label, cv::COLORMAP_JET);
127
123
label_to_color (label, color_label);
128
124
129
125
// Generate an image message
130
- // ROS_INFO("[PyTorchENetROS inference] cv_bridge to image msg");
131
126
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage (std_msgs::Header (), " mono8" , label).toImageMsg ();
132
127
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage (std_msgs::Header (), " rgb8" , color_label).toImageMsg ();
133
128
134
-
135
- // sensor_msgs::ImagePtr color_msg = cv_bridge::CvImage(std_msgs::Header(), "bgr8", image).toImageMsg();
136
-
137
- // ROS_INFO("[PyTorchENetROS inference] Publish");
138
129
return std::forward_as_tuple (label_msg, color_label_msg);
139
130
}
140
131
@@ -155,11 +146,11 @@ cv_bridge::CvImagePtr
155
146
PyTorchENetROS::msg_to_cv_bridge (sensor_msgs::ImageConstPtr msg)
156
147
{
157
148
cv_bridge::CvImagePtr cv_ptr;
149
+
158
150
// Convert the image message to a cv_bridge object
159
151
try
160
152
{
161
153
cv_ptr = cv_bridge::toCvCopy (msg, sensor_msgs::image_encodings::BGR8);
162
- // ROS_INFO("[PyTorchENetROS image_callback] Convert to cv_bridge object");
163
154
}
164
155
catch (cv_bridge::Exception& e)
165
156
{
@@ -177,11 +168,11 @@ cv_bridge::CvImagePtr
177
168
PyTorchENetROS::msg_to_cv_bridge (sensor_msgs::Image msg)
178
169
{
179
170
cv_bridge::CvImagePtr cv_ptr;
171
+
180
172
// Convert the image message to a cv_bridge object
181
173
try
182
174
{
183
175
cv_ptr = cv_bridge::toCvCopy (msg, sensor_msgs::image_encodings::BGR8);
184
- // ROS_INFO("[PyTorchENetROS image_callback] Convert to cv_bridge object");
185
176
}
186
177
catch (cv_bridge::Exception& e)
187
178
{
0 commit comments