Skip to content

Commit 707d18e

Browse files
[update] Change program to output probability map
1 parent 2879cbf commit 707d18e

File tree

3 files changed

+21
-12
lines changed

3 files changed

+21
-12
lines changed

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ class PyTorchCppWrapper {
1515
private :
1616
// std::shared_ptr<torch::jit::script::Module> module_;
1717
torch::jit::script::Module module_;
18+
// c = P(s|y=1) in PU learning, calculated during training
19+
float c_;
1820
// torch::data::transforms::Normalize<at::Tensor> normalizer_;
1921

2022
public:

include/pytorch_enet_ros/pytorch_enet_ros.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
#include <opencv2/opencv.hpp>
1313
#include<image_transport/image_transport.h>
1414
#include<cv_bridge/cv_bridge.h>
15-
#include<semantic_segmentation_srvs/GetLabelImage.h>
15+
//#include<semantic_segmentation_srvs/GetLabelImage.h>
16+
#include<semantic_segmentation_srvs/GetLabelAndProbability.h>
1617

1718
#include"pytorch_cpp_wrapper/pytorch_cpp_wrapper.h"
1819

@@ -31,6 +32,7 @@ class PyTorchENetROS {
3132
image_transport::Subscriber sub_image_;
3233
image_transport::Publisher pub_label_image_;
3334
image_transport::Publisher pub_color_image_;
35+
image_transport::Publisher pub_prob_image_;
3436

3537
PyTorchCppWrapper pt_wrapper_;
3638

@@ -40,9 +42,9 @@ class PyTorchENetROS {
4042
PyTorchENetROS(ros::NodeHandle & nh);
4143

4244
void image_callback(const sensor_msgs::ImageConstPtr& msg);
43-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
44-
bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelImage::Request & req,
45-
semantic_segmentation_srvs::GetLabelImage::Response & res);
45+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr> inference(cv::Mat & input_image);
46+
bool image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
47+
semantic_segmentation_srvs::GetLabelAndProbability::Response & res);
4648
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::ImageConstPtr msg);
4749
cv_bridge::CvImagePtr msg_to_cv_bridge(sensor_msgs::Image msg);
4850
void label_to_color(cv::Mat& label, cv::Mat& color_label);

src/pytorch_enet_ros.cpp

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ PyTorchENetROS::PyTorchENetROS(ros::NodeHandle & nh)
1212
sub_image_ = it_.subscribe("image", 1, &PyTorchENetROS::image_callback, this);
1313
pub_label_image_ = it_.advertise("label", 1);
1414
pub_color_image_ = it_.advertise("color_label", 1);
15+
pub_prob_image_ = it_.advertise("prob", 1);
1516
get_label_image_server_ = nh_.advertiseService("get_label_image", &PyTorchENetROS::image_inference_srv_callback, this);
1617

1718
// Import the model
@@ -45,22 +46,24 @@ PyTorchENetROS::image_callback(const sensor_msgs::ImageConstPtr& msg)
4546
sensor_msgs::ImagePtr label_msg;
4647
sensor_msgs::ImagePtr color_label_msg;
4748
sensor_msgs::ImagePtr prob_msg;
48-
std::tie(label_msg, color_label_msg) = inference(cv_ptr->image);
49+
std::tie(label_msg, color_label_msg, prob_msg) = inference(cv_ptr->image);
4950

5051
// Set header
5152
label_msg->header = msg->header;
5253
color_label_msg->header = msg->header;
54+
prob_msg->header = msg->header;
5355

5456
pub_label_image_.publish(label_msg);
5557
pub_color_image_.publish(color_label_msg);
58+
pub_prob_image_.publish(prob_msg);
5659
}
5760

5861
/*
5962
* image_inference_srv_callback : Callback for the service
6063
*/
6164
bool
62-
PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabelImage::Request & req,
63-
semantic_segmentation_srvs::GetLabelImage::Response & res)
65+
PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabelAndProbability::Request & req,
66+
semantic_segmentation_srvs::GetLabelAndProbability::Response & res)
6467
{
6568
ROS_INFO("[PyTorchENetROS image_inference_srv_callback] Start");
6669

@@ -71,18 +74,19 @@ PyTorchENetROS::image_inference_srv_callback(semantic_segmentation_srvs::GetLabe
7174
sensor_msgs::ImagePtr label_msg;
7275
sensor_msgs::ImagePtr color_label_msg;
7376
sensor_msgs::ImagePtr prob_msg;
74-
std::tie(label_msg, color_label_msg) = inference(cv_ptr->image);
77+
std::tie(label_msg, color_label_msg, prob_msg) = inference(cv_ptr->image);
7578

7679
res.label_img = *label_msg;
7780
res.colorlabel_img = *color_label_msg;
81+
res.prob_img = *prob_msg;
7882

7983
return true;
8084
}
8185

8286
/*
8387
* inference : Forward the given input image through the network and return the inference result
8488
*/
85-
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
89+
std::tuple<sensor_msgs::ImagePtr, sensor_msgs::ImagePtr, sensor_msgs::ImagePtr>
8690
PyTorchENetROS::inference(cv::Mat & input_img)
8791
{
8892

@@ -128,15 +132,16 @@ PyTorchENetROS::inference(cv::Mat & input_img)
128132
// Resize the input image back to the original size
129133
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
130134
cv::resize(prob_cv, prob_cv, s_orig, cv::INTER_LINEAR);
135+
// Generate color label image
131136
cv::Mat color_label;
132137
label_to_color(label, color_label);
133138

134139
// Generate an image message
135-
// sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
136-
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
140+
sensor_msgs::ImagePtr label_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", label).toImageMsg();
137141
sensor_msgs::ImagePtr color_label_msg = cv_bridge::CvImage(std_msgs::Header(), "rgb8", color_label).toImageMsg();
142+
sensor_msgs::ImagePtr prob_msg = cv_bridge::CvImage(std_msgs::Header(), "mono8", prob_cv).toImageMsg();
138143

139-
return std::forward_as_tuple(label_msg, color_label_msg);
144+
return std::forward_as_tuple(label_msg, color_label_msg, prob_msg);
140145
}
141146

142147
/*

0 commit comments

Comments
 (0)