In [None]:
%load_ext autoreload
# Enable autoreload for all modules
%autoreload 2

In [None]:
from iris_environments.environments import get_environment_builder
import numpy as np
import ipywidgets as widgets
from functools import partial
from pydrake.all import (RigidTransform, Rgba, Sphere, RotationMatrix)

In [None]:
# Candidate environment names; the selected one is handed to
# get_environment_builder below.
names = [
    '2DOFFLIPPER',
    '3DOFFLIPPER',
    '5DOFUR5',
    '6DOFUR5',
    '7DOFIIWA',
    '7DOF4SHELVES',
    '7DOFBINS',
    '14DOFIIWAS',
]

# Environment used for the rest of the notebook (change the index to switch).
currname = names[0]
plant_builder = get_environment_builder(currname)

# Build the environment with Meshcat enabled and unpack everything it returns.
(plant,
 scene_graph,
 diagram,
 diagram_context,
 plant_context,
 meshcat) = plant_builder(usemeshcat=True)

# Scene-graph context derived from the root diagram context.
scene_graph_context = scene_graph.GetMyMutableContextFromRoot(diagram_context)

In [None]:
# Open the YAML file that collects the seed points of the current environment.
import os

seed_point_file = 'benchmarks/'+currname+'seedpoints.yml'

# Only a missing or empty file needs the top-level 'seedpoints:' key. An
# existing file is opened in append mode so the header and the previously
# saved points are kept; opening it with 'w' would truncate it and leave a
# header-less list that the inspection cell cannot index by 'seedpoints'.
needs_header = (not os.path.isfile(seed_point_file)
                or os.path.getsize(seed_point_file) == 0)

# The handle is deliberately left open: the "Save" button callback writes to
# it, and a later cell closes it.
file = open(seed_point_file, 'a')
if needs_header:
    file.write('seedpoints:\n')
    file.flush()

# Run the cell below, then use the sliders to choose a configuration and press the Save button to store it as a seed point

In [None]:


# Configuration vector edited through the sliders (one entry per position).
num_q = plant.num_positions()
q = np.zeros(num_q)

# Query the position limits once instead of on every loop iteration.
q_lower = plant.GetPositionLowerLimits()
q_upper = plant.GetPositionUpperLimits()

sliders = []
for i in range(num_q):
    slider = widgets.FloatSlider(min=q_lower[i], max=q_upper[i], value=0, description=f"q{i}")
    # The slider keeps its value inside [min, max]; copy the value it actually
    # holds back into q so that q and the displayed slider agree even when 0
    # lies outside the limits.
    q[i] = slider.value
    sliders.append(slider)

# Marker colours: translucent red when in collision, translucent green otherwise.
col_col = Rgba(0.8, 0.0, 0, 0.5)
col_free = Rgba(0.0, 0.8, 0.5, 0.5)


def showres(qvis):
    """Show configuration ``qvis`` in Meshcat and report whether it collides.

    Sets the plant positions to ``qvis``, force-publishes the diagram, and
    draws a sphere marker at (0, 0, 2) coloured by the collision result.

    Args:
        qvis: position vector passed to ``plant.SetPositions``.

    Returns:
        The value of ``HasCollisions()`` on the plant's geometry query object
        for this configuration.
    """
    marker_path = "/drake/visualizer/shunk"

    plant.SetPositions(plant_context, qvis)
    diagram.ForcedPublish(diagram_context)

    query_object = plant.get_geometry_query_input_port().Eval(plant_context)
    in_collision = query_object.HasCollisions()

    # Same marker in both cases; only its colour depends on the result.
    marker_color = col_col if in_collision else col_free
    meshcat.SetObject(marker_path, Sphere(0.2), marker_color)
    meshcat.SetTransform(
        marker_path,
        RigidTransform(RotationMatrix(), np.array([0, 0, 2])),
    )
    return in_collision

def handle_slider_change(change, idx):
    """Store a slider's new value in ``q[idx]`` and refresh the visualization.

    Args:
        change: change dictionary delivered by the widget observer; only its
            'new' entry is read.
        idx: index of the entry of ``q`` controlled by the slider.
    """
    q[idx] = change['new']
    showres(q)


# Bind every slider to its own entry of q and show it. partial() fixes the
# index at registration time, and enumerate replaces the hand-maintained
# counter that previously leaked into the notebook namespace.
for slider_idx, slider in enumerate(sliders):
    slider.observe(partial(handle_slider_change, idx=slider_idx), names='value')
    display(slider)

def write_seed_point_to_file(button):
    """Append the current configuration ``q`` to the seed-point file.

    Registered as the click handler of the "Save" button.

    Args:
        button: the widget passed by the click event (unused).

    Raises:
        ValueError: if ``showres`` reports that ``q`` is in collision.
    """
    if showres(q):
        raise ValueError("That point is in collision")

    # One YAML list item per seed point, e.g. "- [0.0, 0.5]".
    leading = ''.join(str(value) + ', ' for value in q[:-1])
    file.write('- [' + leading + str(q[-1]) + ']\n')
    # Flush immediately so the point is on disk even if the file is never
    # closed explicitly.
    file.flush()


# "Save" button: clicking it stores the current configuration.
button = widgets.Button(description="Save")
button.on_click(write_seed_point_to_file)
display(button)

In [None]:
# Close the seed-point file handle opened earlier in the notebook. After this
# the "Save" button can no longer write to it until the file is reopened.
file.close()

# Inspect seed points

In [None]:
import yaml
import time

# Replay every saved seed point in the visualizer, pausing between points.
with open(seed_point_file, 'r') as f:
    loaded = yaml.safe_load(f)

# An empty file parses to None, and a bare 'seedpoints:' key maps to None.
# Treat both as "no seed points" instead of building a 0-d array that cannot
# be iterated.
seed_points = np.array((loaded or {}).get('seedpoints') or [])

if len(seed_points) == 0:
    print("no seed points saved yet")

for i, s in enumerate(seed_points):
    showres(s)
    print(f" point {i+1} / {len(seed_points)}")
    # Leave each configuration on screen long enough to inspect it.
    time.sleep(2)
