In [1]:
import numpy as np
import pyrender
import smplx
import torch
import trimesh
import matplotlib.pyplot as plt
import os
import cv2
os.environ["PYOPENGL_PLATFORM"] = "egl"

def normalize_angle(x):
    return torch.atan2(torch.sin(x), torch.cos(x))

def quat_to_angle_axis(q):
    # type: (Tensor) -> Tuple[Tensor, Tensor]
    # computes axis-angle representation from quaternion q
    # q must be normalized
    min_theta = 1e-5
    qx, qy, qz, qw = 0, 1, 2, 3

    sin_theta = torch.sqrt(1 - q[..., qw] * q[..., qw])
    angle = 2 * torch.acos(q[..., qw])
    angle = normalize_angle(angle)
    sin_theta_expand = sin_theta.unsqueeze(-1)
    axis = q[..., qx:qw] / sin_theta_expand

    mask = sin_theta > min_theta
    default_axis = torch.zeros_like(axis)
    default_axis[..., -1] = 1

    angle = torch.where(mask, angle, torch.zeros_like(angle))
    mask_expand = mask.unsqueeze(-1)
    axis = torch.where(mask_expand, axis, default_axis)
    return angle, axis

def angle_axis_to_exp_map(angle, axis):
    # type: (Tensor, Tensor) -> Tensor
    # compute exponential map from axis-angle
    angle_expand = angle.unsqueeze(-1)
    exp_map = angle_expand * axis
    return exp_map

def quat_to_exp_map(q):
    # type: (Tensor) -> Tensor
    # compute exponential map from quaternion
    # q must be normalized
    angle, axis = quat_to_angle_axis(q)
    exp_map = angle_axis_to_exp_map(angle, axis)
    return exp_map

def obs_visualize(obs):
    '''
    obs numpy shape (1, 4+63) root_rot + dof_pos

    #save obs for experiment
    obs_save = torch.cat((self._humanoid_root_states[:1, 3:7], self._dof_pos[:1]), dim=-1)
    obs_save = obs_save.detach().cpu().numpy()
    import datetime
    time_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    obs_folder = os.path.abspath("obs/new")
    # Create output folder if needed
    os.makedirs(obs_folder, exist_ok=True)
    savepath = os.path.join(obs_folder, time_str+".npy")
    np.save(savepath, obs_save)
    
    '''
    DOF_BODY_IDS = [3, 6, 9, 13, 16, 18, 20, 12, 15, 14, 17, 19, 21, 2, 5, 8, 11, 1, 4, 7, 10]
    body_model = 'smpl'
    body_model_path = './body_model/smpl_model/models' #share drive
    '''
    obs = torch.from_numpy(obs).view(-1, 1, 63)
    
    rot = torch.zeros(list(obs.shape[:-2]) + [24, 3])
    rot[..., DOF_BODY_IDS, :] = obs[..., 4:].view(list(obs.shape[:-2]) + [21, 3])
    rot[..., 0, :] = quat_to_exp_map(obs[:, 0, :4])
    '''
    obs = torch.from_numpy(obs)
    pred24_4 = obs[..., 3:].view(-1, 4)
    rot = quat_to_exp_map(pred24_4).view(-1, 24, 3)
    
    body_model = smplx.create(model_path=body_model_path, model_type=body_model)
    faces = body_model.faces

    vertices = body_model(global_orient=rot[..., :1, :], body_pose=rot[..., 1:, :]).vertices[0].detach().numpy()
    # vertices = body_model().vertices[0].detach().numpy()

    original_mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    original_mesh.export('obsvistest.ply')
    mesh = pyrender.Mesh.from_trimesh(original_mesh)
    scene = pyrender.Scene(bg_color=[0, 0, 0, 0], ambient_light=(0.3, 0.3, 0.3))
    # scene = pyrender.Scene()
    scene.add(mesh, 'mesh')

    # add camera pose
    camera_pose = np.array([[1, 0, 0, 0],
                            [0, 1, 0, 0],
                            [0, 0, 1, 3],
                            [0, 0, 0, 1]])
    # use this to make it to center
    camera = pyrender.camera.PerspectiveCamera(yfov=1)
    scene.add(camera, pose=camera_pose)

    # Get the lights from the viewer
    light = pyrender.SpotLight(color=np.ones(3), intensity=3.0, innerConeAngle=np.pi/3.0, outerConeAngle=np.pi/3.0)
    scene.add(light, pose=camera_pose)

    # offscreen render
    r = pyrender.OffscreenRenderer(viewport_width=512, viewport_height=512)
    color, depth = r.render(scene, flags=pyrender.RenderFlags.RGBA)
    # plt.figure(figsize=(8, 8))
    # plt.imshow(color[:, :, 0:3])
    # plt.show()
    cv2.imwrite('obsvistest.png', color[:, :, 0:3])
    return color[:, :, 0:3]


In [3]:
data_path = '/home/wenbin/kangning/motion/RIOT/decision-transformer/gym/motion_train_dataset/004882c9-da8d-4a76-90a8-039d9a690d73.npy'
obs = np.load(data_path, allow_pickle=True)
od_dict = obs.item()
rotation_traj = od_dict["rotation"]["arr"]
root_traj = od_dict["root_translation"]["arr"]
traj_len = len(rotation_traj)
rotation_traj = rotation_traj.reshape(traj_len, -1)
obs_99dim = np.concatenate((rotation_traj[0], root_traj[0]), axis=0)
#print(traj.shape)
img = obs_visualize(obs_99dim)

ValueError: Unknown model type models, exiting!