### Visualize SinMDM for NC HJK `Common Rig`

In [1]:
import os
import pickle

import cv2
import matplotlib.pyplot as plt
import mujoco
import numpy as np

from mujoco_parser import MuJoCoParserClass
from util import rpy2r
np.set_printoptions(precision=2,suppress=True,linewidth=100)
plt.rc('xtick',labelsize=6); plt.rc('ytick',labelsize=6)
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
print ("MuJoCo version:[%s]"%(mujoco.__version__))

MuJoCo version:[2.3.2]




### Parse `scene_common_rig.xml`

In [2]:
xml_path = '../asset/common_rig/scene_common_rig.xml'
env = MuJoCoParserClass(name='Common Rig', rel_xml_path=xml_path, VERBOSE=True)
# Recolor every geom attached to a body other than 'world' with RGBA [0.3, 0.3, 0.5, 0.5]
# (alpha 0.5, i.e. half transparent). The position in env.body_names is used as the body id
# that geom_bodyid refers to; body names are unique, so enumerate() matches list.index().
for body_idx, body_name in enumerate(env.body_names):
    if body_name == 'world':
        continue
    for geom_idx, owner_body_idx in enumerate(env.model.geom_bodyid):
        if owner_body_idx == body_idx:
            env.model.geom(geom_idx).rgba = [0.3, 0.3, 0.5, 0.5]
print("Done.")

dt:[0.0050] HZ:[200]
n_dof (=nv):[41]
n_geom:[26]
geom_names:['floor', 'base', 'root2spine', 'spine2neck', 'neck2rshoulder', 'neck2lshoulder', 'rshoulder2relbow', 'relbow2rwrist', 'rthumb', 'rpalm', 'lshoulder2lelbow', 'lelbow2lwrist', 'lthumb', 'lpalm', 'head', 'nose', 'base2rpelvis', 'rpelvis2rknee', 'rknee2rankle', 'rankle', 'rfoot', 'base2lpelvis', 'lpelvis2lknee', 'lknee2lankle', 'lankle', 'lfoot']
n_body:[20]
body_names:['world', 'base', 'torso', 'spine', 'neck', 'right_shoulder', 'right_elbow', 'right_hand', 'left_shoulder', 'left_elbow', 'left_hand', 'head', 'right_leg', 'right_pelvis', 'right_knee', 'right_ankle', 'left_leg', 'left_pelvis', 'left_knee', 'left_ankle']
n_joint:[36]
joint_names:['base', 'root1', 'root2', 'root3', 'spine', 'rs1', 'rs2', 'rs3', 're', 'rw1', 'rw2', 'rw3', 'ls1', 'ls2', 'ls3', 'le', 'lw1', 'lw2', 'lw3', 'head1', 'head2', 'head3', 'rp1', 'rp2', 'rp3', 'rk', 'ra1', 'ra2', 'ra3', 'lp1', 'lp2', 'lp3', 'lk', 'la1', 'la2', 'la3']
joint_types:[0 3 3 3 3 3 3

### Load motion data, render original vs. sampled motion, and save a side-by-side video

In [5]:
# Result pickles to visualize. Each file is expected to contain the keys
# 'sample_rot', 'sample_trs', 'org_rot', 'org_trs' and 'motion_length'.
# Uncomment entries to process several files in one run.
data_path_list = [
    './VAAI_Non_M_02_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_01_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_03_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_04_a_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_06_a_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_07_a_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_08_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_09_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_M_10_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_R_01_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_R_02_de_01_results.pkl',
    # '../nc_hjk/VAAI_Non_R_03_c_me_01_results.pkl',
    # '../nc_hjk/VAAI_Non_R_04_de_01_results.pkl',
]

SAMPLE_IDX = 0             # index along axis 0 of sample_rot/sample_trs (presumably the sample id) to visualize
RENDER_TICK_EVERY = 2      # render + record one frame every N forward-kinematics ticks; same for both passes
VIDEO_FPS = 30
PANEL_SIZE = (1920, 2160)  # (width, height) of each half of the side-by-side video
RESULT_DIR = '../result'

# Rotation arrays carry one column per revolute joint in "data order". JOINT_PERM maps them to
# rig order: columns 0:18 unchanged, data columns 32:35 -> 18:21, data columns 18:32 -> 21:35.
# Judging by the rig's joint list (head1..head3 follow the left wrist), this presumably moves
# the head joints from the end of the data next to the arms -- TODO confirm against the exporter.
N_REV_JOINT_DATA = 35
JOINT_PERM = np.r_[0:18, 32:35, 18:32]


def reorder_joints(rot):
    """Return a copy of `rot` with its last axis permuted from data order to rig order.

    Works for any leading shape, e.g. (T, 35) or (n_sample, T, 35). The input array (and hence
    the loaded pickle dict) is left unmodified.

    Raises:
        ValueError: if the last axis does not have N_REV_JOINT_DATA columns.
    """
    if rot.shape[-1] != N_REV_JOINT_DATA:
        raise ValueError('expected %d joint columns, got %d'
                         % (N_REV_JOINT_DATA, rot.shape[-1]))
    return rot[..., JOINT_PERM]


def crop_resize_frame(img, panel_size=PANEL_SIZE):
    """Keep the middle two thirds of `img`'s width and resize the result to `panel_size`.

    For a 2880-px-wide grab this keeps columns 480:2400, the previously hard-coded crop;
    other widths are cropped symmetrically. cv2.resize does not preserve aspect ratio.
    """
    width = img.shape[1]
    margin = width // 6
    return cv2.resize(img[:, margin:width - margin], panel_size)


def _plot_overlays(env, label_pos):
    """Draw world frame, tick label at `label_pos`, body frames, revolute-joint axes and contacts."""
    # World frame
    env.plot_T(p=np.zeros(3), R=np.eye(3, 3),
               PLOT_AXIS=True, axis_len=0.5, axis_width=0.005)
    # Text information (only the label follows the root translation; the rig itself is
    # driven through revolute joints only)
    env.plot_T(p=label_pos, R=np.eye(3), PLOT_AXIS=False,
               label='Tick:[%d]' % (env.tick))
    # Body frames
    for body_name in env.body_names:
        p, R = env.get_pR_body(body_name=body_name)
        env.plot_T(p=p, R=R, PLOT_AXIS=True, axis_len=0.05, axis_width=0.005)
    # Revolute joint axes as arrows, colored by the dominant local axis
    for rev_joint_idx, rev_joint_name in zip(env.rev_joint_idxs, env.rev_joint_names):
        axis_joint = env.model.jnt_axis[rev_joint_idx]
        p_joint, R_joint = env.get_pR_joint(joint_name=rev_joint_name)
        axis_world = R_joint @ axis_joint
        axis_rgba = np.append(np.eye(3)[:, np.argmax(axis_joint)], 0.2)
        axis_len, axis_r = 0.1, 0.01
        env.plot_arrow_fr2to(p_fr=p_joint, p_to=p_joint + axis_len * axis_world,
                             r=axis_r, rgba=axis_rgba)
    # Contact information
    env.plot_contact_info(h_arrow=0.3, rgba_arrow=[1, 0, 0, 1],
                          PRINT_CONTACT_BODY=False)


def render_motion(env, rot_traj, trs_traj, motion_length, tick_every=RENDER_TICK_EVERY):
    """Play a joint trajectory in a fresh MuJoCo viewer and return the recorded frames.

    Args:
        env: the MuJoCoParserClass environment to drive.
        rot_traj: per-frame revolute-joint positions, indexed [frame, joint] in rig order.
        trs_traj: per-frame root positions; only used to place the tick label.
        motion_length: number of frames to record; clamped to len(rot_traj).
        tick_every: record one frame every this many forward-kinematics ticks.

    Returns:
        List of cropped/resized frames (fewer than motion_length if the viewer is closed early).
    """
    env.init_viewer(viewer_title='Common Rig', viewer_width=1200, viewer_height=800,
                    viewer_hide_menus=True)
    env.update_viewer(azimuth=152, distance=3.0, elevation=-30, lookat=[0.02, -0.03, 0.8])
    joint_idxs_fwd = [env.model.joint(jname).qposadr[0] for jname in env.rev_joint_names]
    env.reset()

    n_frames = min(int(motion_length), len(rot_traj))
    frames = []
    frame_idx = 0
    while frame_idx < n_frames and env.is_viewer_alive():
        env.forward(q=rot_traj[frame_idx], joint_idxs=joint_idxs_fwd, INCREASE_TICK=True)
        if not env.loop_every(tick_every=tick_every):
            continue
        _plot_overlays(env, label_pos=trs_traj[frame_idx])
        env.render()
        frames.append(crop_resize_frame(env.grab_image()))
        frame_idx += 1
    env.close_viewer()
    return frames


def write_side_by_side_video(out_file_name, frames_left, frames_right, fps=VIDEO_FPS):
    """Write frames_left | frames_right, concatenated horizontally, to an mp4 file.

    Only min(len(frames_left), len(frames_right)) frames are written; nothing is written
    if either list is empty. Both panels are assumed to share the same size.

    Raises:
        OSError: if OpenCV cannot open the output file for writing.
    """
    n_frames = min(len(frames_left), len(frames_right))
    if n_frames == 0:
        print('No frames recorded; skipping %s' % out_file_name)
        return
    os.makedirs(os.path.dirname(out_file_name) or '.', exist_ok=True)
    height, width = frames_left[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    # frameSize must match the written frames exactly, otherwise OpenCV drops them silently.
    out = cv2.VideoWriter(out_file_name, fourcc, fps, (2 * width, height))
    if not out.isOpened():
        raise OSError('could not open video writer for %s' % out_file_name)
    try:
        for img_org, img_sample in zip(frames_left, frames_right):
            # grab_image is presumably RGB; VideoWriter expects BGR.
            img = np.concatenate([cv2.cvtColor(img_org, cv2.COLOR_RGB2BGR),
                                  cv2.cvtColor(img_sample, cv2.COLOR_RGB2BGR)], axis=1)
            out.write(img)
    finally:
        out.release()


for data_path in data_path_list:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted result files.
    with open(data_path, 'rb') as f:
        data = pickle.load(f)
    print(data.keys())
    sample_rot = reorder_joints(data['sample_rot'].transpose(0, 2, 1))
    sample_trs = data['sample_trs'].transpose(0, 2, 1)
    org_rot = reorder_joints(data['org_rot'])
    org_trs = data['org_trs']
    motion_length = data['motion_length']
    print('sample rotation shape', sample_rot.shape)
    print('sample org shape', org_rot.shape)

    img_list_org = render_motion(env, org_rot, org_trs, motion_length)
    img_list_sample = render_motion(env, sample_rot[SAMPLE_IDX], sample_trs[SAMPLE_IDX],
                                    motion_length)

    filename = os.path.splitext(os.path.basename(data_path))[0]
    out_file_name = os.path.join(RESULT_DIR, filename + '.mp4')
    write_side_by_side_video(out_file_name, img_list_org, img_list_sample)


dict_keys(['sample_rot', 'sample_trs', 'org_rot', 'org_trs', 'motion_length'])
sample rotation shape (5, 192, 35)
sample org shape (192, 35)
