# Playback with  `MPC dataset` and `Ground Truth`

In [1]:
import torch
import mujoco
import numpy as np
import matplotlib.pyplot as plt
from mujoco_parser import MuJoCoParserClass
from util import rpy2r
import pickle
import mediapy as media
from util import rpy2r,MultiSliderClass,create_folder_if_not_exists
import cv2,glob,os

from policy import GaussianPolicy

np.set_printoptions(precision=2,suppress=True,linewidth=100)
plt.rc('xtick',labelsize=6); plt.rc('ytick',labelsize=6)
%config InlineBackend.figure_format = 'retina'
# %matplotlib inline
print ("MuJoCo version:[%s]"%(mujoco.__version__))

MuJoCo version:[2.3.7]


### Load `Ground Truth`

In [2]:
with open(file='../data/smplrig_walk_optimized_recon_local.pkl', mode='rb') as f:
    GT = pickle.load(f)

print(GT.keys())
# GT Stand 

dict_keys(['length', 'p_root', 'R_root', 'v_root', 'w_root', 'qpos', 'qvel', 'xpos', 'local_xpos', 'rotation'])


### Load `MPC dataset`

In [4]:
with open(file='../data/SMPL_MPC_motion_240502.pkl', mode='rb') as f:
    dataset = pickle.load(f)

horizon = 200

action_batch = np.array(dataset['action'])
qpos_batch = np.array(dataset['qpos'])
qvel_batch = np.array(dataset['qvel'])

L = action_batch.shape[0] # actual dataset size
size = L

action_batch = action_batch.reshape(L, horizon, -1)[:size, :, :]#.reshape(size*horizon, -1)
qpos_batch = qpos_batch.reshape(L, horizon, -1)[:size, :, :]#.reshape(size*horizon, -1)
qvel_batch = qvel_batch.reshape(L, horizon, -1)[:size, :, :]#.reshape(size*horizon, -1)

corrupted_idx   = np.where(action_batch == None)[0]
size -= len(corrupted_idx)
action_batch    = np.ma.array(action_batch, mask=False)
action_batch.mask[corrupted_idx] = True
action_batch    = action_batch.compressed().reshape(-1,horizon,action_batch.shape[-1]).reshape(-1,action_batch.shape[-1])

qpos_batch      = np.ma.array(qpos_batch, mask=False)
qpos_batch.mask[corrupted_idx] = True
qpos_batch      = qpos_batch.compressed().reshape(-1,horizon,qpos_batch.shape[-1]).reshape(-1,qpos_batch.shape[-1])

qvel_batch      = np.ma.array(qvel_batch, mask=False)
qvel_batch.mask[corrupted_idx] = True
qvel_batch      = qvel_batch.compressed().reshape(-1,horizon,qvel_batch.shape[-1]).reshape(-1,qvel_batch.shape[-1])


# action_batch = action_batch.reshape(int(action_batch.shape[0] * horizon), -1)
# qpos_batch = qpos_batch.reshape(int(qpos_batch.shape[0] * horizon) , -1)
# qvel_batch = qvel_batch.reshape(int(qvel_batch.shape[0] * horizon), -1)

del dataset

qpos_batch = qpos_batch[:, :]
# qpos_batch = torch.cat((qpos_batch[:-2, :],qpos_batch[1:-1, :],qpos_batch[2:, :]),dim=1)
# qvel_batch = torch.cat((qvel_batch[:-2, :],qvel_batch[1:-1, :],qvel_batch[2:, :]),dim=1)

obs_batch = np.concatenate((qpos_batch, qvel_batch), axis=1)

# action_batch = action_batch[2:]
obs_dim = obs_batch.shape[1]
action_dim = action_batch.shape[1]
hidden_dim = 512

print("obs : ", obs_batch.shape)
print("action : ", action_batch.shape)


obs :  (39800, 87)
action :  (39800, 37)


In [5]:
xml_path = '../asset/smpl_rig/scene_smpl_rig.xml'
env = MuJoCoParserClass(name='smplrig',rel_xml_path=xml_path,VERBOSE=True)
# Modify the color of body exluding 'world'
for body_name in env.body_names:
    if body_name in ['world']: continue
    body_idx = env.body_names.index(body_name)
    geom_idxs = [idx for idx,val in enumerate(env.model.geom_bodyid) if val==body_idx]
    for geom_idx in geom_idxs:
        env.model.geom(geom_idx).rgba = [0.3,0.3,0.5,0.5]
print ("Done.")

dt:[0.0083] HZ:[120]
n_dof (=nv):[43]
n_geom:[67]
geom_names:['floor', 'base', 'base2lpelvis', 'base2rpelvis', 'base2spine1', 'spine2spine', 'spine2spine2', 'spine2spine3', 'spine2lcollar', 'spine2rcollar', 'neck2head', 'head', 'nose', 'lcollar2lshoulder', 'lshoulder2lelbow', 'lelbow2lwrist', 'lwrist2lindex1', 'lwrist2lmiddle1', 'lwrist2lpinky1', 'lwrist2lring1', 'lwrist2lthumb1', 'lindex1-lindex2', 'lindex2-lindex3', 'lindex3-lindextip', 'lmiddle1-lmiddle2', 'lmiddle2-lmiddle3', 'lmiddl3-lmiddletip', 'lring1-lring2', 'lring2-lring3', 'lring3-lringtip', 'lpinky1-lpinky2', 'lpinky2-lpinky3', 'lpinky3-lpinkytip', 'lthumb1-lthumb2', 'lthumb2-lthumb3', 'lthumb3-lthumbtip', 'rcollar2rshoulder', 'rshoulder2relbow', 'relbow2rwrist', 'rwrist2rindex1', 'rwrist2rmiddle1', 'rwrist2rpinky1', 'rwrist2rring1', 'rwrist2rthumb1', 'rindex1-rindex2', 'rindex2-rindex3', 'rindex3-rindextip', 'rmiddle1-rmiddle2', 'rmiddle2-rmiddle3', 'rmiddle3-rmiddletip', 'rring1-rring2', 'rring2-rring3', 'rring3-rringtip

In [6]:
env.model.body_mass

array([0.  , 2.59, 0.83, 0.95, 0.56, 3.4 , 0.69, 2.18, 0.75, 1.58, 1.52, 0.06, 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.75, 1.58, 1.52, 0.06, 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  ,
       0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 0.  , 2.16, 2.28, 0.72, 2.16,
       2.28, 0.72])

In [7]:
# Set which joints to control
ctrl_joint_names = env.ctrl_joint_names # <==
joint_idxs_fwd = env.get_idxs_fwd(joint_names=ctrl_joint_names)
joint_idxs_jac = env.get_idxs_jac(joint_names=ctrl_joint_names)
q_ctrl_init = env.get_qpos_joints(ctrl_joint_names)
n_ctrl_joint = len(ctrl_joint_names)

In [8]:
PLOT_EVERY  = 4
SAVE_VID    = False
print ("PLOT_EVERY:[%d] SAVE_VID:[%d]"%(PLOT_EVERY,SAVE_VID))

PLOT_EVERY:[4] SAVE_VID:[0]


In [9]:
GT['xpos'].shape

(315, 62, 3)

In [12]:
# Initialize MuJoCo viewer
env.init_viewer(viewer_title='SMPL',viewer_width=1200,viewer_height=800,
                viewer_hide_menus=False,FONTSCALE_VALUE=300)
env.update_viewer(azimuth=152,distance=3.0,elevation=-20,lookat=[0.02,-0.03,1.2])
env.reset()

for t in range(GT['length']):
    for h in range(horizon):
        qpos = qpos_batch[horizon*t+h]
        env.forward(qpos)

        xpos = env.data.xpos[:-1] # task space position

    
        if env.loop_every(tick_every=1):
            # Plot world frame
            env.plot_T(p=np.zeros(3),R=np.eye(3,3),
                    PLOT_AXIS=True,axis_len=0.5,axis_width=0.005)
            env.plot_T(p=np.array([0,0,0.5]),R=np.eye(3,3),
                    PLOT_AXIS=False,label="tick:[%d]"%(env.tick))

            # Stand
            for p_GT in GT['xpos'][int(t+h)]:
                env.plot_sphere(p=p_GT,r=0.02,rgba=[1,0.2,0.2,1])
            
            env.render()

            if SAVE_VID:
                # Save iamges
                png_path = '../temp/smplrig_MPC_%05d.png'%(env.tick)
                create_folder_if_not_exists(png_path)
                image = cv2.cvtColor(env.grab_image(),cv2.COLOR_RGB2BGR)
                cv2.imwrite(png_path,image)

Pressed ESC
Quitting.


Exception: GLFW window does not exist but you tried to render.

In [None]:
if SAVE_VID:
    # Show video
    png_paths = sorted(glob.glob('../temp/smplrig_MPC_*.png'))
    frames = []
    for png_path in png_paths:
        img_bgr = cv2.imread(png_path)
        img_rgb = cv2.cvtColor(img_bgr,cv2.COLOR_BGR2RGB)
        frames.append(img_rgb)
    media.show_video(images=frames,fps=int(1/env.dt/PLOT_EVERY),width=500)
    # Save video to mp4
    vid_path = '../vid/kin_chain.mp4'
    create_folder_if_not_exists(vid_path)
    media.write_video(images=frames,fps=int(1/env.dt/PLOT_EVERY),path=vid_path)
    print ("[%s] saved."%(vid_path))

0
This browser does not support the video tag.


ValueError: Image dimensions (1387, 2560) do not match those of the initialized video (800, 1200).

In [10]:
# Remove existing images
png_paths = sorted(glob.glob('../temp/smplrig_MPC_*.png'))
for png_path in png_paths:
    os.remove(png_path)
print ("[%d] images removed."%(len(png_paths)))

[0] images removed.
