In [1]:
%load_ext autoreload
%autoreload 2

import os
import time
import numpy as np

from pydrake.common import FindResourceOrThrow, RandomGenerator
from pydrake.geometry import Box, Role, SceneGraph
from pydrake.math import RigidTransform
from pydrake.multibody.parsing import LoadModelDirectives, Parser, ProcessModelDirectives
from pydrake.multibody.plant import CoulombFriction, AddMultibodyPlantSceneGraph, MultibodyPlant
from pydrake.multibody.tree import RevoluteJoint, SpatialInertia, UnitInertia
from pydrake.systems.analysis import Simulator
from pydrake.systems.framework import DiagramBuilder
from pydrake.systems.meshcat_visualizer import ConnectMeshcatVisualizer
from pydrake.systems.primitives import TrajectorySource
from pydrake.systems.rendering import MultibodyPositionToGeometryPose
from pydrake.trajectories import PiecewisePolynomial

from pydrake.planning.common_robotics_utilities import (
    MakeKinematicLinearRRTNearestNeighborsFunction,
    MakeKinematicLinearBiRRTNearestNeighborsFunction,
    MakeRRTTimeoutTerminationFunction,
    MakeBiRRTTimeoutTerminationFunction,
    PropagatedState,
    RRTPlanSinglePath,
    BiRRTPlanSinglePath,
    SimpleRRTPlannerState)

from meshcat import Visualizer

In [2]:
# Setup meshcat
from meshcat.servers.zmqserver import start_zmq_server_as_subprocess
proc, zmq_url, web_url = start_zmq_server_as_subprocess(server_args=[])
vis = Visualizer(zmq_url=zmq_url)

# Sporadically need to run `pkill -f meshcat`

You can open the visualizer by visiting the following URL:
http://127.0.0.1:7010/static/


# Visualize Model

In [3]:
model_file = "drake/planning/models/planar_iiwa_dense_collision_welded_gripper.yaml"
# model_file = "drake/planning/models/planar_iiwa_simple_collision_welded_gripper.yaml"

viz_role = Role.kIllustration
# viz_role = Role.kProximity

In [4]:
vis.delete()
display(vis.jupyter_cell())

builder = DiagramBuilder()
plant, scene_graph = AddMultibodyPlantSceneGraph(builder, time_step=0.0)
parser = Parser(plant)
parser.package_map().Add( "wsg_50_description", os.path.dirname(FindResourceOrThrow(
            "drake/manipulation/models/wsg_50_description/package.xml")))

table_dim = np.array([4, 4, 0.2])
table = plant.AddRigidBody("table", SpatialInertia(
        mass=1.0, p_PScm_E=np.array([0., 0., 0.]), G_SP_E=UnitInertia(1.0, 1.0, 1.0)))
plant.WeldFrames(plant.world_frame(), table.body_frame(), RigidTransform(p=np.array([0, 0, -table_dim[2]/2])))
plant.RegisterVisualGeometry(table, RigidTransform(), Box(*table_dim), "table_vis",
                             np.array([0.5, 0.5, 0.5, 1.]))
plant.RegisterCollisionGeometry(table, RigidTransform(), Box(*table_dim), "table_collision",
                                CoulombFriction(0.9, 0.8))

directives_file = FindResourceOrThrow(model_file)
directives = LoadModelDirectives(directives_file)
models = ProcessModelDirectives(directives, plant, parser)
[iiwa, wsg, shelves] = models

plant.Finalize()

visualizer = ConnectMeshcatVisualizer(builder, scene_graph, zmq_url=zmq_url,
                                      delete_prefix_on_load=False, role=viz_role)
diagram = builder.Build()

q0 = [-0.2, -1.2, 1.6]
index = 0
for joint_index in plant.GetJointIndices(iiwa.model_instance):
    joint = plant.get_mutable_joint(joint_index)
    if isinstance(joint, RevoluteJoint):
        joint.set_default_angle(q0[index])
        index += 1

visualizer.load()
context = diagram.CreateDefaultContext()
plant_context = plant.GetMyContextFromRoot(context)
sg_context = scene_graph.GetMyContextFromRoot(context)
diagram.Publish(context)

Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6001...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7010/static/
Connected to meshcat-server.


# Visualize Trajectory

In [19]:
def visualize_trajectory(traj):
    builder = DiagramBuilder()
    
    scene_graph = builder.AddSystem(SceneGraph())
    plant = MultibodyPlant(time_step=0.0)
    plant.RegisterAsSourceForSceneGraph(scene_graph)
    parser = Parser(plant)
    parser.package_map().Add( "wsg_50_description", os.path.dirname(FindResourceOrThrow(
                "drake/manipulation/models/wsg_50_description/package.xml")))

    table_dim = np.array([4, 4, 0.2])
    table = plant.AddRigidBody("table", SpatialInertia(
            mass=1.0, p_PScm_E=np.array([0., 0., 0.]), G_SP_E=UnitInertia(1.0, 1.0, 1.0)))
    plant.WeldFrames(plant.world_frame(), table.body_frame(), RigidTransform(p=np.array([0, 0, -table_dim[2]/2])))
    plant.RegisterVisualGeometry(table, RigidTransform(), Box(*table_dim), "table_vis",
                                 np.array([0.5, 0.5, 0.5, 1.]))
    plant.RegisterCollisionGeometry(table, RigidTransform(), Box(*table_dim), "table_collision",
                                    CoulombFriction(0.9, 0.8))
    
    directives_file = FindResourceOrThrow(model_file)
    directives = LoadModelDirectives(directives_file)
    models = ProcessModelDirectives(directives, plant, parser)
    
    plant.Finalize()

    to_pose = builder.AddSystem(MultibodyPositionToGeometryPose(plant))
    builder.Connect(to_pose.get_output_port(), scene_graph.get_source_pose_port(plant.get_source_id()))

    traj_system = builder.AddSystem(TrajectorySource(traj))
    builder.Connect(traj_system.get_output_port(), to_pose.get_input_port())
    
    meshcat = ConnectMeshcatVisualizer(builder, scene_graph, zmq_url=zmq_url,
                                       delete_prefix_on_load=False, role=viz_role)

    vis_diagram = builder.Build()
    simulator = Simulator(vis_diagram)
    meshcat.start_recording()
    simulator.AdvanceTo(traj.end_time())
    meshcat.publish_recording()

q_test = [0.4, 1.8, -0.5]
traj = PiecewisePolynomial.FirstOrderHold([0, 3], np.array([[-0.2, -1.2, 1.6], q_test]).T)
vis.delete()
display(vis.jupyter_cell())
visualize_trajectory(traj)

Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6001...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7010/static/
Connected to meshcat-server.


# Use RRT to Plan Trajectory

In [76]:
q_start = np.array([0.5, -1.2, 0])
q_end = np.array([0.8, -1.6, -0.9])

goal_bias = 0.05
step_size = np.pi/16
collision_step_size = step_size/4
solve_timeout = 100

plant_context_planning = plant.GetMyContextFromRoot(context)


def sampling_cspace():
    if np.random.rand() < goal_bias:
        return q_end
    # we have no joint limits imposed. Thus each joint can move between 0 and 2*pi
    return np.array([np.random.rand()*2*np.pi for _ in range(3)])

def check_goal_fn(q):
    return np.linalg.norm(q_end - q) < 1e-6

def difference(q_target, q_source):
    return (q_target - q_source + np.pi) % (2*np.pi) - np.pi

def distance_fn(q_1, q_2):
    diffs = np.array(list(map(lambda q_i1, q_i2: difference(q_i1, q_i2) ,q_1, q_2)))
    
    return np.sqrt(np.sum(np.square(diffs)))

def check_edge_validity_fn(start, end):
    #use first order interpolation to check collisions along a path between start and end
    def checkEdgeCollisionFree(start, end, stepsize):
        num_steps = np.ceil(distance_fn(start, end)/stepsize)

        for step in range(int(num_steps)+1):
            interpolation_ratio = step / num_steps
            interpolated_point = start + np.round(interpolation_ratio*(end-start))
            
            #check for collisions
            plant.SetPositions(plant_context_planning, interpolated_point)
            query_object = plant.get_geometry_query_input_port().Eval(plant_context_planning)
            
            if query_object.HasCollisions():
                return False

        return True

    return (checkEdgeCollisionFree(start, end, collision_step_size/2)
            and checkEdgeCollisionFree(end, start, collision_step_size/2))

def extend_fn(nearest, sample, is_start_tree = None):
    extend = None
    extend_dist = distance_fn(nearest, sample)
    
    if extend_dist <= step_size:
        extend = sample
    else:
        extend = nearest + step_size/extend_dist * (sample - nearest)
    
    extend = difference(extend, nearest) + nearest #ensure that shortest rotation path will be taken
    
    if not check_edge_validity_fn(nearest, extend):
        return []
    
    return [PropagatedState(state=extend, relative_parent_index=-1)]
    
    
    

rrt_tree = [SimpleRRTPlannerState(q_start)]


nearest_neighbor_fn = MakeKinematicLinearRRTNearestNeighborsFunction(distance_fn=distance_fn, use_parallel = False)

termination_fn = MakeRRTTimeoutTerminationFunction(solve_timeout)

single_result = RRTPlanSinglePath(
    tree=rrt_tree, sampling_fn=sampling_cspace,
    nearest_neighbor_fn =nearest_neighbor_fn,
    forward_propagation_fn=extend_fn,
    state_added_callback_fn=None,
    check_goal_reached_fn=check_goal_fn, goal_reached_callback_fn=None,
    termination_check_fn=termination_fn)

path = single_result.Path()

print(path)

rrt_traj = PiecewisePolynomial.FirstOrderHold(list(range(len(path))), np.stack(path).T) 

vis.delete()
display(vis.jupyter_cell())
visualize_trajectory(rrt_traj)

KeyboardInterrupt: 

# Use Bi-RRT to Plan Trajectory

In [89]:
# Reuse single RRT functions, but need to add connect
seed = 0
step_size = np.pi/32
collision_step_size = step_size/128

def connect_fn(nearest, sample, is_start_tree):
    total_dist = distance_fn(nearest, sample)
    total_steps = int(np.ceil(total_dist / step_size))        
        
    propagated_states = []
    parent_offset = -1
    current = nearest
    for steps in range(total_steps):
        current_target = None
        target_dist = distance_fn(current, sample)
        if target_dist > step_size:
            #interpolate
            current_target = current + step_size/target_dist * (sample - current)
            
        elif target_dist < 1e-6:
            break
        else:
            current_target = sample
        
        current_target = difference(current_target, current) + current #ensure that shortest rotation path will be taken

        if not check_edge_validity_fn(current, current_target):
            return propagated_states
    
                    
        propagated_states.append(PropagatedState(state=current_target, relative_parent_index=parent_offset))
        parent_offset += 1
        current = current_target

    return propagated_states
        
        
def states_connected_fn(source, target, is_start_tree):
    return np.linalg.norm(source - target) < 1e-6
        
start_tree = [SimpleRRTPlannerState(q_start)]
end_tree = [SimpleRRTPlannerState(q_end)]


nearest_neighbor_fn = MakeKinematicLinearBiRRTNearestNeighborsFunction(distance_fn=distance_fn, use_parallel = False)

termination_fn = MakeBiRRTTimeoutTerminationFunction(solve_timeout)

extend_result = BiRRTPlanSinglePath(
    start_tree=start_tree, goal_tree=end_tree,
    state_sampling_fn=sampling_cspace,
    nearest_neighbor_fn=nearest_neighbor_fn, propagation_fn=extend_fn,
    state_added_callback_fn=None,
    states_connected_fn=states_connected_fn,
    goal_bridge_callback_fn=None,
    tree_sampling_bias=0.5, p_switch_tree=0.25,
    termination_check_fn=termination_fn, rng=RandomGenerator(seed))

connect_result = BiRRTPlanSinglePath(
    start_tree=start_tree, goal_tree=end_tree,
    state_sampling_fn=sampling_cspace,
    nearest_neighbor_fn=nearest_neighbor_fn, propagation_fn=connect_fn,
    state_added_callback_fn=None,
    states_connected_fn=states_connected_fn,
    goal_bridge_callback_fn=None,
    tree_sampling_bias=0.5, p_switch_tree=0.25,
    termination_check_fn=termination_fn, rng=RandomGenerator(seed))

extend_path = extend_result.Path()
connect_path = connect_result.Path()

  interpolation_ratio = step / num_steps


[array([ 0.5, -1.2,  0. ]), array([ 0.70749379, -0.93853702,  0.04721764]), array([ 0.70694375, -0.8729474 ,  0.1427304 ]), array([ 0.70694375, -0.8729474 ,  0.1427304 ]), array([ 0.70963084, -0.88183402,  0.10206186]), array([ 0.71595487, -0.90274847,  0.00634939]), array([ 0.70195977, -0.94528196, -0.08101951]), array([ 0.73200878, -0.9278626 , -0.17284494]), array([ 0.71687251, -0.94470455, -0.26837256]), array([ 0.75838831, -0.88490495, -0.33424162]), array([ 0.77417326, -0.87390973, -0.43051325]), array([ 0.79323201, -1.33579998, -0.75655865]), array([ 0.8, -1.6, -0.9])]
[array([ 0.5, -1.2,  0. ]), array([ 0.52860673, -1.23814231, -0.08582019]), array([ 0.55721346, -1.27628461, -0.17164038]), array([ 0.63699159, -1.11716413, -0.07145177]), array([ 0.71419189, -0.96318524,  0.02549948]), array([ 0.78927332, -0.81343251,  0.11978978]), array([ 0.86212828, -0.66812056,  0.21128399]), array([ 0.932686  , -0.52739054,  0.29989322]), array([ 1.00091013, -0.39131494,  0.38557185]), array

In [92]:
birrt_traj = PiecewisePolynomial.FirstOrderHold(list(range(len(connect_path))), np.stack(connect_path).T) 
#birrt_traj = PiecewisePolynomial.FirstOrderHold(list(range(len(extend_path))), np.stack(extend_path).T) 

vis.delete()
display(vis.jupyter_cell())
visualize_trajectory(birrt_traj)

Connecting to meshcat-server at zmq_url=tcp://127.0.0.1:6001...
You can open the visualizer by visiting the following URL:
http://127.0.0.1:7010/static/
Connected to meshcat-server.
