Skip to content

Commit

Permalink
Creating a simple action server to be able to invoke Random Crawl fro…
Browse files Browse the repository at this point in the history
…m BT (#945)

added random crawl action to bt
  • Loading branch information
mhpanah committed Aug 2, 2019
1 parent 2350778 commit 675574f
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2018 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef NAV2_BEHAVIOR_TREE__RANDOM_CRAWL_ACTION_HPP_
#define NAV2_BEHAVIOR_TREE__RANDOM_CRAWL_ACTION_HPP_

#include <string>

#include "nav2_behavior_tree/bt_action_node.hpp"
#include "nav2_msgs/action/random_crawl.hpp"

namespace nav2_behavior_tree
{

class RandomCrawlAction : public BtActionNode<nav2_msgs::action::RandomCrawl>
{
public:
explicit RandomCrawlAction(const std::string & action_name)
: BtActionNode<nav2_msgs::action::RandomCrawl>(action_name)
{
}
};

} // namespace nav2_behavior_tree

#endif // NAV2_BEHAVIOR_TREE__RANDOM_CRAWL_ACTION_HPP_
2 changes: 2 additions & 0 deletions nav2_bt_navigator/src/navigate_to_pose_behavior_tree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "nav2_behavior_tree/is_stuck_condition.hpp"
#include "nav2_behavior_tree/rate_controller_node.hpp"
#include "nav2_behavior_tree/spin_action.hpp"
#include "nav2_behavior_tree/random_crawl_action.hpp"
#include "rclcpp/rclcpp.hpp"

using namespace std::chrono_literals;
Expand All @@ -40,6 +41,7 @@ NavigateToPoseBehaviorTree::NavigateToPoseBehaviorTree()
factory_.registerNodeType<nav2_behavior_tree::FollowPathAction>("FollowPath");
factory_.registerNodeType<nav2_behavior_tree::BackUpAction>("BackUp");
factory_.registerNodeType<nav2_behavior_tree::SpinAction>("Spin");
factory_.registerNodeType<nav2_behavior_tree::RandomCrawlAction>("RandomCrawl");

// Register our custom condition nodes
factory_.registerNodeType<nav2_behavior_tree::IsStuckCondition>("IsStuck");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import rclpy
import parameters

from rclpy.node import Node
from time import sleep
from keras.models import load_model

Expand Down
110 changes: 110 additions & 0 deletions nav2_experimental/nav2_rl/nav2_turtlebot3_rl/random_crawl_action.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
# Copyright (c) 2019 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os.path
from ament_index_python.packages import get_package_share_directory
from turtlebot3_env import TurtlebotEnv
import numpy as np
import rclpy
import parameters

from rclpy.node import Node
from rclpy.action import ActionServer, CancelResponse, GoalResponse
from time import sleep
from keras.models import load_model

import time

from nav2_msgs.action import RandomCrawl
from rclpy.executors import MultiThreadedExecutor
from rclpy.callback_groups import ReentrantCallbackGroup
import threading


class RandomCrawlActionServer(Node):

def __init__(self):
super().__init__('action_server')
self.env = TurtlebotEnv()
self.action_size = self.env.action_space()
print(self.action_size)
self.state = self.env.reset()
self.observation_space = len(self.state)
self.state = np.reshape(self.state, [1, self.observation_space])
pkg_share_directory = get_package_share_directory('nav2_turtlebot3_rl')
path = os.path.join(pkg_share_directory, "saved_models/random_crawl_waffle.h5")
self.model = load_model(path)
q = self.model.predict(self.state)
self._goal_handle = None
self._goal_lock = threading.Lock()
self._action_server = ActionServer(
self,
RandomCrawl,
'RandomCrawl',
execute_callback=self.execute_callback,
handle_accepted_callback=self.handle_accepted_callback,
goal_callback=self.goal_callback,
cancel_callback=self.cancel_callback,
callback_group=ReentrantCallbackGroup())

def destroy(self):
self.env.cleanup()
self._action_server.destroy()
super().destroy_node()

def goal_callback(self, goal_request):
self.get_logger().info('Received goal request')
return GoalResponse.ACCEPT

def handle_accepted_callback(self, goal_handle):
with self._goal_lock:
# This server only allows one goal at a time
if self._goal_handle is not None and self._goal_handle.is_active:
self.get_logger().info('Aborting previous goal')
# Abort the existing goal
self._goal_handle.abort()
self._goal_handle = goal_handle

goal_handle.execute()

def cancel_callback(self, goal_handle):
self.get_logger().info('Received cancel request')
return CancelResponse.ACCEPT

def execute_callback(self, goal_handle):
while not goal_handle.is_cancel_requested and rclpy.ok():
q_values = self.model.predict(self.state)
action = np.argmax(q_values)
next_state, reward, terminal = self.env.step(action)
next_state = np.reshape(next_state, [1, self.observation_space])
self.state = next_state
sleep(parameters.LOOP_RATE)
goal_handle.succeed()
self.env.stop_action()
result = RandomCrawl.Result()
return result


def main(args=None):
rclpy.init(args=args)

action_server = RandomCrawlActionServer()
executor = MultiThreadedExecutor()
rclpy.spin(action_server, executor=executor)

action_server.destroy()
rclpy.shutdown()

if __name__ == '__main__':
main()
2 changes: 2 additions & 0 deletions nav2_experimental/nav2_rl/nav2_turtlebot3_rl/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
py_modules=[
'random_crawl_train',
'random_crawl',
'random_crawl_action',
'turtlebot3_env',
'dqn',
'parameters',
Expand Down Expand Up @@ -39,6 +40,7 @@
'console_scripts': [
'random_crawl_train = random_crawl_train:main',
'random_crawl = random_crawl:main',
'random_crawl_action = random_crawl_action:main',
],
},
)
42 changes: 25 additions & 17 deletions nav2_experimental/nav2_rl/nav2_turtlebot3_rl/turtlebot3_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import rclpy
from rclpy.node import Node
from rclpy.duration import Duration
from rclpy.qos import qos_profile_sensor_data
from rclpy.executors import SingleThreadedExecutor

import numpy as np
import math
Expand All @@ -31,9 +33,12 @@
from std_msgs.msg import String
from gazebo_msgs.srv import GetEntityState, SetEntityState


class TurtlebotEnv():
def __init__(self):
self.node_ = rclpy.create_node('turtlebot3_env')
self.executor = SingleThreadedExecutor()
self.executor.add_node(self.node_)
self.act = 0
self.done = False
self.actions = [[parameters.LINEAR_FWD_VELOCITY, parameters.ANGULAR_FWD_VELOCITY],
Expand All @@ -48,8 +53,9 @@ def __init__(self):
self.zero_div_tol = 0.01
self.range_min = 0.0

self.pub_cmd_vel = self.node_.create_publisher(Twist, 'cmd_vel')
self.sub_scan = self.node_.create_subscription(LaserScan, 'scan', self.scan_callback)
self.pub_cmd_vel = self.node_.create_publisher(Twist, 'cmd_vel', 1)
self.sub_scan = self.node_.create_subscription(LaserScan, 'scan', self.scan_callback,
qos_profile_sensor_data)

self.reset_simulation = self.node_.create_client(Empty, 'reset_simulation')
self.reset_world = self.node_.create_client(Empty, 'reset_world')
Expand All @@ -58,21 +64,21 @@ def __init__(self):
self.get_entity_state = self.node_.create_client(GetEntityState, 'get_entity_state')
self.set_entity_state = self.node_.create_client(SetEntityState, 'set_entity_state')
self.scan_msg_received = False
self.t = Thread(target=rclpy.spin, args=[self.node_])
self.t = Thread(target=self.executor.spin)
self.t.start()

def cleanup(self):
self.t.join()

def get_reward(self):
reward = 0
if self.collision == True:
if self.collision is True:
reward = -10
self.done = True
return reward, self.done
elif self.collision == False and self.act == 0:
elif self.collision is False and self.act is 0:
if abs(min(self.states_input)) >= self.zero_div_tol:
reward = 0.08 - (1/(min(self.states_input)**2))*0.005
reward = 0.08 - (1 / (min(self.states_input)**2)) * 0.005
else:
reward = -10
if reward > 0:
Expand All @@ -82,12 +88,11 @@ def get_reward(self):
bonous_discount_factor = 0.6
self.bonous_reward *= bonous_discount_factor
if abs(min(self.states_input)) >= self.zero_div_tol:
reward = 0.02 - (1/min(self.states_input))*0.005
reward = 0.02 - (1 / min(self.states_input)) * 0.005
else:
reward = -10
return reward, self.done


def scan_callback(self, LaserScan):
self.scan_msg_received = True
self.laser_scan_range = []
Expand All @@ -106,8 +111,9 @@ def scan_callback(self, LaserScan):
self.done = True
self.states_input = []
for i in range(8):
step = int(len(LaserScan.ranges)/8)
self.states_input.append(min(self.laser_scan_range[i*step:(i+1)*step], default=0))
step = int(len(LaserScan.ranges) / 8)
self.states_input.append(min(self.laser_scan_range[i * step:(i + 1) * step],
default=0))

def action_space(self):
return len(self.actions)
Expand All @@ -122,20 +128,22 @@ def step(self, action):
vel_cmd.angular.z = 0.0
get_reward = self.get_reward()
return self.states_input, get_reward[0], self.done

def check_collision(self):
if min(self.laser_scan_range) < self.range_min + self.collision_tol:
if min(self.laser_scan_range) < self.range_min + self.collision_tol:
print("Near collision detected... " + str(min(self.laser_scan_range)))
return True
return False

def reset(self):
self.scan_msg_received = False
def stop_action(self):
vel_cmd = Twist()
vel_cmd.linear.x = 0.0
vel_cmd.angular.z = 0.0
self.pub_cmd_vel.publish(vel_cmd)


def reset(self):
self.scan_msg_received = False
self.stop_action
while not self.reset_world.wait_for_service(timeout_sec=1.0):
print('Reset world service is not available...')
self.reset_world.call_async(Empty.Request())
Expand All @@ -151,4 +159,4 @@ def reset(self):
self.collision = False
self.done = False
self.bonous_reward = 0
return self.states_input
return self.states_input
3 changes: 2 additions & 1 deletion nav2_msgs/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,10 @@ rosidl_generate_interfaces(${PROJECT_NAME}
"action/NavigateToPose.action"
"action/Spin.action"
"action/DummyRecovery.action"
"action/RandomCrawl.action"
DEPENDENCIES builtin_interfaces geometry_msgs std_msgs action_msgs
)

ament_export_dependencies(rosidl_default_runtime)

ament_package()
ament_package()
7 changes: 7 additions & 0 deletions nav2_msgs/action/RandomCrawl.action
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#goal definition
std_msgs/Empty target
---
#result definition
std_msgs/Empty result
---
#feedback

0 comments on commit 675574f

Please sign in to comment.