# Pretrain Policy with  `MPC dataset`

In [1]:
import torch
import mujoco
import numpy as np
import matplotlib.pyplot as plt
from mujoco_parser import MuJoCoParserClass
from util import rpy2r

from policy import GaussianPolicy

np.set_printoptions(precision=2,suppress=True,linewidth=100)
plt.rc('xtick',labelsize=6); plt.rc('ytick',labelsize=6)
%config InlineBackend.figure_format = 'retina'
%matplotlib inline
print ("MuJoCo version:[%s]"%(mujoco.__version__))


  from .autonotebook import tqdm as notebook_tqdm


MuJoCo version:[2.3.6]


### Load  `Pretrained Model`

In [2]:
checkpoint = torch.load("pretrained.pth")
obs_dim = 99
action_dim = 26
hidden_dim = 256
policy = GaussianPolicy(
    input_dim=obs_dim,
    output_dim=action_dim,
    hidden_dim=hidden_dim,
    is_deterministic=True,)
policy.load_state_dict(checkpoint)

<All keys matched successfully>

In [3]:
xml_path = '../asset/smpl_rig/smpl_rig_rilab_mpc.xml'
env = MuJoCoParserClass(name='smplrig',rel_xml_path=xml_path,VERBOSE=True)
# Modify the color of body exluding 'world'
for body_name in env.body_names:
    if body_name in ['world']: continue
    body_idx = env.body_names.index(body_name)
    geom_idxs = [idx for idx,val in enumerate(env.model.geom_bodyid) if val==body_idx]
    for geom_idx in geom_idxs:
        env.model.geom(geom_idx).rgba = [0.3,0.3,0.5,0.5]
print ("Done.")

dt:[0.0010] HZ:[1000]
n_dof (=nv):[49]
n_geom:[31]
geom_names:['floor', 'base', 'base2lpelvis', 'base2rpelvis', 'base2spine1', 'lhip2lknee', 'lknee2lankle', 'foot1_left', 'foot2_left', 'rhip2rknee', 'rknee2rankle', 'foot1_right', 'foot2_right', 'spine2spine', 'spine2spine2', 'spine2spine3', 'spine2lcollar', 'spine2rcollar', 'neck2head', 'head', 'nose', 'lcollar2lshoulder', 'lshoulder2lelbow', 'lelbow2lwrist', 'lthumb', 'lpalm', 'rcollar2rshoulder', 'rshoulder2relbow', 'relbow2rwrist', 'rthumb', 'rpalm']
n_body:[22]
body_names:['world', 'base', 'pelvis', 'left_hip', 'left_knee', 'left_ankle', 'right_hip', 'right_knee', 'right_ankle', 'spine1', 'spine2', 'spine3', 'neck', 'head', 'left_collar', 'left_shoulder', 'left_elbow', 'left_wrist', 'right_collar', 'right_shoulder', 'right_elbow', 'right_wrist']
n_joint:[44]
joint_names:['base', 'pelvis1', 'pelvis2', 'pelvis3', 'l_hip1', 'l_hip2', 'l_hip3', 'l_knee', 'l_ankle1', 'l_ankle2', 'l_ankle3', 'r_hip1', 'r_hip2', 'r_hip3', 'r_knee', 'r_ank

In [4]:
# Set which joints to control
ctrl_joint_names = env.ctrl_names # <==
joint_idxs_fwd = env.get_idxs_fwd(joint_names=ctrl_joint_names)
joint_idxs_jac = env.get_idxs_jac(joint_names=ctrl_joint_names)
q_ctrl_init = env.get_qpos_joints(ctrl_joint_names)
n_ctrl_joint = len(ctrl_joint_names)

# import pandas as pd
# pkl_path = '../data/F01A0V1.pkl'
# pd.read_pickle(pkl_path)

import pickle
pkl_path = '../data/M02F4V1.pkl'
with open(pkl_path,'rb') as f:
    data = pickle.load(f)

print(data['length'])
print(data['root'].shape)
print(data['qpos'].shape)

1136
(1136, 3)
(1136, 50)


In [10]:
L = data['length']

# Initialize MuJoCo viewer
env.init_viewer(viewer_title='SMPL',viewer_width=1200,viewer_height=800,
                viewer_hide_menus=True)
env.update_viewer(azimuth=152,distance=3.0,elevation=-20,lookat=[0.02,-0.03,1.2])
env.reset()
tick = 0
q = data['qpos'][tick,:]
p_root = data['root'][tick,:]
# tick = min(tick+1,L-1)
# if tick==(L-1): tick = 0
# else: tick = tick + 1
env.set_p_root(root_name='base',p=p_root)
env.forward(q=q,INCREASE_TICK=True)
policy.eval()

while env.is_viewer_alive():
    qpos = env.data.qpos
    qvel = env.data.qvel
    state = np.concatenate((qpos, qvel))
    state = torch.Tensor(state)
    action, _ = policy(state)
    
    env.step(action.detach().numpy())
    print(env.data.ctrl[:])
    if env.loop_every(tick_every=1):
        # Plot world frame
        env.plot_T(p=np.zeros(3),R=np.eye(3,3),
                   PLOT_AXIS=True,axis_len=0.5,axis_width=0.005)
        env.plot_T(p=np.array([0,0,0.5]),R=np.eye(3,3),
                   PLOT_AXIS=False,label="tick:[%d]"%(tick))
        # Plot foot
        # env.plot_geom_T(geom_name='rfoot',axis_len=0.3)
        # env.plot_geom_T(geom_name='lfoot',axis_len=0.3)
        # Plot revolute joints with arrow
        # env.plot_joint_axis(axis_len=0.1,axis_r=0.01)    
        env.render()
# Close MuJoCo viewer
env.close_viewer()
print ("Done.")

[ 0.03  0.02 -0.01  0.1   0.11  0.05  0.22 -0.12  0.16 -0.15 -0.21  0.07  0.05  0.13 -0.06  0.08
 -0.01 -0.11  0.09  0.01  0.1  -0.04  0.07 -0.17  0.02 -0.03]
[ 0.03  0.02 -0.01  0.1   0.12  0.07  0.21 -0.13  0.16 -0.15 -0.2   0.08  0.04  0.14 -0.05  0.08
 -0.02 -0.11  0.09  0.01  0.1  -0.04  0.08 -0.17  0.03 -0.04]
[ 0.04  0.01 -0.01  0.1   0.12  0.09  0.21 -0.13  0.16 -0.15 -0.19  0.09  0.04  0.14 -0.04  0.07
 -0.03 -0.1   0.09  0.02  0.1  -0.05  0.09 -0.17  0.03 -0.05]
[ 0.04  0.01 -0.    0.1   0.12  0.1   0.21 -0.14  0.16 -0.15 -0.18  0.09  0.04  0.15 -0.03  0.06
 -0.03 -0.1   0.09  0.02  0.1  -0.05  0.1  -0.17  0.04 -0.06]
[ 0.04  0.01  0.    0.09  0.12  0.11  0.21 -0.14  0.15 -0.15 -0.17  0.1   0.03  0.15 -0.02  0.06
 -0.04 -0.09  0.09  0.02  0.1  -0.06  0.11 -0.16  0.04 -0.07]
[ 0.05  0.01  0.    0.09  0.12  0.13  0.21 -0.15  0.15 -0.16 -0.17  0.11  0.03  0.16 -0.01  0.06
 -0.05 -0.08  0.09  0.02  0.1  -0.07  0.11 -0.16  0.04 -0.07]
[ 0.05  0.01  0.01  0.08  0.12  0.14  0.21 -0.