In [None]:
# Ball Sorting Robot: Pick and Throw Red/Blue Balls into Matching Bins
from pydrake.all import *
from manipulation.station import LoadScenario, MakeHardwareStation
from pydrake.geometry import Rgba
from pydrake.systems.sensors import RgbdSensor, CameraConfig
from scipy.optimize import least_squares
from pathlib import Path
import numpy as np
import tempfile
import os
import sys

sys.path.append(str(Path.cwd()))
from throw_helpers import create_q_knots

# Start one Meshcat instance for the whole notebook (the perception and
# simulation cells both pass it to MakeHardwareStation) and print its URL.
meshcat = StartMeshcat()
print(f"Meshcat: {meshcat.web_url()}")

INFO:drake:Meshcat listening for connections at http://localhost:7001


Meshcat: http://localhost:7001


In [None]:
# === CONFIGURATION ===
# Balls: ground-truth start positions, with z chosen so each ball's centre sits
# one radius above z = 0. Used to place the balls in the scenario and to score
# the perception result.
ball_radius = 0.05
red_pos = np.array((0.0, -0.5, ball_radius))
blue_pos = np.array((-0.15, -0.55, ball_radius))

# Bins: mirrored about x = 0 at a common y.
BIN_CENTER_Y = 1.65
BIN_OFFSET_X = 0.35
red_bin_pos = np.array((BIN_OFFSET_X, BIN_CENTER_Y, 0.0))
blue_bin_pos = np.array((-BIN_OFFSET_X, BIN_CENTER_Y, 0.0))

# Throw heading offsets (radians); added to the base joint angle when the throw
# configuration is built.
RED_THROW_ANGLE = np.radians(-15)
BLUE_THROW_ANGLE = np.radians(15)

# Gripper position commands sent to the "wsg.position" port.
# NOTE(review): presumably finger opening in metres -- confirm against the driver.
GRIPPER_OPEN = 0.107
GRIPPER_CLOSED = 0.0

# Gripper-frame z offset above the ball centre at grasp, and the additional z
# offset of the approach / retreat pose above the grasp pose.
GRASP_HEIGHT = 0.12
APPROACH_HEIGHT = 0.1

# Length of the throwing swing on the trajectory time axis, and the fraction of
# that swing after which the gripper is commanded open.
THROW_DURATION = 0.4
RELEASE_FRAC = 0.5

In [None]:
# === HELPER FUNCTIONS ===
def depth_to_points(depth, rgb, cam_info):
    """Back-project a depth image into camera-frame points with per-point colour.

    Args:
        depth: 2-D array of depth values, one per pixel.
        rgb: array indexed as ``rgb[row, col]`` giving the colour channels of
            each pixel, on a 0-255 scale.
        cam_info: object exposing ``center_x()``, ``center_y()``, ``focal_x()``
            and ``focal_y()`` (pinhole intrinsics).

    Returns:
        ``(points, colors)``: ``points`` is an (N, 3) array of x, y, z in the
        camera frame and ``colors`` holds the matching ``rgb`` rows divided by
        255. Only pixels with ``0.01 < depth < 5.0`` are kept, in row-major
        pixel order.
    """
    in_range = (depth > 0.01) & (depth < 5.0)
    # Row index is the image v coordinate, column index is u.
    rows, cols = np.nonzero(in_range)
    z = depth[rows, cols]
    x = (cols - cam_info.center_x()) * z / cam_info.focal_x()
    y = (rows - cam_info.center_y()) * z / cam_info.focal_y()
    points = np.column_stack([x, y, z])
    colors = rgb[rows, cols] / 255.0
    return points, colors

def make_camera_transform(pos, target):
    """Build the world pose of a camera at ``pos`` looking at ``target``.

    The rotation columns are (right, down, forward): camera +Z points from
    ``pos`` to ``target``, camera +X is horizontal (perpendicular to world Z),
    and camera +Y completes the right-handed frame.

    Args:
        pos: length-3 camera position in world coordinates.
        target: length-3 point the optical axis should pass through.

    Returns:
        RigidTransform with the rotation described above and translation ``pos``.

    Raises:
        ValueError: if ``pos`` and ``target`` coincide, since no view direction
            exists (previously this silently produced NaN axes).
    """
    pos = np.asarray(pos, dtype=float)
    offset = np.asarray(target, dtype=float) - pos
    distance = np.linalg.norm(offset)
    if distance < 1e-9:
        raise ValueError("camera position and target coincide; view direction is undefined")
    fwd = offset / distance

    right = np.cross(fwd, [0.0, 0.0, 1.0])
    right_norm = np.linalg.norm(right)
    if right_norm < 1e-9:
        # Looking straight up or down: fwd x world-Z vanishes, so "horizontal"
        # is ambiguous. Pick world +X, which is perpendicular to fwd here.
        right = np.array([1.0, 0.0, 0.0])
    else:
        right = right / right_norm

    # fwd and right are orthonormal, so this is already unit length and
    # right x down == fwd.
    down = np.cross(fwd, right)
    return RigidTransform(RotationMatrix(np.column_stack([right, down, fwd])), pos)

def fit_sphere(points, radius):
    """Fit the centre of a sphere of known radius to a set of surface points.

    Args:
        points: (N, 3) array of points assumed to lie near the sphere surface.
        radius: the fixed sphere radius.

    Returns:
        Length-3 array: the centre minimising the squared differences between
        each point's distance to the centre and ``radius``. The optimisation
        starts from the centroid of ``points``.
    """
    def radial_error(center):
        # Signed distance of every point from the sphere surface.
        return np.linalg.norm(points - center, axis=1) - radius

    centroid = np.mean(points, axis=0)
    solution = least_squares(radial_error, centroid)
    return solution.x

def detect_ball(points, colors, color_type):
    """Locate a red or blue ball in a coloured point cloud.

    Points are scored by how much the target channel exceeds the larger of the
    other two channels; the top-scoring ~2% (at least 50) are kept, outliers
    far from the cluster median are trimmed, and a sphere of radius
    ``ball_radius`` is fitted to what remains.

    Args:
        points: (N, 3) array of point positions.
        colors: (N, 3) array of per-point R, G, B values.
        color_type: ``"red"`` or ``"blue"``.

    Returns:
        ``np.array([x, y, ball_radius])`` using the fitted centre's x and y, or
        ``None`` when there are too few points to attempt a fit. The fitted z
        is discarded and replaced by ``ball_radius``.
        NOTE(review): that presumes the ball rests on a z = 0 surface -- confirm.

    Raises:
        ValueError: if ``color_type`` is neither ``"red"`` nor ``"blue"``
            (previously any other value was silently treated as blue).
    """
    if color_type == "red":
        target, other_a, other_b = 0, 1, 2
    elif color_type == "blue":
        target, other_a, other_b = 2, 0, 1
    else:
        raise ValueError(f"color_type must be 'red' or 'blue', got {color_type!r}")
    score = colors[:, target] - np.maximum(colors[:, other_a], colors[:, other_b])

    n_top = int(max(50, len(points) * 0.02))
    if len(score) < n_top:
        return None
    thresh = np.sort(score)[-n_top]
    pts = points[score >= thresh]

    # Can still be short when scores contain NaN (every comparison is False).
    if len(pts) < 10:
        return None

    # Trim stragglers: keep points within one ball radius of the running
    # median, stopping early rather than shrinking below 10 points. The gate
    # uses ball_radius (not a separate literal) so it tracks the configured
    # ball size, like the sphere fit below.
    for _ in range(3):
        c = np.median(pts, axis=0)
        mask = np.linalg.norm(pts - c, axis=1) < ball_radius
        if np.sum(mask) < 10:
            break
        pts = pts[mask]

    center = fit_sphere(pts, ball_radius)
    return np.array([center[0], center[1], ball_radius])

In [None]:
# === SCENE SETUP ===
# Write one single-link sphere model per ball colour into a fresh temp directory.
# NOTE(review): the <radius> values (0.05) and the inertia (0.0001, which equals
# 2/5 * m * r^2 for m = 0.1, r = 0.05) are literals here; they must be kept in
# step with ball_radius by hand if that setting changes.
temp_dir = tempfile.mkdtemp()
_sphere_geometry = '<geometry><sphere><radius>0.05</radius></sphere></geometry>'
_ball_inertial = (
    '<inertial><mass>0.1</mass><inertia>'
    '<ixx>0.0001</ixx><ixy>0</ixy><ixz>0</ixz>'
    '<iyy>0.0001</iyy><iyz>0</iyz><izz>0.0001</izz>'
    '</inertia></inertial>'
)
for name, color in {"red": "0.9 0.1 0.1", "blue": "0.2 0.4 0.9"}.items():
    sdf_text = (
        f'<?xml version="1.0"?><sdf version="1.7"><model name="{name}_ball"><link name="ball">'
        + _ball_inertial
        + f'<visual name="v">{_sphere_geometry}<material><diffuse>{color} 1.0</diffuse></material></visual>'
        + f'<collision name="c">{_sphere_geometry}</collision>'
        + '</link></model></sdf>'
    )
    with open(f"{temp_dir}/{name}_ball.sdf", 'w') as f:
        f.write(sdf_text)

# Scenario passed to LoadScenario(data=...) in both the perception and the
# simulation cells. It declares:
#   - an iiwa7 (no-collision model) welded to the world, with default joint angles,
#   - a WSG gripper welded to iiwa_link_7,
#   - a table and two bins loaded from ./assets (resolved against Path.cwd()),
#     the bins welded at red_bin_pos / blue_bin_pos (x, y) with z = 0,
#   - the two generated ball models as free bodies starting at red_pos / blue_pos,
#   - drivers: position-only iiwa control and a Schunk WSG driver.
# This is an f-string, so literal YAML braces are written doubled ({{ }}).
scenario_yaml = f"""
directives:
- add_model:
    name: iiwa
    file: package://drake_models/iiwa_description/sdf/iiwa7_no_collision.sdf
    default_joint_positions:
      iiwa_joint_1: [-1.57]
      iiwa_joint_2: [0.1]
      iiwa_joint_3: [0]
      iiwa_joint_4: [-1.2]
      iiwa_joint_5: [0]
      iiwa_joint_6: [1.6]
      iiwa_joint_7: [0]
- add_weld:
    parent: world
    child: iiwa::iiwa_link_0
- add_model:
    name: wsg
    file: package://manipulation/hydro/schunk_wsg_50_with_tip.sdf
- add_weld:
    parent: iiwa::iiwa_link_7
    child: wsg::body
    X_PC:
      translation: [0, 0, 0.09]
      rotation: !Rpy {{ deg: [90, 0, 90] }}
- add_model:
    name: table
    file: file://{Path.cwd()}/assets/table_extended.sdf
- add_weld:
    parent: world
    child: table::table_link
    X_PC:
      translation: [0.0, 1.0, -0.05]
- add_model:
    name: bin_red
    file: file://{Path.cwd()}/assets/bin_red.sdf
- add_weld:
    parent: world
    child: bin_red::bin_link
    X_PC:
      translation: [{red_bin_pos[0]}, {red_bin_pos[1]}, 0.0]
- add_model:
    name: bin_blue
    file: file://{Path.cwd()}/assets/bin_blue.sdf
- add_weld:
    parent: world
    child: bin_blue::bin_link
    X_PC:
      translation: [{blue_bin_pos[0]}, {blue_bin_pos[1]}, 0.0]
- add_model:
    name: red_ball
    file: file://{temp_dir}/red_ball.sdf
    default_free_body_pose:
      ball:
        translation: [{red_pos[0]}, {red_pos[1]}, {red_pos[2]}]
- add_model:
    name: blue_ball
    file: file://{temp_dir}/blue_ball.sdf
    default_free_body_pose:
      ball:
        translation: [{blue_pos[0]}, {blue_pos[1]}, {blue_pos[2]}]

model_drivers:
  iiwa: !IiwaDriver
    control_mode: position_only
    hand_model_name: wsg
  wsg: !SchunkWsgDriver {{}}
"""

# Camera positions
# Four viewpoints, all aimed at a single point in the ball pick-up area.
workspace = np.array([0.0, -0.5, 0.05])
cam_positions = [
    np.array(xyz)
    for xyz in (
        [0.4, -0.15, 0.5],
        [-0.35, -0.9, 0.25],
        [-0.5, -0.4, 0.4],
        [0.5, -0.6, 0.35],
    )
]
cam_transforms = [make_camera_transform(cam_pos, workspace) for cam_pos in cam_positions]

print(f"Red bin: {red_bin_pos[:2]}, Blue bin: {blue_bin_pos[:2]}")

Red bin: [0.35 1.65], Blue bin: [-0.35  1.65]


In [None]:
# === INITIAL PERCEPTION ===
# Build a station plus one RGB-D sensor per camera pose, render a single frame
# from each at the default (t = 0) state, and locate both balls in the merged
# world-frame cloud.
meshcat.Delete()
builder = DiagramBuilder()
station = MakeHardwareStation(LoadScenario(data=scenario_yaml), meshcat=meshcat)
builder.AddSystem(station)
plant = station.GetSubsystemByName("plant")
scene_graph = station.GetSubsystemByName("scene_graph")

# Get initial gripper pose (make_trajectory reads this later as its start pose)
ctx = station.CreateDefaultContext()
X_WG_init = plant.EvalBodyPoseInWorld(plant.GetMyContextFromRoot(ctx), plant.GetBodyByName("body"))

# Add cameras: one shared renderer, one sensor per pose, all fixed in the world frame
scene_graph.AddRenderer("r", MakeRenderEngineVtk(RenderEngineVtkParams()))
cam_cfg = CameraConfig()
cam_cfg.width, cam_cfg.height, cam_cfg.renderer_name = 640, 480, "r"
sensors, infos = [], []
for X in cam_transforms:
    cc, dc = cam_cfg.MakeCameras()
    s = builder.AddSystem(RgbdSensor(scene_graph.world_frame_id(), X, cc, dc))
    builder.Connect(station.GetOutputPort("query_object"), s.query_object_input_port())
    sensors.append(s)
    infos.append(cc.core().intrinsics())

diagram = builder.Build()
context = diagram.CreateDefaultContext()

# Capture and merge point clouds
all_pts, all_cols = [], []
for s, info, X in zip(sensors, infos, cam_transforms):
    sc = s.GetMyContextFromRoot(context)
    rgb = s.color_image_output_port().Eval(sc).data[:, :, :3].astype(np.uint8)
    depth = s.depth_image_32F_output_port().Eval(sc).data.squeeze()
    cam_pts, cam_cols = depth_to_points(depth, rgb, info)
    # Transform the whole cloud camera -> world in one call (3xN in, 3xN out)
    # instead of one Python-level multiply per point; this also keeps an empty
    # cloud at shape (0, 3) so the vstack below still works.
    all_pts.append(X.multiply(cam_pts.T).T)
    all_cols.append(cam_cols)

pts = np.vstack(all_pts)
cols = np.vstack(all_cols)
# Drop everything at or below z = 0.02 so low-lying surface points do not
# compete with the balls in the colour ranking.
above = pts[:, 2] > 0.02
pts, cols = pts[above], cols[above]

# Detect balls
det_red = detect_ball(pts, cols, "red")
det_blue = detect_ball(pts, cols, "blue")
# detect_ball returns None when it has too few points; fail here with a clear
# message instead of a TypeError inside the error printout below.
_undetected = [label for label, det in (("red", det_red), ("blue", det_blue)) if det is None]
if _undetected:
    raise RuntimeError(f"Ball detection failed for: {', '.join(_undetected)}")
print(f"Red: {det_red}, err={np.linalg.norm(det_red-red_pos)*1000:.1f}mm")
print(f"Blue: {det_blue}, err={np.linalg.norm(det_blue-blue_pos)*1000:.1f}mm")



Red: [ 2.24477300e-04 -4.99852757e-01  5.00000000e-02], err=0.3mm
Blue: [-0.14975433 -0.55002004  0.05      ], err=0.2mm


In [None]:
# === TRAJECTORY GENERATION ===
def make_trajectory(ball_pos, throw_angle, t_offset=0):
    """Build the arm and gripper trajectories for picking one ball and throwing it.

    Timeline (before ``t_offset`` is added):
        0.0 -> 1.0  move from the initial gripper pose to the approach pose
        1.0 -> 1.5  descend to the grasp pose
        1.5 -> 2.0  hold at the grasp pose (gripper closes at 2.0)
        2.0 -> 2.5  retreat to the approach pose
        2.5 -> 4.0  move in joint space to the wind-up configuration
        4.0 -> 4.0 + THROW_DURATION  linear joint-space swing to the end
                    configuration; the gripper opens RELEASE_FRAC of the way in

    Reads the module-level ``X_WG_init`` and ``plant`` and the constants
    ``GRASP_HEIGHT``, ``APPROACH_HEIGHT``, ``THROW_DURATION``, ``RELEASE_FRAC``,
    ``GRIPPER_OPEN`` and ``GRIPPER_CLOSED``.

    Args:
        ball_pos: length-3 world position of the ball centre.
        throw_angle: angle in radians added to the base joint for the throw.
        t_offset: value added to every breakpoint of both trajectories.

    Returns:
        ``(q_traj, g_traj, t_end)``: the 7-joint position trajectory, the
        1-D gripper command trajectory, and the end time of ``q_traj``.
    """
    # --- Pick: Cartesian keyframes, densely sampled and converted to joint knots
    X_ball_to_gripper = RigidTransform(RotationMatrix.MakeXRotation(-np.pi/2), [0, 0, GRASP_HEIGHT])
    X_grasp = RigidTransform(p=ball_pos) @ X_ball_to_gripper
    X_approach = RigidTransform(p=[0, 0, APPROACH_HEIGHT]) @ X_grasp

    t_pick_end = 2.5
    key_times = [0, 1.0, 1.5, 2.0, t_pick_end]
    key_poses = [X_WG_init, X_approach, X_grasp, X_grasp, X_approach]
    pose_traj = PiecewisePose.MakeLinear(key_times, key_poses)
    t_pick = np.linspace(0, t_pick_end, 50)
    # create_q_knots comes from throw_helpers.
    # NOTE(review): presumably it solves IK for each pose in order -- confirm.
    q_pick = create_q_knots(plant, [pose_traj.GetPose(t) for t in t_pick])

    # --- Throw: straight-line joint interpolation from wind-up to end posture
    ja1 = np.pi/2 + throw_angle - np.pi
    q_windup = np.array([ja1, 0, 0, 1.9, 0, -1.9, np.pi])
    q_finish = np.array([ja1, 0, 0, 0.4, 0, -0.4, np.pi])
    t_throw = np.linspace(0, THROW_DURATION, 30)
    q_throw = [q_windup + (t/THROW_DURATION)*(q_finish-q_windup) for t in t_throw]

    # --- Stitch pick and throw; the first throw sample equals the wind-up knot,
    # so it is dropped to keep the breakpoints strictly increasing.
    t_throw_start = t_pick_end + 1.5
    t_knots = [*t_pick, t_throw_start, *(t_throw_start + t_throw[1:])]
    q_knots = [*q_pick, q_windup, *q_throw[1:]]
    t_knots = [t + t_offset for t in t_knots]
    q_traj = PiecewisePolynomial.CubicShapePreserving(t_knots, np.array(q_knots).T)

    # --- Gripper: open -> closed at the end of the grasp dwell -> open at
    # release, each switch ramped over 0.01 on the time axis.
    t_close = 2.0 + t_offset
    t_release = t_throw_start + RELEASE_FRAC * THROW_DURATION + t_offset
    t_end = q_traj.end_time()
    g_breaks = [t_offset, t_close, t_close+0.01, t_release, t_release+0.01, t_end]
    g_values = np.array([[
        GRIPPER_OPEN, GRIPPER_OPEN,
        GRIPPER_CLOSED, GRIPPER_CLOSED,
        GRIPPER_OPEN, GRIPPER_OPEN,
    ]])
    g_traj = PiecewisePolynomial.FirstOrderHold(g_breaks, g_values)
    return q_traj, g_traj, t_end

# Generate trajectories
# Red goes first from t = 0; blue is scheduled to begin 1.0 after red ends.
q_red, g_red, t_red_end = make_trajectory(det_red, RED_THROW_ANGLE, t_offset=0)
t_blue_start = t_red_end + 1.0
q_blue, g_blue, t_blue_end = make_trajectory(det_blue, BLUE_THROW_ANGLE, t_offset=t_blue_start)

# The two trajectories stay separate objects; only the overall end time is
# recorded here for the simulation loop.
t_total = t_blue_end
print(f"Red throw: 0 - {t_red_end:.1f}s")
print(f"Blue throw: {t_blue_start:.1f} - {t_blue_end:.1f}s")
print(f"Total: {t_total:.1f}s")

Red throw: 0 - 4.4s
Blue throw: 5.4 - 9.8s
Total: 9.8s


In [None]:
# === RUN SIMULATION ===
# Start from a clean Meshcat scene and build a second station from the same
# scenario. This rebinds builder / station / plant from the perception cell;
# no cameras are added to this diagram.
meshcat.Delete()
builder = DiagramBuilder()
station = MakeHardwareStation(LoadScenario(data=scenario_yaml), meshcat=meshcat)
builder.AddSystem(station)
plant = station.GetSubsystemByName("plant")
# Custom trajectory source that switches between red and blue
class DualTrajectorySource(LeafSystem):
    """Time-switched source of arm and gripper commands for two throws.

    Output port "q" (size 7) and output port "g" (size 1) are evaluated from
    the red trajectories while ``time < t_switch`` and from the blue
    trajectories from ``t_switch`` onward.

    NOTE(review): each trajectory is queried with the absolute context time,
    including times outside its own breakpoints (after red ends, before blue
    starts). Presumably PiecewisePolynomial.value holds the endpoint value
    there, which would make the arm command step at ``t_switch`` -- confirm.
    """

    def __init__(self, q_red, g_red, q_blue, g_blue, t_switch):
        LeafSystem.__init__(self)
        self.q_red = q_red
        self.g_red = g_red
        self.q_blue = q_blue
        self.g_blue = g_blue
        self.t_switch = t_switch
        self.DeclareVectorOutputPort("q", 7, self.calc_q)
        self.DeclareVectorOutputPort("g", 1, self.calc_g)

    def _write_active(self, context, output, first, second):
        # Evaluate whichever trajectory owns the current time and copy it out.
        t = context.get_time()
        active = first if t < self.t_switch else second
        output.SetFromVector(active.value(t).flatten())

    def calc_q(self, context, output):
        """Write the commanded joint positions for the current time."""
        self._write_active(context, output, self.q_red, self.q_blue)

    def calc_g(self, context, output):
        """Write the commanded gripper position for the current time."""
        self._write_active(context, output, self.g_red, self.g_blue)

# Hand over from the red to the blue trajectory 0.5 after the red one ends.
t_switch = t_red_end + 0.5
traj_source = builder.AddSystem(DualTrajectorySource(q_red, g_red, q_blue, g_blue, t_switch))
for out_name, in_name in (("q", "iiwa.position"), ("g", "wsg.position")):
    builder.Connect(traj_source.GetOutputPort(out_name), station.GetInputPort(in_name))

diagram = builder.Build()
sim = Simulator(diagram)
sim.set_target_realtime_rate(1.0)
ctx = sim.get_mutable_context()

red_body = plant.GetBodyByName("ball", plant.GetModelInstanceByName("red_ball"))
blue_body = plant.GetBodyByName("ball", plant.GetModelInstanceByName("blue_ball"))

# Advance in 0.02 increments until 2 past the end of the blue trajectory,
# recording the run for playback in Meshcat.
t_stop = t_total + 2
meshcat.StartRecording()
print("Running simulation...")
while ctx.get_time() < t_stop:
    sim.AdvanceTo(ctx.get_time() + 0.02)
meshcat.StopRecording()
meshcat.PublishRecording()

# Results
# A ball counts as "in" when its final x-y position is within 0.15 of its bin's
# x-y position; height is not checked.
pctx = plant.GetMyContextFromRoot(ctx)
red_final = plant.EvalBodyPoseInWorld(pctx, red_body).translation()
blue_final = plant.EvalBodyPoseInWorld(pctx, blue_body).translation()

red_dist = np.linalg.norm(red_final[:2] - red_bin_pos[:2])
blue_dist = np.linalg.norm(blue_final[:2] - blue_bin_pos[:2])
red_in = red_dist < 0.15
blue_in = blue_dist < 0.15

print(f"\n=== RESULTS ===")
print(f"Red ball:  {red_final[:2]} -> bin {red_bin_pos[:2]}, dist={red_dist*100:.1f}cm, IN={'YES' if red_in else 'NO'}")
print(f"Blue ball: {blue_final[:2]} -> bin {blue_bin_pos[:2]}, dist={blue_dist*100:.1f}cm, IN={'YES' if blue_in else 'NO'}")
if red_in and blue_in:
    print("\n*** PERFECT SORT! ***")

Running simulation...

=== RESULTS ===
Red ball:  [0.36910443 1.61991478] -> bin [0.35 1.65], dist=3.6cm, IN=YES
Blue ball: [-0.29865488  1.57215986] -> bin [-0.35  1.65], dist=9.3cm, IN=YES

*** PERFECT SORT! ***
