In [2]:
import kaolin
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import polyscope as ps
import interactive_polyscope
from scipy.spatial import Voronoi, voronoi_plot_2d, Delaunay
from io import BytesIO
from PIL import Image
import sdfpred_utils.sdfpred_utils as su
import sdfpred_utils.sdf_MLP as mlp
import sdfpred_utils.sdf_functions as sdf
import sdfpred_utils.loss_functions as lf
import trimesh
import diffvoronoi


#cuda devices
device = torch.device("cuda:0")
print("Using device: ", torch.cuda.get_device_name(device))

#default tensor types
torch.set_default_tensor_type(torch.cuda.DoubleTensor)

multires = 2
input_dims = 3
lr_sites = 0.03
lr_model = 0.0003
iterations = 5000
save_every = 100
max_iter = 100
#learning_rate = 0.03
destination = "./images/autograd/3D/TrueSDF/"
mesh = ["chair", "./Resources/chair_low.obj"]
mesh = ["chair", "./Resources/chair_low.obj"]
#mesh = ["bunny", "./Resources/stanford-bunny.obj"]



Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
Using device:  NVIDIA GeForce RTX 3050 Laptop GPU


In [3]:
#currently sites are between -5 and 5 in all 3 dimensions
# check if sites exists
#num_centroids = 16*16*16
#num_centroids = 24**3
num_centroids = 16**3
site_fp = f'sites_{num_centroids}_{input_dims}.pt'

if os.path.exists(site_fp):
    sites = torch.load(site_fp)
    print("Sites loaded:", sites.shape)
else:
    num_centroids = int(num_centroids**(1/3))
    domain = 2
    x = torch.linspace(-domain, domain, num_centroids)
    y = torch.linspace(-domain, domain, num_centroids)
    z = torch.linspace(-domain, domain, num_centroids)
    meshgrid = torch.meshgrid(x, y, z)
    meshgrid = torch.stack(meshgrid, dim=3).view(-1, 3)
    print("Meshgrid shape:", meshgrid.shape)
    print("Meshgrid 1st 5:", meshgrid[:5])
    #add noise to meshgrid
    meshgrid += torch.randn_like(meshgrid) * 0.01
    print("Meshgrid 1st 5:", meshgrid[:5])

    sites = meshgrid.to(device, dtype=torch.double).requires_grad_(True)
    
    max_dim = torch.max(sites, dim=0)[0]
    print("Max dim:", max_dim)
# else:
#     print("Creating new sites")
#     sites = su.createCVTgrid(num_centroids=num_centroids, dimensionality=input_dims)
#     #save the initial sites torch tensor
#     torch.save(sites, site_fp)


def plot_voronoi_3d(sites, xlim=5, ylim=5, zlim=5):
    import numpy as np
    import pyvoro
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection
    # initialize random number generator
    rng = np.random.default_rng(11)
    # create a set of points in 3D
    points = sites.detach().cpu().numpy()

    # use pyvoro to compute the Voronoi tessellation
    # the second argument gives the the axis limits in x,y and z direction
    # in this case all between 0 and 1.
    # the third argument gives "dispersion = max distance between two points
    # that might be adjacent" (not sure how exactly this works)
    voronoi = pyvoro.compute_voronoi(points,[[-xlim,xlim],[-ylim,ylim],[-zlim,zlim]],1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    # for each Voronoi cell, plot all the faces of the corresponding polygon
    for vnoicell in voronoi:
        faces = []
        # the vertices are the corner points of the Voronoi cell
        vertices = np.array(vnoicell['vertices'])
        # cycle through all faces of the polygon
        for face in vnoicell['faces']:
            faces.append(vertices[np.array(face['vertices'])])
            
        # join the faces into a 3D polygon
        polygon = Poly3DCollection(faces, alpha=0.5, 
                                facecolors=rng.uniform(0,1,3),
                                linewidths=0.5,edgecolors='black')
        ax.add_collection3d(polygon)
    
    ax.set_xlim([-xlim,xlim])
    ax.set_ylim([-ylim,ylim])
    ax.set_zlim([-zlim,zlim])
            
    plt.show()

#plot_voronoi_3d(sites)

Meshgrid shape: torch.Size([3375, 3])
Meshgrid 1st 5: tensor([[-2.0000, -2.0000, -2.0000],
        [-2.0000, -2.0000, -1.7143],
        [-2.0000, -2.0000, -1.4286],
        [-2.0000, -2.0000, -1.1429],
        [-2.0000, -2.0000, -0.8571]])
Meshgrid 1st 5: tensor([[-1.9937, -1.9908, -1.9943],
        [-1.9971, -1.9894, -1.7230],
        [-2.0170, -2.0182, -1.4322],
        [-1.9855, -2.0185, -1.1252],
        [-1.9916, -1.9875, -0.8528]])
Max dim: tensor([2.0381, 2.0246, 2.0244], grad_fn=<MaxBackward0>)


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
ps.init()
#ps.register_point_cloud("initial_cvt_grid",sites.detach().cpu().numpy())


[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 3.3.0 NVIDIA 570.144


In [5]:
class SDFGrid:
    def __init__(self, filename="./Resources/dolphin.obj"):
        self.filename = filename
        self.sdf_grid = np.load(filename[:-4] + '.npy')
        self.grid = torch.tensor(self.sdf_grid, device=device)
        self.pc = self.pointcloud()

    def pointcloud(self):
        # Load the mesh
        mesh = trimesh.load(self.filename)
        # Sample points from the mesh
        points = mesh.sample(32*32*32)
        # Convert to torch tensor
        points = torch.tensor(np.array(points), dtype=torch.float64, device=device)
        points = points - torch.mean(points, dim=0)  # Center the points
        points = points / torch.max(torch.abs(points))
        return points

    def sdf(self, sites):
        gridsize = self.sdf_grid.shape[0]  # Assuming a cubic grid of size (N,N,N)
        
        # Normalize points to [0, 1] range in all dimensions
        max_dim = torch.max(sites).floor().item()
        sites = sites +  max_dim
        points_normalized = sites / (2*max_dim)
        
        
        # sites = sites + 2 #shift to 0-10
        # points_normalized = sites / 4
        

        # Scale to grid coordinates
        points_grid = points_normalized * (gridsize - 1)

        # Separate grid coordinates into integer and fractional parts
        x, y, z = points_grid[:, 0], points_grid[:, 1], points_grid[:, 2]
        x0 = x.floor().long().clamp(0, gridsize - 1)
        y0 = y.floor().long().clamp(0, gridsize - 1)
        z0 = z.floor().long().clamp(0, gridsize - 1)
        x1 = (x0 + 1).clamp(0, gridsize - 1)
        y1 = (y0 + 1).clamp(0, gridsize - 1)
        z1 = (z0 + 1).clamp(0, gridsize - 1)
        dx, dy, dz = x - x0, y - y0, z - z0

        # Perform trilinear interpolation
        values = (
            (1 - dx) * (1 - dy) * (1 - dz) * self.grid[x0, y0, z0] +
            dx * (1 - dy) * (1 - dz) * self.grid[x1, y0, z0] +
            (1 - dx) * dy * (1 - dz) * self.grid[x0, y1, z0] +
            dx * dy * (1 - dz) * self.grid[x1, y1, z0] +
            (1 - dx) * (1 - dy) * dz * self.grid[x0, y0, z1] +
            dx * (1 - dy) * dz * self.grid[x1, y0, z1] +
            (1 - dx) * dy * dz * self.grid[x0, y1, z1] +
            dx * dy * dz * self.grid[x1, y1, z1]
        )

        return values
    
    def __str__(self):
        return f"{self.filename} with shape {self.sdf_grid.shape}"

model = SDFGrid(mesh[1])
sdf_values = model.sdf(sites)
print(model)
print(sdf_values.shape)
# ps_cloud = ps.register_point_cloud("sites_w_sdf", sites.detach().cpu().numpy())
# ps_cloud.add_scalar_quantity("sdf_values", sdf_values.detach().cpu().numpy(), enabled=True)
# ps.show()

./Resources/chair_low.obj with shape (128, 128, 128)
torch.Size([3375])


In [6]:
near_target_pc = model.pc + torch.randn_like(model.pc) * 0.1
sites = torch.cat((sites, near_target_pc), dim=0)

near_target_pc = model.pc + torch.randn_like(model.pc) * 0.1
sites = torch.cat((sites, near_target_pc), dim=0)

sites = torch.tensor(sites, device=device, dtype=torch.double).requires_grad_(True)

sdf_values_trilinear = model.sdf(sites)

#ps.register_point_cloud("model_pc", model.pc.detach().cpu().numpy())

#use torch.nn.functionnal.grid_sample
#df_Values_gridsample = torch.nn.functional.grid_sample(sites, torch.tensor(sdf_points, device=device), mode='bilinear', padding_mode='border')


  sites = torch.tensor(sites, device=device, dtype=torch.double).requires_grad_(True)


In [7]:
def get_zero_crossing_mesh_3d(sites, model):
    sites_np = sites.detach().cpu().numpy()
    vor = Voronoi(sites_np)  # Compute 3D Voronoi diagram

    sdf_values = model.sdf(sites).detach().cpu().numpy()  # Compute SDF values

    valid_faces = []  # List of polygonal faces
    used_vertices = set()  # Set of indices for valid vertices

    for (point1, point2), ridge_vertices in zip(vor.ridge_points, vor.ridge_vertices):
        if -1 in ridge_vertices:
            continue  # Skip infinite ridges

        # Check if SDF changes sign across this ridge
        if np.sign(sdf_values[point1]) != np.sign(sdf_values[point2]):
            valid_faces.append(ridge_vertices)
            used_vertices.update(ridge_vertices)

    # **Filter Voronoi vertices**
    used_vertices = sorted(used_vertices)  # Keep unique, sorted indices
    vertex_map = {old_idx: new_idx for new_idx, old_idx in enumerate(used_vertices)}
    filtered_vertices = vor.vertices[used_vertices]

    # **Re-index faces to match the new filtered vertex list**
    filtered_faces = [[vertex_map[v] for v in face] for face in valid_faces]

    return filtered_vertices, filtered_faces



In [8]:
# ps_cloud = ps.register_point_cloud("sites_w_sdf", sites.detach().cpu().numpy())
# ps_cloud.add_scalar_quantity("sdf_values_trilinear", sdf_values_trilinear.detach().cpu().numpy(), enabled=True)
# # # ps_cloud.add_scalar_quantity("sdf_Values_gridsample", sdf_Values_gridsample.detach().cpu().numpy(), enabled=True)


# #ps.register_point_cloud("sites_w_zeroes_sdf_trilinear", sites.detach().cpu().numpy()[sdf_values_trilinear.detach().cpu().numpy()[:] <= 0])
# # # ps.register_point_cloud("sites_w_zeroes_sdf_gridsample", sites.detach().cpu().numpy()[sdf_Values_gridsample.detach().cpu().numpy()[:] <= 0])

initial_mesh = get_zero_crossing_mesh_3d(sites, model)
ps.register_surface_mesh("initial_mesh Zero-Crossing faces", initial_mesh[0], initial_mesh[1])
# # ps.register_point_cloud("initial_mesh vertices", initial_mesh[0])

# ps.show()


<polyscope.surface_mesh.SurfaceMesh at 0x7c6c08077280>

In [None]:
cvt_loss_values = []
min_distance_loss_values = []
edge_smoothing_loss_values = []
chamfer_distance_loss_values = []
eikonal_loss_values = []
domain_restriction_loss_values = []
zero_target_points_loss_values = []
sdf_loss_values = []
loss_values = []

def autograd(sites, model, max_iter=100, stop_train_threshold=1e-6, upsampling=0, lambda_weights = [0.1,1.0,0.1,0.1,1.0,1.0,0.1]):
    optimizer = torch.optim.Adam([
    {'params': [sites], 'lr': lr_sites}
], betas=(0.5, 0.999))

    prev_loss = float("inf")
    best_loss = float("inf")
    upsampled = 0
    epoch = 0
    lambda_cvt = lambda_weights[0]
    lambda_sdf = lambda_weights[1]
    lambda_min_distance = lambda_weights[2]
    lambda_laplace = lambda_weights[3]
    lambda_chamfer = lambda_weights[4]
    lambda_eikonal = lambda_weights[5]
    lambda_domain_restriction = lambda_weights[6]
    lambda_target_points = lambda_weights[7]
    
    best_sites = sites.clone()
    best_sites.best_loss = best_loss
    
    while epoch <= max_iter:
        optimizer.zero_grad()
        sites_np = sites.detach().cpu().numpy()
        
        d3dsimplices = diffvoronoi.get_delaunay_simplices(sites_np.reshape(input_dims*sites_np.shape[0]))
        d3dsimplices = np.array(d3dsimplices)
        print("d3dsimplices shape: ", d3dsimplices.shape)

        vertices_to_compute, bisectors_to_compute = su.compute_zero_crossing_vertices_3d(sites, None, None, d3dsimplices, model)
        
        vertices = su.compute_vertices_3d_vectorized(sites, vertices_to_compute)    
        bisectors = su.compute_all_bisectors_vectorized(sites, bisectors_to_compute)
        #combine vertices and bisectors to one tensor for chamfer
        points = torch.cat((vertices, bisectors), 0)
        from pytorch3d.loss import chamfer_distance
        chamfer_loss, _ = chamfer_distance(model.pc.unsqueeze(0).detach(), ((points-torch.mean(points, dim=0))/ torch.max(torch.abs(points))).unsqueeze(0))
#        print("chamfer_loss: ", chamfer_loss.item())

        # Compute losses       
        cvt_loss = lf.compute_cvt_loss_vectorized_delaunay(sites, None, d3dsimplices)
        
        sdf_loss = torch.mean(model.sdf(points)**2)        
        
        #laplacian_loss = lf.compute_ridge_smoothing_loss(bisectors_to_compute, sites, model)
        #laplacian_loss = mean_curvature_loss(ridge_vertices_pairs, sites)

        # Track raw losses (unweighted)
        cvt_loss_values.append(cvt_loss.item())
        sdf_loss_values.append(sdf_loss.item())
        #edge_smoothing_loss_values.append(laplacian_loss.item())
  
        loss = (
            lambda_cvt * cvt_loss +
            lambda_sdf * sdf_loss +
            lambda_chamfer * chamfer_loss 
            #lambda_laplace * laplacian_loss
        )
        loss_values.append(loss.item())
        #print(f"cvt_loss: {cvt_loss}, laplace_loss: {laplacian_loss}, ")
        print(f"Epoch {epoch}: loss = {loss.item()}")
                
        loss.backward()
        optimizer.step()
        
        if loss.item() < best_loss:
            best_loss = loss.item()
            best_epoch = epoch
            best_sites = sites.clone()
            best_sites.best_loss = best_loss
            if upsampled > 0:
                print(f"UPSAMPLED {upsampled} Best Epoch {best_epoch}: Best loss = {best_loss}")
                #return best_sites
        
        if abs(prev_loss - loss.item()) < stop_train_threshold:
            print(f"Converged at epoch {epoch} with loss {loss.item()}")
            #break
        
        prev_loss = loss.item() 
        
        if epoch/max_iter > (0.5)*(upsampled+1)/(upsampling+1) and upsampled < upsampling:
            print("sites length BEFORE UPSAMPLING: ",len(sites))
            
            #new_sites = su.upsampling_inside(best_sites, model)
            #new_sites = su.adaptive_density_upsampling(best_sites, model)
            
            #sites = su.add_upsampled_sites(best_sites, new_sites)
            
            
            sites = su.upsampling_vectorized(sites, tri=None, vor=None, simplices=d3dsimplices, model=model)
            
            sites = sites.detach().requires_grad_(True)
            #print("upsampled sites length: ",len(sites))
            
            #best_sites = sites.clone()
            #best_sites.best_loss = best_loss
            
            optimizer = torch.optim.Adam([#{'params': [p for _, p in model.named_parameters()], 'lr': lr_model},
                                          {'params': [sites], 'lr': lr_sites}])
            upsampled += 1.0
            print("sites length AFTER: ",len(sites))
            
          
        if epoch % (max_iter/10) == 0:
            print(f"Epoch {epoch}: loss = {loss.item()}")
            print(f"Best Epoch {best_epoch}: Best loss = {best_loss}")
            #save model and sites
            site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}.pth'
            torch.save(sites, site_file_path)
        
        epoch += 1           
        
    return best_sites


: 

In [None]:
lambda_weights = [20.01,2,0,0,2,0,0,0]

lambda_cvt = lambda_weights[0]
lambda_sdf = lambda_weights[1]
lambda_min_distance = lambda_weights[2]
lambda_laplace = lambda_weights[3]
lamda_chamfer = lambda_weights[4]
lambda_eikonal = lambda_weights[5]
lambda_domain_restriction = lambda_weights[6]
lambda_target_points = lambda_weights[7]
max_iter = 100

site_file_path = f'{destination}{mesh[0]}{max_iter}3d_sites_{num_centroids}.npy'
#check if optimized sites file exists
if not os.path.exists(site_file_path):
    #import sites
    print("Importing sites")
    sites = np.load(site_file_path)
    sites = torch.from_numpy(sites).to(device).requires_grad_(True)    
else:
    # import cProfile, pstats
    # import time
    # profiler = cProfile.Profile()
    # profiler.enable()
    

    
    
    sites = autograd(sites, model, max_iter=max_iter, upsampling=0, lambda_weights=lambda_weights)
    
    # profiler.disable()
    # stats = pstats.Stats(profiler).sort_stats('cumtime')
    # stats.print_stats()
    # stats.dump_stats(f'{destination}{mesh[0]}{max_iter}_3d_profile_{num_centroids}.prof')
    
    
    sites_np = sites.detach().cpu().numpy()
    np.save(site_file_path, sites_np)
    

print("Sites length: ", len(sites))
print("min sites: ", torch.min(sites))
print("max sites: ", torch.max(sites))
ps_cloud = ps.register_point_cloud("best_optimized_cvt_grid",sites.detach().cpu().numpy())
    
lim=torch.abs(torch.max(sites)).detach().cpu().numpy()*1.1
#plot_voronoi_3d(sites,lim,lim,lim)

d3dsimplices shape:  (464452, 4)


In [None]:
epoch = 100

site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}.pth'
sites = torch.load(site_file_path)
sites_np = sites.detach().cpu().numpy()
#
print("sites", site_file_path)
sdf_values_trilinear = model.sdf(sites)

ps.register_point_cloud("final_sites_w_zeroes_sdf_trilinear", sites.detach().cpu().numpy()[sdf_values_trilinear.detach().cpu().numpy()[:] <= 0])




NameError: name 'destination' is not defined

In [None]:

# tri = Delaunay(sites_np)
# delaunay_vertices =torch.tensor(np.array(tri.simplices), device=device)
# sdf_values = model.sdf(sites)


# # Assuming sites is a PyTorch tensor of shape [M, 3]
# sites = sites.unsqueeze(0)  # Now shape [1, M, 3]

# # Assuming SDF_Values is a PyTorch tensor of shape [M]
# sdf_values = sdf_values.unsqueeze(0)  # Now shape [1, M]

#marching_tetrehedra_mesh = kaolin.ops.conversions.marching_tetrahedra(sites, delaunay_vertices, sdf_values, return_tet_idx=False)
#print(marching_tetrehedra_mesh)
# vertices_list, faces_list = marching_tetrehedra_mesh
# vertices = vertices_list[0]
# faces = faces_list[0]
# vertices_np = vertices.detach().cpu().numpy()  # Shape [N, 3]
# faces_np = faces.detach().cpu().numpy()  # Shape [M, 3] (triangles)
# ps.register_surface_mesh("Marching Tetrahedra Mesh", vertices_np, faces_np)


zero_crossing_final_mesh = get_zero_crossing_mesh_3d(sites, model)
ps.register_surface_mesh("Zero-Crossing faces", zero_crossing_final_mesh[0], zero_crossing_final_mesh[1])
#ps.register_point_cloud("Mesh vertices", zero_crossing_final_mesh[0])

ps.show()



: 