In [None]:
import torch
import igl
import meshplot as mp
import numpy as np

from models.SkinningNet import SkinningNet
from utils.OneRingIdx import getFacesOneRingIdx,getSkeletonOneRingIdx
from utils.Skeleton import getBonesEdge
from utils.Visualization import showWeight
from utils.SaveFile import save_obj
from deformation.DiffusionFlow import smoothWeight
from deformation.GetTransferMatrix import skeletonTransferWithVirtualJoints,AddVirtualJoints, skeletonTransferWithVirtualJoints_Q,skeletonTransferWithVirtualJoints_Animal,AddVirtualJoints_Animal
from deformation.SkinningDeformation import LBS,DQS,DQS_pytorch
from datasets.smpl import SMPLLayer
from utils.Garment_mapping import get_garment_weight
from SkeletonPoseTransfer import SkeletonPoseTransfer

from deformation.dqs import dqs

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
##### 1.Input Data

SOURCE_MESH_PATH = './test_data/smpl0.obj'
REFER_MESH_PATH = './test_data/smpl1.obj'


def load_mesh_inputs(path):
    """Read a triangle mesh and build the batched tensors fed to SkinningNet.

    Args:
        path: mesh file readable by ``igl.read_triangle_mesh``.

    Returns:
        (V, F, F_idx, S_idx) where
          V     -- float32 tensor of the mesh vertices with a leading batch dim of 1, on ``device``
          F     -- the face array exactly as returned by igl (kept on the host for igl/meshplot)
          F_idx -- long tensor built from ``getFacesOneRingIdx(F)``, batch dim 1, on ``device``
          S_idx -- long tensor built from ``getSkeletonOneRingIdx()``, batch dim 1, on ``device``

    Raises:
        ValueError: if no vertices or faces were read from ``path`` (e.g. missing/unreadable file).
    """
    V, F = igl.read_triangle_mesh(path)
    # Fail here with a clear message instead of deep inside the one-ring index helpers.
    if len(V) == 0 or len(F) == 0:
        raise ValueError(f"could not read a non-empty triangle mesh from {path!r}")
    F_idx = torch.tensor(getFacesOneRingIdx(F)).unsqueeze(0).to(device).long()
    # NOTE(review): S_idx is not used by the cells visible here — kept for parity with the original.
    S_idx = torch.tensor(getSkeletonOneRingIdx()).unsqueeze(0).to(device).long()
    V = torch.tensor(V).unsqueeze(0).to(device).float()
    return V, F, F_idx, S_idx


source_V, source_F, source_F_idx, source_S_idx = load_mesh_inputs(SOURCE_MESH_PATH)
refer_V, refer_F, refer_F_idx, refer_S_idx = load_mesh_inputs(REFER_MESH_PATH)

In [None]:
##### 2.Neural Network

JOINT_NUM = 24
CHECKPOINT_PATH = './statedict/skinningNet_finetune_noise.pkl'
# Alternative (animal) checkpoint from the original notebook — uncomment to use:
# JOINT_NUM = 33
# CHECKPOINT_PATH = './statedict/skinningNet_finetuen_animal.pkl'

model = SkinningNet(jointNum=JOINT_NUM).to(device)
# map_location lets a checkpoint saved on a GPU load when `device` fell back to the CPU.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
model.eval()
print("model load")


def hard_skinning_weights(att):
    """Convert SkinningNet attention into hard 0/1 skinning weights as a NumPy array.

    Marks, along dim 1, every entry equal to the maximum over that dim (ties mark all
    tied entries), then swaps dims 1 and 2 and drops the batch dim.

    NOTE(review): assumes ``att`` is (1, num_joints, num_vertices), so the result is
    (num_vertices, num_joints) — TODO confirm against SkinningNet.
    """
    att = att.detach()
    hard = (att == att.max(dim=1, keepdim=True)[0]).float()
    return hard.cpu().permute(0, 2, 1).squeeze(0).numpy()


def show_skeleton(J, V):
    """Plot predicted joints (blue), mesh vertices (red) and bones (green) with meshplot.

    ``J`` and ``V`` are batched tensors; only batch element 0 is drawn.
    NOTE(review): ``getBonesEdge()`` presumably matches the 24-joint layout — confirm
    before using it with a different JOINT_NUM.
    """
    joints = J[0].detach().cpu().numpy()
    viewer = mp.plot(joints, shading={'point_color': 'blue', 'point_size': 0.3})
    viewer.add_points(V[0].detach().cpu().numpy(), shading={'point_color': 'red', 'point_size': 0.05})
    viewer.add_edges(joints, getBonesEdge(), shading={"line_color": "green"})
    return viewer


# Inference only: no autograd graph is needed for the predicted joints/attention.
with torch.no_grad():
    source_J, source_Att = model(source_V, source_F_idx)
    refer_J, refer_Att = model(refer_V, refer_F_idx)

source_rigW = hard_skinning_weights(source_Att)
refer_rigW = hard_skinning_weights(refer_Att)

print('source joints')
viewer = show_skeleton(source_J, source_V)
print('refer joints')
viewer = show_skeleton(refer_J, refer_V)

In [None]:
##### 3.Deformation


def _to_host_numpy(x):
    """Return ``x`` as an unbatched NumPy array on the host.

    Tensors are expected to carry a leading batch dim of size 1 (as built in the
    earlier cells); element 0 is taken, detached and copied to the CPU. NumPy inputs
    are returned unchanged, so re-running this cell after the variables were already
    converted does not crash.
    """
    if torch.is_tensor(x):
        return x[0].detach().cpu().numpy()
    return x


def _batched(a):
    """Wrap a NumPy array as a CPU tensor with a leading batch dim of 1."""
    return torch.tensor(a).unsqueeze(0)


# source_rigW / refer_rigW are the hard 0/1 weights already computed from the network
# attention in the previous cell; they are reused rather than recomputed here.
source_V = _to_host_numpy(source_V)
source_J = _to_host_numpy(source_J)
source_smoothW = smoothWeight(source_V, source_F, source_rigW)

refer_V = _to_host_numpy(refer_V)
refer_J = _to_host_numpy(refer_J)
refer_smoothW = smoothWeight(refer_V, refer_F, refer_rigW)

print('source weight')
viewer = showWeight(source_V, source_F, source_rigW)
print('refer weight')
viewer = showWeight(refer_V, refer_F, refer_rigW)

# The pose transfer is estimated from the hard weights; the skinning below uses the
# smoothWeight-smoothed weights.
transfer_matrix = SkeletonPoseTransfer(
    AddVirtualJoints(_batched(source_J), _batched(source_V), _batched(source_rigW)),
    AddVirtualJoints(_batched(refer_J), _batched(refer_V), _batched(refer_rigW)),
)

LBS_result = LBS(_batched(source_V), _batched(source_smoothW), transfer_matrix)


In [None]:
##### 4.Plot


def _titled_plot(title, verts, faces):
    """Print a caption, then draw the mesh with meshplot and return the viewer."""
    print(title)
    return mp.plot(verts, faces)


# Inputs first, then the source mesh deformed by LBS (it reuses the source faces).
_titled_plot('source', source_V, source_F)
_titled_plot('reference', refer_V, refer_F)
_titled_plot('LBS', LBS_result[0].numpy(), source_F)
