In [12]:
import kaolin
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import polyscope as ps
import diffvoronoi
import sdfpred_utils.sdfpred_utils as su
import sdfpred_utils.loss_functions as lf

#cuda devices
device = torch.device("cuda:0")
print("Using device: ", torch.cuda.get_device_name(device))

input_dims = 3
lr_sites = 0.005
lr_model = 0.00001
destination = "./images/autograd/End2End_DCCVT/"
model_trained_it = ""

mesh = ["gargoyle","/home/wylliam/dev/Kyushu_experiments/data/gargoyle"]
trained_model_path = f"/home/wylliam/dev/HotSpot/log/3D/pc/HotSpot-all-2025-04-24-18-16-03/gargoyle/gargoyle/trained_models/model{model_trained_it}.pth"

# mesh = ["chair","/home/wylliam/dev/Kyushu_experiments/data/chair"]
# trained_model_path = f"/home/wylliam/dev/HotSpot/log/3D/pc/HotSpot-all-2025-05-02-17-56-25/chair/chair/trained_models/model{model_trained_it}.pth"

#mesh = ["bunny","/home/wylliam/dev/Kyushu_experiments/data/bunny"]
#trained_model_path = f"/home/wylliam/dev/HotSpot/log/3D/pc/HotSpot-all-2025-04-25-17-32-49/bunny/bunny/trained_models/model{model_trained_it}.pth"


Using device:  NVIDIA GeForce RTX 3090


In [13]:
from pytorch3d.ops import knn_points, knn_gather
import torch
from torch import nn

# class Voroloss_opt(nn.Module):
#     def __init__(self):
#         super(Voroloss_opt, self).__init__()
#         self.knn = 16

#     def __call__(self, points, spoints):
#         """points, self.points"""
#         # WARNING: fecthing for knn
#         with torch.no_grad():
#             indices = knn_points(points, spoints, K=self.knn).idx

#         points_knn = knn_gather(spoints, indices)
#         points_to_voronoi_center = points - points_knn[:, :, 0]

#         voronoi_edge = points_knn[:, :, 1:] - points_knn[:, :, 0].unsqueeze(2)
#         voronoi_edge_l = torch.sqrt(((voronoi_edge**2).sum(-1)))
#         vector_length = (points_to_voronoi_center.unsqueeze(2) * voronoi_edge).sum(
#             -1
#         ) / voronoi_edge_l
#         sq_dist = (vector_length - voronoi_edge_l / 2) ** 2
#         return sq_dist.min(-1)[0]
    
voroloss = lf.Voroloss_opt().to(device)

In [14]:
num_centroids = 8**3
grid = 32
print("Creating new sites")
noise_scale = 0.1
domain_limit = 1
x = torch.linspace(-domain_limit, domain_limit, int(round(num_centroids**(1/3))))
y = torch.linspace(-domain_limit, domain_limit, int(round(num_centroids**(1/3))))
z = torch.linspace(-domain_limit, domain_limit, int(round(num_centroids**(1/3))))
meshgrid = torch.meshgrid(x, y, z)
meshgrid = torch.stack(meshgrid, dim=3).view(-1, 3)


#add noise to meshgrid
#meshgrid += torch.randn_like(meshgrid) * noise_scale


sites = meshgrid.to(device, dtype=torch.float32).requires_grad_(True)

print("Sites shape: ", sites.shape)
ps.init()


Creating new sites
Sites shape:  torch.Size([512, 3])


In [15]:
#LOAD MODEL WITH HOTSPOT
import sys
sys.path.append("3rdparty/HotSpot")
from dataset import shape_3d
import models.Net as Net

loss_type = "igr_w_heat"
loss_weights = [350, 0, 0, 1, 0, 0, 20]

train_set = shape_3d.ReconDataset(
    file_path = mesh[1]+".ply",
    n_points=grid*grid*150,#15000, #args.n_points,
    n_samples=10001, #args.n_iterations,
    grid_res=256, #args.grid_res,
    grid_range=1.1, #args.grid_range,
    sample_type="uniform_central_gaussian", #args.nonmnfld_sample_type,
    sampling_std=0.5, #args.nonmnfld_sample_std,
    n_random_samples=7500, #args.n_random_samples,
    resample=True,
    compute_sal_dist_gt=(
        True if "sal" in loss_type and loss_weights[5] > 0 else False
    ),
    scale_method="mean"#"mean" #args.pcd_scale_method,
)

model = Net.Network(
    latent_size=0,#args.latent_size,
    in_dim=3,
    decoder_hidden_dim=128,#args.decoder_hidden_dim,
    nl="sine",#args.nl,
    encoder_type="none",#args.encoder_type,
    decoder_n_hidden_layers=5,#args.decoder_n_hidden_layers,
    neuron_type="quadratic",#args.neuron_type,
    init_type="mfgi",#args.init_type,
    sphere_init_params=[1.6, 0.1],#args.sphere_init_params,
    n_repeat_period=30#args.n_repeat_period,
)
model.to(device)

######       
test_dataloader = torch.utils.data.DataLoader(train_set, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)   
test_data = next(iter(test_dataloader))
mnfld_points = test_data["mnfld_points"].to(device)
mnfld_points.requires_grad_()
print("mnfld_points shape: ", mnfld_points.shape)
if torch.cuda.is_available():
    map_location = torch.device("cuda")
else:
    map_location = torch.device("cpu")
model.load_state_dict(torch.load(trained_model_path, weights_only=True, map_location=map_location))

mnfld_points shape:  torch.Size([1, 153600, 3])


  return self.fget.__get__(instance, owner)()


<All keys matched successfully>

In [16]:
# #add mnfld points with random noise to sites 
N = mnfld_points.squeeze(0).shape[0]
num_samples = grid**3 - num_centroids
idx = torch.randint(0, N, (num_samples,))
sampled = mnfld_points.squeeze(0)[idx]
perturbed = sampled + (torch.rand_like(sampled)-0.5)*0.05
sites = torch.cat((sites, perturbed), dim=0)

# make sites a leaf tensor
sites = sites.detach().requires_grad_()
print(sites.dtype)
print(f"Allocated: {torch.cuda.memory_allocated() / 1e6} MB, Reserved: {torch.cuda.memory_reserved() / 1e6} MB")


torch.float32
Allocated: 605.983232 MB, Reserved: 2208.301056 MB


In [17]:
sites_pred = model(sites).detach()#["nonmanifold_pnts_pred"]
print(f"Allocated: {torch.cuda.memory_allocated() / 1e6} MB, Reserved: {torch.cuda.memory_reserved() / 1e6} MB")

#mnfld_preds = model(mnfld_points)#["nonmanifold_pnts_pred"]

ps_cloud = ps.register_point_cloud("initial_cvt_grid+pc_gt",sites.detach().cpu().numpy())
mnf_cloud = ps.register_point_cloud("mnfld_points_pred",mnfld_points.squeeze(0).detach().cpu().numpy())
#mnf_cloud.add_scalar_quantity("mnfld_points_pred", mnfld_preds.reshape(-1).detach().cpu().numpy(), enabled=True)
ps_cloud.add_scalar_quantity("vis_grid_pred", sites_pred.reshape(-1).detach().cpu().numpy(), enabled=True)

#initial_mesh = su.get_zero_crossing_mesh_3d(sites, model)
#ps.register_surface_mesh("initial Zero-Crossing faces", initial_mesh[0], initial_mesh[1])

#v_vect, f_vect = su.get_clipped_mesh_torch(sites, model, None, batch_size=4096)
v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, None, True)
triangle_faces = [[f[0], f[i], f[i+1]] for f in f_vect for i in range(1, len(f)-1)]
ps.register_surface_mesh("initial triangle clipped mesh", v_vect.detach().cpu().numpy(), triangle_faces)

v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, None, False)
triangle_faces = [[f[0], f[i], f[i+1]] for f in f_vect for i in range(1, len(f)-1)]
ps.register_surface_mesh("initial triangle mesh", v_vect.detach().cpu().numpy(), triangle_faces)
ps.show()

Allocated: 605.983232 MB, Reserved: 2208.301056 MB


In [None]:
# SITES OPTIMISATION LOOP
cvt_loss_values = []
min_distance_loss_values = []
chamfer_distance_loss_values = []
eikonal_loss_values = []
domain_restriction_loss_values = []
sdf_loss_values = []
div_loss_values = []
loss_values = []



def train_DCCVT(sites, model, max_iter=100, stop_train_threshold=1e-6, upsampling=0, lambda_weights = [0.1,1.0,0.1,0.1,1.0,1.0,0.1]):
    optimizer = torch.optim.Adam([
    {'params': [sites], 'lr': lr_sites},
    #{'params': model.parameters(), 'lr': lr_model}
])
    
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[80, 150, 200, 250], gamma=0.5)

    prev_loss = float("inf")
    best_loss = float("inf")
    upsampled = 0.0
    epoch = 0
    lambda_cvt = lambda_weights[0]
    lambda_chamfer = lambda_weights[4]
    best_sites = sites.clone()
    best_sites.best_loss = best_loss
    
    while epoch <= max_iter:
        optimizer.zero_grad()
        
        # sites_np = sites.detach().cpu().numpy()
        # d3dsimplices = diffvoronoi.get_delaunay_simplices(sites_np.reshape(input_dims*sites_np.shape[0]))
        # d3dsimplices = np.array(d3dsimplices)

        # vertices_to_compute, bisectors_to_compute = su.compute_zero_crossing_vertices_3d(sites, None, None, d3dsimplices, model)
        # vertices = su.compute_vertices_3d_vectorized(sites, vertices_to_compute)    
        # bisectors = su.compute_all_bisectors_vectorized(sites, bisectors_to_compute)
        # points = torch.cat((vertices, bisectors), 0)
        # print("points", points.shape) 
    
        # cvt_loss = lf.compute_cvt_loss_vectorized_delaunay(sites, None, d3dsimplices)
        # print("CVT loss: ", cvt_loss, "weighted: ", lambda_cvt*cvt_loss)
        # #min_distance_loss = lf.sdf_weighted_min_distance_loss(model, sites)
        
        # from pytorch3d.loss import chamfer_distance
        # chamfer_loss_points, _ = chamfer_distance(mnfld_points.detach(), points.unsqueeze(0))
        # print(f"Points Chamfer loss PYTORCH3D {chamfer_loss_points} weighted: {lambda_chamfer*chamfer_loss_points} : Allocated: {torch.cuda.memory_allocated() / 1e6} MB, Reserved: {torch.cuda.memory_reserved() / 1e6} MB")

        # for param in model.parameters():
        #     param.requires_grad = False
        # s1 = torch.mean(model(points)**2)
        # s2 = torch.maximum((model(sites).abs() - 0.05), torch.tensor(0.0)).mean()
        # sdf_loss = 0*s1+s2
        # sdf_loss.backward(retain_graph=True)
        # for param in model.parameters():
        #     param.requires_grad = True
            
        # #print("SDF loss: ", sdf_loss, "weighted: ", lambda_chamfer/10*sdf_loss)
        
        # #v_vect, f_vect = su.get_clipped_mesh_torch(sites, model, d3dsimplices, batch_size=4096)
        # v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, d3dsimplices)
        
        
        # #ps.register_surface_mesh("polygon clipped mesh", v_vect.detach().cpu().numpy(), f_vect)

        # # fanning to transform polygon faces to triangle faces
        # triangle_faces = [[f[0], f[i], f[i+1]] for f in f_vect for i in range(1, len(f)-1)]
        # #ps.register_surface_mesh("triangle clipped mesh", v_vect.detach().cpu().numpy(), triangle_faces)

        # triangle_faces = torch.tensor(triangle_faces, device=device)
        # hs_p = su.sample_mesh_points_heitz(v_vect, triangle_faces, num_samples=32*32*150)
        # print("hs_p shape: ", hs_p.shape)
        # #ps.register_point_cloud("heitz clipped mesh", hs_p.detach().cpu().numpy())
        
        # from pytorch3d.loss import chamfer_distance
        # chamfer_loss_mesh, _ = chamfer_distance(mnfld_points.detach(), hs_p.unsqueeze(0))
        # print(f"Mesh Chamfer loss PYTORCH3D {chamfer_loss_mesh} weighted: {lambda_chamfer*chamfer_loss_mesh} : Allocated: {torch.cuda.memory_allocated() / 1e6} MB, Reserved: {torch.cuda.memory_reserved() / 1e6} MB")


        voroloss_loss = voroloss(mnfld_points.squeeze(0), sites).mean()
        # #voroloss_loss = voroloss(sites.unsqueeze(0), mnfld_points)
        
        
        # # triangle area loss
        # # 1) Gather triangle vertices
        # v0 = v_vect[triangle_faces[:, 0]]  # (F,3)
        # v1 = v_vect[triangle_faces[:, 1]]  # (F,3)
        # v2 = v_vect[triangle_faces[:, 2]]  # (F,3)

        # # 2) Compute triangle areas for weighting
        # e0 = v1 - v0               # (F,3)
        # e1 = v2 - v0               # (F,3)
        # cross = torch.cross(e0, e1, dim=1)  # (F,3)
        # areas = 0.5 * cross.norm(dim=1)     # (F,)
        # mean_area = areas.mean()  # (1,)
        # triangle_area_loss = torch.mean(areas-mean_area)**2
        # print("triangle loss: ", triangle_area_loss, "weighted: ", lambda_chamfer*0.01*triangle_area_loss)


        sites_loss = (
            #lambda_cvt * cvt_loss +
            #lambda_chamfer * chamfer_loss_mesh 
            #+ lambda_chamfer*0.01 * triangle_area_loss
            #+ lambda_chamfer * chamfer_loss_points
            lambda_chamfer * voroloss_loss
            #+ lambda_chamfer/10 * sdf_loss
        )
            
        loss = sites_loss
        loss_values.append(loss.item())
        print(f"Epoch {epoch}: loss = {loss.item()}")
        print(f"before loss.backward(): Allocated: {torch.cuda.memory_allocated() / 1e6} MB, Reserved: {torch.cuda.memory_reserved() / 1e6} MB")

        loss.backward()
        print(f"After loss.backward(): Allocated: {torch.cuda.memory_allocated() / 1e6} MB, Reserved: {torch.cuda.memory_reserved() / 1e6} MB")
        print("-----------------")
        
        optimizer.step()
        scheduler.step()
        
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_epoch = epoch
            best_sites = sites.clone()
            best_sites.best_loss = best_loss
            #if upsampled > 0:
                #print(f"UPSAMPLED {upsampled} Best Epoch {best_epoch}: Best loss = {best_loss}")
                #return best_sites
        
        if abs(prev_loss - loss.item()) < stop_train_threshold:
            print(f"Converged at epoch {epoch} with loss {loss.item()}")
            #break
        
        prev_loss = loss.item() 
        
        # if epoch>100 and (epoch // 100) == upsampled+1 and loss.item() < 0.5 and upsampled < upsampling:
        if epoch/max_iter > (upsampled+1)/(upsampling+1) and upsampled < upsampling:
            print("sites length BEFORE UPSAMPLING: ",len(sites))
            sites = su.upsampling_vectorized(sites, tri=None, vor=None, simplices=d3dsimplices, model=model)
            sites = sites.detach().requires_grad_(True)
            optimizer = torch.optim.Adam([{'params': [sites], 'lr': lr_sites}, 
                                          #{'params': model.parameters(), 'lr': lr_model}
                                          ])
            upsampled += 1.0
            print("sites length AFTER: ",len(sites))
            
          
        if epoch % (max_iter/10) == 0:
            #print(f"Epoch {epoch}: loss = {loss.item()}")
            #print(f"Best Epoch {best_epoch}: Best loss = {best_loss}")
            #save model and sites
            #ps.register_surface_mesh(f"{epoch} triangle clipped mesh", v_vect.detach().cpu().numpy(), triangle_faces.detach().cpu().numpy())
            
            site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lambda_chamfer}.pth'
            model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lambda_chamfer}.pth'
            torch.save(model.state_dict(), model_file_path)
            torch.save(sites, site_file_path)
            
        epoch += 1           
    
    #Export the sites, their sdf values, the gradients of the sdf values and the hessian
    sdf_values = model(sites)

    sdf_gradients = torch.autograd.grad(outputs=sdf_values, inputs=sites, grad_outputs=torch.ones_like(sdf_values), create_graph=True, retain_graph=True,)[0] # (N, 3)

    N, D = sites.shape
    hess_sdf = torch.zeros(N, D, D, device=sites.device)
    for i in range(D):
        grad2 = torch.autograd.grad(outputs=sdf_gradients[:, i], inputs=sites, grad_outputs=torch.ones_like(sdf_gradients[:, i]), create_graph=False, retain_graph=True,)[0] # (N, 3)
        hess_sdf[:, i, :] = grad2 # fill row i of each 3Ã—3 Hessian
    
    np.savez(f'{mesh[0]}voroloss_to_clip{model_trained_it}.npz', sites=sites.detach().cpu().numpy(), sdf_values=sdf_values.detach().cpu().numpy(), sdf_gradients=sdf_gradients.detach().cpu().numpy(), sdf_hessians=hess_sdf.detach().cpu().numpy())
    print(f"Saved to {mesh[0]}voroloss_to_clip{model_trained_it}.npz")
    return sites

In [19]:
#lambda_weights = [252,0,0,0,10.211111,0,100,0]
#lambda_weights = [500,0,0,0,1000,0,100,0]
lambda_weights = [100,0,0,0,1000,0,100,0]


lambda_cvt = lambda_weights[0]
lambda_sdf = lambda_weights[1]
lambda_min_distance = lambda_weights[2]
lambda_laplace = lambda_weights[3]
lambda_chamfer = lambda_weights[4]
lambda_eikonal = lambda_weights[5]
lambda_domain_restriction = lambda_weights[6]
lambda_true_points = lambda_weights[7]

max_iter = 400

In [20]:
site_file_path = f'{destination}{max_iter}_cvt_{lambda_cvt}_chamfer_{lambda_chamfer}_eikonal_{lambda_eikonal}.npy'
#check if optimized sites file exists
if not os.path.exists(site_file_path):
    #import sites
    print("Importing sites")
    sites = np.load(site_file_path)
    sites = torch.from_numpy(sites).to(device).requires_grad_(True)
else:
    # import cProfile, pstats
    # import time
    # profiler = cProfile.Profile()
    # profiler.enable()

#     with torch.profiler.profile(activities=[
#             torch.profiler.ProfilerActivity.CPU,
#             torch.profiler.ProfilerActivity.CUDA,
#         ],
#         record_shapes=False,
#         with_stack=True  # Captures function calls
#     ) as prof:
#         sites = train_DCCVT(sites, model, max_iter=max_iter, upsampling=1, lambda_weights=lambda_weights)
#         torch.cuda.synchronize()
# # 
#     print(prof.key_averages().table(sort_by="self_cuda_time_total"))
#     prof.export_chrome_trace("trace.json")
    
    # 
    sites = train_DCCVT(sites, model, max_iter=max_iter, upsampling=0, lambda_weights=lambda_weights)

    
    # profiler.disable()
    # stats = pstats.Stats(profiler).sort_stats('cumtime')
    # stats.print_stats()
    # stats.dump_stats(f'{destination}{mesh[0]}{max_iter}_3d_profile_{num_centroids}_chamfer{lambda_chamfer}.prof')
    
    
    sites_np = sites.detach().cpu().numpy()
    np.save(site_file_path, sites_np)
    
print("Sites length: ", len(sites))
print("min sites: ", torch.min(sites))
print("max sites: ", torch.max(sites))

Epoch 0: loss = 0.007484010420739651
before loss.backward(): Allocated: 686.322176 MB, Reserved: 3783.262208 MB
After loss.backward(): Allocated: 608.549376 MB, Reserved: 3783.262208 MB
-----------------
Epoch 1: loss = 0.006939945742487907
before loss.backward(): Allocated: 688.552448 MB, Reserved: 3783.262208 MB
After loss.backward(): Allocated: 609.335808 MB, Reserved: 3783.262208 MB
-----------------
Epoch 2: loss = 0.00553774693980813
before loss.backward(): Allocated: 688.552448 MB, Reserved: 3783.262208 MB
After loss.backward(): Allocated: 609.335808 MB, Reserved: 3783.262208 MB
-----------------
Epoch 3: loss = 0.004699479788541794
before loss.backward(): Allocated: 688.552448 MB, Reserved: 3783.262208 MB
After loss.backward(): Allocated: 609.335808 MB, Reserved: 3783.262208 MB
-----------------
Epoch 4: loss = 0.004056088160723448
before loss.backward(): Allocated: 688.552448 MB, Reserved: 3783.262208 MB
After loss.backward(): Allocated: 609.335808 MB, Reserved: 3783.262208 MB

In [23]:
epoch = 400

model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lambda_chamfer}.pth'
site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lambda_chamfer}.pth'
 
sites = torch.load(site_file_path)

sites_np = sites.detach().cpu().numpy()
model.load_state_dict(torch.load(model_file_path))
#
#polyscope_sdf(model)
#
print("model", model_file_path)
print("sites", site_file_path)
ps_cloud = ps.register_point_cloud(f"{epoch} epoch_cvt_grid",sites_np)

print("sites_np shape: ", sites_np.shape)

#print sites if Nan
if np.isnan(sites_np).any():
    print("sites_np contains NaN values")
    print("sites_np NaN values: ", np.isnan(sites_np).sum())
#remove nan values from sites tensor
sites_np = sites_np[~np.isnan(sites_np).any(axis=1)]
sites = torch.from_numpy(sites_np).to(device).requires_grad_(True)

model ./images/autograd/End2End_DCCVT/gargoyle400_400_3d_model_512_chamfer1000.pth
sites ./images/autograd/End2End_DCCVT/gargoyle400_400_3d_sites_512_chamfer1000.pth
sites_np shape:  (32768, 3)


In [24]:
final_mesh = su.get_zero_crossing_mesh_3d(sites, model)
#ps.register_surface_mesh("Zero-Crossing faces direct", final_mesh[0], final_mesh[1])

#save to file
final_mesh_file = f'{mesh[0]}voroloss_sdf_trained{model_trained_it}.npz'
faces = np.array(final_mesh[1], dtype=object)
np.savez(final_mesh_file, vertices=final_mesh[0], faces=faces)

data = np.load(final_mesh_file, allow_pickle=True)
verts = data["vertices"]       # (N_vertices, 3)
faces = data["faces"].tolist() # back to a list of lists

# print("Zero-Crossing faces final shape: ", verts.shape)
# ps.register_surface_mesh("Zero-Crossing faces final", verts, faces, back_face_policy="identical")

#v_vect, f_vect = su.get_clipped_mesh_torch(sites, model, None, batch_size=3072)
v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, None, True)
#only for voroloss case
ps.register_surface_mesh("final clipped polygon mesh", v_vect.detach().cpu().numpy(), f_vect)
# fanning to transform polygon faces to triangle faces
triangle_faces = [[f[0], f[i], f[i+1]] for f in f_vect for i in range(1, len(f)-1)]
ps.register_surface_mesh("final clipped triangle mesh", v_vect.detach().cpu().numpy(), triangle_faces, back_face_policy="identical")


v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, None, False)

#only for voroloss case
ps.register_surface_mesh("final polygon mesh", v_vect.detach().cpu().numpy(), f_vect)
# fanning to transform polygon faces to triangle faces
triangle_faces = [[f[0], f[i], f[i+1]] for f in f_vect for i in range(1, len(f)-1)]
ps.register_surface_mesh("final triangle mesh", v_vect.detach().cpu().numpy(), triangle_faces, back_face_policy="identical")

# triangle_faces = torch.tensor(triangle_faces, device=device)
# s_p = su.sample_mesh_points(v_vect, triangle_faces, num_samples=150*32**2)
# ps.register_point_cloud("sampled clipped mesh", s_p.detach().cpu().numpy())

# hs_p = su.sample_mesh_points_heitz(v_vect, triangle_faces, num_samples=150*32**2)
# ps.register_point_cloud("heitz clipped mesh", hs_p.detach().cpu().numpy())

# ##register original mesh
# mesh_file = mesh[1]+".stl"
# #load mesh 
# m = trimesh.load(mesh_file)
# #convert to numpy
# mesh_np = np.array(m.vertices)
# #normalize mesh
# mesh_np = mesh_np - np.mean(mesh_np, axis=0)
# mesh_np = mesh_np / np.max(np.abs(mesh_np))
# mesh_faces = np.array(m.faces)
# ps.register_surface_mesh("Original Mesh", mesh_np, mesh_faces)


import scipy.spatial as spatial
from scipy.spatial import Delaunay

tri = Delaunay(sites_np)
delaunay_vertices =torch.tensor(np.array(tri.simplices), device=device)
sdf_values = model(sites)

# Assuming sites is a PyTorch tensor of shape [M, 3]
sites = sites.unsqueeze(0)  # Now shape [1, M, 3]

# Assuming SDF_Values is a PyTorch tensor of shape [M]
sdf_values = sdf_values.unsqueeze(0)  # Now shape [1, M]

marching_tetrehedra_mesh = kaolin.ops.conversions.marching_tetrahedra(sites, delaunay_vertices, sdf_values, return_tet_idx=False)
#print(marching_tetrehedra_mesh)
vertices_list, faces_list = marching_tetrehedra_mesh
vertices = vertices_list[0]
faces = faces_list[0]
vertices_np = vertices.detach().cpu().numpy()  # Shape [N, 3]
faces_np = faces.detach().cpu().numpy()  # Shape [M, 3] (triangles)
ps.register_surface_mesh("Marching Tetrahedra Mesh final", vertices_np, faces_np)

# clipped_cvt = "clipped_CVT.obj"
# if os.path.exists(clipped_cvt):
#     clipped_cvt_mesh = trimesh.load(clipped_cvt)
#     ps.register_surface_mesh("Clipped CVT", clipped_cvt_mesh.vertices, clipped_cvt_mesh.faces)
ps.show()