Skip to content

Commit 2879cbf

Browse files
[update] Add probability module
1 parent 1b14c50 commit 2879cbf

File tree

4 files changed

+35
-12
lines changed

4 files changed

+35
-12
lines changed

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ private :
2525
bool import_module(const std::string filename);
2626
void img2tensor(cv::Mat & img, at::Tensor & tensor, const bool use_gpu = true);
2727
void tensor2img(at::Tensor tensor, cv::Mat & img);
28-
at::Tensor get_output(at::Tensor input_tensor);
28+
// at::Tensor get_output(at::Tensor input_tensor);
29+
std::tuple<at::Tensor, at::Tensor> get_output(at::Tensor input_tensor);
2930
at::Tensor get_argmax(at::Tensor input_tensor);
3031
};
3132
//}

launch/pytorch_enet_ros.launch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
<param name="colormap" value="$(find pytorch_enet_ros)/images/camvid12.png" />
1010
-->
1111
<remap from="~image" to="$(arg image)" />
12-
<param name="model_file" value="$(find pytorch_enet_ros)/models/espdnet_ue_uest_trav.pt" />
12+
<param name="model_file" value="$(find pytorch_enet_ros)/models/espdnet_ue_trav_20210115-151110.pt" />
1313
<!--
1414
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_greenhouse.pt" />
1515
-->

src/pytorch_cpp_wrapper.cpp

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
*/
66

77

8+
#include <torch/torch.h>
89
#include "pytorch_cpp_wrapper/pytorch_cpp_wrapper.h"
910
#include <torch/script.h> // One-stop header.
1011
#include <torch/data/transforms/tensor.h> // One-stop header.
@@ -52,7 +53,7 @@ PyTorchCppWrapper::import_module(const std::string filename)
5253
return true;
5354
}
5455
catch (const c10::Error& e) {
55-
std::cerr << "error loading the model\n";
56+
std::cerr << e.what();
5657
return false;
5758
}
5859
}
@@ -78,16 +79,18 @@ PyTorchCppWrapper::img2tensor(cv::Mat & img, at::Tensor & tensor, const bool use
7879
void
7980
PyTorchCppWrapper::tensor2img(at::Tensor tensor, cv::Mat & img)
8081
{
81-
std::cout << tensor.sizes() << std::endl;
8282
// Get the size of the input image
8383
int height = tensor.sizes()[0];
8484
int width = tensor.sizes()[1];
8585

86+
tensor = tensor.to(torch::kCPU);
87+
8688
// Convert to OpenCV
8789
img = cv::Mat(height, width, CV_8U, tensor. template data<uint8_t>());
8890
}
8991

90-
at::Tensor
92+
//at::Tensor
93+
std::tuple<at::Tensor, at::Tensor>
9194
PyTorchCppWrapper::get_output(at::Tensor input_tensor)
9295
{
9396
// Execute the model and turn its output into a tensor.
@@ -97,8 +100,17 @@ PyTorchCppWrapper::get_output(at::Tensor input_tensor)
97100

98101
at::Tensor output1 = outputs->elements()[0].toTensor();
99102
at::Tensor output2 = outputs->elements()[1].toTensor();
103+
at::Tensor prob = outputs->elements()[2].toTensor();
104+
105+
// Divide probability by c
106+
prob = torch::sigmoid(prob) / 0.3;
107+
// Limit the values in range [0, 1]
108+
prob = at::clamp(prob, 0.0, 1.0);
109+
110+
// return output1 + 0.5 * output2;
111+
at::Tensor segmentation = output1 + 0.5 * output2;
100112

101-
return output1 + 0.5 * output2;
113+
return std::forward_as_tuple(segmentation, prob);
102114
}
103115

104116
at::Tensor

src/pytorch_enet_ros.cpp

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
1818
std::string filename;
1919
nh_.param<std::string>("model_file", filename, "");
2020
if(!pt_wrapper_.import_module(filename)) {
21-
ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
21+
ROS_ERROR("Failed to import the model file [%s]", filename.c_str());
2222
ros::shutdown();
2323
}
2424

@@ -44,6 +44,7 @@ PyTorchENetROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
4444
// Run inference
4545
sensor_msgs::ImagePtr label_msg;
4646
sensor_msgs::ImagePtr color_label_msg;
47+
sensor_msgs::ImagePtr prob_msg;
4748
std::tie(label_msg, color_label_msg) = inference(cv_ptr->image);
4849

4950
// Set header
@@ -69,6 +70,7 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
6970
// Run inference
7071
sensor_msgs::ImagePtr label_msg;
7172
sensor_msgs::ImagePtr color_label_msg;
73+
sensor_msgs::ImagePtr prob_msg;
7274
std::tie(label_msg, color_label_msg) = inference(cv_ptr->image);
7375

7476
res.label_img = *label_msg;
@@ -103,27 +105,35 @@ PyTorchENetROS::inference(cv::Mat & input_img)
103105
for(int i = 0; i < mean_vec.size(); i++) {
104106
input_tensor[0][i] = (input_tensor[0][i] - mean_vec[i]) / std_vec[i];
105107
}
106-
std::cout << input_tensor.sizes() << std::endl;
108+
// std::cout << input_tensor.sizes() << std::endl;
107109

108110
// Execute the model and turn its output into a tensor.
109-
at::Tensor output = pt_wrapper_.get_output(input_tensor);
111+
at::Tensor segmentation;
112+
at::Tensor prob;
113+
std::tie(segmentation, prob) = pt_wrapper_.get_output(input_tensor);
114+
// at::Tensor output = pt_wrapper_.get_output(input_tensor);
110115
// Calculate argmax to get a label on each pixel
111-
at::Tensor output_args = pt_wrapper_.get_argmax(output);
116+
// at::Tensor output_args = pt_wrapper_.get_argmax(output);
117+
118+
at::Tensor output_args = pt_wrapper_.get_argmax(segmentation);
112119

113120
// Convert to OpenCV
114121
cv::Mat label;
122+
cv::Mat prob_cv;
115123
pt_wrapper_.tensor2img(output_args[0], label);
124+
pt_wrapper_.tensor2img((prob[0][0]*255).to(torch::kByte), prob_cv);
116125

117126
// Set the size
118127
cv::Size s_orig(width_orig, height_orig);
119128
// Resize the input image back to the original size
120129
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
121-
130+
cv::resize(prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
122131
cv::Mat color_label;
123132
label_to_color(label, color_label);
124133

125134
// Generate an image message
126-
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
135+
// sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
136+
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
127137
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
128138

129139
return std::forward_as_tuple(label_msg, color_label_msg);

0 commit comments

Comments
 (0)