In [3]:
import kaolin
import torch
import math
import matplotlib
import os
import numpy as np
import matplotlib.pyplot as plt
import polyscope as ps
import trimesh

from scipy.spatial import Voronoi, voronoi_plot_2d
from io import BytesIO
from PIL import Image
import sdfpred_utils.sdfpred_utils as su
import sdfpred_utils.sdf_MLP as mlp
import sdfpred_utils.sdf_functions as sdf
import sdfpred_utils.loss_functions as lf
import sdfpred_utils.Steik_data3d as sd3d
import sdfpred_utils.Steik_Loss as sl
import sdfpred_utils.Steik_utils as Stu 

#cuda devices
device = torch.device("cuda:0")
print("Using device: ", torch.cuda.get_device_name(device))

#default tensor types
torch.set_default_tensor_type(torch.cuda.DoubleTensor)

multires = 2
input_dims = 3
lr_sites = 0.005/2
lr_model = 0.00005*2

iterations = 5000
save_every = 100
max_iter = 100
#learning_rate = 0.03
destination = "./images/autograd/3Dsteik/"


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Using device:  NVIDIA GeForce RTX 3090


In [4]:
#currently sites are between -5 and 5 in all 3 dimensions
# check if sites exists
#num_centroids = 16*16*16
num_centroids =16*16*16*8
num_centroids = 16*16*16

site_fp = f'sites_{num_centroids}_{input_dims}.pt'

if not os.path.exists(site_fp):
    sites = torch.load(site_fp)
    print("Sites loaded:", sites.shape)
else:
    print("Creating new sites")
    noise_scale = 0.1
    domain_limit = 1.5
    x = torch.linspace(-domain_limit, domain_limit, int(round(num_centroids**(1/3))))
    y = torch.linspace(-domain_limit, domain_limit, int(round(num_centroids**(1/3))))
    z = torch.linspace(-domain_limit, domain_limit, int(round(num_centroids**(1/3))))
    meshgrid = torch.meshgrid(x, y, z)
    meshgrid = torch.stack(meshgrid, dim=3).view(-1, 3)
    print("Meshgrid shape:", meshgrid.shape)
    print("Meshgrid 1st 5:", meshgrid[:5])
    #add noise to meshgrid
    meshgrid += torch.randn_like(meshgrid) * noise_scale
    print("Meshgrid 1st 5:", meshgrid[:5])
    sites = meshgrid.to(device, dtype=torch.double).requires_grad_(True)
    
    #print min max sites 
    print("Sites min:", sites.min(dim=0).values)
    print("Sites max:", sites.max(dim=0).values)
    print("Sites shape:", sites.shape)

    #sites = su.createCVTgrid(num_centroids=num_centroids, dimensionality=input_dims)
    #save the initial sites torch tensor
    #torch.save(sites, site_fp)


def plot_voronoi_3d(sites, xlim=5, ylim=5, zlim=5):
    import numpy as np
    import pyvoro
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
    # initialize random number generator
    rng = np.random.default_rng(11)
    # create a set of points in 3D
    points = sites.detach().cpu().numpy()

    # use pyvoro to compute the Voronoi tessellation
    # the second argument gives the the axis limits in x,y and z direction
    # in this case all between 0 and 1.
    # the third argument gives "dispersion = max distance between two points
    # that might be adjacent" (not sure how exactly this works)
    voronoi = pyvoro.compute_voronoi(points,[[-xlim,xlim],[-ylim,ylim],[-zlim,zlim]],1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # for each Voronoi cell, plot all the faces of the corresponding polygon
    for vnoicell in voronoi:
        faces = []
        # the vertices are the corner points of the Voronoi cell
        vertices = np.array(vnoicell['vertices'])
        # cycle through all faces of the polygon
        for face in vnoicell['faces']:
            faces.append(vertices[np.array(face['vertices'])])
            
        # join the faces into a 3D polygon
        polygon = Poly3DCollection(faces, alpha=0.5, 
                                facecolors=rng.uniform(0,1,3),
                                linewidths=0.5,edgecolors='black')
        ax.add_collection3d(polygon)
    
    ax.set_xlim([-xlim,xlim])
    ax.set_ylim([-ylim,ylim])
    ax.set_zlim([-zlim,zlim])
            
    plt.show()
    

def polyscope_sdf(model,i):
    # Render the SDF as an implicit surface (zero-level set)
    def model_sdf(pts):
        pts_tensor = torch.tensor(pts, dtype=torch.float64, device=device)
        sdf_values = model(pts_tensor)
        sdf_values_np = sdf_values.detach().cpu().numpy().flatten()  # Convert to NumPy
        
        return sdf_values_np

    ps.render_implicit_surface(f"SDF Surface {i}", model_sdf, mode="sphere_march", enabled=True, subsample_factor=2)



#plot_voronoi_3d(sites)

Creating new sites
Meshgrid shape: torch.Size([4096, 3])
Meshgrid 1st 5: tensor([[-1.5000, -1.5000, -1.5000],
        [-1.5000, -1.5000, -1.3000],
        [-1.5000, -1.5000, -1.1000],
        [-1.5000, -1.5000, -0.9000],
        [-1.5000, -1.5000, -0.7000]])
Meshgrid 1st 5: tensor([[-1.5298, -1.6393, -1.6647],
        [-1.4631, -1.4759, -1.2799],
        [-1.4324, -1.4025, -1.1389],
        [-1.4017, -1.3404, -0.8131],
        [-1.5774, -1.3658, -0.6894]])
Sites min: tensor([-1.7297, -1.7830, -1.7764], grad_fn=<MinBackward0>)
Sites max: tensor([1.9073, 1.8222, 1.8076], grad_fn=<MaxBackward0>)
Sites shape: torch.Size([4096, 3])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [5]:
ps.init()
ps_cloud = ps.register_point_cloud("initial_cvt_grid",sites.detach().cpu().numpy())

[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 3.3.0 NVIDIA 570.124.04


In [6]:
# Load the mesh
mesh = ["bunny", "Resources/stanford-bunny.obj"]
# #mesh = ["staryu", "Resources/staryu.obj"]
# #mesh = ["chair", "Resources/chair_low.obj"]

# bunny = trimesh.load(mesh[1])
# #target_points = bunny.sample(16*16*16)

# target_points = bunny.sample(num_centroids*8)
# target_points = target_points - np.mean(target_points, axis=0)
# target_points = target_points / np.max(np.abs(target_points))


# target_points = torch.tensor(target_points, device=device)
# print("Target points:", target_points.shape)
# min_target = target_points.min(0)[0]
# max_target = target_points.max(0)[0]
# print("min_target", min_target)
# print("max_target", max_target)

# ps.register_point_cloud("Target_points",target_points.detach().cpu().numpy())

#ps.show()


In [7]:
shape_type = 'bunny'
res = 128 # has to be even
example_idx = 0
sample_type = 'grid'
n_samples = 1
n_points = num_centroids*8
#TODO: change to 3D look into Steik to make it work 
# dataset = sd2d.get2D_dataset(n_points, n_samples, res, sample_type, 0.005, shape_type=shape_type)  # BasicShape2D(100, 20, res=50)
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=False, num_workers=0, pin_memory=False)

dataset = sd3d.ReconDataset("Resources/stanford-bunny.obj", n_points, n_samples=n_samples)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0 ,pin_memory=False)
data = next(iter(dataloader))
print("Data keys: ", data.keys())
# mnfld_points, normals_gt, nonmnfld_dist_gt, nonmnfld_points, nonmnfld_n_gt= data['points'].to(device), data['mnfld_n'].to(device), \
#                                                                             data['nonmnfld_dist'].to(device), \
#                                                                             data['nonmnfld_points'].to(device), data['nonmnfld_n'].to(device),
mnfld_points, mnfld_n_gt, nonmnfld_points = data['points'].to(device), data['mnfld_n'].to(device),  data['nonmnfld_points'].to(device)

mnfld_points = mnfld_points[0]
mnfld_points = mnfld_points.double()
mnfld_n_gt = mnfld_n_gt[0]
mnfld_n_gt = mnfld_n_gt.double()
nonmnfld_points = nonmnfld_points[0]
nonmnfld_points = nonmnfld_points.double()

print("Manifold points shape: ", mnfld_points.shape)
print("Manifold points normals GT shape: ", mnfld_n_gt.shape)
print("Non-manifold points shape: ", nonmnfld_points.shape)

ps.register_point_cloud("mnfld", mnfld_points.detach().cpu().numpy())
ps.register_point_cloud("non mnfld", nonmnfld_points.detach().cpu().numpy())

#ps.show()
mnfld_points.requires_grad_()
nonmnfld_points.requires_grad_()
print("Manifold points shape: ", mnfld_points.shape)
print("Non-manifold shape: ", nonmnfld_points.shape)

points shape:  (32768, 3)
estimated normals shape:  (32768, 3)
Data keys:  dict_keys(['points', 'mnfld_n', 'nonmnfld_points'])
Manifold points shape:  torch.Size([32768, 3])
Manifold points normals GT shape:  torch.Size([32768, 3])
Non-manifold points shape:  torch.Size([32768, 3])
Manifold points shape:  torch.Size([32768, 3])
Non-manifold shape:  torch.Size([32768, 3])


In [8]:
model = mlp.Decoder(multires=multires, input_dims=input_dims).to(device)
radius = 3.0
#model_path = 'models_resources/pretrained_sphere_small.pth'
model_path = f'models_resources/pretrained_sphere_{radius}.pth'
#model_path = 'models_resources/trained_bunny_GT.pth'


if os.path.exists(model_path):
    model.load_state_dict(torch.load(model_path))
    print('loaded model')
else:
    print("no model found, pretraining")
    model.pre_train_sphere(1000, radius=radius)
    torch.save(model.state_dict(),model_path)



loaded model


In [9]:
def upsampling_vectorized(sites, model):
    sdf_values = model(sites)
    sites_np = sites.detach().cpu().numpy()
    # Compute Voronoi diagram
    vor = Voronoi(sites_np)
    
    neighbors = torch.tensor(np.array(vor.ridge_points), device=device)
    
    # Extract the SDF values for each site in the pair
    sdf_i = sdf_values[neighbors[:, 0]]  # First site in each pair
    sdf_j = sdf_values[neighbors[:, 1]]  # Second site in each pair
    # Find the indices where SDF values have opposing signs or one is zero
    mask_zero_crossing_sites = (sdf_i * sdf_j <= 0).squeeze()
    sites_to_upsample = torch.unique(neighbors[mask_zero_crossing_sites].view(-1))
    
    print("Sites to upsample ",sites_to_upsample.shape)
    
    tet_centroids = sites[sites_to_upsample]

    # Tetrahedron relative positions (unit tetrahedron)
    basic_tet_1 = torch.tensor([[1, 1, 1]], device=device, dtype=torch.float64)
    basic_tet_1 = basic_tet_1.repeat(len(tet_centroids), 1)
    basic_tet_2 = torch.tensor([-1, -1, 1], device=device, dtype=torch.float64)    
    basic_tet_2 = basic_tet_2.repeat(len(tet_centroids), 1)
    basic_tet_3 = torch.tensor([-1, 1, -1], device=device, dtype=torch.float64)    
    basic_tet_3 = basic_tet_3.repeat(len(tet_centroids), 1)
    basic_tet_4 = torch.tensor([1, -1, -1], device=device, dtype=torch.float64)
    basic_tet_4 = basic_tet_4.repeat(len(tet_centroids), 1)


    #compute scale based on cell volume
    centroids = torch.tensor(np.array([vor.vertices[vor.regions[vor.point_region[i]]].mean(axis=0) for i in range(len(sites_np))]), device=device)
    #centroids = torch.tensor(np.array(centroids), device=sites.device, dtype=sites.dtype)
    cells_vertices = [vor.vertices[vor.regions[vor.point_region[i]]] for i in range(len(sites_np))]

    #compute the distance between each centroid  and each vertex in cells_vertices row
    distances = []
    for i in range(len(cells_vertices)):
        min_dist = 100000000000
        for j in range(len(cells_vertices[i])):
            dist = torch.norm(centroids[i] - torch.tensor(cells_vertices[i][j], device=device), p=2)
            if dist < min_dist:
                min_dist = dist
        distances.append(min_dist)
    distances = torch.tensor(distances, device=device)
 
    
    scale = distances[sites_to_upsample] / 2
    
    scale = scale.unsqueeze(1)
    
    
    new_sites = torch.cat((tet_centroids + basic_tet_1 * scale, tet_centroids + basic_tet_2 * scale, tet_centroids + basic_tet_3 * scale, tet_centroids + basic_tet_4 * scale), dim=0)

    updated_sites = torch.cat((sites, new_sites), dim=0)

    return updated_sites
                



In [10]:
cvt_loss_values = []
min_distance_loss_values = []
chamfer_distance_loss_values = []
eikonal_loss_values = []
domain_restriction_loss_values = []
sdf_loss_values = []
div_loss_values = []
loss_values = []


def autograd(sites, model, max_iter=100, stop_train_threshold=1e-6, upsampling=0, lambda_weights = [0.1,1.0,0.1,0.1,1.0,1.0,0.1]):
    optimizer = torch.optim.Adam([
    {'params': [p for _, p in model.named_parameters()], 'lr': lr_model},
    {'params': [sites], 'lr': lr_sites}
], betas=(0.9, 0.999))

    prev_loss = float("inf")
    best_loss = float("inf")
    upsampled = 0.0
    epoch = 0
    lambda_cvt = lambda_weights[0]
    # lambda_pc = lambda_weights[1]
    lambda_min_distance = lambda_weights[2]
    # lambda_laplace = lambda_weights[3]
    lamda_chamfer = lambda_weights[4]
    lambda_eikonal = lambda_weights[5]
    lambda_domain_restriction = lambda_weights[6]
    # lambda_target_points = lambda_weights[7]
    lambda_div = 300
    lambda_sdf = 300
    #lambda_eikonal = 50
    best_sites = sites.clone()
    best_sites.best_loss = best_loss
    
    while epoch <= max_iter:
        optimizer.zero_grad()
        
        vertices_to_compute, bisectors_to_compute = su.compute_zero_crossing_vertices_3d(sites, model)
        vertices = su.compute_vertices_3d_vectorized(sites, vertices_to_compute)    
        bisectors = su.compute_all_bisectors_vectorized(sites, bisectors_to_compute)
        points = torch.cat((vertices, bisectors), 0)

        # Compute losses       
        cvt_loss = lf.compute_cvt_loss_vectorized(sites, model)
        #min_distance_loss = lf.sdf_weighted_min_distance_loss(model, sites)
        chamfer_loss = lf.chamfer_distance(mnfld_points, points)
        domain_restriction_loss = lf.domain_restriction_sphere(mnfld_points, model, input_dim=input_dims)
        sites_loss = (
            lambda_cvt * cvt_loss +
            lamda_chamfer * chamfer_loss
        )
        
        # Compute model loss               
        non_manifold_pred = model(nonmnfld_points)
        manifold_pred = model(mnfld_points)
        div_loss = torch.tensor([0.0], device=mnfld_points.device)
        # compute gradients for div (divergence), curl and curv (curvature)
        if manifold_pred is not None:
            mnfld_grad = Stu.gradient(mnfld_points, manifold_pred)
        else:
            mnfld_grad = None

        nonmnfld_grad = Stu.gradient(nonmnfld_points, non_manifold_pred)
        div_loss = torch.abs(lf.directional_div(nonmnfld_points, nonmnfld_grad)).mean() #+ mnfld_divergence_term.mean()
        eikonal_term = lf.eikonal_loss(nonmnfld_grad, mnfld_grad=mnfld_grad, eikonal_type='abs')
        sdf_term = torch.abs(manifold_pred).mean()

        model_loss = (
            lambda_sdf*sdf_term +
            lambda_eikonal*eikonal_term +
            lambda_div*div_loss +
            lambda_domain_restriction * domain_restriction_loss
        )
        #print weights
        print("-----------------")
        print(f"lambda_sdf: {lambda_sdf}, lambda_eikonal: {lambda_eikonal}, lambda_div: {lambda_div}, lambda_domain_restriction: {lambda_domain_restriction}")
        print(f"sdf_term: {sdf_term}, eikonal_term: {eikonal_term}, div_loss: {div_loss}, domain_restriction_loss: {domain_restriction_loss}")
        print(f"Epoch {epoch}: model_loss = {model_loss.item()}")
        
        #DIVDECAY='linear' # 'linear' | 'quintic' | 'step'
        div_decay_params = [1e2, 0.2, 1e2, 0.4, 0.0, 0.0]
        div_decay_params = [300, 0.5, 100, 0.8, 0.0, 0.0]
        lambda_div = lf.update_div_weight(epoch, max_iter, lambda_div, 'linear', div_decay_params)
               

        #print weights
        print(f"cvt_loss: {cvt_loss}, chamfer_loss: {chamfer_loss}")
              #, min_distance_loss: {min_distance_loss}, 
        print(f"Epoch {epoch}: site_loss = {sites_loss.item()}")

        print(f"lambda_cvt: {lambda_cvt}, lambda_min_distance: {lambda_min_distance}, lambda_chamfer: {lamda_chamfer}")
         
        loss = sites_loss + model_loss
        loss_values.append(loss.item())
        print(f"Epoch {epoch}: loss = {loss.item()}")
                
        loss.backward()
        optimizer.step()
        
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_epoch = epoch
            best_sites = sites.clone()
            best_sites.best_loss = best_loss
            #if upsampled > 0:
                #print(f"UPSAMPLED {upsampled} Best Epoch {best_epoch}: Best loss = {best_loss}")
                #return best_sites
        
        if abs(prev_loss - loss.item()) < stop_train_threshold:
            print(f"Converged at epoch {epoch} with loss {loss.item()}")
            #break
        
        prev_loss = loss.item() 
        
        # if epoch>100 and (epoch // 100) == upsampled+1 and loss.item() < 0.5 and upsampled < upsampling:
        if epoch/max_iter > (0.7)*(upsampled+1)/(upsampling+1) and upsampled < upsampling:
            print("sites length BEFORE UPSAMPLING: ",len(sites))
            
            #new_sites = su.upsampling_inside(best_sites, model)
            #new_sites = su.adaptive_density_upsampling(best_sites, model)
            
            #sites = su.add_upsampled_sites(best_sites, new_sites)
            
            sites = upsampling_vectorized(sites, model)
            
            sites = sites.detach().requires_grad_(True)
            #print("upsampled sites length: ",len(sites))
            
            #best_sites = sites.clone()
            #best_sites.best_loss = best_loss
            
            optimizer = torch.optim.Adam([{'params': [p for _, p in model.named_parameters()], 'lr': lr_model},
                                          {'params': [sites], 'lr': lr_sites}])
            upsampled += 1.0
            print("sites length AFTER: ",len(sites))
            
          
        if epoch % (max_iter/10) == 0:
            #print(f"Epoch {epoch}: loss = {loss.item()}")
            #print(f"Best Epoch {best_epoch}: Best loss = {best_loss}")
            #save model and sites
            site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lamda_chamfer}.pth'
            model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lamda_chamfer}.pth'
            torch.save(model.state_dict(), model_file_path)
            torch.save(sites, site_file_path)
            
        
        epoch += 1           
        
    return best_sites

In [33]:
lambda_weights = [252,0,0,0,10,50,0.01,0.1]

lambda_cvt = lambda_weights[0]
lambda_sdf = lambda_weights[1]
lambda_min_distance = lambda_weights[2]
lambda_laplace = lambda_weights[3]
lamda_chamfer = lambda_weights[4]
lambda_eikonal = lambda_weights[5]
lambda_domain_restriction = lambda_weights[6]
lambda_true_points = lambda_weights[7]

max_iter = 3000

site_file_path = f'{destination}{max_iter}_cvt_{lambda_cvt}_chamfer_{lamda_chamfer}_eikonal_{lambda_eikonal}.npy'
#check if optimized sites file exists
if os.path.exists(site_file_path):
    #import sites
    print("Importing sites")
    sites = np.load(site_file_path)
    sites = torch.from_numpy(sites).to(device).requires_grad_(True)
else:
    import cProfile, pstats
    import time
    profiler = cProfile.Profile()
    profiler.enable()
    
    sites = autograd(sites, model, max_iter=max_iter, upsampling=2, lambda_weights=lambda_weights)
    
    profiler.disable()
    stats = pstats.Stats(profiler).sort_stats('cumtime')
    stats.print_stats()
    stats.dump_stats(f'{destination}{mesh[0]}{max_iter}_3d_profile_{num_centroids}_chamfer{lamda_chamfer}.prof')
    
    
    sites_np = sites.detach().cpu().numpy()
    np.save(site_file_path, sites_np)
    
    

print("Sites length: ", len(sites))
print("min sites: ", torch.min(sites))
print("max sites: ", torch.max(sites))
ps_cloud = ps.register_point_cloud("best_optimized_cvt_grid",sites.detach().cpu().numpy())
    
lim=torch.abs(torch.max(sites)).detach().cpu().numpy()*1.1
#plot_voronoi_3d(sites,lim,lim,lim)

-----------------
lambda_sdf: 300, lambda_eikonal: 50, lambda_div: 300, lambda_domain_restriction: 0.01
sdf_term: 0.1856829229087778, eikonal_term: 0.7808206143179233, div_loss: 0.30167357654908444, domain_restriction_loss: 0.012880221705329719
Epoch 0: model_loss = 185.24810935547188
cvt_loss: 0.020210590341197286, chamfer_loss: 1.8353925656411463
Epoch 0: site_loss = 23.44699442239318
lambda_cvt: 252, lambda_min_distance: 0, lambda_chamfer: 10
Epoch 0: loss = 208.69510377786506
-----------------
lambda_sdf: 300, lambda_eikonal: 50, lambda_div: 300, lambda_domain_restriction: 0.01
sdf_term: 0.17764432517479817, eikonal_term: 0.7665039502359297, div_loss: 0.30792892835380986, domain_restriction_loss: 0.010626494378365652
Epoch 1: model_loss = 183.99727983532267
cvt_loss: 0.02851068219508354, chamfer_loss: 1.8892817535789264
Epoch 1: site_loss = 26.077509448950316
lambda_cvt: 252, lambda_min_distance: 0, lambda_chamfer: 10
Epoch 1: loss = 210.07478928427298
-----------------
lambda_sdf:

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.76 GiB (GPU 0; 23.57 GiB total capacity; 15.25 GiB already allocated; 4.53 GiB free; 18.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
epoch = 1200
model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lamda_chamfer}.pth'
site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lamda_chamfer}.pth'
 
sites = torch.load(site_file_path)
sites_np = sites.detach().cpu().numpy()
model.load_state_dict(torch.load(model_file_path))
#
#polyscope_sdf(model)
#
print("model", model_file_path)
print("sites", site_file_path)
ps_cloud = ps.register_point_cloud(f"{epoch} epoch_cvt_grid",sites_np)



model /home/wylliam/dev/Kyushu_experiments/images/autograd/3Dsteik/bunny2000_2000_3d_model_4096_chamfer10.pth
sites /home/wylliam/dev/Kyushu_experiments/images/autograd/3Dsteik/bunny2000_2000_3d_sites_4096_chamfer10.pth


In [13]:
final_mesh = su.get_zero_crossing_mesh_3d(sites, model)
ps.register_point_cloud("mnfld", mnfld_points.detach().cpu().numpy())

ps.register_surface_mesh("Zero-Crossing faces", final_mesh[0], final_mesh[1])
ps.register_point_cloud("Mesh vertices", final_mesh[0])
polyscope_sdf(model,2)

import scipy.spatial as spatial
from scipy.spatial import Delaunay

tri = Delaunay(sites_np)
delaunay_vertices =torch.tensor(np.array(tri.simplices), device=device)
sdf_values = model(sites)


# Assuming sites is a PyTorch tensor of shape [M, 3]
sites = sites.unsqueeze(0)  # Now shape [1, M, 3]

# Assuming SDF_Values is a PyTorch tensor of shape [M]
sdf_values = sdf_values.unsqueeze(0)  # Now shape [1, M]

marching_tetrehedra_mesh = kaolin.ops.conversions.marching_tetrahedra(sites, delaunay_vertices, sdf_values, return_tet_idx=False)
print(marching_tetrehedra_mesh)
vertices_list, faces_list = marching_tetrehedra_mesh
vertices = vertices_list[0]
faces = faces_list[0]
vertices_np = vertices.detach().cpu().numpy()  # Shape [N, 3]
faces_np = faces.detach().cpu().numpy()  # Shape [M, 3] (triangles)
ps.register_surface_mesh("Marching Tetrahedra Mesh", vertices_np, faces_np)


ps.show()


[(tensor([[-0.6560,  0.1585,  0.0153],
        [-0.6234, -0.0438, -0.0091],
        [-0.6513, -0.0637,  0.0248],
        ...,
        [ 0.7745, -0.3487,  0.2407],
        [ 0.7732, -0.3482,  0.2431],
        [ 0.7777, -0.3466,  0.2313]], grad_fn=<SumBackward1>),), (tensor([[140,  53,  52],
        [401, 413, 412],
        [307, 314, 196],
        ...,
        [581, 579, 570],
        [581, 571, 572],
        [581, 580, 571]]),)]


In [None]:
def export_visualisation_3d():
    import imageio
    img_buffer_mesh = []
    img_buffer_model = []
    for i in range(int(max_iter/10)+1):
        epoch = i*int(max_iter/10)
        
        site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lamda_chamfer}.pth'
        model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lamda_chamfer}.pth'
        if os.path.exists(site_file_path) and os.path.exists(model_file_path):
            print("importing sites and model")
        else:
            print("files not found")
            continue
        print("mesh of epoch: ", epoch)
        
        model.load_state_dict(torch.load(model_file_path))
    
        current_mesh = su.get_zero_crossing_mesh_3d(torch.load(site_file_path), model)
        ps.remove_all_structures()
        ps.register_surface_mesh("Zero-Crossing faces", current_mesh[0], current_mesh[1])
        ps.register_point_cloud("Mesh vertices", current_mesh[0])
        img_buffer_mesh.append(ps.screenshot_to_buffer(transparent_bg=False))
        
        ps.remove_all_structures()
        #polyscope_sdf(model)
        img_buffer_model.append(ps.screenshot_to_buffer(transparent_bg=False))


    imageio.mimsave(f'{destination}{max_iter}_3d_{num_centroids}_optimization_mesh.gif',img_buffer_mesh, fps=1, duration=1, loop=0)
    imageio.mimsave(f'{destination}{max_iter}_3d_{num_centroids}_optimization_sdf.gif', img_buffer_model, fps=1, duration=1, loop=0)

#export_visualisation_3d()