In [None]:
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

In [1]:
import os
import torch
from pytorch3d.io import load_obj, save_obj
from pytorch3d.structures import Meshes
from pytorch3d.utils import ico_sphere
from pytorch3d.ops import sample_points_from_meshes
from pytorch3d.loss import (
    chamfer_distance, 
    mesh_edge_loss, 
    mesh_laplacian_smoothing, 
    mesh_normal_consistency,
)
import numpy as np
from tqdm.notebook import tqdm
%matplotlib notebook 
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import matplotlib as mpl


from Mesh_utilities import open_mesh_multitracker,separate_faces
from Dcel import DCEL_Data
from Mesh_smoothing import *#extract_faces_manifolds,Compute_Area_Faces,Compute_Volume_manifold

In [2]:
from Plotting_tools import plot_cells_polyscope

device = 'cpu'

# Load the multi-cell mesh. Columns 0-2 of faces_mm are used below as vertex
# indices; columns 3-4 carry per-face extra data (presumably the labels of the
# cells on each side of the face — TODO confirm against open_mesh_multitracker)
# and are shifted by +1 before building the DCEL.
verts, faces_mm, _ = open_mesh_multitracker('txt', "8_cells_compaction.rec")
faces_mm[:, [3, 4]] += 1
Mesh = DCEL_Data(verts, faces_mm)

# Face lists of the individual closed manifolds (cells) of the mesh.
Faces_manifolds = extract_faces_manifolds(Mesh)

faces_idx = torch.tensor(Mesh.f[:, :3], dtype=torch.long).to(device)
verts = torch.tensor(Mesh.v, dtype=torch.float).to(device)

# Center the mesh at the origin and divide by its largest absolute coordinate,
# so every vertex lies in the cube [-1, 1]^3. Normalizing speeds up the
# optimization but is not strictly necessary. (center, scale) can be used to map
# results back to the original frame: original = normalized * scale + center.
center = verts.mean(0)
verts = verts - center
scale = verts.abs().max()
verts = verts / scale

# PyTorch3D Meshes structure for the whole mesh that will be deformed.
src_mesh = Meshes(verts=[verts], faces=[faces_idx])

In [4]:
# Collect the trijunctional edges of the DCEL mesh (presumably edges shared by
# three or more cells — TODO confirm in DCEL_Data.find_trijunctional_edges).
# They feed the trijunction-length term (Compute_length_edges_trijunctions) of
# the optimization loop.
Edges_trijunctions = Mesh.find_trijunctional_edges()

Number of trijunctional edges : 269 5


In [5]:
# Build one PyTorch3D mesh per cell. Every cell mesh shares the full normalized
# vertex set `verts` and only differs by its faces, so the same per-vertex
# offsets can later be applied to each of them. Each cell's initial volume is
# stored as the target volume the optimization tries to preserve.
# Manifold 0 is skipped — presumably the exterior region (TODO confirm in
# extract_faces_manifolds).
Meshes_list = []
Target_volumes = []

# Distinct loop names so that the whole-mesh `faces_idx` built in the loading
# cell is not silently overwritten by the last cell's faces.
for manifold_idx in range(1, len(Faces_manifolds)):
    manifold_faces = torch.tensor(Faces_manifolds[manifold_idx], dtype=torch.long).to(device)
    Meshes_list.append(Meshes(verts=[verts], faces=[manifold_faces]))
    Target_volumes.append(Compute_Volume_manifold(verts, manifold_faces))

In [19]:
# Learnable per-vertex offsets, one 3D offset for every vertex of src_mesh,
# starting at zero so the first iteration sees the undeformed mesh.
deform_verts = torch.zeros(
    src_mesh.verts_packed().shape,
    device=device,
    requires_grad=True,
)

# SGD with momentum, optimizing the offsets only.
optimizer = torch.optim.SGD(params=[deform_verts], momentum=0.9, lr=1.0)

# Plan: keep one PyTorch3D mesh per cell and compute the volume and normal
# losses on each of them.

In [22]:
# Number of optimization steps
Niter = 1000

w_area = 0.00001
w_line = 0.00000001
w_volume = 10
# Weight for mesh edge loss
w_edge = 1.0 
# Weight for mesh normal consistency
w_normal = 0.01 
# Weight for mesh laplacian smoothing
w_laplacian = 0.1 
# Plot period for the losses
plot_period = 250
loop = tqdm(range(Niter))

laplacian_losses = []
edge_losses = []
normal_losses = []

%matplotlib inline

for i in loop:
    # Initialize optimizer
    optimizer.zero_grad()
    
    
    """
    GLOBAL LOSSES : 
    """
    
    # Deform the whole mesh
    new_src_mesh = src_mesh.offset_verts(deform_verts)
    
    #loss area : 
    V = new_src_mesh.verts_packed()
    F = new_src_mesh.faces_packed()
    loss_area = (Compute_Area_Faces(V,F)).sum()**2
    
    # the edge length of the predicted mesh
    loss_edge = mesh_edge_loss(new_src_mesh)
    
    loss_line = (Compute_length_edges_trijunctions(V,Edges_trijunctions)).sum()**2
    
    """
    CELL-SPECIFIC LOSSES : 
    """
    
    losses_normals = []
    losses_volume = []
    new_cells_mesh = []
    for i,cell_mesh in enumerate(Meshes_list) : 
        new_cells_mesh.append(cell_mesh.offset_verts(deform_verts))
        losses_normals.append(mesh_normal_consistency(new_cells_mesh[-1]))
        
        V = new_cells_mesh[-1].verts_packed()
        F = new_cells_mesh[-1].faces_packed()
        target_volume = Target_volumes[i]
        losses_volume.append((Compute_Volume_manifold(V,F)-target_volume)**2)
    
   

    loss = loss_edge * w_edge + loss_area * w_area + loss_line * w_line
    
    for i in range(len(Meshes_list)): 
        loss = loss + losses_volume[i] * w_volume + losses_normals[i] * w_normal 
        
        
    loss_volume = torch.sum(torch.tensor(losses_volume))
    loss_normal = torch.sum(torch.tensor(losses_normals))
    print(loss_area.item(),loss_volume.item(),loss_edge.item(),loss_normal.item(),loss_line.item())
    #print(Compute_Volume_manifold(V,F).item(),target_volume)
    # Print the losses
    loop.set_description('total_loss = %.6f' % loss)
    
    # Save the losses for plotting
    #edge_losses.append(float(loss_edge.detach().cpu()))
    #normal_losses.append(float(loss_normal.detach().cpu()))
    #laplacian_losses.append(float(loss_laplacian.detach().cpu()))
    
        
    # Optimization step
    loss.backward()
    optimizer.step()


  0%|          | 0/1000 [00:00<?, ?it/s]

313.2655944824219 4.96334109811869e-07 0.011458984576165676 0.42556464672088623 736442.625
313.1990966796875 3.2837476737768156e-07 0.011434598825871944 0.4199637174606323 735118.25
313.1571044921875 2.2091259666012775e-07 0.011410387232899666 0.4150676131248474 733797.625
313.1119384765625 1.7327447210391256e-07 0.011385847814381123 0.41047653555870056 732452.75
313.0460510253906 1.8144969260447397e-07 0.011360663920640945 0.40593889355659485 731066.125
312.9534606933594 2.4466010017931694e-07 0.011334773153066635 0.4012075960636139 729636.0625
312.8397521972656 3.6155788052383286e-07 0.011308442801237106 0.3967781364917755 728177.5625
312.7123107910156 5.13133272761479e-07 0.011282028630375862 0.39287132024765015 726708.1875
312.5876159667969 6.542028927469801e-07 0.011255927383899689 0.38910049200057983 725250.5625
312.47955322265625 7.382926696664072e-07 0.011230501346290112 0.3852301836013794 723823.75
312.3965759277344 7.437948852384579e-07 0.011206083931028843 0.3817086517810821

KeyboardInterrupt: 

In [23]:
# Visualize the optimized mesh (last iterate of the optimization loop) with
# polyscope. Vertices are in the normalized frame set up in the loading cell.
# .cpu() is needed before .numpy() when `device` is a GPU; it is a no-op on CPU.
plot_cells_polyscope(new_src_mesh.verts_packed().detach().cpu().numpy(), faces_mm)