In [1]:
import os
# These environment variables must be set before CUDA is initialised / pyrender is imported.
# CUDA_VISIBLE_DEVICES restricts this process to GPU index 6, so "cuda:0" below refers to that GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,"
# Select the OSMesa (headless, software) OpenGL backend so pyrender.OffscreenRenderer works without a display.
os.environ["PYOPENGL_PLATFORM"] = "osmesa"
from pathlib import Path
from omegaconf import OmegaConf
import numpy as np
import cv2  # NOTE(review): not used by any visible cell — candidate for removal
from PIL import Image
import torch

from bps_torch.bps import bps_torch
from smplx import SMPLXLayer
from psbody.mesh.colors import name_to_rgb
import pyrender
import meshplot as mp

# Project-local modules (GNet model, helpers, optimiser, mesh wrapper).
from models.cvae import gnet_model
from models.model_utils import parms_6D2full
from tools.objectmodel import ObjectModel
from tools.utils import to_cpu, euler
from tools.gnet_optim import GNetOptim as FitSmplxStatic
from tools.meshviewer import Mesh

# Single compute device for the whole notebook (first visible GPU, i.e. physical GPU 6).
device = "cuda:0"

In [2]:
# Load the static GNet config + checkpoint, a neutral SMPL-X body model and the BPS basis
# used to encode objects. Everything later in the notebook depends on these globals.
torch.cuda.empty_cache()

SMPLX_MODEL_DIR = "/data/3D_dataset/smpl_related/models/smplx"

cfg_static = OmegaConf.load("configs/GNet_orig.yaml")
cfg_static.best_model = "models/GNet_model.pt"

body_model = SMPLXLayer(
    model_path=SMPLX_MODEL_DIR, gender='neutral', num_pca_comps=45, flat_hand_mean=True
).to(device)

# Zero-pose template body; only its joints are read later, so no autograd graph is needed.
with torch.no_grad():
    sbj_vtemp = body_model()

network_static = gnet_model(**cfg_static.network.gnet_model).eval().to(device)
network_static.cfg = cfg_static
# strict=False ignores keys that do not match the model; report them instead of silently
# running with partially-initialised weights.
load_report = network_static.load_state_dict(
    torch.load(cfg_static.best_model, map_location=device), strict=False
)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"GNet checkpoint mismatch: missing={list(load_report.missing_keys)}, "
        f"unexpected={list(load_report.unexpected_keys)}"
    )

bps_torch_model = bps_torch()
# BPS basis is loaded on CPU; it is moved to `device` where it is used.
bps = torch.load("configs/bps.pt")

In [7]:
# Place a Dex-YCB object relative to the SMPL-X template's root joint, BPS-encode it,
# build the GNet fitting batch and preview the object next to a template body.
obj_path = '/shared/3D/GrabFusion/data/dex-ycb-models/002_master_chef_can/textured_simple.obj'
input_trans = [0.0,  1.6, 0.0] # rel. to root joint of smplx
input_rot = np.deg2rad([90,0,0])  # rotation given in degrees, converted to radians for ObjectModel

transl_obj = torch.tensor([input_trans], dtype=torch.float32, device=device)
# as_tensor on an ndarray avoids torch.tensor's slow (and warned) list-of-ndarrays path.
global_orient_obj = torch.as_tensor(input_rot[None], dtype=torch.float32, device=device)
# Offset the requested translation by the template's root joint (batch 0, joint 0).
transl_obj = sbj_vtemp.joints[0, 0] + transl_obj

obj_mesh = Mesh(filename=obj_path, vscale=1.0)
obj_m = ObjectModel(v_template=torch.from_numpy(obj_mesh.vertices)).to(device)
obj_m.faces = obj_mesh.faces

verts_obj = obj_m(transl=transl_obj, global_orient=global_orient_obj).vertices[0]
# Re-centre the BPS basis points on the object translation before encoding.
obj_bps = bps['obj'].to(device) + transl_obj.reshape(1, 1, 3)
bps_obj_glob = bps_torch_model.encode(x=verts_obj, feature_type=['deltas'], custom_basis=obj_bps)['deltas']

batch = {
    'transl_obj' : transl_obj,
    'global_orient_obj' : global_orient_obj,
    'bps_obj_glob' : bps_obj_glob,
    'gender' : 1  # gender code consumed by GNetOptim — TODO confirm its meaning for the neutral model
}
fit_smplx_static = FitSmplxStatic(sbj_model=body_model, obj_model=obj_m, cfg=cfg_static, verbose=True)

# Preview only: a template body placed in front of the object, both shifted down for a centred view.
PREVIEW_SBJ_TRANSL = [[0.0, 1.25, -0.5]]
PREVIEW_OFFSET = np.array([0.0, 0.75, 0.0])

with torch.no_grad():
    sbj_output = body_model(transl=torch.tensor(PREVIEW_SBJ_TRANSL, device=device))
    verts_sbj = sbj_output.vertices[0].detach()

mp_viewer = mp.plot(to_cpu(verts_sbj) - PREVIEW_OFFSET, body_model.faces, np.array([0.75, 0.75, 0.75]))
mp_viewer.add_mesh(to_cpu(verts_obj) - PREVIEW_OFFSET, obj_mesh.faces, np.array([1.0, 0.0, 0.0]))
mp_viewer

Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(5.0663948…

<meshplot.Viewer.Viewer at 0x7f93ea783b50>

In [None]:
# Sample one GNet latent code, decode a grasping body for the object in `batch`, refine it
# with the SMPL-X optimiser and preview the fitted body next to the object.
LATENT_DIM = 16  # size of GNet's latent z — TODO confirm against cfg_static.network.gnet_model
torch.manual_seed(5171)  # fixed seed so the sampled grasp is reproducible

z_enc_s = torch.distributions.normal.Normal(
    loc=torch.zeros([1, LATENT_DIM], device=device),
    scale=torch.ones([1, LATENT_DIM], device=device),
).rsample()

# Decoder input = concatenation of these flattened features, in this (insertion) order.
dec_inputs = {
    'betas' : torch.zeros([1, 1, 10], dtype=torch.float32, device=device),
    'transl_obj' : batch['transl_obj'],
    'bps_obj' : batch['bps_obj_glob'].norm(dim=-1),
    'z' : z_enc_s
}
dec_x = torch.cat([v.reshape(1, -1).to(device) for v in dec_inputs.values()], dim=1)

net_output = network_static.decode(dec_x)
# Convert the decoder's 6D rotation output into full body-model parameters.
bparams = parms_6D2full(net_output['pose'], net_output['trans'], d62rot=True)
net_output['m_params'] = bparams

cnet_verts, cnet_s_verts = fit_smplx_static.get_smplx_verts(batch, {"cnet" : net_output})
optim_output = fit_smplx_static.fitting(batch, {"cnet" : net_output})

# Distinct names so earlier cells' `sbj_verts` / `obj_verts` are not silently overwritten.
fit_sbj_verts = to_cpu(optim_output['opt_verts'][0])
fit_obj_verts = to_cpu(fit_smplx_static.obj_verts[0])

mp_viewer = mp.plot(fit_sbj_verts, body_model.faces, np.array([0.75, 0.75, 0.75]))
mp_viewer.add_mesh(fit_obj_verts, obj_mesh.faces, np.array([1.0, 0.0, 0.0]))
mp_viewer

In [None]:
# Render the fitted body (grey) and object (red) offscreen with pyrender and show an RGBA
# image whose alpha channel is the foreground mask (pixels with non-zero depth).
angle = [0,0,0]          # extra mesh rotation passed to tools.utils.euler ('xzy' order)
trans = [0.0, 1.0, 0.0]  # subtracted from all vertices to centre the scene
z_dist = 2.0             # camera distance along +z
a_light = 0.4            # ambient light
d_light = 3.0            # headlight intensity
w, h = 512, 512

# sbj_verts = to_cpu(cnet_verts[0]) - np.array([trans])  # alternative: raw GNet output, before fitting
sbj_verts = to_cpu(optim_output['opt_verts'][0]) - np.array([trans])
obj_verts = to_cpu(fit_smplx_static.obj_verts[0]) - np.array([trans])
sbj_opt = Mesh(vertices=sbj_verts, faces=body_model.faces, vc=[0.5,0.5,0.5])
obj_i = Mesh(vertices=obj_verts, faces=obj_mesh.faces, vc=name_to_rgb['red'])
obj_array = [sbj_opt, obj_i]

scene = pyrender.Scene(bg_color=[0.0,0.0,0.0,1.0], ambient_light=a_light, name='scene')

# Aspect ratio follows the viewport so non-square renders are not stretched.
pc = pyrender.PerspectiveCamera(yfov=np.pi / 3.0, aspectRatio=w / h)
camera_pose = np.eye(4)
camera_pose[:3, :3] = euler([0.0, 0.0, 0.0], 'xzy')
camera_pose[:3, 3] = np.array([0.0, 0.0, z_dist])
cam = pyrender.Node(name='camera', camera=pc, matrix=camera_pose)
scene.add_node(cam)

# Headlight: a directional light placed at the camera pose.
light = pyrender.Node(
    light=pyrender.light.DirectionalLight(color=np.ones(3), intensity=d_light),
    matrix=camera_pose,
)
scene.add_node(light)

for obj in obj_array:
    obj.rot_verts(euler(angle, 'xzy'))  # rotates the Mesh's vertices in place
    scene.add(pyrender.Mesh.from_trimesh(obj))

viewer = pyrender.OffscreenRenderer(w, h)
try:
    color, depth_buffer = viewer.render(scene)
finally:
    # Always release the offscreen GL context, even if rendering fails.
    viewer.delete()

mask = depth_buffer > 0
color_image = Image.fromarray(
    np.concatenate([color, (mask[..., np.newaxis] * 255.).astype(np.uint8)], axis=-1)
)
color_image

In [None]:
# ds_test = LoadData(cfg_static.datasets, split_name='test')
# batch = ds_test[0]
# batch = {k: batch[k].unsqueeze(0).to(device) for k in ['idx', 'gender', 'sbj_vtemp', 'transl_obj', 'global_orient_obj', 'betas', 'bps_obj_glob']}

# ds_dir = Path("/data/3D_dataset/GrabNet/data/GRAB/data/GNet_data")

# ds = np2torch(np.load(ds_dir / 'test/GNet_data.npy', allow_pickle=True))
# frame_names = np.load(ds_dir / 'test/frame_names.npz')['frame_names']
# sbj_info = np.load(ds_dir / 'sbj_info.npy', allow_pickle=True).item()

# base_path = ds_dir.parent / 'tools/subject_meshes'
# file_list = []
# for sbj, sbj_dict in list(sbj_info.items()):
#     gender = sbj_dict['gender']
#     file_list.append(base_path / f'{gender}/{sbj}.ply')
# sbj_vtemp = torch.from_numpy(np.asarray([Mesh(filename=file).vertices.astype(np.float32) for file in file_list]))
# sbj_betas = torch.from_numpy(np.asarray([np.load(file=f.parent / f'{f.stem}_betas.npy').astype(np.float32) for f in file_list]))

# idx = 0

# frame_name = Path(frame_names[idx])
# sequence_name, obj_name = frame_name.parts[-2], frame_name.parts[-1].split("_")[0]
# batch = {k: to_tensor(ds[k][idx], dtype=torch.float32).unsqueeze(0).to(device) for k in ['transl_obj', 'global_orient_obj', 'bps_obj_glob']}
# batch['gender'] = sbj_dict['gender']

# sbj_dict = sbj_info[sequence_name]
# sbj_dict['betas'] = to_tensor(sbj_dict['betas'], dtype=torch.float32).unsqueeze(0).to(device)
# # sbj_dict['betas'] = torch.zeros([1, 1, 10], dtype=torch.float32, device=device)
# sbj_dict['vtemp'] = to_tensor(sbj_dict['vtemp'], dtype=torch.float32).unsqueeze(0).to(device)
# sbj_m = female_model if sbj_dict['gender'] == "female" else male_model
# sbj_m.v_template = sbj_dict['vtemp'] 

# obj_mesh = Mesh(filename=ds_dir.parent / f'tools/object_meshes/contact_meshes/{obj_name}.ply')
# obj_verts = torch.from_numpy(obj_mesh.vertices)
# obj_m = ObjectModel(v_template=obj_verts).to(device)
# obj_m.faces = obj_mesh.faces

In [None]:
# import smplx
# from tools.utils import parse_npz, aa2rotmat, rotmat2aa, rotate, rotmul, euler, prepare_params, params2torch

# bps = torch.load("configs/bps.pt")
# all_seqs = [x for x in ds_dir.parent.glob("grab/*/*.npz") if x.stem.split("_")[0] in ['mug', 'binoculars']]
# sequence = all_seqs[0]
# seq_data = parse_npz(sequence)

# motion_obj = params2torch(prepare_params(seq_data.object.params, frame_mask, rel_offset))

# R = torch.tensor([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]]).reshape(1, 3, 3).transpose(1,2)
# root_offset = smplx.lbs.vertices2joints(sbj_m.J_regressor, sbj_m.v_template.view(1, -1, 3))[0, 0]


# trans_obj_rel = rotate(motion_obj['transl'], R)
# global_orient_obj_rotmat = aa2rotmat(motion_obj['global_orient'])
# global_orient_obj_rel = rotmul(global_orient_obj_rotmat, R.transpose(1, 2))
# transl_obj = to_tensor(trans_obj_rel)
# global_orient_obj = rotmat2aa(to_tensor(global_orient_obj_rel).squeeze()).squeeze()

# verts_obj = obj_m(**motion_obj).vertices
# obj_bps = bps['obj'] + motion_obj['transl'].reshape(1, 1, 3)
# bps_obj = bps_torch.encode(x=verts_obj, feature_type=['deltas'], custom_basis=obj_bps)['deltas']

# ds_dir = Path("/data/3D_dataset/GrabNet/data/GRAB/data/GNet_data")
# ds = np2torch(np.load(ds_dir / 'test/GNet_data.npy', allow_pickle=True))
# frame_names = np.load(ds_dir / 'test/frame_names.npz')['frame_names']
# bps = torch.load("configs/bps.pt")

# # idx = 400
# obj_path = 'mug'
# obj_name = Path(frame_names[idx]).parts[-1].split("_")[0]
# # transl_obj = to_tensor(ds['transl_obj'][idx], dtype=torch.float32).unsqueeze(0).to(device)
# transl_obj = torch.tensor([[0.0, 1.0, 0.0]], dtype=torch.float32, device=device)
# # global_orient_obj = to_tensor(ds['global_orient_obj'][idx], dtype=torch.float32).unsqueeze(0).to(device)
# global_orient_obj = torch.tensor([np.deg2rad([26.009653, -108.771225, -122.78726])], dtype=torch.float32, device=device)

# obj_mesh = Mesh(filename=ds_dir.parent / f'tools/object_meshes/contact_meshes/{obj_name}.ply')
# obj_m = ObjectModel(v_template=torch.from_numpy(obj_mesh.vertices)).to(device)
# obj_m.faces = obj_mesh.faces

# verts_obj = obj_m(transl=transl_obj, global_orient=global_orient_obj).vertices[0]
# obj_bps = bps['obj'].to(device) + transl_obj.reshape(1, 1, 3)
# bps_obj_glob = bps_torch.encode(x=verts_obj, feature_type=['deltas'], custom_basis=obj_bps)['deltas']

# batch = {
#     'transl_obj' : transl_obj,
#     'global_orient_obj' : global_orient_obj,
#     'bps_obj_glob' : bps_obj_glob,
#     'gender' : 1
# }
# fit_smplx_static = FitSmplxStatic(sbj_model=body_model, obj_model=obj_m, cfg=cfg_static, verbose=True)

In [None]:
# import time

# obj_path = '/shared/3D/GrabFusion/data/dex-ycb-models/002_master_chef_can/textured_simple.obj'
# input_trans = [0.0,  1.6, 0.0]
# input_rot = np.deg2rad([90,0,0])

# time_array = []
# for _ in range(10):
#     # obj_verts, obj_faces, obj_img = pipeline3d.get_obj_verts(obj_path, obj_rot=obj_rot, obj_trans=obj_trans, re_scale=obj_scale)
#     stime = time.time()
#     transl_obj = torch.tensor([input_trans], dtype=torch.float32, device=device)
#     global_orient_obj = torch.tensor([input_rot], dtype=torch.float32, device=device)

#     obj_mesh = Mesh(filename=obj_path, vscale=1.0)
#     obj_verts = obj_mesh.vertices
#     obj_m = ObjectModel(v_template=torch.from_numpy(obj_verts)).to(device)
#     obj_m.faces = obj_mesh.faces

#     verts_obj = obj_m(transl=transl_obj, global_orient=global_orient_obj).vertices[0]
#     obj_bps = bps['obj'].to(device) + transl_obj.reshape(1, 1, 3)
#     bps_obj_glob = bps_torch.encode(x=verts_obj, feature_type=['deltas'], custom_basis=obj_bps)['deltas']

#     batch = {
#         'transl_obj' : transl_obj,
#         'global_orient_obj' : global_orient_obj,
#         'bps_obj_glob' : bps_obj_glob,
#         'gender' : 1
#     }
#     fit_smplx_static = FitSmplxStatic(sbj_model=body_model, obj_model=obj_m, cfg=cfg_static, verbose=True)

#     torch.manual_seed(5171)

#     z_enc_s = torch.distributions.normal.Normal(
#         loc=torch.zeros([1, 16], requires_grad=False).to(device),
#         scale=torch.ones([1, 16], requires_grad=False).to(device)
#     ).rsample()

#     dec_x = {
#         'betas' : torch.zeros([1, 1, 10], dtype=torch.float32, device=device), 
#         'transl_obj' : batch['transl_obj'], 
#         'bps_obj' : batch['bps_obj_glob'].norm(dim=-1),
#         'z' : z_enc_s
#     }
#     dec_x = torch.cat([v.reshape(1, -1).to(device) for v in dec_x.values()], dim=1)

#     net_output = network_static.decode(dec_x)
#     bparams = parms_6D2full(net_output['pose'], net_output['trans'], d62rot=True)
#     net_output[f'm_params'] = bparams

#     cnet_verts, cnet_s_verts = fit_smplx_static.get_smplx_verts(batch, {"cnet" : net_output})
#     optim_output = fit_smplx_static.fitting(batch, {"cnet" : net_output})

#     sbj_verts = to_cpu(optim_output['opt_verts'][0])
#     obj_verts = to_cpu(fit_smplx_static.obj_verts[0])

#     time_array.append(time.time() - stime)

# np.array(time_array).mean()