Skip to content

Commit 279a24f

Browse files
WIP: Implement uncertainty estimation
1 parent 0f64d46 commit 279a24f

File tree

5 files changed

+35
-2
lines changed

5 files changed

+35
-2
lines changed

include/pytorch_cpp_wrapper/pytorch_cpp_wrapper_base.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,14 @@ protected :
4949
* @param[out] tensor that has index of max value in each element
5050
*/
5151
at::Tensor get_argmax(at::Tensor input_tensor);
52+
53+
/**
54+
* @brief Take element-wise entropy
55+
* @param[in] tensor
56+
* @param[out] tensor that has index of max value in each element
57+
*/
58+
at::Tensor get_entropy(at::Tensor input_tensor);
59+
5260
};
5361
//}
5462
#endif

launch/pytorch_enet_ros.launch

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
<?xml version="1.0"?>
22
<launch>
33
<arg name="image" default="/camera/rgb/image_rect_color" />
4-
<arg name="model_name" default="$(find pytorch_ros)/models/espdnet_ue_trav_path_20210518-221714.pt" />
4+
<arg name="model_name" default="$(find pytorch_ros)/models/espdnet_ue_trav_path_20210712-134315.pt" />
55

66
<node pkg="pytorch_ros" type="pytorch_seg_trav_path_node" name="pytorch_seg_trav_path_node" output="screen">
77
<remap from="~image" to="$(arg image)" />

script/visualizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,4 @@ def main():
8383

8484

8585
if __name__=='__main__':
86-
main()
86+
main()

src/pytorch_cpp_wrapper_base.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,21 @@ PyTorchCppWrapperBase::get_argmax(at::Tensor input_tensor)
105105

106106
return output;
107107
}
108+
109+
/**
110+
* @brief Take element-wise entropy
111+
* @param[in] tensor
112+
* @param[out] tensor that has index of max value in each element
113+
*/
114+
at::Tensor
115+
PyTorchCppWrapperBase::get_entropy(at::Tensor input_tensor)
116+
{
117+
input_tensor.to(torch::kCUDA);
118+
// Calculate the entropy at each pixel
119+
at::Tensor log_p = torch::log_softmax(input_tensor, /*dim=*/1);//at::argmax(input_tensor, 1).to(torch::kCPU).to(at::kByte);
120+
at::Tensor p = torch::log_softmax(input_tensor, /*dim=*/1);
121+
122+
at::Tensor entropy = torch::sum(p * log_p, /*dim=*/1);
123+
124+
return entropy;
125+
}

src/pytorch_seg_trav_path_ros.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,10 +133,17 @@ PyTorchSegTravPathROS::inference(cv::Mat & input_img)
133133
at::Tensor segmentation;
134134
at::Tensor prob;
135135
at::Tensor points;
136+
// segmentation: raw output for segmentation (before softmax)
137+
// prob: traversability
138+
// points: coordinates of the line points
136139
std::tie(segmentation, prob, points) = pt_wrapper_.get_output(input_tensor);
137140

141+
// Get class label map by taking argmax of 'segmentation'
138142
at::Tensor output_args = pt_wrapper_.get_argmax(segmentation);
139143

144+
// Uncertainty of segmentation
145+
at::Tensor uncertainty = pt_wrapper_.get_entropy(segmentation);
146+
140147
// Convert to OpenCV
141148
cv::Mat label;
142149
cv::Mat prob_cv;

0 commit comments

Comments
 (0)