In [1]:
import kaolin
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import polyscope as ps
import interactive_polyscope
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
from io import BytesIO
from PIL import Image
import sdfpred_utils.sdfpred_utils as su
import sdfpred_utils.sdf_MLP as mlp
import sdfpred_utils.sdf_functions as sdf
import sdfpred_utils.loss_functions as lf
import trimesh


#cuda devices
device = torch.device("cuda:0")
print("Using device: ", torch.cuda.get_device_name(device))

#default tensor types
torch.set_default_tensor_type(torch.cuda.DoubleTensor)

multires = 2
input_dims = 3
lr_sites = 0.03
lr_model = 0.0003
iterations = 5000
save_every = 100
max_iter = 100
#learning_rate = 0.03
destination = "./images/autograd/3D/TrueSDF/"
mesh = ["chair", "./Resources/chair_low.obj"]
mesh = ["chair", "./Resources/chair_low.obj"]
#mesh = ["bunny", "./Resources/stanford-bunny.obj"]



Using device:  NVIDIA GeForce RTX 3090


In [None]:
#currently sites are between -5 and 5 in all 3 dimensions
# check if sites exists
#num_centroids = 16*16*16
#num_centroids = 24**3
num_centroids = 20**3
site_fp = f'sites_{num_centroids}_{input_dims}.pt'

if os.path.exists(site_fp):
    sites = torch.load(site_fp)
    print("Sites loaded:", sites.shape)
elif num_centroids > 32*32*32:
    print("toobig for createCVTgrid")
    #create meshgrid between -5 and 5 in 3D
    num_centroids = int(num_centroids**(1/3))
    domain = 4.9
    x = torch.linspace(-domain, domain, num_centroids)
    y = torch.linspace(-domain, domain, num_centroids)
    z = torch.linspace(-domain, domain, num_centroids)
    meshgrid = torch.meshgrid(x, y, z)
    meshgrid = torch.stack(meshgrid, dim=3).view(-1, 3)
    print("Meshgrid shape:", meshgrid.shape)
    print("Meshgrid 1st 5:", meshgrid[:5])
    #add noise to meshgrid
    meshgrid += torch.randn_like(meshgrid) * 0.1
    print("Meshgrid 1st 5:", meshgrid[:5])

    sites = meshgrid.to(device, dtype=torch.double).requires_grad_(True)
else:
    print("Creating new sites")
    sites = su.createCVTgrid(num_centroids=num_centroids, dimensionality=input_dims)
    #save the initial sites torch tensor
    torch.save(sites, site_fp)


def plot_voronoi_3d(sites, xlim=5, ylim=5, zlim=5):
    import numpy as np
    import pyvoro
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
    # initialize random number generator
    rng = np.random.default_rng(11)
    # create a set of points in 3D
    points = sites.detach().cpu().numpy()

    # use pyvoro to compute the Voronoi tessellation
    # the second argument gives the the axis limits in x,y and z direction
    # in this case all between 0 and 1.
    # the third argument gives "dispersion = max distance between two points
    # that might be adjacent" (not sure how exactly this works)
    voronoi = pyvoro.compute_voronoi(points,[[-xlim,xlim],[-ylim,ylim],[-zlim,zlim]],1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # for each Voronoi cell, plot all the faces of the corresponding polygon
    for vnoicell in voronoi:
        faces = []
        # the vertices are the corner points of the Voronoi cell
        vertices = np.array(vnoicell['vertices'])
        # cycle through all faces of the polygon
        for face in vnoicell['faces']:
            faces.append(vertices[np.array(face['vertices'])])
            
        # join the faces into a 3D polygon
        polygon = Poly3DCollection(faces, alpha=0.5, 
                                facecolors=rng.uniform(0,1,3),
                                linewidths=0.5,edgecolors='black')
        ax.add_collection3d(polygon)
    
    ax.set_xlim([-xlim,xlim])
    ax.set_ylim([-ylim,ylim])
    ax.set_zlim([-zlim,zlim])
            
    plt.show()

#plot_voronoi_3d(sites)

toobig for createCVTgrid
Meshgrid shape: torch.Size([16581375, 3])
Meshgrid 1st 5: tensor([[-4.9000, -4.9000, -4.9000],
        [-4.9000, -4.9000, -4.8614],
        [-4.9000, -4.9000, -4.8228],
        [-4.9000, -4.9000, -4.7843],
        [-4.9000, -4.9000, -4.7457]])
Meshgrid 1st 5: tensor([[-4.9014, -4.9803, -4.8741],
        [-4.7217, -4.9711, -4.7129],
        [-4.9745, -4.8654, -4.9572],
        [-4.8902, -4.8421, -4.6739],
        [-4.9502, -4.8006, -4.8620]])


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
ps.init()
ps.register_point_cloud("initial_cvt_grid",sites.detach().cpu().numpy())


[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 3.3.0 NVIDIA 570.124.04


<polyscope.point_cloud.PointCloud at 0x768d34108d60>

In [4]:
class SDFGrid:
    def __init__(self, filename="./Resources/dolphin.obj"):
        self.filename = filename
        self.sdf_grid = np.load(filename[:-4] + '.npy')
        self.grid = torch.tensor(self.sdf_grid, device=device)

    def sdf(self, sites):
        gridsize = self.sdf_grid.shape[0]  # Assuming a cubic grid of size (N,N,N)
        
        # Normalize points to [0, 1] range in all dimensions
        sites = sites + 5.0 #shift to 0-10
        points_normalized = sites / 10.0

        # Scale to grid coordinates
        points_grid = points_normalized * (gridsize - 1)

        # Separate grid coordinates into integer and fractional parts
        x, y, z = points_grid[:, 0], points_grid[:, 1], points_grid[:, 2]
        x0 = x.floor().long().clamp(0, gridsize - 1)
        y0 = y.floor().long().clamp(0, gridsize - 1)
        z0 = z.floor().long().clamp(0, gridsize - 1)
        x1 = (x0 + 1).clamp(0, gridsize - 1)
        y1 = (y0 + 1).clamp(0, gridsize - 1)
        z1 = (z0 + 1).clamp(0, gridsize - 1)
        dx, dy, dz = x - x0, y - y0, z - z0

        # Perform trilinear interpolation
        values = (
            (1 - dx) * (1 - dy) * (1 - dz) * self.grid[x0, y0, z0] +
            dx * (1 - dy) * (1 - dz) * self.grid[x1, y0, z0] +
            (1 - dx) * dy * (1 - dz) * self.grid[x0, y1, z0] +
            dx * dy * (1 - dz) * self.grid[x1, y1, z0] +
            (1 - dx) * (1 - dy) * dz * self.grid[x0, y0, z1] +
            dx * (1 - dy) * dz * self.grid[x1, y0, z1] +
            (1 - dx) * dy * dz * self.grid[x0, y1, z1] +
            dx * dy * dz * self.grid[x1, y1, z1]
        )

        return values
    
    def __str__(self):
        return f"{self.filename} with shape {self.sdf_grid.shape}"
    
    
# import torch
# import torch.nn.functional as F
# import numpy as np

# class SDFGrid:
#     def __init__(self, filename="./Resources/dolphin.obj", device='cuda'):
#         self.filename = filename
#         self.sdf_grid = np.load(filename[:-4] + '.npy')  # Load precomputed SDF grid
#         self.device = device
        
#         # Convert grid to a proper tensor shape (1, 1, D, H, W) for grid_sample
#         self.grid = torch.tensor(self.sdf_grid, dtype=torch.float64, device=device).unsqueeze(0).unsqueeze(0)

#     def sdf(self, sites):
#         gridsize = self.sdf_grid.shape[0]  # Assuming (N, N, N)

#         # Normalize points from world space to [-1, 1] for grid_sample
#         sites = sites + 5.0  # Shift to [0, 10] range
#         points_normalized = (2.0 * (sites / 10.0)) - 1  # Normalize to [-1, 1]

#         # Reshape for grid_sample (batch_size=1)
#         points_grid = points_normalized.view(1, 1, -1, 1, 3)  # Shape: (1, 1, N, 1, 3)

#         # Perform differentiable trilinear interpolation
#         sdf_values = F.grid_sample(self.grid, points_grid, mode='bilinear', align_corners=True, padding_mode='border')
        
#         return sdf_values.view(-1)  # Reshape to (N,)

#     def __str__(self):
#         return f"{self.filename} with shape {self.sdf_grid.shape}"



model = SDFGrid(mesh[1])
sdf_values = model.sdf(sites)
print(model)
print(sdf_values.shape)

./Resources/chair_low.obj with shape (128, 128, 128)
torch.Size([16581375])


In [5]:
#render sdf grid as a point cloud
#create a 128x128s128 grid

x = np.linspace(-1, 1, 128)
y = np.linspace(-1, 1, 128)
z = np.linspace(-1, 1, 128)
X, Y, Z = np.meshgrid(x, y, z)
points = np.vstack([X.ravel(), Y.ravel(), Z.ravel()]).T
print(points.shape)

#coordinate points with the sdf values
sdf_points = np.zeros((points.shape[0], 4))
sdf_points[:, :3] = points
sdf_points[:, 3] = model.sdf_grid.ravel()



# ps.register_point_cloud("sdf_points", sdf_points[sdf_points[:, 3] < 0][:, :3])
# ps.register_point_cloud("sdf_points_pos", sdf_points[sdf_points[:, 3] >= 0][:, :3])
# ps.show()


(2097152, 3)


In [6]:

sdf_values_trilinear = model.sdf(sites)


#use torch.nn.functionnal.grid_sample
#df_Values_gridsample = torch.nn.functional.grid_sample(sites, torch.tensor(sdf_points, device=device), mode='bilinear', padding_mode='border')


In [7]:
def get_zero_crossing_mesh_3d(sites, model):
    sites_np = sites.detach().cpu().numpy()
    vor = Voronoi(sites_np)  # Compute 3D Voronoi diagram

    sdf_values = model.sdf(sites).detach().cpu().numpy()  # Compute SDF values

    valid_faces = []  # List of polygonal faces
    used_vertices = set()  # Set of indices for valid vertices

    for (point1, point2), ridge_vertices in zip(vor.ridge_points, vor.ridge_vertices):
        if -1 in ridge_vertices:
            continue  # Skip infinite ridges

        # Check if SDF changes sign across this ridge
        if np.sign(sdf_values[point1]) != np.sign(sdf_values[point2]):
            valid_faces.append(ridge_vertices)
            used_vertices.update(ridge_vertices)

    # **Filter Voronoi vertices**
    used_vertices = sorted(used_vertices)  # Keep unique, sorted indices
    vertex_map = {old_idx: new_idx for new_idx, old_idx in enumerate(used_vertices)}
    filtered_vertices = vor.vertices[used_vertices]

    # **Re-index faces to match the new filtered vertex list**
    filtered_faces = [[vertex_map[v] for v in face] for face in valid_faces]

    return filtered_vertices, filtered_faces



In [8]:
# ps_cloud = ps.register_point_cloud("sites_w_sdf", sites.detach().cpu().numpy())
# ps_cloud.add_scalar_quantity("sdf_values_trilinear", sdf_values_trilinear.detach().cpu().numpy(), enabled=True)
# # ps_cloud.add_scalar_quantity("sdf_Values_gridsample", sdf_Values_gridsample.detach().cpu().numpy(), enabled=True)


# ps.register_point_cloud("sites_w_zeroes_sdf_trilinear", sites.detach().cpu().numpy()[sdf_values_trilinear.detach().cpu().numpy()[:] <= 0])
# # ps.register_point_cloud("sites_w_zeroes_sdf_gridsample", sites.detach().cpu().numpy()[sdf_Values_gridsample.detach().cpu().numpy()[:] <= 0])

# initial_mesh = get_zero_crossing_mesh_3d(sites, model)

# ps.register_surface_mesh("initial_mesh Zero-Crossing faces", initial_mesh[0], initial_mesh[1])
# ps.register_point_cloud("initial_mesh vertices", initial_mesh[0])

#ps.show()


In [9]:
def upsampling_vectorized(sites, model):
    sdf_values = model.sdf(sites)
    sites_np = sites.detach().cpu().numpy()
    # Compute Voronoi diagram
    vor = Voronoi(sites_np)
    
    neighbors = torch.tensor(np.array(vor.ridge_points), device=device)
    
    # Extract the SDF values for each site in the pair
    sdf_i = sdf_values[neighbors[:, 0]]  # First site in each pair
    sdf_j = sdf_values[neighbors[:, 1]]  # Second site in each pair
    # Find the indices where SDF values have opposing signs or one is zero
    mask_zero_crossing_sites = (sdf_i * sdf_j <= 0).squeeze()
    sites_to_upsample = torch.unique(neighbors[mask_zero_crossing_sites].view(-1))
    
    print("Sites to upsample ",sites_to_upsample.shape)
    
    tet_centroids = sites[sites_to_upsample]

    # Tetrahedron relative positions (unit tetrahedron)
    basic_tet_1 = torch.tensor([[1, 1, 1]], device=device, dtype=torch.float64)
    basic_tet_1 = basic_tet_1.repeat(len(tet_centroids), 1)
    basic_tet_2 = torch.tensor([-1, -1, 1], device=device, dtype=torch.float64)    
    basic_tet_2 = basic_tet_2.repeat(len(tet_centroids), 1)
    basic_tet_3 = torch.tensor([-1, 1, -1], device=device, dtype=torch.float64)    
    basic_tet_3 = basic_tet_3.repeat(len(tet_centroids), 1)
    basic_tet_4 = torch.tensor([1, -1, -1], device=device, dtype=torch.float64)
    basic_tet_4 = basic_tet_4.repeat(len(tet_centroids), 1)


    #compute scale based on cell volume
    centroids = torch.tensor(np.array([vor.vertices[vor.regions[vor.point_region[i]]].mean(axis=0) for i in range(len(sites_np))]), device=device)
    #centroids = torch.tensor(np.array(centroids), device=sites.device, dtype=sites.dtype)
    cells_vertices = [vor.vertices[vor.regions[vor.point_region[i]]] for i in range(len(sites_np))]

    #compute the distance between each centroid  and each vertex in cells_vertices row
    distances = []
    for i in range(len(cells_vertices)):
        min_dist = 100000000000
        for j in range(len(cells_vertices[i])):
            dist = torch.norm(centroids[i] - torch.tensor(cells_vertices[i][j], device=device), p=2)
            if dist < min_dist:
                min_dist = dist
        distances.append(min_dist)
    distances = torch.tensor(distances, device=device)
 
    
    scale = distances[sites_to_upsample] / 2
    
    scale = scale.unsqueeze(1)
    
    
    new_sites = torch.cat((tet_centroids + basic_tet_1 * scale, tet_centroids + basic_tet_2 * scale, tet_centroids + basic_tet_3 * scale, tet_centroids + basic_tet_4 * scale), dim=0)

    updated_sites = torch.cat((sites, new_sites), dim=0)

    return updated_sites
                
def compute_zero_crossing_vertices_3d(sites, model):
    """
    Computes the indices of the sites composing vertices where neighboring sites have opposite or zero SDF values.

    Args:
        sites (torch.Tensor): (N, D) tensor of site positions.
        model (callable): Function or neural network that computes SDF values.

    Returns:
        zero_crossing_vertices_index (list of triplets): List of sites indices (si, sj, sk) where atleast 2 sites have opposing SDF signs.
    """
    # Compute Delaunay neighbors
    # Detach and convert to NumPy for Delaunay triangulation
    points_np = sites.detach().cpu().numpy()
    
    # Compute the Delaunay tessellation
    tri = Delaunay(points_np)
    vor = Voronoi(points_np)
    
    # Compute SDF values for all sites
    sdf_values = model.sdf(sites)  # Assuming model outputs (N, 1) or (N,) tensor

    neighbors = torch.tensor(np.array(vor.ridge_points), device=device)
    all_tetrahedra = torch.tensor(np.array(tri.simplices), device=device)
    #all_ridge_vertices = torch.tensor(np.array(vor.ridge_vertices), device=device)
    
    # Extract the SDF values for each site in the pair
    sdf_i = sdf_values[neighbors[:, 0]]  # First site in each pair
    sdf_j = sdf_values[neighbors[:, 1]]  # Second site in each pair
    # Find the indices where SDF values have opposing signs or one is zero
    mask_zero_crossing_sites = (sdf_i * sdf_j <= 0).squeeze()
    zero_crossing_pairs = neighbors[mask_zero_crossing_sites]
    #ridge_vertices_pairs = all_ridge_vertices[mask_zero_crossing_sites]
    #compute ridge vertices neighbors
    

    # Check if vertices has a pair of zero crossing sites
    sdf_0 = sdf_values[all_tetrahedra[:, 0]]  # First site in each pair
    sdf_1 = sdf_values[all_tetrahedra[:, 1]]  # Second site in each pair
    sdf_2 = sdf_values[all_tetrahedra[:, 2]]  # Third site in each pair
    sdf_3 = sdf_values[all_tetrahedra[:, 3]]  # Fourth site in each pair
    mask_zero_crossing_faces = (sdf_0*sdf_1<=0).squeeze() | (sdf_0*sdf_2<=0).squeeze() | (sdf_0*sdf_3<=0).squeeze() | (sdf_1*sdf_2<=0).squeeze() | (sdf_1*sdf_3<=0).squeeze() | (sdf_2*sdf_3<=0).squeeze()
    zero_crossing_vertices_index = all_tetrahedra[mask_zero_crossing_faces]
    return zero_crossing_vertices_index, zero_crossing_pairs

def compute_cvt_loss_vectorized(sites):
    # Convert sites to NumPy for Voronoi computation
    sites_np = sites.detach().cpu().numpy()
    vor = Voronoi(sites_np)
        
    #Todo C++ loop for this
    # create a nested list of vertices for each site
    centroids = [vor.vertices[vor.regions[vor.point_region[i]]].mean(axis=0) for i in range(len(sites_np)) if vor.regions[vor.point_region[i]] and -1 not in vor.regions[vor.point_region[i]]]
    centroids = torch.tensor(np.array(centroids), device=sites.device, dtype=sites.dtype)
    valid_indices = torch.tensor([i for i in range(len(sites_np)) if vor.regions[vor.point_region[i]] and -1 not in vor.regions[vor.point_region[i]]], device=sites.device)
    
    valid_sites = sites[valid_indices]
    
    penalties = torch.where(abs(valid_sites - centroids) < 10, valid_sites - centroids, torch.tensor(0.0, device=sites.device))
    
    cvt_loss = torch.mean(penalties**2)
    
    return cvt_loss

# def mean_curvature_loss(vertices, adjacency_list):
#     """
#     Computes the mean curvature loss for a given set of vertices and their adjacency list.

#     Args:
#         vertices (torch.Tensor): Tensor of shape (N, 3), where N is the number of vertices.
#         adjacency_list (list of lists): adjacency_list[i] contains indices of neighbors of vertex i.

#     Returns:
#         torch.Tensor: Scalar loss value encouraging smoother geometry.
#     """
#     device = vertices.device
#     loss = torch.tensor(0.0, device=device)

#     for i, neighbors in enumerate(adjacency_list):
#         if len(neighbors) == 0:
#             continue  # Skip isolated points
        
#         # Compute the mean of the neighboring vertices
#         neighbor_vertices = vertices[neighbors]  # Shape (num_neighbors, 3)
#         mean_neighbor = neighbor_vertices.mean(dim=0)  # Shape (3,)

#         # Mean curvature flow loss (squared distance to neighborhood mean)
#         loss += torch.norm(vertices[i] - mean_neighbor, p=2) ** 2

#     return loss / len(adjacency_list)  # Normalize by number of vertices




In [10]:
cvt_loss_values = []
min_distance_loss_values = []
edge_smoothing_loss_values = []
chamfer_distance_loss_values = []
eikonal_loss_values = []
domain_restriction_loss_values = []
zero_target_points_loss_values = []
sdf_loss_values = []
loss_values = []

def autograd(sites, model, max_iter=100, stop_train_threshold=1e-6, upsampling=0, lambda_weights = [0.1,1.0,0.1,0.1,1.0,1.0,0.1]):
    optimizer = torch.optim.Adam([
    {'params': [sites], 'lr': lr_sites}
], betas=(0.5, 0.999))

    prev_loss = float("inf")
    best_loss = float("inf")
    upsampled = 0
    epoch = 0
    lambda_cvt = lambda_weights[0]
    lambda_sdf = lambda_weights[1]
    lambda_min_distance = lambda_weights[2]
    lambda_laplace = lambda_weights[3]
    lamda_chamfer = lambda_weights[4]
    lamda_eikonal = lambda_weights[5]
    lambda_domain_restriction = lambda_weights[6]
    lambda_target_points = lambda_weights[7]
    
    best_sites = sites.clone()
    best_sites.best_loss = best_loss
    
    while epoch <= max_iter:
        optimizer.zero_grad()
        
        vertices_to_compute, bisectors_to_compute = compute_zero_crossing_vertices_3d(sites, model)
        vertices = su.compute_vertices_3d_vectorized(sites, vertices_to_compute)    
        bisectors = su.compute_all_bisectors_vectorized(sites, bisectors_to_compute)
        #combine vertices and bisectors to one tensor for chamfer
        points = torch.cat((vertices, bisectors), 0)


        # Compute losses       
        cvt_loss = compute_cvt_loss_vectorized(sites)

        sdf_loss = torch.mean(model.sdf(points)**2)        
        
        laplacian_loss = lf.compute_ridge_smoothing_loss(bisectors_to_compute, sites, model)
        #laplacian_loss = mean_curvature_loss(ridge_vertices_pairs, sites)

        # Track raw losses (unweighted)
        cvt_loss_values.append(cvt_loss.item())
        sdf_loss_values.append(sdf_loss.item())
        edge_smoothing_loss_values.append(laplacian_loss.item())
  
        loss = (
            lambda_cvt * cvt_loss +
            lambda_sdf * sdf_loss +
            lambda_laplace * laplacian_loss
        )
        loss_values.append(loss.item())
        print(f"cvt_loss: {cvt_loss}, laplace_loss: {laplacian_loss}, ")
        print(f"Epoch {epoch}: loss = {loss.item()}")
                
        loss.backward()
        optimizer.step()
        
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_epoch = epoch
            best_sites = sites.clone()
            best_sites.best_loss = best_loss
            if upsampled > 0:
                print(f"UPSAMPLED {upsampled} Best Epoch {best_epoch}: Best loss = {best_loss}")
                #return best_sites
        
        if abs(prev_loss - loss.item()) < stop_train_threshold:
            print(f"Converged at epoch {epoch} with loss {loss.item()}")
            #break
        
        prev_loss = loss.item() 
        
        if epoch/max_iter > (0.15)*(upsampled+1)/(upsampling+1) and upsampled < upsampling:
            print("sites length BEFORE UPSAMPLING: ",len(sites))
            
            #new_sites = su.upsampling_inside(best_sites, model)
            #new_sites = su.adaptive_density_upsampling(best_sites, model)
            
            #sites = su.add_upsampled_sites(best_sites, new_sites)
            
            sites = upsampling_vectorized(sites, model)
            
            sites = sites.detach().requires_grad_(True)
            #print("upsampled sites length: ",len(sites))
            
            #best_sites = sites.clone()
            #best_sites.best_loss = best_loss
            
            optimizer = torch.optim.Adam([#{'params': [p for _, p in model.named_parameters()], 'lr': lr_model},
                                          {'params': [sites], 'lr': lr_sites}])
            upsampled += 1.0
            print("sites length AFTER: ",len(sites))
            
          
        if epoch % (max_iter/10) == 0:
            print(f"Epoch {epoch}: loss = {loss.item()}")
            print(f"Best Epoch {best_epoch}: Best loss = {best_loss}")
            #save model and sites
            site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}.pth'
            torch.save(sites, site_file_path)
        
        epoch += 1           
        
    return best_sites


In [None]:
lambda_weights = [0.001,1,0,0,0,0,0,0]

lambda_cvt = lambda_weights[0]
lambda_sdf = lambda_weights[1]
lambda_min_distance = lambda_weights[2]
lambda_laplace = lambda_weights[3]
lamda_chamfer = lambda_weights[4]
lambda_eikonal = lambda_weights[5]
lambda_domain_restriction = lambda_weights[6]
lambda_target_points = lambda_weights[7]
max_iter = 100

site_file_path = f'{destination}{mesh[0]}{max_iter}3d_sites_{num_centroids}.npy'
#check if optimized sites file exists
if os.path.exists(site_file_path):
    #import sites
    print("Importing sites")
    sites = np.load(site_file_path)
    sites = torch.from_numpy(sites).to(device).requires_grad_(True)    
else:
    import cProfile, pstats
    import time
    profiler = cProfile.Profile()
    profiler.enable()
    
    sites = autograd(sites, model, max_iter=max_iter, upsampling=6, lambda_weights=lambda_weights)
    
    profiler.disable()
    stats = pstats.Stats(profiler).sort_stats('cumtime')
    stats.print_stats()
    stats.dump_stats(f'{destination}{mesh[0]}{max_iter}_3d_profile_{num_centroids}.prof')
    
    
    sites_np = sites.detach().cpu().numpy()
    np.save(site_file_path, sites_np)
    

print("Sites length: ", len(sites))
print("min sites: ", torch.min(sites))
print("max sites: ", torch.max(sites))
ps_cloud = ps.register_point_cloud("best_optimized_cvt_grid",sites.detach().cpu().numpy())
    
lim=torch.abs(torch.max(sites)).detach().cpu().numpy()*1.1
#plot_voronoi_3d(sites,lim,lim,lim)

In [None]:
epoch = 0

site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}.pth'
sites = torch.load(site_file_path)
sites_np = sites.detach().cpu().numpy()
#
print("sites", site_file_path)
sdf_values_trilinear = model.sdf(sites)

ps.register_point_cloud("final_sites_w_zeroes_sdf_trilinear", sites.detach().cpu().numpy()[sdf_values_trilinear.detach().cpu().numpy()[:] <= 0])




sites ./images/autograd/3D/TrueSDF/chair100_20_3d_sites_3375.pth


<polyscope.point_cloud.PointCloud at 0x7b25166ca850>

In [None]:

# tri = Delaunay(sites_np)
# delaunay_vertices =torch.tensor(np.array(tri.simplices), device=device)
# sdf_values = model.sdf(sites)


# # Assuming sites is a PyTorch tensor of shape [M, 3]
# sites = sites.unsqueeze(0)  # Now shape [1, M, 3]

# # Assuming SDF_Values is a PyTorch tensor of shape [M]
# sdf_values = sdf_values.unsqueeze(0)  # Now shape [1, M]

#marching_tetrehedra_mesh = kaolin.ops.conversions.marching_tetrahedra(sites, delaunay_vertices, sdf_values, return_tet_idx=False)
#print(marching_tetrehedra_mesh)
# vertices_list, faces_list = marching_tetrehedra_mesh
# vertices = vertices_list[0]
# faces = faces_list[0]
# vertices_np = vertices.detach().cpu().numpy()  # Shape [N, 3]
# faces_np = faces.detach().cpu().numpy()  # Shape [M, 3] (triangles)
# ps.register_surface_mesh("Marching Tetrahedra Mesh", vertices_np, faces_np)


zero_crossing_final_mesh = get_zero_crossing_mesh_3d(sites, model)
ps.register_surface_mesh("Zero-Crossing faces", zero_crossing_final_mesh[0], zero_crossing_final_mesh[1])
ps.register_point_cloud("Mesh vertices", zero_crossing_final_mesh[0])

ps.show()

