In [None]:
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

In [2]:
import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
%matplotlib notebook 
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl


from Mesh_utilities import open_mesh_multitracker
from Dcel import DCEL_Data

In [53]:
from Mesh_smoothing import extract_faces_manifolds,Compute_Area_Faces,Compute_Volume_manifold

In [29]:
device = 'cpu'
verts, faces_mm, _ = open_mesh_multitracker('txt', "8_cells_compaction.rec")
faces = faces_mm[:, :3]
# Shift columns 3 and 4 of the face array by one before building the DCEL.
# NOTE(review): these presumably hold per-face cell labels and the +1 adapts the
# labelling convention expected by DCEL_Data — TODO confirm.
faces_mm[:, [3, 4]] += 1
Mesh = DCEL_Data(verts, faces_mm)

# Split the multi-cell mesh into per-manifold face lists and keep entry 1.
Faces_manifolds = extract_faces_manifolds(Mesh)
f1 = Faces_manifolds[1]

faces_idx = torch.tensor(f1, dtype=torch.long, device=device)
verts = torch.tensor(Mesh.v, dtype=torch.float, device=device)

# Center the vertices at the origin, then scale them so the largest absolute
# coordinate equals 1. The mesh therefore fits in the cube [-1, 1]^3. It does
# NOT fit in the unit sphere: a corner point can have norm up to sqrt(3).
# (center, scale) can be used to map the optimized mesh back to its original
# position and size. Normalizing speeds up the optimization but is not required.
# NOTE(review): center and scale are computed from ALL vertices in Mesh.v, not
# only those referenced by f1, so the selected cell itself is not necessarily
# centered at the origin — confirm this is intended.
center = verts.mean(0)
verts = verts - center
scale = verts.abs().max()  # torch reduction over all coordinates (0-d tensor)
verts = verts / scale

# Meshes structure for the selected cell; its enclosed volume at the start is
# the target the optimization tries to preserve.
src_mesh = Meshes(verts=[verts], faces=[faces_idx])
target_volume = Compute_Volume_manifold(verts, faces_idx)

from Plotting_tools import plot_cells_polyscope
# Optional preview of every cell:
# plot_cells_polyscope(verts, faces_mm)

In [8]:
import polyscope as ps

# Preview the selected, normalized cell before optimization. `verts` is a torch
# tensor at this point. Convert it to NumPy explicitly, as the post-optimization
# preview does, instead of relying on polyscope's implicit array conversion.
ps.init()
ps.register_surface_mesh("Mesh", verts.detach().cpu().numpy(), f1)
ps.show()

In [48]:
# The optimizer learns a per-vertex displacement field. It starts at zero and has
# the same shape as the packed vertex tensor of src_mesh.
deform_verts = torch.zeros(
    src_mesh.verts_packed().shape,
    device=device,
    requires_grad=True,
)
# SGD with momentum, optimizing only the displacement field.
optimizer = torch.optim.SGD([deform_verts], momentum=0.9, lr=1.0)

In [51]:
# Number of optimization steps
Niter = 1000

w_area = 0.00001
w_volume = 10
# Weight for mesh edge loss
w_edge = 1.0 
# Weight for mesh normal consistency
w_normal = 0.01 
# Weight for mesh laplacian smoothing
w_laplacian = 0.1 
# Plot period for the losses
plot_period = 250
loop = tqdm(range(Niter))

laplacian_losses = []
edge_losses = []
normal_losses = []

%matplotlib inline

for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    
    # Deform the mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    
    # and (b) the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(new_src_mesh)
    
    # mesh normal consistency
    loss_normal = mesh_normal_consistency(new_src_mesh)
    
    # mesh laplacian smoothing
    #loss_laplacian = mesh_laplacian_smoothing(new_src_mesh, method="uniform")
    
    # Weighted sum of the losses
    #loss = loss_edge * w_edge + loss_normal * w_normal + loss_laplacian * w_laplacian
    
    V = new_src_mesh.verts_packed()
    F = new_src_mesh.faces_packed()

    loss_area = (Compute_Area_Faces(V,F)).sum()**2
    loss_volume =(Compute_Volume_manifold(V,F)-target_volume)**2
    
    loss = loss_edge * w_edge + loss_normal * w_normal + loss_area * w_area + loss_volume * w_volume
    #print(loss_area.item(),loss_volume.item(),loss_edge.item(),loss_normal.item())
    #print(Compute_Volume_manifold(V,F).item(),target_volume)
    # Print the losses
    loop.set_description('total_loss = %.6f' % loss)
    
    # Save the losses for plotting
    edge_losses.append(float(loss_edge.detach().cpu()))
    normal_losses.append(float(loss_normal.detach().cpu()))
    #laplacian_losses.append(float(loss_laplacian.detach().cpu()))
    
        
    # Optimization step
    loss.backward()
    optimizer.step()


  0%|          | 0/20000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [52]:
import polyscope as ps

# Show the optimized (deformed) mesh. The vertex tensor is detached from the
# autograd graph, and both arrays are handed to polyscope as NumPy arrays.
deformed_verts_np = new_src_mesh.verts_packed().detach().numpy()
deformed_faces_np = new_src_mesh.faces_packed().numpy()

ps.init()
ps.register_surface_mesh("Mesh", deformed_verts_np, deformed_faces_np)
ps.show()

In [56]:
def view_vertex_values_on_mesh_no_trijunction(Mesh, H):
    """Show per-vertex values on `Mesh` as per-face colors in polyscope,
    ignoring vertices that lie on trijunctions.

    How each face's value is computed:

    - A face is the first three columns of ``Mesh.f``. Its value is the mean of
      ``H`` over its vertices whose ``on_trijunction`` flag is False.
    - Faces whose vertices are all on trijunctions get the mean of the other
      faces' values, or 0 if no face has a value.
    - Face values are min-max normalized to [0, 1]. A constant field maps to 0
      rather than NaN.
    - The normalized values are colored with matplotlib's ``jet`` colormap.

    Parameters
    ----------
    Mesh : object
        Must expose ``v``, ``f``, ``vertices[i].on_trijunction`` and
        ``mark_trijunctional_vertices()`` (a DCEL_Data at the call site).
        ``Mesh.mark_trijunctional_vertices()`` is called on it.
    H : indexable
        One scalar per vertex, indexed by vertex index. An earlier note
        described it as Gaussian curvature. The caller passes the first output
        of ``compute_curvature_vertices_robust_laplacian`` — TODO confirm which
        curvature that is.

    Returns
    -------
    None
    """
    # Local imports make the function independent of which notebook cells ran
    # before it. Previously `cm` was referenced without ever being imported.
    import polyscope as ps
    from matplotlib import cm

    v, f = Mesh.v, Mesh.f
    tri_faces = f[:, [0, 1, 2]]

    # Presumably sets the `on_trijunction` flags read below — TODO confirm.
    Mesh.mark_trijunctional_vertices()

    values = np.zeros(len(f))
    faces_without_values = []
    for i, face in enumerate(tri_faces):
        kept = [H[vert_idx] for vert_idx in face
                if not Mesh.vertices[vert_idx].on_trijunction]
        if kept:
            values[i] = np.mean(np.array(kept))
        else:
            faces_without_values.append(i)

    # Faces with no usable vertex inherit the mean of the faces that had one.
    # If no face has a value, fall back to 0 (previously the empty mean gave NaN).
    if faces_without_values:
        has_value = np.ones(len(values), dtype=bool)
        has_value[faces_without_values] = False
        values[faces_without_values] = values[has_value].mean() if has_value.any() else 0.0

    # Min-max normalize to [0, 1]. Guard the division so that a constant field
    # (or an empty face list) does not produce NaN colors.
    if values.size > 0:
        values = values - values.min()
        value_range = values.max()
        if value_range > 0:
            values = values / value_range

    ps.init()
    ps_mesh = ps.register_surface_mesh("my mesh", v, tri_faces)
    ps.set_ground_plane_mode("none")

    # Colormap output is RGBA; keep RGB only, one color per face.
    colors_face = cm.jet(values)[:, :3]
    ps_mesh.add_color_quantity("Values", colors_face, defined_on='faces', enabled=True)
    ps.show()
   
from Curvature import compute_curvature_vertices_robust_laplacian

# The curvature routine's result unpacks into exactly three items. Only the first
# is used, as per-vertex values colored on the mesh (trijunction vertices excluded).
curvature_results = compute_curvature_vertices_robust_laplacian(Mesh)
H, _, _ = curvature_results
view_vertex_values_on_mesh_no_trijunction(Mesh, H)