In [None]:
import os, sys
# Environment configuration. These are set before torch and the mesh/rendering
# libraries are imported below, so the values are already in place when those
# libraries read them (PYOPENGL_PLATFORM presumably selects off-screen OSMesa
# rendering for PyOpenGL-based viewers).
# With CUDA_VISIBLE_DEVICES="6," only physical GPU 6 is visible to this process,
# where it is addressed as cuda:0 (see `device` at the end of this cell).
os.environ["CUDA_VISIBLE_DEVICES"] = "6,"
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
from pathlib import Path
from omegaconf import OmegaConf
import math
import numpy as np
import torch

# Body/hand models, mesh handling and notebook 3D plotting.
import smplx, mano, trimesh
from pytorch3d.structures import Meshes
from psbody.mesh import MeshViewers, Mesh
from psbody.mesh.colors import name_to_rgb
import meshplot as mp

# Project-local modules: networks, object layer, rotation/pose helpers and the
# optimiser used in the final fitting cell.
from pct.two_stage import concatMap, Point2Hand
from transformers_model.motion_model import pretrain_actor_rel
from tools.objectmodel import ObjectModel
# from tools.meshviewer import Mesh
from tools.utils import to_cpu, to_tensor, trans_global2loc_rh_wrist
from tools.utils import aa2rotmat, rotmat2aa, rotmat2d6, d62rotmat, get_relation_map_new
from tools.model_utils import parms_6D2full_rh, full2bone_pretrain
from tools.optim_union import CoopOptim

# Index 0 within the restricted visible-device set configured above.
device = "cuda:0"

In [None]:
# Fix the torch RNG and release cached GPU memory left over from a previous run.
torch.manual_seed(3407)
torch.cuda.empty_cache()

# MANO hand layer (used as the right hand) and male SMPL-X body layer, both with
# full non-PCA hand articulation (use_pca=False) and a flat mean hand pose.
rhand_model = mano.load(
    model_path='/data/3D_dataset/smpl_related/models/mano',
    model_type='mano', num_pca_comps=45,
    use_pca=False, batch_size=1,
    flat_hand_mean=True).to(device)

body_model = smplx.SMPLXLayer(
    model_path='/data/3D_dataset/smpl_related/models/smplx',
    gender='male', num_pca_comps=45,
    use_pca=False, batch_size=1,
    flat_hand_mean=True).to(device)

object_model = ObjectModel().to(device)


def load_pretrained(net_cls, cfg_path, ckpt_path, **extra_kwargs):
    """Build a network from its YAML config, load its checkpoint, and put it in eval mode.

    Args:
        net_cls: network class; called with ``extra_kwargs`` plus the entries of
            ``cfg.network.coop_model``.
        cfg_path: path of the OmegaConf YAML config.
        ckpt_path: path of the saved ``state_dict``. It is loaded on the CPU and
            copied into the parameters that already live on ``device``.
        **extra_kwargs: additional constructor arguments not stored in the config.

    Returns:
        ``(cfg, net)``: the loaded config and the network, on ``device``, in eval mode.
    """
    cfg = OmegaConf.load(cfg_path)
    net = net_cls(**extra_kwargs, **cfg.network.coop_model).to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
    # Every network in this notebook is used for inference only, so always disable
    # training-mode behaviour (e.g. any dropout / batch-norm statistics updates).
    return cfg, net.eval()


# Relation map consumed by the body network.
relation_map = get_relation_map_new()

# Object points -> per-point contact map.
contact_cfg, contact_network = load_pretrained(
    concatMap, "../model/Hnet/hand2contact.yaml", "../model/Hnet/hand2contact_pt.pt")
# Object points + contact map -> right-hand pose and translation.
hand_cfg, hand_network = load_pretrained(
    Point2Hand, "../model/Hnet/contact2hand.yaml", "../model/Hnet/contact2hand_pt.pt")
# Body network; it is only run under torch.no_grad below, so it gets eval mode too
# (it used to be left in training mode).
body_cfg, body_network = load_pretrained(
    pretrain_actor_rel, "../model/Bnet/body_net.yaml", "../model/Bnet/body_net.pt",
    relation_map=relation_map)

rh_ids_sampled = torch.from_numpy(np.load('./consts/valid_rh_idx_99.npy'))
# The .pkl is unpickled by np.load (allow_pickle=True): only load trusted files.
rh_verts_ids = to_tensor(np.load('./consts/MANO_SMPLX_vertex_ids.pkl', allow_pickle=True)['right_hand'], dtype=torch.long)

In [None]:
# from data.dataloader_union import LoadData
# ds_test = LoadData(body_cfg.datasets, split_name='test')
# batch = ds_test[0]
# for k,v in batch.items():
#     batch[k] = v.unsqueeze(0).to(device)

In [None]:
# Load the object mesh. force='mesh' guarantees a single Trimesh even when the OBJ
# would otherwise be returned as a Scene (e.g. multiple groups/materials), because
# `.vertices` and `.faces` are read from it directly here and in later cells.
obj_path = "/shared/3D/GrabFusion/data/dex-ycb-models/025_mug/textured_simple.obj"
obj_mesh = trimesh.load(obj_path, force='mesh')
# Vertex template for the object layer. Named distinctly from the later `obj_verts`,
# which holds posed vertices.
obj_template_verts = torch.from_numpy(obj_mesh.vertices)
obj_m = ObjectModel(v_template=obj_template_verts).to(device)

# Optimisation wrapper over the body, hand and object layers; its `fitting` method
# is called in the last cell.
body_fit_smplx = CoopOptim(
    sbj_model=body_model, rh_model=rhand_model, obj_model=obj_m,
    cfg=body_cfg, device=device, verbose=True
)

# grnd_mesh, cage, axis_l = get_ground()

In [None]:
# Target object position. z is used as the object height (see input_height below);
# the world-frame translation re-orders the components as (y, z, x).
x, y, z = 0.4, -0.4, 1.4

# transl = torch.tensor([[-0.0255,  1.3412, -0.5916]], dtype=torch.float32, device=device)
transl_obj = torch.tensor([[y, z, x]], dtype=torch.float32, device=device)
# global_orient_obj = torch.tensor([[1.1428, 1.3750, 1.4559]], device=device)
# NOTE(review): if ObjectModel treats global_orient as axis-angle in radians (like
# aa2rotmat / rotmat2aa), 90.0 is ~116.6 degrees about x after wrapping by 2*pi, not
# 90 degrees; use math.pi / 2 if a quarter turn was intended. In this notebook the
# value only affects the world-frame `obj_verts` used for plotting - TODO confirm.
global_orient_obj = torch.tensor([[90.0, 0.0, 0.0]], device=device)
# global_orient_obj_RH = torch.tensor([[-0.1865,  0.1269, -4.4412]], device=device)
global_orient_obj_RH = torch.tensor([[0.0, 0.0, 0.0]], device=device)

# batch_ = batch.copy()
# batch_['transl_obj'][0] = torch.tensor([y, z, x])

# Angle of (x, y) in the horizontal plane, kept in [-pi/2, 3*pi/2): atan(y/x) for
# x > 0 and pi + atan(y/x) for x < 0, exactly as before. atan2 additionally handles
# x == 0, where the previous y / x raised ZeroDivisionError.
r = math.atan2(y, x)
if r < -math.pi / 2:
    r += 2 * math.pi
# Opposite rotation (both branches equal -r modulo 2*pi); passed to the fitting step.
reserved_r = - r if x >= 0 else 2 * math.pi - r

# Object orientation in the RH frame: rotmat(global_orient_obj_RH) @ Rz(r).
cos_r, sin_r = math.cos(r), math.sin(r)
rot_z = torch.tensor([
    [cos_r, -sin_r, 0.0],
    [sin_r, cos_r, 0.0],
    [0.0, 0.0, 1.0],
], dtype=torch.float32, device=device)
rotmat = aa2rotmat(global_orient_obj_RH)
new_rotmat = torch.matmul(rotmat, rot_z).reshape(-1, 3, 3)
# batch_['global_orient_obj_RH'] = rotmat2aa(new_rotmat)

# Rotated object at the origin, then centred on its vertex mean.
obj_verts = obj_m(transl=torch.zeros((1,3), device=device), global_orient=new_rotmat, pose2rot=False).vertices[0]
centroid = torch.mean(obj_verts, dim=0)  # mean vertex; shape (3,) for (N, 3) vertices
pc_center = obj_verts - centroid

# Per-vertex normals of the centred, rotated mesh.
mesh = Meshes(verts=pc_center[None,:], faces=torch.from_numpy(obj_mesh.faces[None,:].astype(np.int64)).to(device))
normal = mesh.verts_normals_packed().view(-1, 3)

# Fixed-size point cloud fed to the networks. A dedicated seeded generator makes the
# sample reproducible (only torch was seeded so far). Sampling falls back to
# replacement when the mesh has fewer vertices than N_POINTS, where replace=False raises.
N_POINTS = 2048
num_obj_verts = obj_verts.shape[0]
point_rng = np.random.default_rng(3407)
simple_vertices_ids = torch.from_numpy(
    point_rng.choice(num_obj_verts, N_POINTS, replace=num_obj_verts < N_POINTS))
input_points = pc_center[simple_vertices_ids].unsqueeze(0).to(device)
input_normal = normal[simple_vertices_ids].unsqueeze(0).to(device)
# Height input, shape (1, 1): target height z minus the centroid's z component, halved.
# NOTE(review): the /2 scaling presumably matches the networks' training data - confirm.
input_height = ((torch.Tensor([z]).to(device) - centroid[2]) / 2).unsqueeze(0).to(device)
# batch_['verts'] = input_points

# World-frame object vertices (batched), used for the plots in later cells.
obj_verts = obj_m(transl=transl_obj, global_orient=global_orient_obj).vertices

In [None]:
# Hand stage: predict a per-point contact map on the sampled object points, then
# regress right-hand MANO parameters from points, normals, height and contacts.
with torch.no_grad():
    torch.manual_seed(3407)
    torch.cuda.empty_cache()

    contact_map = contact_network.infer(
        points=input_points,
        normal=input_normal,
        obj_height=input_height
    )['contact_map']
    pre_map = torch.softmax(contact_map, dim=-1)
    # Lower the class-0 probability by 0.2 before taking the argmax. The class axis
    # is the last one (softmax above), so index it with [..., 0]. For a
    # (batch, points, classes) map, [:, 0] would instead hit the first point's row
    # and leave the argmax unchanged; for a 2-D (points, classes) map both forms match.
    pre_map[..., 0] -= 0.2
    # One label per sampled point, reshaped to (batch, num_points).
    num_points = input_points.shape[1]
    pre_map = pre_map.argmax(-1).view(-1, num_points).long()
    # batch_['contact_map_obj'] = pre_map

    net_output = hand_network(
        # betas=batch['betas_rh'],
        betas=torch.zeros((1,10), dtype=torch.float32, device=device),
        contact_map=pre_map,
        points=input_points,
        normal=input_normal,
        obj_height=input_height
    )
    pose, trans = net_output['pose'], net_output['trans']
    cnet_params = parms_6D2full_rh(pose, trans, d62rot=True)
    # Axis-angle view of the full hand pose: first 3 values = global orientation,
    # the remainder = finger articulation.
    pose_aa = rotmat2aa(cnet_params['fullpose_rotmat']).view(1, -1)
    cnet_params['hand_pose'], cnet_params['global_orient'] = pose_aa[:,3:], pose_aa[:,:3]
    # The networks saw the centred point cloud; shift the hand back into the
    # uncentred, rotated object frame that the fitting step uses
    # (transl_obj_RH = 0, global_orient_obj_RH = new_rotmat).
    cnet_params['transl'] += centroid
    # rhand_model.v_template = batch['sbj_vtemp_rh'].to(device)
    hand_verts = rhand_model(**cnet_params).vertices[0]

# Show the object in the same frame as the hand. The hand translation includes the
# centroid, so plotting the centred cloud (pc_center) would misalign the two by `centroid`.
mp_viewer = mp.plot(to_cpu(hand_verts), rhand_model.faces, name_to_rgb['gray'])
mp_viewer.add_mesh(to_cpu(pc_center + centroid), obj_mesh.faces, name_to_rgb['red'])
mp_viewer.add_mesh(to_cpu(pc_center), obj_mesh.faces, name_to_rgb['red'])

In [None]:
with torch.no_grad():
    net_output = body_network(
        # betas=batch['betas_body'][:,0,:],
        betas=torch.zeros((1,10), dtype=torch.float32, device=device),
        wrist_transl=transl_obj,
        gender=torch.tensor([[0]], dtype=torch.long, device=device),
        # body_pose=torch.cat([
        #     batch['fullpose_rotmat'][:,:21,:2,:],
        #     batch['fullpose_rotmat'][:,25:40,:2,:]
        # ],dim=1).view(1,-1,6),
        body_pose=rotmat2d6(aa2rotmat(torch.zeros((1,3), device=device))).unsqueeze(1).repeat(1,36,1),
        mask_ids = torch.tensor([[0] * 36 + [1]], dtype=torch.float32, device=device)
    )
    pose, trans = net_output['pose'], net_output['transl']

    pose = d62rotmat(pose).view(1,-1,9).view(1,-1,3,3)
    trans = torch.cat([torch.zeros(1,1).to(device), trans, torch.zeros(1,1).to(device)],dim=-1)
    # trans
    
    # bparams = parms_6D2full_addrh_pretrain(pose, trans, batch['fullpose_rotmat'])
    pre_rh_pose, post_rh_pose = pose[:, :21], pose[:, 21:]
    rh_global_orient_rotmat, rh_pose = cnet_params['fullpose_rotmat'][:, 0:1], cnet_params['fullpose_rotmat'][:, 1:]
    rh_rel_orient_rotmat = trans_global2loc_rh_wrist(rh_global_orient_rotmat,pre_rh_pose)
    union_pose = torch.cat([pre_rh_pose,rh_rel_orient_rotmat,post_rh_pose,rh_pose],dim=1).reshape([1, -1, 3, 3])
    
    bparams = full2bone_pretrain(union_pose,trans)
    bparams['fullpose_rotmat'] = union_pose
    body_net_output = body_model(**bparams)
    body_net_params = {
        "m_verts_full" : body_net_output.vertices,
        "m_joints_full" : body_net_output.joints,
        "m_params" : bparams
    }

mp_viewer = mp.plot(to_cpu(body_net_output.vertices[0]), body_model.faces, name_to_rgb['gray'])
mp_viewer.add_mesh(to_cpu(obj_verts[0]), obj_mesh.faces, name_to_rgb['red'])

In [None]:
# Final stage: run CoopOptim.fitting, initialised from the hand-network prediction
# (cnet_params) and the body-network prediction (body_net_params).
# `fit_inputs` stands in for a dataset batch (the commented-out `batch_` path).
fit_inputs = {
    'verts': input_points,
    'contact_map_obj': pre_map,
    'transl_obj': transl_obj,
    'gender': torch.tensor([0], dtype=torch.long, device=device),
    'transl': trans[0],
    'transl_obj_RH': torch.zeros((1, 3), device=device),
    'global_orient_obj_RH': rotmat2aa(new_rotmat),
}
hand_init = {"cnet": {'params': cnet_params}}
body_init = {"cnet": body_net_params}

optim_output = body_fit_smplx.fitting(
    fit_inputs,
    hand_init,
    body_init,
    obj_mesh.faces,
    reserved_r,
)

# Fitted body vertices against the world-frame object.
fitted_body_verts = to_cpu(optim_output['opt_verts'][0])
mp_viewer = mp.plot(fitted_body_verts, body_model.faces, name_to_rgb['gray'])
mp_viewer.add_mesh(to_cpu(obj_verts[0]), obj_mesh.faces, name_to_rgb['red'])