In [None]:
# Optional: uncomment both lines for interactive (ipympl widget) figures.
# %pip install ipympl==0.9.4
# %matplotlib widget

In [None]:
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial.transform import Rotation as R
from episode_storage import EpisodeReader

In [None]:
def plot_base_pose(base_pose1, base_pose2, arrow_step=2):
    """Plot two base trajectories (observation vs. action) in the x-y plane.

    Each pose array is indexed as ``pose[:, 0]`` (x), ``pose[:, 1]`` (y) and
    ``pose[:, 2]`` (heading ``th``, fed to cos/sin, so presumably radians --
    TODO confirm). A short heading arrow is drawn every ``arrow_step`` samples.
    The second trajectory's arrows are shifted by +0.02 in y so the two arrow
    sets do not sit on top of each other; the curves themselves are not shifted.

    Args:
        base_pose1: observed base poses (labelled 'obs').
        base_pose2: commanded base poses (labelled 'action').
        arrow_step: sample stride between heading arrows; larger values give
            sparser arrows. Must be a positive integer.
    """
    fig, ax = plt.subplots()
    # j selects both the legend label and the vertical arrow offset.
    for j, pose in enumerate([base_pose1, base_pose2]):
        x, y, th = pose[:, 0], pose[:, 1], pose[:, 2]
        ax.plot(x, y, label=['obs', 'action'][j])
        for i in range(0, len(x), arrow_step):
            ax.arrow(x=x[i], y=y[i] + 0.02 * j,
                     dx=0.005 * np.cos(th[i]), dy=0.005 * np.sin(th[i]),
                     head_width=0.0025)
    ax.axis('equal')
    ax.set_title('base_pose')
    ax.legend()
    plt.show()

In [None]:
def plot_3d_pose(pos1, quat1, pos2, quat2, arrow_step=5):
    """Plot two arm trajectories (observation vs. action) in 3D with orientation triads.

    Positions are indexed as ``pos[:, 0]``, ``pos[:, 1]``, ``pos[:, 2]`` (x, y, z).
    Quaternions are passed to ``scipy.spatial.transform.Rotation.from_quat``,
    which by default expects scalar-last (x, y, z, w) order. Every
    ``arrow_step``-th sample gets a small red/green/blue triad drawn from the
    columns of the rotation matrix. The second trajectory's triads are shifted
    by +0.1 in y so the two sets do not overlap; the curves and the
    'Start'/'End' labels are not shifted.

    The input arrays are not modified.

    Args:
        pos1, quat1: observed positions and quaternions (labelled 'obs').
        pos2, quat2: commanded positions and quaternions (labelled 'action').
        arrow_step: sample stride between orientation triads; must be a
            positive integer.
    """
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for k, (pos, quat) in enumerate([(pos1, quat1), (pos2, quat2)]):
        x, y, z = pos[:, 0], pos[:, 1], pos[:, 2]
        ax.plot(x, y, z, label=['obs', 'action'][k])
        matrix = R.from_quat(quat).as_matrix()
        offset = np.array([0.0, 0.1 * k, 0.0])
        for i in range(0, len(pos), arrow_step):
            # Adding the offset yields a new array, so `pos` (and the x/y/z
            # views of it used for the labels below) stay unshifted.
            start = pos[i] + offset
            for j in range(3):
                direction = 0.02 * matrix[i, :, j]
                ax.quiver(*start, *direction, color=['r', 'g', 'b'][j], arrow_length_ratio=0.1)
        ax.text(x[0], y[0], z[0], 'Start')
        ax.text(x[-1], y[-1], z[-1], 'End')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('z')
    ax.set_aspect('equal')
    plt.title('arm_pose')
    plt.legend()
    plt.show()

In [None]:
def plot_gripper_pos(gripper_pos1, gripper_pos2):
    """Plot observed and commanded gripper positions against sample index.

    Args:
        gripper_pos1: observed gripper positions (labelled 'obs').
        gripper_pos2: commanded gripper positions (labelled 'action').
    """
    plt.figure()
    for series, name in ((gripper_pos1, 'obs'), (gripper_pos2, 'action')):
        plt.plot(series, label=name)
    plt.title('gripper_pos')
    plt.legend()
    plt.show()

In [None]:
def plot_episode(episode_dir):
    """Load one episode and plot its base, arm and gripper data (observation vs. action).

    Reads ``observations`` and ``actions`` from an ``EpisodeReader``. For each
    field it stacks the per-step values into a NumPy array before handing the
    arrays to the plotting helpers.
    """
    reader = EpisodeReader(episode_dir)
    obs_steps = reader.observations
    action_steps = reader.actions

    def stack(records, key):
        # Gather one field across all steps into a single array.
        return np.array([record[key] for record in records])

    plot_base_pose(stack(obs_steps, 'base_pose'), stack(action_steps, 'base_pose'))

    plot_3d_pose(
        stack(obs_steps, 'arm_pos'), stack(obs_steps, 'arm_quat'),
        stack(action_steps, 'arm_pos'), stack(action_steps, 'arm_quat'),
    )

    plot_gripper_pos(stack(obs_steps, 'gripper_pos'), stack(action_steps, 'gripper_pos'))

In [None]:
def main(input_dir):
    """Plot every episode stored as an immediate subdirectory of ``input_dir``, in sorted path order."""
    root = Path(input_dir)
    subdirs = filter(Path.is_dir, root.iterdir())
    for episode_dir in sorted(subdirs):
        plot_episode(episode_dir)

main('data/sim-v1')