In [2]:
# Ball Sorting Robot: Pick and Throw Red/Blue Balls into Matching Bins
from pydrake.all import *
from manipulation.station import LoadScenario, MakeHardwareStation
from pydrake.geometry import Rgba
from pydrake.systems.sensors import RgbdSensor, CameraConfig
from scipy.optimize import least_squares
from pathlib import Path
import numpy as np
import tempfile
import os
import sys

sys.path.append(str(Path.cwd()))
from throw_helpers import create_q_knots

meshcat = StartMeshcat()
print(f"Meshcat: {meshcat.web_url()}")

INFO:drake:Meshcat listening for connections at http://localhost:7001


Meshcat: http://localhost:7001


In [3]:
# === CONFIGURATION ===
ball_radius = 0.05

# Ball positions - 2 red, 2 blue scattered on table
red_positions = [
    np.array([0.0, -0.5, ball_radius]),
    np.array([0.15, -0.6, ball_radius]),
]
blue_positions = [
    np.array([-0.15, -0.55, ball_radius]),
    np.array([-0.05, -0.65, ball_radius]),
]

BIN_CENTER_Y = 1.65
BIN_OFFSET_X = 0.35
red_bin_pos = np.array([BIN_OFFSET_X, BIN_CENTER_Y, 0.0])
blue_bin_pos = np.array([-BIN_OFFSET_X, BIN_CENTER_Y, 0.0])

RED_THROW_ANGLE = np.radians(-15)
BLUE_THROW_ANGLE = np.radians(15)

GRIPPER_OPEN, GRIPPER_CLOSED = 0.107, 0.0
GRASP_HEIGHT, APPROACH_HEIGHT = 0.12, 0.1
THROW_DURATION, RELEASE_FRAC = 0.4, 0.5

In [4]:
# === HELPER FUNCTIONS ===
def depth_to_points(depth, rgb, cam_info):
    h, w = depth.shape
    u, v = np.meshgrid(np.arange(w), np.arange(h))
    valid = (depth > 0.01) & (depth < 5.0)
    z = depth[valid]
    x = (u[valid] - cam_info.center_x()) * z / cam_info.focal_x()
    y = (v[valid] - cam_info.center_y()) * z / cam_info.focal_y()
    return np.column_stack([x, y, z]), rgb[valid] / 255.0

def make_camera_transform(pos, target):
    fwd = (target - pos) / np.linalg.norm(target - pos)
    right = np.cross(fwd, [0,0,1]); right /= np.linalg.norm(right)
    down = np.cross(fwd, right)
    return RigidTransform(RotationMatrix(np.column_stack([right, down, fwd])), pos)

def fit_sphere(points, radius):
    res = least_squares(lambda c: np.linalg.norm(points - c, axis=1) - radius, np.mean(points, axis=0))
    return res.x

def detect_ball(points, colors, color_type):
    if color_type == "red":
        score = colors[:,0] - np.maximum(colors[:,1], colors[:,2])
    else:
        score = colors[:,2] - np.maximum(colors[:,0], colors[:,1])
    
    n_top = int(max(50, len(points) * 0.02))
    if len(score) < n_top:
        return None
    thresh = np.sort(score)[-n_top]
    pts = points[score >= thresh].copy()
    
    if len(pts) < 10:
        return None
    for _ in range(3):
        c = np.median(pts, axis=0)
        mask = np.linalg.norm(pts - c, axis=1) < 0.05
        if np.sum(mask) < 10:
            break
        pts = pts[mask]
    
    center = fit_sphere(pts, ball_radius)
    return np.array([center[0], center[1], ball_radius])

In [5]:
# === SCENE SETUP ===
temp_dir = tempfile.mkdtemp()

# Create ball SDFs
for name, color in [("red", "0.9 0.1 0.1"), ("blue", "0.2 0.4 0.9")]:
    with open(f"{temp_dir}/{name}_ball.sdf", 'w') as f:
        f.write(f'<?xml version="1.0"?><sdf version="1.7"><model name="{name}_ball"><link name="ball"><inertial><mass>0.1</mass><inertia><ixx>0.0001</ixx><ixy>0</ixy><ixz>0</ixz><iyy>0.0001</iyy><iyz>0</iyz><izz>0.0001</izz></inertia></inertial><visual name="v"><geometry><sphere><radius>0.05</radius></sphere></geometry><material><diffuse>{color} 1.0</diffuse></material></visual><collision name="c"><geometry><sphere><radius>0.05</radius></sphere></geometry></collision></link></model></sdf>')

# Build ball directives dynamically
ball_directives = ""
for i, pos in enumerate(red_positions):
    ball_directives += f"""
- add_model:
    name: red_ball_{i}
    file: file://{temp_dir}/red_ball.sdf
    default_free_body_pose:
      ball:
        translation: [{pos[0]}, {pos[1]}, {pos[2]}]"""

for i, pos in enumerate(blue_positions):
    ball_directives += f"""
- add_model:
    name: blue_ball_{i}
    file: file://{temp_dir}/blue_ball.sdf
    default_free_body_pose:
      ball:
        translation: [{pos[0]}, {pos[1]}, {pos[2]}]"""

scenario_yaml = f"""
directives:
- add_model:
    name: iiwa
    file: package://drake_models/iiwa_description/sdf/iiwa7_no_collision.sdf
    default_joint_positions:
      iiwa_joint_1: [-1.57]
      iiwa_joint_2: [0.1]
      iiwa_joint_3: [0]
      iiwa_joint_4: [-1.2]
      iiwa_joint_5: [0]
      iiwa_joint_6: [1.6]
      iiwa_joint_7: [0]
- add_weld:
    parent: world
    child: iiwa::iiwa_link_0
- add_model:
    name: wsg
    file: package://manipulation/hydro/schunk_wsg_50_with_tip.sdf
- add_weld:
    parent: iiwa::iiwa_link_7
    child: wsg::body
    X_PC:
      translation: [0, 0, 0.09]
      rotation: !Rpy {{ deg: [90, 0, 90] }}
- add_model:
    name: table
    file: file://{Path.cwd()}/assets/table_extended.sdf
- add_weld:
    parent: world
    child: table::table_link
    X_PC:
      translation: [0.0, 1.0, -0.05]
- add_model:
    name: bin_red
    file: file://{Path.cwd()}/assets/bin_red.sdf
- add_weld:
    parent: world
    child: bin_red::bin_link
    X_PC:
      translation: [{red_bin_pos[0]}, {red_bin_pos[1]}, 0.0]
- add_model:
    name: bin_blue
    file: file://{Path.cwd()}/assets/bin_blue.sdf
- add_weld:
    parent: world
    child: bin_blue::bin_link
    X_PC:
      translation: [{blue_bin_pos[0]}, {blue_bin_pos[1]}, 0.0]{ball_directives}

model_drivers:
  iiwa: !IiwaDriver
    control_mode: position_only
    hand_model_name: wsg
  wsg: !SchunkWsgDriver {{}}
"""

# Camera positions
workspace = np.array([0.0, -0.55, 0.05])
cam_positions = [
    np.array([0.4, -0.15, 0.5]),
    np.array([-0.35, -0.9, 0.25]),
    np.array([-0.5, -0.4, 0.4]),
    np.array([0.5, -0.6, 0.35])
]
cam_transforms = [make_camera_transform(p, workspace) for p in cam_positions]

print(f"Red balls: {len(red_positions)}, Blue balls: {len(blue_positions)}")
print(f"Red bin: {red_bin_pos[:2]}, Blue bin: {blue_bin_pos[:2]}")

Red balls: 2, Blue balls: 2
Red bin: [0.35 1.65], Blue bin: [-0.35  1.65]


In [6]:
# === INITIAL PERCEPTION ===
import builtins  # For Python's built-in min/max

meshcat.Delete()
builder = DiagramBuilder()
station = MakeHardwareStation(LoadScenario(data=scenario_yaml), meshcat=meshcat)
builder.AddSystem(station)
plant = station.GetSubsystemByName("plant")
scene_graph = station.GetSubsystemByName("scene_graph")

# Get initial gripper pose
ctx = station.CreateDefaultContext()
X_WG_init = plant.EvalBodyPoseInWorld(plant.GetMyContextFromRoot(ctx), plant.GetBodyByName("body"))

# Add cameras
scene_graph.AddRenderer("r", MakeRenderEngineVtk(RenderEngineVtkParams()))
cam_cfg = CameraConfig(); cam_cfg.width, cam_cfg.height, cam_cfg.renderer_name = 640, 480, "r"
sensors, infos = [], []
for X in cam_transforms:
    cc, dc = cam_cfg.MakeCameras()
    s = builder.AddSystem(RgbdSensor(scene_graph.world_frame_id(), X, cc, dc))
    builder.Connect(station.GetOutputPort("query_object"), s.query_object_input_port())
    sensors.append(s); infos.append(cc.core().intrinsics())

diagram = builder.Build()
context = diagram.CreateDefaultContext()

def capture_point_cloud():
    """Capture and merge point clouds from all cameras."""
    all_pts, all_cols = [], []
    for s, info, X in zip(sensors, infos, cam_transforms):
        sc = s.GetMyContextFromRoot(context)
        rgb = s.color_image_output_port().Eval(sc).data[:,:,:3].astype(np.uint8)
        depth = s.depth_image_32F_output_port().Eval(sc).data.squeeze()
        pts, cols = depth_to_points(depth, rgb, info)
        all_pts.append(np.array([X @ p for p in pts]))
        all_cols.append(cols)
    pts = np.vstack(all_pts); cols = np.vstack(all_cols)
    above = pts[:, 2] > 0.02
    return pts[above], cols[above]

def detect_all_balls(pts, cols, color_type, expected_count):
    """Detect multiple balls of the same color by iteratively removing detected balls."""
    detected = []
    remaining_pts = pts.copy()
    remaining_cols = cols.copy()
    
    for _ in range(expected_count):
        pos = detect_ball(remaining_pts, remaining_cols, color_type)
        if pos is None:
            break
        detected.append(pos)
        # Remove points near detected ball for next iteration
        dist = np.linalg.norm(remaining_pts - pos, axis=1)
        mask = dist > ball_radius * 1.5
        remaining_pts = remaining_pts[mask]
        remaining_cols = remaining_cols[mask]
    
    return detected

# Capture initial point cloud
pts, cols = capture_point_cloud()

# Detect all balls
detected_red = detect_all_balls(pts, cols, "red", len(red_positions))
detected_blue = detect_all_balls(pts, cols, "blue", len(blue_positions))

print(f"Detected {len(detected_red)} red balls:")
for i, pos in enumerate(detected_red):
    err = builtins.min(np.linalg.norm(pos - rp) for rp in red_positions) * 1000
    print(f"  Red {i}: {pos}, err={err:.1f}mm")

print(f"Detected {len(detected_blue)} blue balls:")
for i, pos in enumerate(detected_blue):
    err = builtins.min(np.linalg.norm(pos - bp) for bp in blue_positions) * 1000
    print(f"  Blue {i}: {pos}, err={err:.1f}mm")

# Combine into a list of (position, color) tuples for sorting
balls_to_sort = [(pos, "red") for pos in detected_red] + [(pos, "blue") for pos in detected_blue]
print(f"\nTotal balls to sort: {len(balls_to_sort)}")



Detected 2 red balls:
  Red 0: [ 0.15028774 -0.59977483  0.05      ], err=0.4mm
  Red 1: [ 1.89902865e-04 -4.99989694e-01  5.00000000e-02], err=0.2mm
Detected 2 blue balls:
  Blue 0: [-0.15044318 -0.55052946  0.05      ], err=0.7mm
  Blue 1: [-0.04977993 -0.64991323  0.05      ], err=0.2mm

Total balls to sort: 4


In [7]:
# === TRAJECTORY GENERATION ===
def make_trajectory(ball_pos, throw_angle, t_offset=0):
    X_grasp = RigidTransform(p=ball_pos) @ RigidTransform(RotationMatrix.MakeXRotation(-np.pi/2), [0, 0, GRASP_HEIGHT])
    X_approach = RigidTransform(p=[0, 0, APPROACH_HEIGHT]) @ X_grasp
    
    # Pick trajectory
    times = [0, 1.0, 1.5, 2.0, 2.5]
    poses = [X_WG_init, X_approach, X_grasp, X_grasp, X_approach]
    pose_traj = PiecewisePose.MakeLinear(times, poses)
    t_samples = np.linspace(0, 2.5, 50)
    q_pick = create_q_knots(plant, [pose_traj.GetPose(t) for t in t_samples])
    
    # Throw trajectory
    ja1 = np.pi/2 + throw_angle - np.pi
    pre = np.array([ja1, 0, 0, 1.9, 0, -1.9, np.pi])
    end = np.array([ja1, 0, 0, 0.4, 0, -0.4, np.pi])
    t_throw = np.linspace(0, THROW_DURATION, 30)
    q_throw = [pre + (t/THROW_DURATION)*(end-pre) for t in t_throw]
    
    # Combine
    t_pick_end = 2.5
    t_full = list(t_samples) + [t_pick_end + 1.5] + list(t_pick_end + 1.5 + t_throw[1:])
    q_full = list(q_pick) + [pre] + q_throw[1:]
    
    # Offset times
    t_full = [t + t_offset for t in t_full]
    q_traj = PiecewisePolynomial.CubicShapePreserving(t_full, np.array(q_full).T)
    
    # Gripper timing
    t_close = 2.0 + t_offset
    t_release = t_pick_end + 1.5 + RELEASE_FRAC * THROW_DURATION + t_offset
    t_end = q_traj.end_time()
    g_traj = PiecewisePolynomial.FirstOrderHold(
        [t_offset, t_close, t_close+0.01, t_release, t_release+0.01, t_end],
        np.array([[GRIPPER_OPEN, GRIPPER_OPEN, GRIPPER_CLOSED, GRIPPER_CLOSED, GRIPPER_OPEN, GRIPPER_OPEN]])
    )
    return q_traj, g_traj, t_end, t_release

def predict_ball_trajectory(q_traj, t_release, plant, num_points=50):
    """Predict parabolic trajectory from release point."""
    # Create a context to evaluate kinematics
    plant_context = plant.CreateDefaultContext()
    
    # Get joint positions and velocities at release
    q_release = q_traj.value(t_release).flatten()
    qd_release = q_traj.derivative(1).value(t_release).flatten()
    
    # Set robot state
    iiwa = plant.GetModelInstanceByName("iiwa")
    plant.SetPositions(plant_context, iiwa, q_release)
    plant.SetVelocities(plant_context, iiwa, qd_release)
    
    # Get gripper pose and spatial velocity
    gripper_body = plant.GetBodyByName("body")
    X_WG = plant.EvalBodyPoseInWorld(plant_context, gripper_body)
    V_WG = plant.EvalBodySpatialVelocityInWorld(plant_context, gripper_body)
    
    # Release position (gripper tip, offset from body frame)
    p_release = X_WG.translation() + X_WG.rotation() @ np.array([0, 0, 0.1])
    v_release = V_WG.translational()
    
    # Predict parabolic trajectory until z < 0.05 (bin height)
    g = np.array([0, 0, -9.81])
    
    # Time to hit z = 0.1 (approximate bin height)
    # z = p0_z + v0_z*t + 0.5*g*t^2
    # Solve for when z = 0.1
    a = 0.5 * g[2]
    b = v_release[2]
    c = p_release[2] - 0.1
    discriminant = b**2 - 4*a*c
    if discriminant > 0:
        t_land = (-b - np.sqrt(discriminant)) / (2*a)
        t_land = builtins.max(0.1, t_land)  # At least some flight time
    else:
        t_land = 1.0  # Default
    
    # Sample trajectory points
    t_flight = np.linspace(0, t_land, num_points)
    points = []
    for t in t_flight:
        p = p_release + v_release * t + 0.5 * g * t**2
        points.append(p)
    
    return np.array(points), p_release, v_release

# Generate trajectories for all balls sequentially
trajectories = []  # List of (q_traj, g_traj, color, t_release, predicted_arc)
t_offset = 0
GAP_BETWEEN_THROWS = 1.0

for ball_pos, color in balls_to_sort:
    throw_angle = RED_THROW_ANGLE if color == "red" else BLUE_THROW_ANGLE
    q_traj, g_traj, t_end, t_release = make_trajectory(ball_pos, throw_angle, t_offset)
    
    # Predict trajectory
    arc_points, p_rel, v_rel = predict_ball_trajectory(q_traj, t_release, plant)
    landing = arc_points[-1]
    
    trajectories.append((q_traj, g_traj, color, t_release, arc_points))
    print(f"{color.capitalize()} ball: {t_offset:.1f}s-{t_end:.1f}s, release v=[{v_rel[0]:.1f}, {v_rel[1]:.1f}, {v_rel[2]:.1f}], land ~[{landing[0]:.2f}, {landing[1]:.2f}]")
    t_offset = t_end + GAP_BETWEEN_THROWS

t_total = t_offset - GAP_BETWEEN_THROWS + 2.0
print(f"\nTotal: {t_total:.1f}s for {len(balls_to_sort)} balls")

Red ball: 0.0s-4.4s, release v=[0.1, 0.2, 2.3], land ~[0.15, 0.56]
Red ball: 5.4s-9.8s, release v=[0.1, 0.2, 2.3], land ~[0.15, 0.56]
Blue ball: 10.8s-15.2s, release v=[-0.1, 0.2, 2.3], land ~[-0.15, 0.56]
Blue ball: 16.2s-20.6s, release v=[-0.1, 0.2, 2.3], land ~[-0.15, 0.56]

Total: 22.6s for 4 balls


In [13]:
# === RUN SIMULATION ===
meshcat.Delete()
builder = DiagramBuilder()
station = MakeHardwareStation(LoadScenario(data=scenario_yaml), meshcat=meshcat)
builder.AddSystem(station)
plant = station.GetSubsystemByName("plant")

# Multi-trajectory source that chains all ball pick-and-throw sequences
class MultiTrajectorySource(LeafSystem):
    def __init__(self, trajectories):
        LeafSystem.__init__(self)
        self.trajectories = trajectories  # List of (q_traj, g_traj, color, t_release, arc)
        self.DeclareVectorOutputPort("q", 7, self.calc_q)
        self.DeclareVectorOutputPort("g", 1, self.calc_g)
    
    def _get_active_trajectory(self, t):
        for q_traj, g_traj, _, _, _ in self.trajectories:
            if t <= q_traj.end_time():
                return q_traj, g_traj
        return self.trajectories[-1][0], self.trajectories[-1][1]
    
    def calc_q(self, context, output):
        t = context.get_time()
        q_traj, _ = self._get_active_trajectory(t)
        t_clamped = builtins.min(t, q_traj.end_time())
        output.SetFromVector(q_traj.value(t_clamped).flatten())
    
    def calc_g(self, context, output):
        t = context.get_time()
        _, g_traj = self._get_active_trajectory(t)
        t_clamped = builtins.min(t, g_traj.end_time())
        output.SetFromVector(g_traj.value(t_clamped).flatten())

traj_source = builder.AddSystem(MultiTrajectorySource(trajectories))
builder.Connect(traj_source.GetOutputPort("q"), station.GetInputPort("iiwa.position"))
builder.Connect(traj_source.GetOutputPort("g"), station.GetInputPort("wsg.position"))

diagram = builder.Build()
sim = Simulator(diagram)
sim.set_target_realtime_rate(1.0)
ctx = sim.get_mutable_context()

# Get all ball bodies
all_ball_bodies = []
for i in range(len(red_positions)):
    body = plant.GetBodyByName("ball", plant.GetModelInstanceByName(f"red_ball_{i}"))
    all_ball_bodies.append((body, "red", red_bin_pos))
for i in range(len(blue_positions)):
    body = plant.GetBodyByName("ball", plant.GetModelInstanceByName(f"blue_ball_{i}"))
    all_ball_bodies.append((body, "blue", blue_bin_pos))

# Match each trajectory to the correct ball body by finding closest initial position
def find_ball_body_for_trajectory(ball_pos, color):
    pctx = plant.GetMyContextFromRoot(ctx)
    best_body = None
    best_dist = float('inf')
    for body, body_color, bin_pos in all_ball_bodies:
        if body_color != color:
            continue
        pos = plant.EvalBodyPoseInWorld(pctx, body).translation()
        dist = np.linalg.norm(pos[:2] - ball_pos[:2])
        if dist < best_dist:
            best_dist = dist
            best_body = body
    return best_body

# Build mapping from trajectory index to ball body
traj_to_ball = []
for ball_pos, color in balls_to_sort:
    body = find_ball_body_for_trajectory(ball_pos, color)
    traj_to_ball.append((body, color))

def compute_arc_from_state(p0, v0):
    """Compute parabolic arc from position and velocity."""
    g = np.array([0, 0, -9.81])
    a = 0.5 * g[2]
    b = v0[2]
    c = p0[2] - 0.1
    disc = b**2 - 4*a*c
    t_land = (-b - np.sqrt(disc)) / (2*a) if disc > 0 else 1.0
    t_land = builtins.max(0.1, builtins.min(2.0, t_land))
    
    t_flight = np.linspace(0, t_land, 50)
    points = np.array([p0 + v0*t + 0.5*g*t**2 for t in t_flight])
    return points

def draw_predicted_arc(arc_points, color, ball_idx):
    rgba = Rgba(0.9, 0.2, 0.2, 0.8) if color == "red" else Rgba(0.2, 0.4, 0.9, 0.8)
    meshcat.SetLine(f"predicted_arc_{ball_idx}", arc_points.T, line_width=4.0, rgba=rgba)

def clear_arc(ball_idx):
    meshcat.Delete(f"predicted_arc_{ball_idx}")

meshcat.StartRecording()
print(f"Running simulation with {len(balls_to_sort)} balls...")

# Track ball states
arc_state = {}

while ctx.get_time() < t_total:
    t = ctx.get_time()
    pctx = plant.GetMyContextFromRoot(ctx)
    
    for idx, (q_traj, g_traj, color, t_release, _) in enumerate(trajectories):
        if idx not in arc_state:
            arc_state[idx] = {"drawn": False, "cleared": False}
        
        state = arc_state[idx]
        ball_body, _ = traj_to_ball[idx]
        
        # Shortly after release, get ball's actual velocity from simulation
        if t >= t_release + 0.02 and not state["drawn"]:
            ball_pos = plant.EvalBodyPoseInWorld(pctx, ball_body).translation()
            ball_vel = plant.EvalBodySpatialVelocityInWorld(pctx, ball_body).translational()
            
            # Only draw if ball is actually moving (released)
            if np.linalg.norm(ball_vel) > 0.5:
                arc_points = compute_arc_from_state(ball_pos, ball_vel)
                draw_predicted_arc(arc_points, color, idx)
                state["drawn"] = True
                print(f"  t={t:.1f}s: {color} #{idx} v=[{ball_vel[0]:.1f}, {ball_vel[1]:.1f}, {ball_vel[2]:.1f}] m/s")
        
        # Clear arc after ball lands
        if state["drawn"] and not state["cleared"] and t >= t_release + 1.5:
            clear_arc(idx)
            state["cleared"] = True
    
    sim.AdvanceTo(t + 0.02)

meshcat.StopRecording()
meshcat.PublishRecording()

for idx in range(len(trajectories)):
    clear_arc(idx)

# Results
print(f"\n=== RESULTS ===")
pctx = plant.GetMyContextFromRoot(ctx)
successes = 0
for body, color, bin_pos in all_ball_bodies:
    final_pos = plant.EvalBodyPoseInWorld(pctx, body).translation()
    dist = np.linalg.norm(final_pos[:2] - bin_pos[:2])
    in_bin = dist < 0.15
    successes += in_bin
    print(f"{color.capitalize()} ball: {final_pos[:2]} -> bin {bin_pos[:2]}, dist={dist*100:.1f}cm, IN={'YES' if in_bin else 'NO'}")

print(f"\n*** {successes}/{len(all_ball_bodies)} BALLS SORTED ***")
if successes == len(all_ball_bodies):
    print("*** PERFECT SORT! ***")

Running simulation with 4 balls...
  t=4.2s: red #0 v=[0.5, 1.7, 2.4] m/s
  t=9.6s: red #1 v=[0.5, 1.7, 2.4] m/s
  t=15.0s: blue #2 v=[-0.4, 1.7, 2.3] m/s
  t=20.4s: blue #3 v=[-0.4, 1.7, 2.3] m/s

=== RESULTS ===
Red ball: [0.39999112 1.61143608] -> bin [0.35 1.65], dist=6.3cm, IN=YES
Red ball: [0.39305297 1.72992085] -> bin [0.35 1.65], dist=9.1cm, IN=YES
Blue ball: [-0.4299535   1.62638691] -> bin [-0.35  1.65], dist=8.3cm, IN=YES
Blue ball: [-0.28359867  1.57116319] -> bin [-0.35  1.65], dist=10.3cm, IN=YES

*** 4/4 BALLS SORTED ***
*** PERFECT SORT! ***
