# Diff Init Demo

(circa 11/23/22) Code taken from Elaine Pham's work and from the starter code at https://deepnote.com/workspace/Manipulation-ac8201a1-470a-4c77-afd0-2cc45bc229ff/project/18ba0481-7e30-46ba-9e7d-f7b34500d6dc/%2Frobot_painter.ipynb

Source for the intermediate-point function (circa 11/25/22 https://stackoverflow.com/questions/43594646/how-to-calculate-the-coordinates-of-the-line-between-two-points-in-python)

AddMeshcatTriad (circa 11/28/22 https://github.com/RussTedrake/manipulation/blob/775d20681ddb1c391e4a4d9a30182d02f8212bbc/manipulation/meshcat_utils.py#L299)

This notebook is modified to demo specific trajectories instead of the Robot Painter trajectory.

In [1]:
import matplotlib.pyplot as plt
import mpld3
import numpy as np
import pandas as pd
from IPython.display import HTML, display
from manipulation import running_as_notebook, FindResource
from manipulation.meshcat_utils import AddMeshcatTriad
from manipulation.scenarios import MakeManipulationStation
from pydrake.all import (AddMultibodyPlantSceneGraph, AngleAxis, BasicVector,
                         ConstantVectorSource, DiagramBuilder,
                         FindResourceOrThrow, Integrator, JacobianWrtVariable,
                         LeafSystem, MeshcatVisualizer,
                         MeshcatVisualizerParams, MultibodyPlant,
                         MultibodyPositionToGeometryPose, Parser,
                         PiecewisePose, Quaternion, RigidTransform,
                         RollPitchYaw, RotationMatrix, SceneGraph, Simulator,
                         StartMeshcat, TrajectorySource, Cylinder, Rgba)
import numpy as np
import torch 
import pandas as pd
import os

In [2]:
# Start the visualizer.
# `meshcat` is a module-level global: IIWA_Painter and AddMeshcatTriad draw into it.
meshcat = StartMeshcat()

INFO:drake:Meshcat listening for connections at https://a569df8b-a69d-4307-8cbd-4f24b608e3b4.deepnoteproject.com/7000/


In [3]:
class IIWA_Painter():
    """Manipulation station (iiwa + wsg) visualized in Meshcat, optionally tracking a pose trajectory.

    The station is loaded from ``models/iiwa_and_wsg.dmd.yaml``. When ``traj`` is
    given, its derivative is fed as the desired gripper spatial velocity ``V_G`` into
    a ``PseudoInverseController``. The controller's joint-velocity output is integrated
    into the station's ``iiwa_position`` command. The wsg position command is held at
    a constant 0.1.

    Drawing goes through the module-level ``meshcat`` global.

    Args:
        traj: optional gripper pose trajectory exposing ``MakeDerivative()``
            (e.g. a ``PiecewisePose``).
        init_angles: optional iiwa joint angles (7 values) to start from. When None,
            ``DEFAULT_Q0`` is used. The angles are remembered and reused by ``paint()``.
    """

    # iiwa joint configuration used when no init_angles are supplied.
    DEFAULT_Q0 = np.array([ 1.40666193e-05,  1.56461165e-01, -3.82761069e-05,
                           -1.32296976e+00, -6.29097287e-06,  1.61181157e+00, -2.66900985e-05])

    def __init__(self, traj=None, init_angles=None):
        # Remember the start configuration so later simulations (paint) use it too.
        self._init_angles = init_angles

        builder = DiagramBuilder()
        # set up the system of manipulation station
        self.station = MakeManipulationStation(filename=FindResource("models/iiwa_and_wsg.dmd.yaml"))

        builder.AddSystem(self.station)
        self.plant = self.station.GetSubsystemByName("plant")

        # optionally add trajectory source
        if traj is not None:
            # The controller consumes a spatial velocity, so feed it the pose
            # trajectory's derivative.
            traj_V_G = traj.MakeDerivative()
            V_G_source = builder.AddSystem(TrajectorySource(traj_V_G))
            self.controller = builder.AddSystem(
                PseudoInverseController(self.plant))
            builder.Connect(V_G_source.get_output_port(),
                            self.controller.GetInputPort("V_G"))

            # Integrate joint velocities into the position command for the iiwa.
            self.integrator = builder.AddSystem(Integrator(7))
            builder.Connect(self.controller.get_output_port(),
                            self.integrator.get_input_port())
            builder.Connect(self.integrator.get_output_port(),
                            self.station.GetInputPort("iiwa_position"))
            builder.Connect(
                self.station.GetOutputPort("iiwa_position_measured"),
                self.controller.GetInputPort("iiwa_position"))

        params = MeshcatVisualizerParams()
        # Keep previously drawn triads when the diagram is (re)initialized.
        params.delete_on_initialization_event = False
        self.visualizer = MeshcatVisualizer.AddToBuilder(
            builder, self.station.GetOutputPort("query_object"), meshcat, params)

        wsg_position = builder.AddSystem(ConstantVectorSource([0.1]))
        builder.Connect(wsg_position.get_output_port(),
                        self.station.GetInputPort("wsg_position"))

        self.diagram = builder.Build()
        self.gripper_frame = self.plant.GetFrameByName('body')
        self.world_frame = self.plant.world_frame()

        # Draw the initial configuration once.
        context = self.CreateDefaultContext(init_angles)
        self.diagram.Publish(context)

    def visualize_frame(self, name, X_WF, length=0.15, radius=0.006, ground_truth=True):
        """
        Visualize an imaginary frame that is not attached to any existing body.

        Input:
            name: the name of the frame (str); drawn under ``"painter/" + name``.
            X_WF: a RigidTransform from frame F to world.
            length, radius: triad axis dimensions.
            ground_truth: selects the triad color scheme passed to AddMeshcatTriad.

        Frames whose names already exist will be overwritten by the new frame.
        """
        AddMeshcatTriad(meshcat, "painter/" + name,
                        length=length, radius=radius, X_PT=X_WF,
                        ground_truth=ground_truth)

    def CreateDefaultContext(self, init_angles=None):
        """Create a diagram context with the iiwa at ``init_angles`` (or ``DEFAULT_Q0``).

        iiwa and wsg velocities are zeroed, and the wsg fingers are set to [-0.05, 0.05].
        If a trajectory integrator exists, its state is set to the same joint angles, so
        the commanded and actual positions agree at t=0.
        """
        context = self.diagram.CreateDefaultContext()
        plant_context = self.diagram.GetMutableSubsystemContext(
            self.plant, context)

        if init_angles is not None:
            q0 = np.asarray(init_angles, dtype=float)
        else:
            q0 = self.DEFAULT_Q0
        # set the joint positions of the kuka arm
        iiwa = self.plant.GetModelInstanceByName("iiwa")
        self.plant.SetPositions(plant_context, iiwa, q0)
        self.plant.SetVelocities(plant_context, iiwa, np.zeros(7))
        wsg = self.plant.GetModelInstanceByName("wsg")
        self.plant.SetPositions(plant_context, wsg, [-0.05, 0.05])
        self.plant.SetVelocities(plant_context, wsg, [0, 0])

        # Only present when a trajectory was supplied to the constructor.
        if hasattr(self, 'integrator'):
            self.integrator.set_integral_value(
                self.integrator.GetMyMutableContextFromRoot(context), q0)

        return context

    def get_X_WG(self, context=None):
        """Return the world pose of the gripper frame ``body``.

        Args:
            context: diagram context to evaluate. When None, a fresh context built by
                ``CreateDefaultContext()`` is used, i.e. the ``DEFAULT_Q0``
                configuration, not the constructor's ``init_angles``.
        """
        if context is None:
            context = self.CreateDefaultContext()
        plant_context = self.plant.GetMyContextFromRoot(context)
        X_WG = self.plant.CalcRelativeTransform(
                    plant_context,
                    frame_A=self.world_frame,
                    frame_B=self.gripper_frame)
        return X_WG

    def paint(self, sim_duration=20.0):
        """Simulate from the constructor's start configuration at real-time rate.

        Runs for ``sim_duration`` when ``running_as_notebook`` is true; otherwise runs
        for only 0.01 (a quick smoke run).
        """
        # Start from the same configuration that __init__ published.
        context = self.CreateDefaultContext(self._init_angles)
        simulator = Simulator(self.diagram, context)
        simulator.set_target_realtime_rate(1.0)

        duration = sim_duration if running_as_notebook else 0.01
        simulator.AdvanceTo(duration)

class PseudoInverseController(LeafSystem):
    """
    Differential-IK controller (same controller seen in-class).

    Inputs:
        V_G (6): desired spatial velocity of the gripper ``body`` frame,
            measured and expressed in the world frame.
        iiwa_position (7): current iiwa joint positions.
    Output:
        iiwa_velocity (7): joint velocities ``pinv(J_G) @ V_G``, where J_G is the
            gripper's spatial-velocity Jacobian restricted to the iiwa joints.
    """
    def __init__(self, plant):
        LeafSystem.__init__(self)
        self._plant = plant
        # Private scratch context; CalcOutput writes the measured q into it.
        self._plant_context = plant.CreateDefaultContext()
        self._iiwa = plant.GetModelInstanceByName("iiwa")
        self._G = plant.GetBodyByName("body").body_frame()
        self._W = plant.world_frame()

        self.V_G_port = self.DeclareVectorInputPort("V_G", BasicVector(6))
        self.q_port = self.DeclareVectorInputPort("iiwa_position", BasicVector(7))
        self.DeclareVectorOutputPort("iiwa_velocity", BasicVector(7),
                                     self.CalcOutput)

        # Column range of the iiwa joints within the plant's velocity vector.
        first_joint = plant.GetJointByName("iiwa_joint_1")
        last_joint = plant.GetJointByName("iiwa_joint_7")
        self.iiwa_start = first_joint.velocity_start()
        self.iiwa_end = last_joint.velocity_start()

    def CalcOutput(self, context, output):
        desired_V = self.V_G_port.Eval(context)
        q_now = self.q_port.Eval(context)

        plant = self._plant
        plant.SetPositions(self._plant_context, self._iiwa, q_now)
        J_full = plant.CalcJacobianSpatialVelocity(
            self._plant_context, JacobianWrtVariable.kV,
            self._G, [0, 0, 0], self._W, self._W)

        # Keep only the iiwa columns (inclusive of joint 7).
        J_iiwa = J_full[:, self.iiwa_start:self.iiwa_end + 1]
        joint_velocities = np.linalg.pinv(J_iiwa) @ desired_V  # least-squares solve
        output.SetFromVector(joint_velocities)

def compose_circular_key_frames(thetas, X_WorldCenter, X_WorldGripper_init, radius):
    """Build key-frame poses on a circle, starting from the gripper's initial pose.

    Each circle pose is ``X_WorldCenter`` rotated by ``theta`` about the center
    frame's y-axis, then offset by ``-radius`` along the rotated z-axis.

    Args:
        thetas: iterable of angles (radians) about the center frame's y-axis.
        X_WorldCenter: RigidTransform of the circle center in world.
        X_WorldGripper_init: initial gripper pose; it is the first key frame.
        radius: circle radius.

    returns: a list of RigidTransforms, ``[X_WorldGripper_init]`` followed by one
        pose per theta (the same "init pose first" layout used for the demo
        key frames in this notebook).
    """
    # The initializer was previously commented out, which made every call raise
    # NameError.
    key_frame_poses_in_world = [X_WorldGripper_init]
    for theta in thetas:
        rotate = RigidTransform(RotationMatrix.MakeYRotation(theta))
        translate = RigidTransform([0, 0, -radius])
        key_frame_poses_in_world.append(X_WorldCenter @ rotate @ translate)

    return key_frame_poses_in_world

# check key frames instead of interpolated trajectory
def visualize_key_frames(painter,frame_poses):
    """Draw a small triad (length 0.05) for each key-frame pose.

    Frames are named ``frame_0``, ``frame_1``, ... in list order.
    """
    for idx, X in enumerate(frame_poses):
        painter.visualize_frame(f"frame_{idx}", X, length=0.05)


In [4]:
### helper function to run and visualize the poses for the not live demo ###

def run_simulation(times, key_frame_poses, init_angles=None):
    """Create a trajectory from ``key_frame_poses`` and run a real-time simulation of it.

    Args:
        times: breakpoint times, one per entry of ``key_frame_poses``.
        key_frame_poses: list of RigidTransforms for the gripper frame.
        init_angles: optional initial iiwa joint angles, used both to build the
            painter and to create the simulation context.
    """
    # Piecewise-linear pose trajectory through the key frames.
    traj = PiecewisePose.MakeLinear(times, key_frame_poses)
    # Station + pseudo-inverse controller that tracks the trajectory.
    painter = IIWA_Painter(traj=traj, init_angles=init_angles)
    # To inspect the key frames themselves: visualize_key_frames(painter, key_frame_poses)

    # run sim
    context = painter.CreateDefaultContext(init_angles=init_angles)
    simulator = Simulator(painter.diagram, context)
    simulator.set_target_realtime_rate(1.0)
    # Step through each breakpoint in turn.
    for timestep in times:
        simulator.AdvanceTo(timestep)

def AddMeshcatTriad(meshcat,
                    path,
                    length=.25,
                    radius=0.01,
                    opacity=1.,
                    X_PT=None,
                    ground_truth=True):
    """Draw a coordinate-frame triad (three cylinders) in Meshcat at ``path``.

    This local definition replaces the ``AddMeshcatTriad`` imported from
    ``manipulation.meshcat_utils`` and adds a ``ground_truth`` flag that selects
    the axis colors. Ground-truth frames use red/green/blue for x/y/z; other
    (predicted) frames use yellow/cyan/magenta.

    Args:
        meshcat: Meshcat instance to draw into.
        path: Meshcat path of the triad; axes go to ``path + "/x-axis"`` etc.
        length: length of each axis cylinder.
        radius: radius of each axis cylinder.
        opacity: alpha channel of the axis colors.
        X_PT: pose of the triad relative to its parent path; identity when None.
        ground_truth: True for the RGB color scheme, False for yellow/cyan/magenta.
    """
    if X_PT is None:
        # Build the default per call instead of sharing one default-argument instance.
        X_PT = RigidTransform()
    meshcat.SetTransform(path, X_PT)

    if ground_truth:
        axis_colors = ((1, 0, 0), (0, 1, 0), (0, 0, 1))
    else:
        axis_colors = ((1, 1, 0), (0, 1, 1), (1, 0, 1))

    # Drake cylinders lie along their local z-axis and are centered at the origin.
    # Rotate and shift each one so that it spans [0, length] along its triad axis.
    axis_poses = (
        ("x-axis", RigidTransform(RotationMatrix.MakeYRotation(np.pi / 2),
                                  [length / 2., 0, 0])),
        ("y-axis", RigidTransform(RotationMatrix.MakeXRotation(np.pi / 2),
                                  [0, length / 2., 0])),
        ("z-axis", RigidTransform([0, 0, length / 2.])),
    )
    for (axis_name, X_TG), (r, g, b) in zip(axis_poses, axis_colors):
        axis_path = path + "/" + axis_name
        meshcat.SetTransform(axis_path, X_TG)
        meshcat.SetObject(axis_path, Cylinder(radius, length),
                          Rgba(r, g, b, opacity))

In [5]:
### Run this cell and the next to run straight line demo ###

# NOTE: pd.read_pickle can execute arbitrary code; only load trusted, locally produced files.
df_diff_init = pd.read_pickle('./data_processed/diff_init/test_out_predictions.pkl')

index = 0 # change this to (0,1,15) view different out of distribution tests

# Select the rows of one test trajectory once and reuse the selection.
df_traj = df_diff_init[df_diff_init.trajectory_index == index]
if df_traj.empty:
    raise ValueError(
        f"no rows with trajectory_index == {index} in test_out_predictions.pkl")

predictions = list(df_traj.prediction)
ground_truth = list(df_traj.label)
# The starting end-effector position and joint angles come from the trajectory's first row.
input_pose = df_traj.end_effector_position.iloc[0]
input_angle = df_traj.joint_angle.iloc[0]

# Build a painter at the default configuration only to read the gripper orientation.
painter = IIWA_Painter()
X_WG_init = painter.get_X_WG()
# Start from the recorded input position while keeping that orientation.
X_WG_init = RigidTransform(
                X_WG_init.rotation(),
                np.array(input_pose))
# Rebuild the painter with the iiwa at the recorded input joint angles.
painter = IIWA_Painter(init_angles=input_angle)

In [6]:
### Run this cell to run straight line demo ###

def diff_init_demo(painter, X_WG_init, predictions, ground_truth, input_pose, input_angle,
                   total_time=20):
    """Draw predicted vs. ground-truth gripper positions, then simulate tracking the predictions.

    Every key frame reuses the orientation of ``X_WG_init``. The trajectory starts at
    ``X_WG_init`` and visits each predicted position in turn. The key frames are spread
    evenly over ``total_time``, whatever the number of predictions.

    Args:
        painter: IIWA_Painter used to draw the predicted/ground-truth triads
            (run_simulation builds its own painter for the simulation).
        X_WG_init: starting gripper pose.
        predictions: sequence of predicted gripper positions (3-vectors).
        ground_truth: sequence of ground-truth positions, same length as predictions.
        input_pose: unused; kept for call compatibility.
        input_angle: initial iiwa joint angles for the simulation.
        total_time: total simulated time over which the key frames are spread.

    Raises:
        ValueError: if predictions and ground_truth differ in length.
    """
    if len(predictions) != len(ground_truth):
        raise ValueError(
            f"predictions ({len(predictions)}) and ground_truth "
            f"({len(ground_truth)}) must have the same length")

    R_WG = X_WG_init.rotation()
    key_frame_preds = [X_WG_init]
    for i, (p_pred, p_true) in enumerate(zip(predictions, ground_truth)):
        pose_pred = RigidTransform(R_WG, np.array(p_pred))
        pose_ground = RigidTransform(R_WG, np.array(p_true))
        key_frame_preds.append(pose_pred)
        painter.visualize_frame('pred_frame_{}'.format(i), pose_pred, length=0.05, ground_truth=False)
        painter.visualize_frame('ground_frame_{}'.format(i), pose_ground, length=0.05, ground_truth=True)

    # One breakpoint per key frame (the initial pose plus one per prediction). Previously
    # this was hard-coded to 5, which broke for any trajectory without 4 predictions.
    times = np.linspace(0, total_time, len(key_frame_preds))
    run_simulation(times, key_frame_preds, input_angle)

diff_init_demo(painter, X_WG_init, predictions, ground_truth, input_pose, input_angle)

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=a569df8b-a69d-4307-8cbd-4f24b608e3b4' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>