In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES']='4'
import numpy as np
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.ops import knn_gather
from tqdm import tqdm
import trimesh

from Closest_Point_on_Surface import DirDist_M2M
from utils.LieAlgebra import so3
from utils.deform_graph import get_deformation_graph_gdist,calculate_gdist



Define a spatial smoothness term for the deformed mesh: it penalizes differences between the offsets of vertices that share a triangle edge.

In [2]:
def spacial_smoothing(offset, faces, norm='l2'):
    """Penalize offset differences between vertices joined by a triangle edge.

    Args:
        offset: per-vertex offsets, indexed by the vertex ids stored in ``faces``
            (the caller passes ``new_verts - src_verts``).
        faces: integer tensor of triangle vertex indices, shape (F, 3).
        norm: ``'l2'`` for the mean squared difference, ``'l1'`` for the mean
            absolute difference.

    Returns:
        Scalar tensor: the average, over the three triangle edges
        (v1-v2, v1-v3, v2-v3), of the element-wise mean penalty.
    """
    assert norm in ['l1','l2']

    if norm == 'l2':
        def penalty(delta):
            return delta ** 2
    else:
        penalty = torch.abs

    # Offsets gathered at each of the three triangle corners, each (F, 3).
    corner_offsets = [offset[faces[:, corner]] for corner in range(3)]

    edge_terms = [
        torch.mean(penalty(corner_offsets[a] - corner_offsets[b]))
        for a, b in ((0, 1), (0, 2), (1, 2))
    ]
    return (edge_terms[0] + edge_terms[1] + edge_terms[2]) / 3

In [3]:
def deform_vertices(nodes,deform_lie,vertices,knn_index,weights):
    """Warp vertices with an embedded deformation graph.

    Each vertex v is moved by the rigid motion of each of its K neighbouring
    graph nodes g_k (rotation R_k about the node, then translation t_k), and the
    K candidate positions are blended with the per-vertex weights w_k:

        v' = sum_k w_k * (R_k (v - g_k) + g_k + t_k)

    Args:
        nodes:      (M,3) deformation-node positions.
        deform_lie: (M,6) per-node parameters; ``[:, :3]`` is mapped to a
                    rotation matrix with ``so3.exp``, ``[:, 3:]`` is the translation.
        vertices:   (N,3) vertices to deform.
        knn_index:  (N,K) indices into ``nodes`` of each vertex's neighbour nodes.
        weights:    (N,K) blending weights (this notebook normalizes each row
                    to sum to 1 before calling).

    Returns:
        (N,3) tensor of deformed vertices.
    """
    # Index once; plain advanced indexing gathers rows exactly like knn_gather
    # with a singleton batch, without the unsqueeze/squeeze round trip.
    idx = knn_index.long()
    knn_nodes = nodes[idx]              # (N,K,3)
    knn_deform_lie = deform_lie[idx]    # (N,K,6)

    knn_rot = so3.exp(knn_deform_lie[..., :3])  # (N,K,3,3)
    knn_trans = knn_deform_lie[..., 3:]         # (N,K,3)

    # Rotate each vertex about each neighbour node, then translate: (N,K,3).
    local = (vertices.unsqueeze(1) - knn_nodes).unsqueeze(-1)
    vertices_new_knn = torch.matmul(knn_rot, local).squeeze(-1) + knn_nodes + knn_trans

    # Weighted blend of the K candidate positions.
    return torch.sum(vertices_new_knn * weights.unsqueeze(-1), dim=1)  # (N,3)

Define the loss function and the optimization hyperparameters. The loss uses 20,000 reference points and a standard deviation (std) of 0.05.

In [4]:
# Seed the RNGs so Restart & Run All is reproducible. NOTE(review): DirDist_M2M and
# get_deformation_graph_gdist presumably draw random samples — confirm in their sources.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# DirDist_M2M loss parameters: number of reference points and their std.
N_REF_POINTS = 20000
REF_STD = 0.05
loss_func = DirDist_M2M(N_REF_POINTS, REF_STD)

device = 'cuda'
Niter = 1000    # number of optimization iterations
w_smooth = 500  # weight of the smoothness term relative to the geometric loss
R_ratio = 5     # node influence radius, in multiples of the mean source edge length
KNN = 5         # number of deformation nodes blended per vertex

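Build an embedded deformation graph on the source mesh, optimize the per-node rotations and translations that deform it towards the target mesh, and then save the deformed mesh.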

In [5]:
# Source mesh (the one being deformed) and target mesh.
src_obj='demo_data/non_rigid_registration/mesh_0042.obj'
trg_obj='demo_data/non_rigid_registration/mesh_0044.obj'


def _load_mesh(path):
    """Load an OBJ file and return (vertices, face vertex indices), both on `device`."""
    verts, faces, _ = load_obj(path)
    return verts.to(device), faces.verts_idx.to(device)


tgt_verts, tgt_faces_idx = _load_mesh(trg_obj)
src_verts, src_faces_idx = _load_mesh(src_obj)

# NumPy copies of the source mesh for the CPU-side geodesic distance computation.
verts_np = src_verts.cpu().numpy()
faces_np = src_faces_idx.cpu().numpy()

# Node influence radius: R_ratio times the mean edge length of the source mesh.
average_edge = np.mean(trimesh.load_mesh(src_obj).edges_unique_length)
dist_thres = float(average_edge * R_ratio)

# (N,N) geodesic distances between source vertices; dist_thres is passed to
# calculate_gdist as a cutoff (exact semantics live in utils.deform_graph).
gdist_np = calculate_gdist(verts_np, faces_np, dist_thres)
gdist_matrix = torch.from_numpy(gdist_np).to(device=device, dtype=torch.float32)

deformation_nodes = get_deformation_graph_gdist(src_verts, gdist_matrix, dist_thres)

# Snap each node to its closest source vertex so vertex-to-node geodesic distances
# can be read from columns of gdist_matrix.
square_distance = torch.sum((deformation_nodes[:, None, :] - src_verts[None, :, :]) ** 2, dim=-1)  # (M,N)
node_idx = torch.argmin(square_distance, dim=1)
vert_node_gdist = gdist_matrix[:, node_idx]  # (N,M)

# For every vertex, the KNN geodesically closest nodes.
knn_dist, knn_index = torch.topk(vert_node_gdist, k=KNN, dim=1, largest=False)

# Blending weights (1 - d^2/r^2)^3, zero at or beyond the radius, normalized per vertex.
weights = torch.clamp(1 - knn_dist ** 2 / dist_thres ** 2, min=0) ** 3
weight_sum = torch.sum(weights, dim=1, keepdim=True)
# A zero row sum would turn that vertex's weights into NaN and poison the whole loss.
if not torch.all(weight_sum > 0):
    raise ValueError(
        "some vertices have no deformation node within dist_thres; "
        "increase R_ratio or the node density"
    )
weights = weights / weight_sum

# Per-node deformation parameters: 3 so3 rotation + 3 translation components.
# All-zero means no translation and (exp of a zero so3 vector) identity rotation.
deform_lie = torch.zeros(deformation_nodes.shape[0], 6, device=device, requires_grad=True)

optimizer = torch.optim.SGD([deform_lie], lr=2, momentum=0.9)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=Niter, eta_min=5e-2)

for _ in tqdm(range(Niter)):
    optimizer.zero_grad()

    new_verts = deform_vertices(deformation_nodes, deform_lie, src_verts, knn_index, weights)
    verts_offset = new_verts - src_verts

    # Geometric distance to the target plus weighted smoothness of the offsets.
    loss_geo = loss_func(new_verts, src_faces_idx, tgt_verts, tgt_faces_idx)
    loss_smooth = spacial_smoothing(verts_offset, src_faces_idx, 'l2')
    loss = loss_geo + loss_smooth * w_smooth

    loss.backward()
    optimizer.step()
    scheduler.step()

# new_verts from the loop were computed *before* the last optimizer.step();
# re-evaluate with the final parameters so the saved mesh reflects them.
with torch.no_grad():
    new_verts = deform_vertices(deformation_nodes, deform_lie, src_verts, knn_index, weights)

save_obj('non_rigid_reg_result.obj', new_verts, src_faces_idx)

100%|██████████| 1000/1000 [01:38<00:00, 10.19it/s]
