# Ball Trajectory Experiments

Let's test: if we place the ball at the throw point and give it the throw velocity, will it reach the catch point?

In [None]:
import logging
from copy import copy
from enum import Enum

import numpy as np
from pydrake.all import (AbstractValue, AddMultibodyPlantSceneGraph, AngleAxis,
                         Concatenate, DiagramBuilder, InputPortIndex,
                         LeafSystem, MeshcatVisualizer, Parser,
                         PiecewisePolynomial, PiecewisePose, PointCloud, RandomGenerator, RigidTransform,
                         RollPitchYaw, Simulator, StartMeshcat,
                         UniformlyRandomRotationMatrix,
                         BsplineTrajectory, Sphere, Rgba,
                         KinematicTrajectoryOptimization, Solve, MinimumDistanceConstraint,
                         PositionConstraint,
                         SpatialVelocity)

from manipulation import FindResource, running_as_notebook
from manipulation.meshcat_utils import AddMeshcatTriad
from manipulation.scenarios import (AddPackagePaths,
                                    MakeManipulationStation, AddIiwa, AddShape, AddWsg, AddMultibodyTriad)

from utils import *

In [None]:
# Start the Meshcat visualizer. The resulting module-level `meshcat` handle is
# used directly by later cells (add_target_objects, BallValidator.Update and
# ball_trajectory_validation), so this cell has to run before them.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at https://68a898be-bd4e-4046-916a-6293aadf6c5d.deepnoteproject.com/7000/


In [None]:
# Radius passed to every meshcat Sphere drawn as a ball marker
# (presumably meters — TODO confirm).
squash_ball_radius = 0.02
# Offset that add_target_objects rotates by a pose's rotation and adds to that
# pose's translation to obtain the ball position.
# NOTE(review): the name suggests "ball B relative to gripper G, expressed in
# G" — confirm against the gripper model.
p_GB_G = [0, 0.11, 0]
def add_target_objects(X_WThrow, X_WCatch):
    """Draw the throw and catch ball markers in meshcat and return their poses.

    Each ball pose keeps the rotation of the given pose and shifts its
    translation by ``p_GB_G`` rotated by that same rotation.

    Args:
        X_WThrow: RigidTransform of the throw pose.
        X_WCatch: RigidTransform of the catch pose.

    Returns:
        Tuple ``(throw_ball_pose, catch_ball_pose)`` of RigidTransforms.
    """
    def ball_pose(X_WG):
        # Same rotation as the input pose; translation offset by rotated p_GB_G.
        offset_W = X_WG.rotation() @ p_GB_G
        return RigidTransform(X_WG.rotation(), X_WG.translation() + offset_W)

    throw_ball_pose = ball_pose(X_WThrow)
    catch_ball_pose = ball_pose(X_WCatch)

    # (meshcat path, marker color, marker pose) — throw is red, catch is green.
    markers = (
        ("throw", Rgba(.9, .1, .1, 1), throw_ball_pose),
        ("catch", Rgba(.1, .9, .1, 1), catch_ball_pose),
    )
    for path, color, pose in markers:
        meshcat.SetObject(path, Sphere(squash_ball_radius), rgba=color)
        meshcat.SetTransform(path, pose)

    return throw_ball_pose, catch_ball_pose

# Throw pose: position and roll-pitch-yaw orientation used to build X_WThrow.
# NOTE(review): q_Throw is not read anywhere in the visible cells; presumably
# the 7-joint arm configuration matching this pose — confirm.
q_Throw = np.array([-0.54, 0.58, 0, -1.79, 0, -0.79, 0])
p_WThrow = [0.57, -0.34, 0.39]
R_WThrow = RollPitchYaw([-np.pi, 0.00, 1.03])

# Catch pose, defined the same way as the throw pose.
# NOTE(review): q_Catch is likewise unused in the visible cells.
q_Catch = np.array([-0.74, 0.58, 0, -1.79, 0, -0.79, 0])
p_WCatch = [0.49, -0.45, 0.39]
R_WCatch = RollPitchYaw([-np.pi, 0.00, 0.83])

X_WThrow = RigidTransform(R_WThrow, p_WThrow)
X_WCatch = RigidTransform(R_WCatch, p_WCatch)

max_height = 0.5 # Max height ball will reach in meters
# Draw the throw/catch markers and get the ball poses at both ends of the flight.
X_WThrowB_W, X_WCatchB_W = add_target_objects(X_WThrow, X_WCatch)
# calculate_ball_vels is not defined in this notebook (presumably provided by
# `from utils import *`); it yields the planned throw velocity, catch velocity
# and flight duration for the two ball positions and the given max height.
V_ThrowB, V_CatchB, t_duration = calculate_ball_vels(X_WThrowB_W.translation(), X_WCatchB_W.translation(), max_height)
# Round to 3 decimals so the duration lines up with the 0.001 s update period
# that BallValidator compares simulation time against.
t_duration = round(t_duration,3)


In [None]:
def vectors_within_eps(vec_a, vec_b, eps=0.005):
    """Check that every component of vec_a lies in a band around vec_b.

    The band half-width is ``eps * ||vec_b||`` (Euclidean norm of vec_b),
    applied identically to each component, so the tolerance scales with the
    magnitude of vec_b and collapses to zero when vec_b is the zero vector.

    Args:
        vec_a: Vector under test.
        vec_b: Reference vector.
        eps: Relative tolerance, as a fraction of the norm of vec_b.

    Returns:
        True when all components of vec_a are inside the band (inclusive).
    """
    tolerance = eps * np.linalg.norm(vec_b)
    upper = vec_b + tolerance
    lower = vec_b - tolerance
    return np.logical_and(vec_a <= upper, vec_a >= lower).all()

class BallValidator(LeafSystem):
    """Leaf system that reports how a simulated ball compares with the plan.

    On every periodic update it reads the pose and spatial velocity of one
    body from its two abstract input ports and prints a message when:

    * simulation time reaches the planned catch time ``t_duration`` (it also
      draws an "actual_catch" sphere in meshcat at the ball's pose);
    * after that moment, the ball position first matches the planned catch
      position ``X_WCatchB_W``;
    * the ball's translational velocity matches ``V_CatchB`` or ``V_ThrowB``.

    The class reads the module-level names ``t_duration``, ``X_WCatchB_W``,
    ``V_CatchB``, ``V_ThrowB``, ``meshcat`` and ``squash_ball_radius``, so the
    cells defining them have to run first.
    """

    # Period of the periodic update event, in seconds of simulation time.
    _UPDATE_PERIOD = 0.001

    def __init__(self, name, body_index):
        """Declare the input ports and the periodic update.

        Args:
            name: Accepted from the caller but not used by this class.
            body_index: Index used to select this body's entry from the
                values on the "body_poses" and "body_spatial_velocities"
                input ports.
        """
        LeafSystem.__init__(self)
        self._body_index = body_index
        self._body_poses_index = self.DeclareAbstractInputPort(
            "body_poses",
            AbstractValue.Make([RigidTransform()])).get_index()
        self._body_spatial_velocities_index = self.DeclareAbstractInputPort(
            "body_spatial_velocities",
            AbstractValue.Make([SpatialVelocity()])).get_index()
        # NOTE(review): this class defines no `Publish` method, so this
        # resolves to whatever LeafSystem provides under that name — confirm
        # it exists and is the intended callback for the installed Drake.
        self.DeclareForcedPublishEvent(self.Publish)

        self.DeclarePeriodicUnrestrictedUpdateEvent(
            self._UPDATE_PERIOD, 0.0, self.Update)

        # Armed at catch time; cleared once the position match is reported.
        self.find_pos_match = False

    def _report_velocity_match(self, label, time, velocity, pose):
        """Print the ball state for a velocity match against `label` velocity."""
        print(f"\nBall velocity matches {label} velocity at ({time})")
        print(f"actual ball velocity: {velocity.translational()}")
        print(f"actual ball position = {pose.translation()}")

    def Update(self, context, state):
        """Periodic update callback: compare the ball state with the plan.

        Args:
            context: System context; provides the time and the input values.
            state: Mutable state supplied by the update event (not modified).
        """
        pose = self.get_input_port(self._body_poses_index).Eval(context)[self._body_index]
        velocity = self.get_input_port(self._body_spatial_velocities_index).Eval(context)[self._body_index]
        now = context.get_time()

        # Compare with a half-period tolerance rather than `==`: an exact float
        # comparison between a sample time and the rounded t_duration can miss
        # (e.g. 9 * 0.001 != 0.009), which would silently skip this report.
        # Sample times are one period apart, so at most one of them qualifies.
        if abs(now - t_duration) < 0.5 * self._UPDATE_PERIOD:
            print(f"\nWe are at catch time! ({now})")
            print(f"actual ball position = {pose.translation()}")
            print(f"actual ball velocity: {velocity.translational()}")
            meshcat.SetObject("actual_catch", Sphere(squash_ball_radius), rgba=Rgba(.9, .1, .9, 1))
            meshcat.SetTransform("actual_catch", pose)
            self.find_pos_match = True

        # Report only the first position match at or after catch time.
        if self.find_pos_match and vectors_within_eps(pose.translation(), X_WCatchB_W.translation()):
            print(f"\nBall pose matches catch pose at ({now})")
            print(f"actual ball velocity: {velocity.translational()}")
            self.find_pos_match = False

        # Velocity checks use a tighter relative tolerance than the position
        # check, which relies on the default of vectors_within_eps.
        vel_eps = 0.001
        if vectors_within_eps(velocity.translational(), V_CatchB, vel_eps):
            self._report_velocity_match("catch", now, velocity, pose)

        if vectors_within_eps(velocity.translational(), V_ThrowB, vel_eps):
            self._report_velocity_match("throw", now, velocity, pose)

In [None]:
def ball_trajectory_validation():
    """Launch a free ball from the throw pose and record one second of flight.

    Builds a station containing a floor welded to the world plus one ball
    model (two when not running as a notebook), wires a BallValidator to the
    station's body pose and spatial velocity outputs, places every floating
    body at the planned throw ball pose with a hard-coded launch velocity,
    prints the planned catch time/position/velocity, and records the
    simulation over the interval [0, 1].

    Reads the module-level names ``X_WThrowB_W``, ``X_WCatchB_W``,
    ``V_CatchB``, ``t_duration`` and ``meshcat``.
    """
    floor_directives = """
directives:
- add_model:
    name: floor
    file: package://manipulation/floor.sdf
- add_weld:
    parent: world
    child: floor::box
    X_PC:
        translation: [0, 0, -.5]
"""
    ball_template = """
- add_model:
    name: ball{index}
    file: package://one_arm_juggling/models/ball.sdf
"""
    ball_count = 1 if running_as_notebook else 2
    model_directives = floor_directives + "".join(
        ball_template.format(index=i) for i in range(ball_count))

    builder = DiagramBuilder()
    station = builder.AddSystem(
        MyMakeManipulationStation(
            model_directives, package_xmls=["./package.xml"], time_step=0.001))
    plant = station.GetSubsystemByName("plant")

    # Feed the ball's pose and spatial velocity into the validator.
    ball_body = plant.GetBodyByName("ball")
    validator = builder.AddSystem(
        BallValidator(ball_body.name(), ball_body.index()))
    for port_name in ("body_poses", "body_spatial_velocities"):
        builder.Connect(station.GetOutputPort(port_name),
                        validator.GetInputPort(port_name))

    visualizer = MeshcatVisualizer.AddToBuilder(
        builder, station.GetOutputPort("query_object"), meshcat)
    diagram = builder.Build()

    simulator = Simulator(diagram)
    context = simulator.get_context()
    plant_context = plant.GetMyMutableContextFromRoot(context)

    # Hard-coded launch velocity used instead of the planned throw velocity.
    # Previously tried: [-0.144, -0.180, 3.103]
    launch_velocity = [-0.133, -0.172, 3.087]  # opened 0.070, wsg_buff 0.001
    launch_pose = RigidTransform(X_WThrowB_W.rotation(),
                                 X_WThrowB_W.translation())
    # 6x1 spatial velocity: three leading zeros followed by the launch velocity.
    launch_spatial_velocity = SpatialVelocity(
        np.array(np.hstack((0, 0, 0, launch_velocity)).reshape(6, 1),
                 dtype=np.float64))
    for body_index in plant.GetFloatingBaseBodies():
        floating_body = plant.get_body(body_index)
        plant.SetFreeBodyPose(plant_context, floating_body, launch_pose)
        plant.SetFreeBodySpatialVelocity(floating_body,
                                         launch_spatial_velocity,
                                         plant_context)

    print(f"planned ball CATCH time: {t_duration}")
    print(f"planned ball CATCH position: {X_WCatchB_W.translation()}")
    print(f"planned ball CATCH velocity: {V_CatchB}")

    simulator.AdvanceTo(0)
    meshcat.Flush()  # Wait for the large object meshes to get to meshcat.
    RecordInterval(0, 1, simulator, context, plant, visualizer)

# Run the experiment; the planned-vs-actual catch values it prints are the
# recorded output shown below this cell.
ball_trajectory_validation()

planned ball CATCH time: 0.639
planned ball CATCH position: [ 0.57117245 -0.52423633  0.39      ]
planned ball CATCH velocity: [-0.14582164 -0.19980315 -3.13155712]

We are at catch time! (0.639)
actual ball position = [ 0.57931589 -0.50653807  0.3566442 ]
actual ball velocity: [-0.133   -0.172   -3.18159]


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=68a898be-bd4e-4046-916a-6293aadf6c5d' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>