In [1]:
import os
# These two environment variables are set before torch / pyrender are imported below,
# so that CUDA and PyOpenGL pick them up when they initialise.
os.environ["CUDA_VISIBLE_DEVICES"] = "6,"   # expose only physical GPU 6 to this process
os.environ["PYOPENGL_PLATFORM"] = "osmesa"  # select the OSMesa (offscreen) OpenGL platform
from pathlib import Path
from omegaconf import OmegaConf
import numpy as np
import cv2
from PIL import Image
import torch

# NOTE(review): several of the imports below (e.g. cv2, Image, SMPLXLayer, name_to_rgb,
# pyrender, gnet_model, parms_6D2full, euler) are not used by the active cells of this
# notebook; some are only referenced from the commented-out scratch cells.
from bps_torch.bps import bps_torch
from smplx import SMPLXLayer
from psbody.mesh.colors import name_to_rgb
import pyrender
import meshplot as mp

from data.gnet_dataloader import LoadData
from models.cvae import gnet_model
from models.model_utils import parms_6D2full
from tools.objectmodel import ObjectModel
from tools.utils import to_cpu, to_tensor, np2torch, euler
from tools.gnet_optim import GNetOptim as FitSmplxStatic
from tools.meshviewer import Mesh
from train.GNet_inference import Trainer

# With CUDA_VISIBLE_DEVICES="6," the single visible GPU is addressed as cuda:0.
device = "cuda:0"

In [2]:
# Release cached GPU memory left over from a previous run of this cell.
torch.cuda.empty_cache()

# Data / model locations. The defaults reproduce the original machine-specific paths;
# set GRAB_DATA_DIR / SMPLX_MODEL_DIR in the environment instead of editing the notebook.
GRAB_DATA_DIR = os.environ.get("GRAB_DATA_DIR", "/data/3D_dataset/GrabNet/data/GRAB/data/")
SMPLX_MODEL_DIR = os.environ.get("SMPLX_MODEL_DIR", "/data/3D_dataset/smpl_related/models/")

cfg_static = OmegaConf.load("configs/GNet_orig.yaml")
cfg_static.best_model = "models/GNet_model.pt"
cfg_static.datasets.dataset_dir = os.path.join(GRAB_DATA_DIR, "GNet_data")
cfg_static.datasets.grab_path = GRAB_DATA_DIR
cfg_static.body_model.model_path = SMPLX_MODEL_DIR
cfg_static.batch_size = 1          # one sample at a time for inference
cfg_static.work_dir = "outputs/tmp"
cfg_static.cuda_id = 0             # index among the GPUs exposed via CUDA_VISIBLE_DEVICES

# Build the trainer in inference mode (loads data, network and losses; see the log below).
tester = Trainer(cfg=cfg_static, inference=True)

2024-09-10 00:59:37.644 | INFO     | train.GNet_inference:__init__:100 - [GNet_terminal] - Started training XXX, experiment code 2024-09-10 00:59:37
2024-09-10 00:59:37.645 | INFO     | train.GNet_inference:__init__:101 - tensorboard --logdir=outputs/tmp/summaries
2024-09-10 00:59:37.646 | INFO     | train.GNet_inference:__init__:102 - Torch Version: 1.13.1+cu116

2024-09-10 00:59:37.776 | INFO     | train.GNet_inference:__init__:116 - Using 1 CUDA cores [NVIDIA A100 80GB PCIe] for training!
2024-09-10 00:59:37.777 | INFO     | train.GNet_inference:load_data:265 - Base dataset_dir is /data/3D_dataset/GrabNet/data/GRAB/data/GNet_data


2024-09-10 00:59:40.400 | INFO     | train.GNet_inference:__init__:125 - Predict offsets: True
2024-09-10 00:59:40.405 | INFO     | train.GNet_inference:__init__:128 - Use exp function on distances: 0.0
2024-09-10 00:59:43.004 | INFO     | train.GNet_inference:loss_setup:188 - Configuring the losses!
2024-09-10 00:59:43.006 | DEBUG    | losses.losses:build_loss:45 - Building loss: l1
2024-09-10 00:59:43.006 | INFO     | train.GNet_inference:loss_setup:200 - Edge loss, weight: L1Loss(), 0.0
2024-09-10 00:59:43.007 | DEBUG    | losses.losses:build_loss:45 - Building loss: l1
2024-09-10 00:59:43.008 | INFO     | train.GNet_inference:loss_setup:207 - Vertex loss, weight: L1Loss(), 5.0
2024-09-10 00:59:43.008 | INFO     | train.GNet_inference:loss_setup:213 - Vertex consist loss weight: 0.0
2024-09-10 00:59:43.009 | DEBUG    | losses.losses:build_loss:45 - Building loss: l1
2024-09-10 00:59:43.009 | INFO     | train.GNet_inference:loss_setup:218 - Right Hand Vertex loss, weight: L1Loss(), 0

In [4]:
# Inspect the `transl` entry of the sample loaded by the inference cell (`batch` is
# created there). This cell was executed after that cell, so guard against a fresh
# kernel / Run All, where `batch` does not exist yet and a NameError would be raised.
batch['transl'] if 'batch' in globals() else None

tensor([[-0.0157,  1.3381, -0.5983]], device='cuda:0')

In [3]:
# Run GNet on one test sample, refine the predicted body with the static SMPL-X
# fitting optimiser, and show subject (grey) + object (red) with meshplot.
SAMPLE_IDX = 8888  # index into the test split
SEED = 5321        # same seed as the alternative cell below; makes any random sampling repeatable

torch.manual_seed(SEED)

with torch.no_grad():
    tester.network.eval()
    batch = tester.ds_test.dataset[SAMPLE_IDX]
    # Add a batch dimension and move every field to the trainer's device.
    batch = {k: v.unsqueeze(0).to(tester.device) for k, v in batch.items()}

    # Gender code 2 selects the female body model, anything else the male one.
    gender = batch['gender'].item()
    sbj_m = tester.female_model if gender == 2 else tester.male_model
    # NOTE: this overwrites the template of the trainer's shared body model,
    # so the change persists into every later cell that uses it.
    sbj_m.v_template = batch['sbj_vtemp'].to(sbj_m.v_template.device)

    # Object name = text before the first '_' in the last path component of the frame name.
    frame_idx = int(batch['idx'].item())
    obj_name = tester.data_info['test']['frame_names'][frame_idx].split('/')[-1].split('_')[0]
    obj_path = tester.data_info['obj_info'][obj_name]['obj_mesh_file']
    obj_mesh = Mesh(filename=obj_path)
    obj_template_verts = torch.from_numpy(obj_mesh.vertices)
    # Keep the object model on the same device as `batch`.
    obj_m = ObjectModel(v_template=obj_template_verts).to(tester.device)

    net_output = tester.forward(batch)

# Kept outside no_grad: the fitting is an iterative optimisation (see its per-iteration
# loss log) and presumably needs autograd.
fit_smplx = FitSmplxStatic(sbj_model=sbj_m, obj_model=obj_m, cfg=cfg_static, verbose=True)
optim_output = fit_smplx.fitting(batch, net_output)

sbj_verts = to_cpu(optim_output['opt_verts'][0])
obj_verts = to_cpu(fit_smplx.obj_verts[0])
mp_viewer = mp.plot(sbj_verts, sbj_m.faces, np.array([0.75, 0.75, 0.75]))  # grey subject
mp_viewer.add_mesh(obj_verts, obj_mesh.faces, np.array([1.0, 0.0, 0.0]))  # red object
mp_viewer

Stage:00 - Iter:0000 - Total Loss: 2.501951e-02 | [dist_rh2obj = 1.30e-02 | grnd_contact = 0.00e+00 | gaze = 1.20e-02 | global_orient = 0.00e+00 | body_pose = 0.00e+00 | left_hand_pose = 0.00e+00 | right_hand_pose = 0.00e+00 | transl = 0.00e+00]
Stage:00 - Iter:0050 - Total Loss: 1.256669e-02 | [dist_rh2obj = 7.74e-03 | grnd_contact = 0.00e+00 | gaze = 3.65e-03 | global_orient = 3.10e-05 | body_pose = 7.62e-05 | left_hand_pose = 0.00e+00 | right_hand_pose = 3.43e-04 | transl = 7.29e-04]
Stage:00 - Iter:0100 - Total Loss: 1.272583e-02 | [dist_rh2obj = 8.13e-03 | grnd_contact = 0.00e+00 | gaze = 3.64e-03 | global_orient = 1.05e-05 | body_pose = 7.56e-05 | left_hand_pose = 0.00e+00 | right_hand_pose = 3.64e-04 | transl = 5.07e-04]
Stage:00 - Iter:0150 - Total Loss: 1.052662e-02 | [dist_rh2obj = 8.27e-03 | grnd_contact = 0.00e+00 | gaze = 1.68e-03 | global_orient = 1.15e-05 | body_pose = 7.93e-05 | left_hand_pose = 0.00e+00 | right_hand_pose = 3.76e-04 | transl = 1.14e-04]


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0510825…



<meshplot.Viewer.Viewer at 0x7f5726eb5070>

In [35]:
# NOTE(review): this whole cell is commented out. The optimisation log and viewer
# output shown below it come from an earlier, uncommented run and are stale.
# Purpose: take the first test batch, move the object to a hand-picked translation,
# re-encode its BPS features for that pose, then run GNet + the static fitting.
# NOTE(review): before re-enabling, check the following:
#   - it always uses `tester.body_model`, whereas the active inference cell picks the
#     gender-specific model and sets the subject template;
#   - the `np.random.choice` vertex subsampling is not covered by `torch.manual_seed`,
#     so it is not reproducible.
# seed = 5321

# with torch.no_grad():
#     tester.network.eval()
#     torch.manual_seed(seed)

#     batch = next(iter(tester.ds_test))

#     obj_name = tester.data_info['test']['frame_names'][batch['idx'].to(torch.long)].split('/')[-1].split('_')[0]
#     obj_path = tester.data_info['obj_info'][obj_name]['obj_mesh_file']
#     obj_mesh = Mesh(filename=obj_path)
#     obj_m = ObjectModel(v_template=torch.from_numpy(obj_mesh.vertices)).to(device)

#     batch['transl_obj'] = torch.tensor([[-0.1, 1.3, 0.0]]).to(tester.device)
#     batch['global_orient_obj'] = batch['global_orient_obj'].to(tester.device)

#     verts_obj = obj_m(transl=batch['transl_obj'], global_orient=batch['global_orient_obj']).vertices
#     obj_bps = tester.bps['obj'].to(device) + batch['transl_obj'].reshape(1, 1, 3)
#     batch['bps_obj_glob'] = tester.bps_torch.encode(x=verts_obj,
#                         feature_type=['deltas'],
#                         custom_basis=obj_bps)['deltas']
#     batch['verts_obj'] = batch['verts_obj'][:,np.random.choice(batch['verts_obj'].shape[1], 2048, replace=False,),:].to(tester.device)
#     batch = {
#         'betas' : batch['betas'].to(tester.device),
#         'transl_obj' : batch['transl_obj'].to(tester.device),
#         'global_orient_obj' : batch['global_orient_obj'].to(tester.device),
#         'bps_obj_glob' : batch['bps_obj_glob'].to(tester.device),
        
#         'gender' : batch['gender'].to(tester.device),
#         'sbj_vtemp' : batch['sbj_vtemp'].to(tester.device),
#         'verts_obj' : batch['verts_obj'].to(tester.device),
#     }

#     net_output = tester.forward(batch)

# fit_smplx = FitSmplxStatic(sbj_model=tester.body_model, obj_model=obj_m, cfg=cfg_static, verbose=True)
# optim_output = fit_smplx.fitting(batch, net_output)

# sbj_verts = to_cpu(optim_output['opt_verts'][0])
# obj_verts = to_cpu(fit_smplx.obj_verts[0])
# mp_viewer = mp.plot(sbj_verts, tester.body_model.faces, np.array([0.75,0.75,0.75]))
# mp_viewer.add_mesh(obj_verts, obj_mesh.faces, np.array([1.0,0.0,0.0]))
# mp_viewer

Stage:00 - Iter:0000 - Total Loss: 5.282243e-01 | [dist_rh2obj = 1.70e-01 | grnd_contact = 1.48e-01 | gaze = 2.10e-01 | global_orient = 0.00e+00 | body_pose = 0.00e+00 | left_hand_pose = 0.00e+00 | right_hand_pose = 0.00e+00 | transl = 0.00e+00]
Stage:00 - Iter:0050 - Total Loss: 2.106871e-01 | [dist_rh2obj = 7.26e-02 | grnd_contact = 3.61e-02 | gaze = 3.96e-02 | global_orient = 2.99e-03 | body_pose = 8.03e-03 | left_hand_pose = 0.00e+00 | right_hand_pose = 4.27e-04 | transl = 5.09e-02]
Stage:00 - Iter:0100 - Total Loss: 1.585726e-01 | [dist_rh2obj = 6.09e-02 | grnd_contact = 5.11e-03 | gaze = 3.81e-02 | global_orient = 5.06e-03 | body_pose = 2.20e-02 | left_hand_pose = 0.00e+00 | right_hand_pose = 3.05e-04 | transl = 2.70e-02]
Stage:00 - Iter:0150 - Total Loss: 1.353094e-01 | [dist_rh2obj = 5.73e-02 | grnd_contact = 8.79e-04 | gaze = 4.03e-02 | global_orient = 5.07e-03 | body_pose = 3.11e-02 | left_hand_pose = 0.00e+00 | right_hand_pose = 9.23e-05 | transl = 5.52e-04]


Renderer(camera=PerspectiveCamera(children=(DirectionalLight(color='white', intensity=0.6, position=(0.0390731…



<meshplot.Viewer.Viewer at 0x7fe76cd96790>

In [None]:
# NOTE(review): scratch cell, fully commented out and never executed ("In [None]").
# Purpose: build a single-sample batch directly from the GNet_data .npy/.npz files and
# subject meshes, instead of going through the Trainer's dataset.
# Known problems to fix before re-enabling:
#   - `batch['gender'] = sbj_dict['gender']` runs before `sbj_dict = sbj_info[sequence_name]`,
#     so it would read the stale `sbj_dict` left over from the `for` loop;
#   - `female_model` / `male_model` are not defined in this notebook (the active cell
#     uses `tester.female_model` / `tester.male_model`);
#   - the first `batch = {...}` built from `ds_test[0]` is overwritten later without being used.
# ds_test = LoadData(cfg_static.datasets, split_name='test')
# batch = ds_test[0]
# batch = {k: batch[k].unsqueeze(0).to(device) for k in ['idx', 'gender', 'sbj_vtemp', 'transl_obj', 'global_orient_obj', 'betas', 'bps_obj_glob']}

# ds_dir = Path("/data/3D_dataset/GrabNet/data/GRAB/data/GNet_data")

# ds = np2torch(np.load(ds_dir / 'test/GNet_data.npy', allow_pickle=True))
# frame_names = np.load(ds_dir / 'test/frame_names.npz')['frame_names']
# sbj_info = np.load(ds_dir / 'sbj_info.npy', allow_pickle=True).item()

# base_path = ds_dir.parent / 'tools/subject_meshes'
# file_list = []
# for sbj, sbj_dict in list(sbj_info.items()):
#     gender = sbj_dict['gender']
#     file_list.append(base_path / f'{gender}/{sbj}.ply')
# sbj_vtemp = torch.from_numpy(np.asarray([Mesh(filename=file).vertices.astype(np.float32) for file in file_list]))
# sbj_betas = torch.from_numpy(np.asarray([np.load(file=f.parent / f'{f.stem}_betas.npy').astype(np.float32) for f in file_list]))

# idx = 0

# frame_name = Path(frame_names[idx])
# sequence_name, obj_name = frame_name.parts[-2], frame_name.parts[-1].split("_")[0]
# batch = {k: to_tensor(ds[k][idx], dtype=torch.float32).unsqueeze(0).to(device) for k in ['transl_obj', 'global_orient_obj', 'bps_obj_glob']}
# batch['gender'] = sbj_dict['gender']

# sbj_dict = sbj_info[sequence_name]
# sbj_dict['betas'] = to_tensor(sbj_dict['betas'], dtype=torch.float32).unsqueeze(0).to(device)
# # sbj_dict['betas'] = torch.zeros([1, 1, 10], dtype=torch.float32, device=device)
# sbj_dict['vtemp'] = to_tensor(sbj_dict['vtemp'], dtype=torch.float32).unsqueeze(0).to(device)
# sbj_m = female_model if sbj_dict['gender'] == "female" else male_model
# sbj_m.v_template = sbj_dict['vtemp'] 

# obj_mesh = Mesh(filename=ds_dir.parent / f'tools/object_meshes/contact_meshes/{obj_name}.ply')
# obj_verts = torch.from_numpy(obj_mesh.vertices)
# obj_m = ObjectModel(v_template=obj_verts).to(device)
# obj_m.faces = obj_mesh.faces

In [None]:
# NOTE(review): scratch cell, fully commented out and never executed ("In [None]").
# Purpose: place an object at a custom translation/orientation, encode its BPS features,
# and set up a static fitting for it.
# Known problems to fix before re-enabling:
#   - `frame_mask` and `rel_offset` (passed to `prepare_params`) are never defined;
#   - `ds_dir` is used in the `glob` before it is assigned further down in this cell
#     (it only works if the previous scratch cell ran first);
#   - `idx = 400` is commented out, so `idx` must come from an earlier cell;
#   - `bps_torch.encode(...)` is called on the name imported in the first cell, which is
#     presumably the class rather than an instance (the other cell uses
#     `tester.bps_torch.encode`) — TODO confirm;
#   - `body_model` on the last line is not defined in this notebook;
#   - `obj_path = 'mug'` is assigned but never used.
# import smplx
# from tools.utils import parse_npz, aa2rotmat, rotmat2aa, rotate, rotmul, euler, prepare_params, params2torch

# bps = torch.load("configs/bps.pt")
# all_seqs = [x for x in ds_dir.parent.glob("grab/*/*.npz") if x.stem.split("_")[0] in ['mug', 'binoculars']]
# sequence = all_seqs[0]
# seq_data = parse_npz(sequence)

# motion_obj = params2torch(prepare_params(seq_data.object.params, frame_mask, rel_offset))

# R = torch.tensor([[1., 0., 0.], [0., 0., -1.], [0., 1., 0.]]).reshape(1, 3, 3).transpose(1,2)
# root_offset = smplx.lbs.vertices2joints(sbj_m.J_regressor, sbj_m.v_template.view(1, -1, 3))[0, 0]


# trans_obj_rel = rotate(motion_obj['transl'], R)
# global_orient_obj_rotmat = aa2rotmat(motion_obj['global_orient'])
# global_orient_obj_rel = rotmul(global_orient_obj_rotmat, R.transpose(1, 2))
# transl_obj = to_tensor(trans_obj_rel)
# global_orient_obj = rotmat2aa(to_tensor(global_orient_obj_rel).squeeze()).squeeze()

# verts_obj = obj_m(**motion_obj).vertices
# obj_bps = bps['obj'] + motion_obj['transl'].reshape(1, 1, 3)
# bps_obj = bps_torch.encode(x=verts_obj, feature_type=['deltas'], custom_basis=obj_bps)['deltas']

# ds_dir = Path("/data/3D_dataset/GrabNet/data/GRAB/data/GNet_data")
# ds = np2torch(np.load(ds_dir / 'test/GNet_data.npy', allow_pickle=True))
# frame_names = np.load(ds_dir / 'test/frame_names.npz')['frame_names']
# bps = torch.load("configs/bps.pt")

# # idx = 400
# obj_path = 'mug'
# obj_name = Path(frame_names[idx]).parts[-1].split("_")[0]
# # transl_obj = to_tensor(ds['transl_obj'][idx], dtype=torch.float32).unsqueeze(0).to(device)
# transl_obj = torch.tensor([[0.0, 1.0, 0.0]], dtype=torch.float32, device=device)
# # global_orient_obj = to_tensor(ds['global_orient_obj'][idx], dtype=torch.float32).unsqueeze(0).to(device)
# global_orient_obj = torch.tensor([np.deg2rad([26.009653, -108.771225, -122.78726])], dtype=torch.float32, device=device)

# obj_mesh = Mesh(filename=ds_dir.parent / f'tools/object_meshes/contact_meshes/{obj_name}.ply')
# obj_m = ObjectModel(v_template=torch.from_numpy(obj_mesh.vertices)).to(device)
# obj_m.faces = obj_mesh.faces

# verts_obj = obj_m(transl=transl_obj, global_orient=global_orient_obj).vertices[0]
# obj_bps = bps['obj'].to(device) + transl_obj.reshape(1, 1, 3)
# bps_obj_glob = bps_torch.encode(x=verts_obj, feature_type=['deltas'], custom_basis=obj_bps)['deltas']

# batch = {
#     'transl_obj' : transl_obj,
#     'global_orient_obj' : global_orient_obj,
#     'bps_obj_glob' : bps_obj_glob,
#     'gender' : 1
# }
# fit_smplx_static = FitSmplxStatic(sbj_model=body_model, obj_model=obj_m, cfg=cfg_static, verbose=True)