Skip to content

Commit ac6af52

Browse files
Merge master
2 parents a201e56 + 3623af0 commit ac6af52

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

launch/pytorch_enet_ros.launch

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
<node pkg="pytorch_enet_ros" type="pytorch_enet_ros_node" name="pytorch_enet_ros_node" output="screen">
55
<!--
66
<remap from="~image" to="/kinect2/qhd/image_color_rect" />
7-
-->
8-
<!--
9-
<remap from="~image" to="/camera/rgb/image_rect_color" />
7+
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_camvid.pt" />
8+
<param name="colormap" value="$(find pytorch_enet_ros)/images/camvid12.png" />
109
-->
1110
<remap from="~image" to="/camera/rgb/image_rect_color" />
12-
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_c2-501.pt" />
11+
<param name="model_file" value="$(find pytorch_enet_ros)/models/ENet_scenes_train_c2-501_2_200.pt" />
1312
<param name="colormap" value="$(find pytorch_enet_ros)/images/greenhouse4.png" />
14-
<param name="model_name" value="camvid" />
13+
<param name="model_name" value="greenhouse" />
1514
</node>
1615
</launch>

src/pytorch_cpp_wrapper.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ PyTorchCppWrapper::import_module(const std::string filename)
2929
try {
3030
// Deserialize the ScriptModule from a file using torch::jit::load().
3131
module_ = torch::jit::load(filename);
32+
// Set evaluation mode
33+
module_->eval();
34+
std::cout << module_->is_training() << std::endl;
35+
3236
std::cout << "Import succeeded" << std::endl;
3337
return true;
3438
}

src/pytorch_enet_ros.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ PyTorchENetROS::inference(cv::Mat & input_img)
9393
int height_orig = input_img.size().height;
9494
int width_orig = input_img.size().width;
9595

96-
cv::Size s(480, 264);
96+
// cv::Size s(320, 240);
9797
// Resize the input image
98-
cv::resize(input_img, input_img, s);
98+
// cv::resize(input_img, input_img, s);
9999

100100
// ROS_INFO("[PyTorchENetROS inference] Start");
101101
at::Tensor input_tensor;
@@ -120,7 +120,7 @@ PyTorchENetROS::inference(cv::Mat & input_img)
120120
// Set the size
121121
cv::Size s_orig(width_orig, height_orig);
122122
// Resize the input image back to the original size
123-
cv::resize(label, label, s_orig, cv::INTER_NEAREST);
123+
// cv::resize(label, label, s_orig, cv::INTER_NEAREST);
124124

125125
cv::Mat color_label;
126126
// cv::applyColorMap(mat, color_label, cv::COLORMAP_JET);

0 commit comments

Comments
 (0)