In [1]:
%load_ext autoreload
# Enable autoreload for all modules
%autoreload 2

In [2]:
from iris_environments.environments import get_environment_builder
import numpy as np
import ipywidgets as widgets
from functools import partial
from pydrake.all import (RigidTransform, Rgba, Sphere, RotationMatrix)

In [3]:
from iris_environments.environments import env_names
currname = 'ALLEGRO'
plant_builder = get_environment_builder(currname)
plant, scene_graph, diagram, diagram_context, plant_context, meshcat = plant_builder(usemeshcat=True)

scene_graph_context = scene_graph.GetMyMutableContextFromRoot(
    diagram_context)

INFO:drake:Meshcat listening for connections at http://localhost:7000


In [4]:
#make configuration file 
import os
seed_point_file = 'benchmarks/seedpoints/'+currname+'.yml'
if seed_point_file.split('/')[2] in os.listdir('benchmarks/seedpoints'):
    file = open(seed_point_file, 'a') 
else:
    file = open(seed_point_file, 'w')
    file.write('seedpoints:\n')

# Run this cell and then use the sliders and the button to save the seed points

In [6]:


q = np.zeros(plant.num_positions()) 
sliders = []
for i in range(plant.num_positions()):
    q_low = plant.GetPositionLowerLimits()[i]*0.99
    q_high = plant.GetPositionUpperLimits()[i]*0.99
    sliders.append(widgets.FloatSlider(min=q_low, max=q_high, value=0, step=0.001, description=f"q{i}"))

col_col =  Rgba(0.8, 0.0, 0, 0.5)    
col_free =  Rgba(0.0, 0.8, 0.5, 0.5) 
def showres(qvis):
    plant.SetPositions(plant_context, qvis)
    diagram.ForcedPublish(diagram_context)
    query = plant.get_geometry_query_input_port().Eval(plant_context)
    col = query.HasCollisions()
    if col:
        meshcat.SetObject(f"/drake/visualizer/shunk",
                                   Sphere(0.2),
                                   col_col)
    else:
        meshcat.SetObject(f"/drake/visualizer/shunk",
                                   Sphere(0.2),
                                   col_free)
    meshcat.SetTransform(f"/drake/visualizer/shunk",
                                   RigidTransform(RotationMatrix(),
                                                  np.array([0,0,2])))
    return col

def handle_slider_change(change, idx):
    q[idx] = change['new']
    showres(q)
idx = 0
for slider in sliders:
    slider.observe(partial(handle_slider_change, idx = idx), names='value')
    idx+=1

for slider in sliders:
    display(slider)

def write_seed_point_to_file(button):
    col = showres(q)
    if not col:
        line  = '- ['
        for a in q[:-1]:
            line+= str(a)+', '
        line+= str(q[-1])+']\n'
        file.write(line)
        file.flush()
        #write seedpoint to file
    else:
        raise ValueError("That point is in collision")

button = widgets.Button(description="Save")

# Attach the function to the button's click event
button.on_click(write_seed_point_to_file)

# Display the button
display(button)

FloatSlider(value=0.0, description='q0', max=0.4653, min=-0.4653, step=0.001)

FloatSlider(value=0.0, description='q1', max=1.5939, min=-0.19404000000000002, step=0.001)

FloatSlider(value=0.0, description='q2', max=1.69191, min=-0.17226, step=0.001)

FloatSlider(value=0.0, description='q3', max=1.60182, min=-0.22473, step=0.001)

FloatSlider(value=0.26037, description='q4', max=1.38204, min=0.26037, step=0.001)

FloatSlider(value=0.0, description='q5', max=1.15137, min=-0.10395, step=0.001)

FloatSlider(value=0.0, description='q6', max=1.62756, min=-0.18711, step=0.001)

FloatSlider(value=0.0, description='q7', max=1.70181, min=-0.16038, step=0.001)

FloatSlider(value=0.0, description='q8', max=0.4653, min=-0.4653, step=0.001)

FloatSlider(value=0.0, description='q9', max=1.5939, min=-0.19404000000000002, step=0.001)

FloatSlider(value=0.0, description='q10', max=1.69191, min=-0.17226, step=0.001)

FloatSlider(value=0.0, description='q11', max=1.60182, min=-0.22473, step=0.001)

FloatSlider(value=0.0, description='q12', max=0.4653, min=-0.4653, step=0.001)

FloatSlider(value=0.0, description='q13', max=1.5939, min=-0.19404000000000002, step=0.001)

FloatSlider(value=0.0, description='q14', max=1.69191, min=-0.17226, step=0.001)

FloatSlider(value=0.0, description='q15', max=1.60182, min=-0.22473, step=0.001)

Button(description='Save', style=ButtonStyle())

In [6]:
file.close()

# Inspect seed points

In [7]:
import yaml
import time

with open(seed_point_file, 'r') as f:
    seed_points = yaml.safe_load(f)
seed_points = np.array(seed_points['seedpoints'])

for i, s in enumerate(seed_points):
    showres(s)
    print(f" point {i+1} / {len(seed_points)}")
    time.sleep(2)


 point 1 / 10
 point 2 / 10
 point 3 / 10
 point 4 / 10
 point 5 / 10
 point 6 / 10
 point 7 / 10
 point 8 / 10
 point 9 / 10
 point 10 / 10


In [14]:
showres(seed_points[2,:])

False

In [9]:
q = seed_points[2,:]

In [10]:
q

array([ 0.434544, -0.6944  ])