In [1]:
import numpy as np
from functools import partial
import ipywidgets as widgets
from IPython.display import display

In [2]:
#pydrake imports
from pydrake.all import RationalForwardKinematics
from pydrake.geometry.optimization import IrisOptions, HPolyhedron, Hyperellipsoid
from pydrake.solvers import MosekSolver, CommonSolverOption, SolverOptions
from pydrake.all import PiecewisePolynomial, InverseKinematics, Sphere, Rgba, RigidTransform, RotationMatrix, IrisInConfigurationSpace
import time
import pydrake
from pydrake.all import (SceneGraphCollisionChecker, 
                         StartMeshcat, 
                         RobotDiagramBuilder,
                         ProcessModelDirectives,
                         LoadModelDirectives,
                         MeshcatVisualizer,
                        DiagramBuilder,
                         MultibodyPlant,
                         AddMultibodyPlantSceneGraph,
                         PiecewisePolynomial,
                        Parser)
from pydrake.all import GeometrySet, CollisionFilterDeclaration
from scipy.special import comb
import matplotlib.pyplot as plt
import pydrake.multibody.rational as rational_forward_kinematics
from pydrake.all import RationalForwardKinematics

import time
from pydrake.geometry.optimization_dev import CspaceFreePath
from pydrake.all import Role, MeshcatVisualizerParams
from pydrake.all import InverseDynamicsController, TrajectorySource,StateInterpolatorWithDiscreteDerivative


In [3]:
from pydrake.examples import ManipulationStation


In [4]:
# Start a meshcat server for 3D visualization; its URL is printed in the cell output.
meshcat = StartMeshcat()


INFO:drake:Meshcat listening for connections at http://localhost:7000


In [5]:
from pydrake.all import LeafSystem
class ConvertSToQ(LeafSystem):
    """Drake system mapping rational-kinematics coordinates s to joint angles q.

    Input port "s": vector of size ``q_star.shape[0]``.
    Output port "q": vector of the same size, equal to
    ``Ratfk.ComputeQValue(s, q_star)``.
    """

    def __init__(self, q_star, Ratfk):
        super().__init__()
        dim = q_star.shape[0]
        # Keep the input-port handle so CalcOutput can evaluate it.
        self.s_port = self.DeclareVectorInputPort("s", dim)
        self.DeclareVectorOutputPort("q", dim, self.CalcOutput)
        self.q_star, self.Ratfk = q_star, Ratfk

    def CalcOutput(self, context, output):
        """Evaluate the "s" input and write the matching q into ``output``."""
        output.set_value(
            self.Ratfk.ComputeQValue(self.s_port.Eval(context), self.q_star)
        )

In [6]:
# Start from an empty meshcat scene and a fresh diagram builder.
meshcat.Delete()
builder = DiagramBuilder()

# Model directives for the robot and for the shelf environment.
iiwa_directive_file = "7_dof_iiwa_directive.yaml"
shelf_directive_file = "7_dof_iiwa_shelves_directive.yaml"
time_step = 0.002  # discrete time step [s], shared by the plants and the interpolator

# "World" plant used for simulation (iiwa + shelves), registered with a SceneGraph.
sim_plant, sim_scene_graph = AddMultibodyPlantSceneGraph(builder, time_step)
sim_parser = Parser(sim_plant)
iiwa_directive = LoadModelDirectives(iiwa_directive_file)
shelf_directive = LoadModelDirectives(shelf_directive_file)
# Model instances added by both directives, robot first, then shelves.
models = [
    model_info
    for directive in (iiwa_directive, shelf_directive)
    for model_info in ProcessModelDirectives(directive, sim_plant, sim_parser)
]
# Add any other models you want the robot to manipulate before finalizing.
sim_plant.Finalize()

meshcat_params = MeshcatVisualizerParams()
visualizer = MeshcatVisualizer.AddToBuilder(
    builder, sim_scene_graph, meshcat, meshcat_params
)



# Robot-only plant: the inverse dynamics controller needs a model of just the iiwa.
controller_plant = MultibodyPlant(time_step)
ProcessModelDirectives(iiwa_directive, controller_plant, Parser(controller_plant))
controller_plant.Finalize()
# Rational forward kinematics of the iiwa; shared by every s <-> q conversion.
Ratfk = RationalForwardKinematics(controller_plant)

num_positions = controller_plant.num_positions()
# Reference configuration passed to every ComputeQValue / ComputeSValue call.
q_star = np.zeros(num_positions)

# Inverse dynamics controller built on controller_plant.
inv_dynamics_controller = builder.AddSystem(
    InverseDynamicsController(
        controller_plant,
        kp=[1000] * num_positions,
        ki=[1] * num_positions,
        kd=[20] * num_positions,
        has_reference_acceleration=False,
    )
)
inv_dynamics_controller.set_name("inv_dynamics_controller")
builder.Connect(
    inv_dynamics_controller.get_output_port_control(),
    sim_plant.get_actuation_input_port(),
)

# Feed the measured iiwa state from the simulated world back to the controller.
iiwa = sim_plant.GetModelInstanceByName("iiwa")
builder.Connect(
    sim_plant.get_state_output_port(iiwa),
    inv_dynamics_controller.get_input_port_estimated_state(),
)

# Turn commanded positions into the controller's desired state via a discrete derivative.
desired_state_from_position = builder.AddSystem(
    StateInterpolatorWithDiscreteDerivative(
        num_positions,
        time_step,
        suppress_initial_transient=True,
    )
)
builder.Connect(
    desired_state_from_position.get_output_port(),
    inv_dynamics_controller.get_input_port_desired_state(),
)

# Command pipeline: s-space trajectory -> ConvertSToQ -> desired state.
# Reuse the notebook-wide Ratfk (rather than constructing a second
# RationalForwardKinematics) so simulated commands use the same s -> q map as
# waypoint generation and visualization.
s_to_q = builder.AddSystem(ConvertSToQ(q_star.copy(), Ratfk))
# Placeholder constant trajectory; simulate_s_traj swaps in the real one.
s_trajectory_source = builder.AddSystem(
    TrajectorySource(PiecewisePolynomial(np.zeros(num_positions)))
)
builder.Connect(s_trajectory_source.get_output_port(), s_to_q.get_input_port())
builder.Connect(s_to_q.get_output_port(), desired_state_from_position.get_input_port())

# Assemble the diagram and fetch the sub-contexts the helper functions use.
diagram = builder.Build()
diagram_context = diagram.CreateDefaultContext()
sim_plant_context = sim_plant.GetMyContextFromRoot(diagram_context)
sim_scene_graph_context = sim_scene_graph.GetMyContextFromRoot(diagram_context)

# Hide meshcat's default decorations and aim the camera at the world origin.
for decoration_path in ("/Axes", "/Grid", "/Background"):
    meshcat.SetProperty(decoration_path, "visible", False)
camera_position = np.array([1.5, 1.75, 1.5])
camera_target = np.array([0, 0, 0])
meshcat.SetCameraPose(camera_in_world=camera_position, target_in_world=camera_target)

# Push the default configuration to meshcat.
diagram.ForcedPublish(diagram_context)

q_star = np.zeros(controller_plant.num_positions())

In [7]:
from scipy.interpolate import CubicSpline

# Number of samples used to densify each joint-space path before converting
# it to s-space and fitting a trajectory through it.
NUM_REFINED_SAMPLES = 30
# Time-scaling factor applied to each s-space trajectory (breaks span [0, 1]).
traj_time = 3


def refine_waypoints(waypoints, num_samples=NUM_REFINED_SAMPLES):
    """Densify a joint-space path with per-joint cubic splines.

    Args:
        waypoints: (num_waypoints, num_joints) array; row i is placed at
            normalized time i / (num_waypoints - 1) on [0, 1].
        num_samples: number of evenly spaced samples on [0, 1] to return.

    Returns:
        (num_samples, num_joints) array of interpolated configurations.
    """
    waypoints = np.asarray(waypoints)
    knots = np.linspace(0, 1, waypoints.shape[0])
    # axis=0 fits an independent spline (default not-a-knot) to each joint column.
    spline = CubicSpline(knots, waypoints, axis=0)
    return spline(np.linspace(0, 1, num_samples))


def to_s_space(waypoints_q):
    """Map each joint-space row to rational coordinates s, using q_star."""
    return np.array([Ratfk.ComputeSValue(q, q_star) for q in waypoints_q])


def make_s_trajectory(waypoints_s, duration):
    """Fit a cubic s-space trajectory (continuous 2nd derivative) through samples.

    Args:
        waypoints_s: (num_samples, num_joints) array of s-space samples, placed
            at evenly spaced breaks on [0, 1].
        duration: time-scaling factor; since the breaks span [0, 1], the
            returned trajectory spans [0, duration].

    Returns:
        (traj, polys): the PiecewisePolynomial, and an array of its per-segment
        polynomial matrices squeezed and transposed to (num_joints, num_segments).
    """
    breaks = np.linspace(0, 1, waypoints_s.shape[0])
    traj = PiecewisePolynomial.CubicWithContinuousSecondDerivatives(
        breaks, waypoints_s.T
    )
    traj.ScaleTime(duration)
    polys = np.array(
        [traj.getPolynomialMatrix(i) for i in range(len(breaks) - 1)]
    ).squeeze().T
    return traj, polys


# Joint-space waypoints of the "safe" path for the 7-DoF iiwa.
waypoints_q = np.array([
    [-2.06706, -0.3944, 1.43294, 0.9056, 0.53294, -0.8944, 0.74567],
    [-2.26706, -0.3944, 1.43294, 0.9056, 0.73294, -0.9944, 1.04567],
    [-2.96706, -0.344, 2.83294, 0.9056, 0.73294, -0.9944, 1.04567],
    [-1.06706, 0.1056, 2.43294, 0.9056, 0.73294, 0.2056, 1.04567],
    [-1.06706, 0.1056, 2.43294, 1.1056, 1.73294, 0.2056, 0.0567],
    [-1.56706, 0.1056, 2.43294, 1.2056, 1.73294, 0.8056, -0.25433],
])
waypoints_q_refined = refine_waypoints(waypoints_q)
waypoints_s = to_s_space(waypoints_q_refined)

# "Unsafe" variant: identical except waypoint 1 shifts joint 0 (-2.26706 -> -2.22706).
waypoints_q_col = waypoints_q.copy()
waypoints_q_col[1, :] = np.array(
    [-2.22706, -0.3944, 1.43294, 0.9056, 0.73294, -0.9944, 1.04567]
)
waypoints_q_col_refined = refine_waypoints(waypoints_q_col)
waypoints_s_col = to_s_space(waypoints_q_col_refined)

traj_safe, polys_safe = make_s_trajectory(waypoints_s, traj_time)
# Bug fix: polys_unsafe used to be extracted from traj_safe, so certifying
# "polys_unsafe" silently re-certified the safe path.
traj_unsafe, polys_unsafe = make_s_trajectory(waypoints_s_col, traj_time)



In [8]:
# Disabled alternative collision check based on an IK MinimumDistanceConstraint.
# NOTE(review): it references `plant`, `plant_context` and `scene_graph`, which
# are not defined in this notebook (the plants here are `sim_plant` and
# `controller_plant`); it needs updating before it can be re-enabled.
# q = np.zeros(plant.num_positions())
# ik = InverseKinematics(plant, plant_context)
# collision_constraint = ik.AddMinimumDistanceConstraint(0.001, 0.001)
# def eval_cons(q, c, tol):
#     return 1-1*float(c.evaluator().CheckSatisfied(q, tol))
# diagram_col_context = diagram.CreateDefaultContext()
# plant_col_context = diagram.GetMutableSubsystemContext(sim_plant,
#                                                        diagram_col_context)
# scene_graph_col_context = diagram.GetMutableSubsystemContext(
#         scene_graph, diagram_col_context
#     )
# Query output port of the simulation SceneGraph; the collision helpers below
# evaluate it against sim_scene_graph_context.
query_port = sim_scene_graph.get_query_output_port()

def SetSimPlantIiwaPosition(q, context=None):
    """Write iiwa joint angles ``q`` into a sim_plant context.

    Only the first ``controller_plant.num_positions()`` entries of the plant's
    position vector are overwritten; any remaining positions are kept.

    Args:
        q: iiwa joint angles, length ``controller_plant.num_positions()``.
        context: sim_plant context to modify. Defaults to the notebook-wide
            ``sim_plant_context``, looked up at call time. A default of
            ``sim_plant_context`` in the signature would be captured once, when
            this cell runs, and would go stale if the diagram cell is re-run.
    """
    if context is None:
        context = sim_plant_context
    positions = sim_plant.GetPositions(context)
    positions[:controller_plant.num_positions()] = q
    sim_plant.SetPositions(context, positions)
    
def SetSimPlantIiwaPositionS(s, context=None):
    """Convert rational coordinates ``s`` to joint angles and write them to sim_plant.

    Args:
        s: array-like of rational-kinematics coordinates for the iiwa.
        context: sim_plant context to modify. Defaults to ``sim_plant_context``,
            resolved at call time so that re-running the diagram cell does not
            leave this function bound to a stale context.
    """
    if context is None:
        context = sim_plant_context
    q = Ratfk.ComputeQValue(np.array(s), q_star)
    SetSimPlantIiwaPosition(q, context)
    
def check_collision_q_by_query(q):
    """Return 1 if the iiwa at joint angles ``q`` is in collision, else 0.

    Side effect: overwrites the iiwa positions in ``sim_plant_context``.
    """
    SetSimPlantIiwaPosition(q)
    in_collision = query_port.Eval(sim_scene_graph_context).HasCollisions()
    return int(bool(in_collision))
    

def check_collision_s_by_query(s):
    """Collision flag (1/0) for the iiwa at rational coordinates ``s``."""
    return check_collision_q_by_query(Ratfk.ComputeQValue(np.array(s), q_star))


    
# Collision predicate used by showres: 1 = in collision, 0 = collision-free.
col_func_handle = check_collision_q_by_query#partial(eval_cons, c=collision_constraint, tol=0.01)
# col_func_handle = partial(eval_cons, c=collision_constraint, tol=0.01)

# Semi-transparent marker colors (reddish = collision, greenish = free).
# Currently only referenced by commented-out marker code in showres.
col_shunk_col =  Rgba(0.8, 0.0, 0, 0.5)    
col_shunk_free =  Rgba(0.0, 0.8, 0.5, 0.5)   


def showres_s(s_vis):
    """Display the iiwa at rational coordinates ``s_vis`` (see ``showres``).

    Uses the notebook-wide ``q_star`` rather than a hard-coded ``np.zeros(7)``,
    so the s -> q conversion matches SetSimPlantIiwaPositionS and the simulation.
    Returns whatever ``showres`` returns.
    """
    return showres(Ratfk.ComputeQValue(np.asarray(s_vis), q_star))

def showres(qvis):
    """Show the iiwa at joint angles ``qvis`` in meshcat and report collision.

    Writes ``qvis`` into ``sim_plant_context``, evaluates ``col_func_handle`` on
    it, and force-publishes the diagram so meshcat updates.

    Returns:
        ``col_func_handle(qvis)``: 1 = collision, 0 = free for the default
        ``check_collision_q_by_query``. This value used to be computed and
        discarded; returning it lets callers use it without a second check.
    """
    SetSimPlantIiwaPosition(qvis)
    col = col_func_handle(qvis)
    diagram.ForcedPublish(diagram_context)
    return col



In [9]:
import C_Iris_Examples.visualization_utils as vis_utils
q_star = np.zeros(controller_plant.num_positions())
vis_bundle = vis_utils.VisualizationBundle(
    diagram, diagram_context, sim_plant, sim_plant_context,
    Ratfk, meshcat, q_star
)
pos_set_fun_s = SetSimPlantIiwaPositionS

num_points = int(1e4)  # samples drawn along each trajectory
safe_color = Rgba(0, 1, 0, 1)
unsafe_color = Rgba(1, 0, 0, 1)


def visualize_traj(traj, meshcat_path, color):
    """Draw the path traced by sim_plant body "body" along an s-space trajectory.

    Thin wrapper around ``vis_utils.visualize_s_space_trajectory``. It uses one
    ``color`` for the start marker, end marker and path, tiny (0.001) start/end
    markers, and ``num_points`` samples. Returns whatever the wrapped call returns.
    """
    options = vis_utils.TrajectoryVisualizationOptions(
        start_size=0.001,
        end_size=0.001,
        start_color=color,
        end_color=color,
        path_color=color,
        num_points=num_points,
    )
    return vis_utils.visualize_s_space_trajectory(
        vis_bundle, traj,
        sim_plant.GetBodyByName("body"),  # presumably the gripper ("shunk") body
        meshcat_path,
        options,
        pos_set_fun_s,
    )


visualize_traj(traj_safe, "/safe_traj", safe_color)
visualize_traj(traj_unsafe, "/unsafe_traj", unsafe_color)

In [10]:
# for sa in waypoints_s_col_refined:
#     #qa = #Ratfk.ComputeQValue(ta, np.zeros(3))
#     showres_s(sa)
#     if col_func_handle(Ratfk.ComputeQValue(sa, np.zeros(7))):
#         print('col')
# #         break
#     time.sleep(0.01)

## Simulate the safe and unsafe trajectories

In [11]:
from pydrake.all import Simulator
def simulate_s_traj(s_traj, end_time = None):
    """Simulate the controlled iiwa tracking an s-space trajectory and record it.

    The trajectory is loaded into ``s_trajectory_source``. ConvertSToQ maps it
    to joint angles, which the inverse dynamics controller tracks. The simulated
    iiwa starts at the trajectory's initial configuration, and the run is
    recorded in meshcat and published as an animation.

    Args:
        s_traj: trajectory in rational coordinates s (e.g. ``traj_safe``).
        end_time: simulation end time [s]; defaults to ``s_traj.end_time()``.
    """
    s_trajectory_source.UpdateTrajectory(s_traj)
    simulator = Simulator(diagram)
    simulator_context = simulator.get_mutable_context()
    plant_context = sim_plant.GetMyContextFromRoot(simulator_context)
    # Set the initial state *before* Initialize(), so initialization events and
    # the initial publish see the trajectory's start configuration rather than
    # the default one.
    SetSimPlantIiwaPositionS(s_traj.value(s_traj.start_time()).flatten(),
                             context=plant_context)
    if end_time is None:
        end_time = s_traj.end_time()

    meshcat.StartRecording()
    simulator.Initialize()  # the initial publish is now part of the recording
    simulator.AdvanceTo(end_time)
    meshcat.PublishRecording()


In [12]:
# Simulate and record the safe trajectory for 5 s.
simulate_s_traj(s_traj=traj_safe, end_time=5)
time.sleep(2)  # NOTE(review): presumably lets meshcat finish receiving the recording — confirm


In [13]:
# Simulate and record the unsafe trajectory for 5 s.
simulate_s_traj(s_traj=traj_unsafe, end_time=5)

In [13]:
# Deliberate stop: keeps "Run All" from running the certification cells below.
# They need `cspace_free_path`, whose construction cell is commented out.
raise ValueError(
    "Intentional stop: the CspaceFreePath certification cells below are not "
    "runnable yet (the `cspace_free_path` construction cell is commented out)."
)

ValueError: 

In [None]:
# 
# plane_order = 2
# max_degree = 3
# t0 = time.time()
# cspace_free_path = CspaceFreePath(
#     plant,
#     scene_graph,
#     q_star,
#     maximum_path_degree=max_degree,
#     plane_order=plane_order,
# )
# t1 = time.time()
# print(f"Time to build collision checker {t1-t0}")

In [None]:
# Number of separating planes in cspace_free_path.
# NOTE(review): `cspace_free_path` is only created in the commented-out cell
# above, so on a fresh kernel this raises NameError.
len(cspace_free_path.separating_planes())

In [None]:
# Disabled debug loop: prints the body names on each side of every separating plane.
# NOTE(review): it uses `plant`, which is not defined in this notebook.
# for plane in cspace_free_path.separating_planes():
#     pos_body_idx = plane.positive_side_geometry.body_index()
#     neg_body_idx = plane.negative_side_geometry.body_index()
#     pos_body_name = plant.get_body(pos_body_idx).name()
#     neg_body_name = plant.get_body(neg_body_idx).name()
#     print(pos_body_name, neg_body_name)
# Same count as the previous cell, printed explicitly.
print(len(cspace_free_path.separating_planes()))

In [None]:
# Options for CspaceFreePath.FindSeparationCertificateGivenPath.
cert_options = CspaceFreePath.FindSeparationCertificateGivenPathOptions()

# Solver configuration: Mosek with default solver options.
cert_options.solver_id = MosekSolver.id()
cert_options.solver_options = SolverOptions()
cert_options.num_threads = -1  # presumably "use all available threads" — confirm in Drake docs
cert_options.verbose = False

# Keep going after a failure, so every segment/path gets a result
# (presumed intent, given the per-entry stats printed below).
cert_options.terminate_segment_certification_at_failure = False
cert_options.terminate_path_certification_at_failure = False

In [None]:
# Interactive joint sliders for exploring configurations, starting at waypoint 1.
q = waypoints_q[1, :].copy()
num_iiwa_positions = q.shape[0]
# Per-joint limits (arrays). Previously these were loop leftovers holding only
# the last joint's scalar limits, which made q_min/q_max/q_diff meaningless.
q_low = sim_plant.GetPositionLowerLimits()[:num_iiwa_positions]
q_high = sim_plant.GetPositionUpperLimits()[:num_iiwa_positions]


def handle_slider_change(change, idx):
    """Set joint ``idx`` of ``q`` from a slider event and redraw the robot."""
    q[idx] = change['new']
    showres(q)


sliders = [
    widgets.FloatSlider(min=lo, max=hi, value=q[i], description=f"q{i}")
    for i, (lo, hi) in enumerate(zip(q_low, q_high))
]

# Per-joint scale factors applied to the limits (1 = full joint range).
scaler = 1
q_min = np.array(q_low) * scaler
q_max = np.array(q_high) * scaler
q_diff = q_max - q_min

for idx, slider in enumerate(sliders):
    slider.observe(partial(handle_slider_change, idx=idx), names='value')

# Show each slider exactly once (sliders[0] used to be displayed twice).
for slider in sliders:
    display(slider)

In [None]:
# Print the iiwa joint angles currently stored in sim_plant_context (e.g. after
# moving the sliders). repr() yields an array(...) literal that can be pasted
# back into the waypoint definitions. Slices with the iiwa DoF count instead of
# a hard-coded 7, matching SetSimPlantIiwaPosition.
print(repr(
    sim_plant.GetPositions(sim_plant_context)[:controller_plant.num_positions()]
))

In [None]:
# Certify the safe s-space path and time it.
# Bug fix: this previously passed the undefined name `polys`; the safe path's
# polynomials are `polys_safe`, built in the trajectory cell.
# NOTE(review): the empty set() argument is presumably "no pairs to ignore" — confirm.
t0 = time.time()
stats, res = cspace_free_path.FindSeparationCertificateGivenPath(
    polys_safe, set(), cert_options
)
t1 = time.time()
print(f"Time to certify path {t1-t0}")

In [None]:
# One line per entry of `stats` (presumably one per path segment): whether it was certified safe.
for s in stats:
    print(s.certified_safe())
    

In [None]:
# Sum of solve times over all `stats` entries.
# NOTE(review): each entry is divided by 20 without explanation — possibly a
# per-thread or per-repeat normalization; confirm before reporting this number.
print(f"Total safe solver time = {sum([s.total_time_to_solve_progs()/20 for s in stats])}")

In [None]:
# Certify the "unsafe" s-space path and time it; overwrites `stats`/`res` from
# the safe-path certification.
# NOTE(review): verify that polys_unsafe is built from traj_unsafe (not traj_safe)
# in the trajectory cell, otherwise this re-certifies the safe path.
t0 = time.time()
stats, res = cspace_free_path.FindSeparationCertificateGivenPath(polys_unsafe,
                                                                 set(),
                                                                 cert_options)
t1 = time.time()
print(f"Time to certify path {t1-t0}")

In [None]:
# One line per entry of `stats` for the unsafe path: whether it was certified safe.
for s in stats:
    print(s.certified_safe())
    

In [None]:
# Sum of solve times for the unsafe path; the /20 divisor is unexplained (see the safe-path cell).
print(f"Total unsafe solver time = {sum([s.total_time_to_solve_progs()/20 for s in stats])}")

In [None]:
# Same quantity as the previous cell, printed without a label.
print(sum([s.total_time_to_solve_progs()/20 for s in stats]))

In [None]:
# Timing breakdown per `stats` entry: certify, solve and build times, then a blank line.
for s in stats:
    print(s.total_time_to_certify())
    print(s.total_time_to_solve_progs())
    print(s.total_time_building_progs())
    print()

In [None]:
# Raw timing fields of the first `stats` entry (first element where the field is indexable).
print(stats[0].total_time_to_certify_pair[0])
print(stats[0].time_to_build_prog)
print(stats[0].time_to_solve_prog[0])