In [8]:
import kaolin
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import polyscope as ps
import interactive_polyscope
from scipy.spatial import Voronoi, voronoi_plot_2d
from io import BytesIO
from PIL import Image
import sdfpred_utils.sdfpred_utils as su
import sdfpred_utils.sdf_MLP as mlp
import sdfpred_utils.sdf_functions as sdf
import sdfpred_utils.loss_functions as lf
import trimesh
from scipy.spatial import Delaunay, Voronoi


# Compute device: everything in this notebook runs on the first CUDA GPU.
device = torch.device("cuda:0")
gpu_name = torch.cuda.get_device_name(device)
print("Using device: ", gpu_name)

# Tensors created without an explicit dtype/device default to float64 on CUDA.
torch.set_default_tensor_type(torch.cuda.DoubleTensor)

# Arguments forwarded to mlp.Decoder when the SDF network is built.
input_dims, multires = 3, 2

# Adam learning rates for the two parameter groups (site positions / network weights).
lr_sites, lr_model = 0.03, 0.0003

# Iteration bookkeeping; max_iter is reassigned right before the optimisation is launched.
iterations, save_every, max_iter = 5000, 100, 100

# Folder prefix for checkpoints, profiles and exported animations.
destination = "./images/autograd/3D/"


Using device:  NVIDIA GeForce RTX 3090


In [9]:
# Initial Voronoi sites (currently between -5 and 5 in all 3 dimensions).
# The grid is cached on disk and keyed by site count and dimensionality, so it
# is only generated the first time this cell runs for a given configuration.
#num_centroids = 16*16*16
num_centroids = 16 * 16 * 16
site_fp = f'sites_{num_centroids}_{input_dims}.pt'

if not os.path.exists(site_fp):
    print("Creating new sites")
    sites = su.createCVTgrid(num_centroids=num_centroids, dimensionality=input_dims)
    # Persist the freshly created grid so later runs start from the same sites.
    torch.save(sites, site_fp)
else:
    sites = torch.load(site_fp)
    print("Sites loaded:", sites.shape)


def plot_voronoi_3d(sites, xlim=5, ylim=5, zlim=5):
    """Plot the 3D Voronoi cells of ``sites`` with matplotlib.

    The tessellation is computed by pyvoro inside the axis-aligned box
    [-xlim, xlim] x [-ylim, ylim] x [-zlim, zlim]. Every cell is drawn as a
    semi-transparent polyhedron with black edges and a random face colour
    (the colour generator uses a fixed seed, so colours repeat between calls).

    Args:
        sites: tensor of site positions; it is detached and copied to the CPU.
        xlim, ylim, zlim: half-extent of the box (and of the plot axes) per axis.
    """
    import numpy as np
    import pyvoro
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d.art3d import Poly3DCollection

    colour_rng = np.random.default_rng(11)
    site_coords = sites.detach().cpu().numpy()

    # Second argument: the axis limits of the bounding box in x, y and z.
    # Third argument: pyvoro's "dispersion" (max distance between two points
    # that might be adjacent) -- NOTE(review): value 1 kept from the original,
    # its effect on the result has not been verified here.
    bounding_box = [[-xlim, xlim], [-ylim, ylim], [-zlim, zlim]]
    cells = pyvoro.compute_voronoi(site_coords, bounding_box, 1)

    figure = plt.figure()
    axes = figure.add_subplot(111, projection='3d')

    for cell in cells:
        # 'vertices' holds the corner points of the cell; every face lists the
        # indices of the corners it is made of.
        corners = np.array(cell['vertices'])
        cell_faces = [corners[np.array(face['vertices'])] for face in cell['faces']]

        axes.add_collection3d(
            Poly3DCollection(
                cell_faces,
                alpha=0.5,
                facecolors=colour_rng.uniform(0, 1, 3),
                linewidths=0.5,
                edgecolors='black',
            )
        )

    axes.set_xlim([-xlim, xlim])
    axes.set_ylim([-ylim, ylim])
    axes.set_zlim([-zlim, zlim])

    plt.show()

#plot_voronoi_3d(sites)

Sites loaded: torch.Size([4096, 3])


In [10]:
# Start polyscope; the following cells register structures with it and open
# the viewer through ps.show().
ps.init()
#ps_cloud = ps.register_point_cloud("initial_cvt_grid",sites.detach().cpu().numpy())

In [11]:

# Mesh to reconstruct, as [short name used in output file names, path to the .obj file]
mesh = ["bunny", "Resources/stanford-bunny.obj"]
#mesh = ["staryu", "Resources/staryu.obj"]
#mesh = ["chair", "Resources/chair_low.obj"]

bunny = trimesh.load(mesh[1])

# Axis-aligned bounding box of the mesh as loaded
min_bound, max_bound = bunny.bounds[0], bunny.bounds[1]  # (x, y, z) corners
current_size = max_bound - min_bound  # Extent along each axis

# Uniform scale: the longest side becomes target_size, so after centring the
# mesh spans [-4, 4] along that axis (inside the [-5, 5] site domain).
target_size = 8
scale_factor = target_size / np.max(current_size)

# Scale first, then measure the scaled bounding box to find its centre
new_vertices = bunny.vertices * scale_factor
new_min = new_vertices.min(axis=0)
new_max = new_vertices.max(axis=0)
new_center = (new_min + new_max) / 2

# Shift the bounding-box centre to the origin
translation = -new_center
bunny.vertices = new_vertices + translation

# Sample surface points (8 per site) that serve as the optimisation target
#target_points = bunny.sample(16*16*16)
target_points = torch.tensor(bunny.sample(num_centroids*8), device=device)
print("Target points:", target_points.shape)

# Per-axis extent of the sampled points, printed as a sanity check of the normalisation
min_target = target_points.min(dim=0).values
max_target = target_points.max(dim=0).values
print("min_target", min_target)
print("max_target", max_target)

#ps_cloud = ps.register_point_cloud("Target_points",target_points.detach().cpu().numpy())


Target points: torch.Size([32768, 3])
min_target tensor([-3.9994, -3.9458, -3.0972])
max_target tensor([3.9998, 3.9602, 3.0996])


In [12]:
# ps.show()

In [None]:
import matplotlib
import os

# SDF network, moved to the compute device
model = mlp.Decoder(multires=multires, input_dims=input_dims).to(device)

# Weight file used to initialise the network (alternatives kept for reference)
#model_path = 'models_resources/pretrained_sphere_small.pth'
model_path = 'models_resources/pretrained_sphere_small_eiko.pth'
#model_path = 'models_resources/trained_bunny_GT.pth'


if not os.path.exists(model_path):
    # No stored weights: pre-train and cache the result for the next run.
    # NOTE(review): the meaning of the arguments (1000, 2) is defined in
    # sdf_MLP.Decoder.pre_train_sphere -- confirm there before changing them.
    print("no model found, pretraining")
    model.pre_train_sphere(1000,2)
    #model.train_GT_mesh(300, mesh, target_points)
    #model.train_GT_grid(300)
    torch.save(model.state_dict(),model_path)
else:
    model.load_state_dict(torch.load(model_path))
    print('loaded model')
    
def polyscope_sdf(model):
    """Register the SDF predicted by ``model`` with polyscope as an implicit surface.

    polyscope is handed a callback that evaluates the network on the query
    points it supplies and returns the values as a flat NumPy array.
    """
    def evaluate_sdf(pts):
        # Query points arrive as an array; evaluate them as float64 on `device`.
        query = torch.tensor(pts, dtype=torch.float64, device=device)
        return model(query).detach().cpu().numpy().flatten()

    ps.render_implicit_surface(
        "SDF Surface",
        evaluate_sdf,
        mode="sphere_march",
        enabled=True,
        subsample_factor=2,
    )

#TODO : change camera parameters
polyscope_sdf(model)


loaded model


In [16]:
# Open the polyscope viewer to inspect the SDF surface of the initial model.
ps.show()

In [17]:
# import numpy as np
# import trimesh
# import pyrender
# from skimage.measure import marching_cubes


# # Create a grid of points
# grid_size = 64  # Resolution of the grid
# x = np.linspace(-4, 4, grid_size)
# y = np.linspace(-4, 4, grid_size)
# z = np.linspace(-4, 4, grid_size)
# X, Y, Z = np.meshgrid(x, y, z, indexing="ij")

# grid_x = X.flatten()
# grid_y = Y.flatten()
# grid_z = Z.flatten()

# grid = np.stack([grid_x, grid_y, grid_z], axis=1)
# grid = torch.tensor(grid, dtype=torch.float64, device=device)

# # Compute SDF values on the grid
# sdf_values = model(grid)

# #sdf_values should be a 3d numpy array
# sdf_values = sdf_values.detach().cpu().numpy().reshape(grid_size, grid_size, grid_size)

# # Extract mesh using Marching Cubes
# vertices, faces, _, _ = marching_cubes(sdf_values, level=0)

# # Create a mesh
# mesh = trimesh.Trimesh(vertices, faces)

# # Render with Pyrender
# scene = pyrender.Scene()
# mesh_pyrender = pyrender.Mesh.from_trimesh(mesh)
# scene.add(mesh_pyrender)

# # Set up the renderer
# viewer = pyrender.Viewer(scene, use_raymond_lighting=True)


In [18]:
def upsampling_vectorized(sites, model):
    sdf_values = model(sites)
    sites_np = sites.detach().cpu().numpy()
    # Compute Voronoi diagram
    vor = Voronoi(sites_np)
    
    neighbors = torch.tensor(np.array(vor.ridge_points), device=device)
    
    # Extract the SDF values for each site in the pair
    sdf_i = sdf_values[neighbors[:, 0]]  # First site in each pair
    sdf_j = sdf_values[neighbors[:, 1]]  # Second site in each pair
    # Find the indices where SDF values have opposing signs or one is zero
    mask_zero_crossing_sites = (sdf_i * sdf_j <= 0).squeeze()
    sites_to_upsample = torch.unique(neighbors[mask_zero_crossing_sites].view(-1))
    
    print("Sites to upsample ",sites_to_upsample.shape)
    
    tet_centroids = sites[sites_to_upsample]

    # Tetrahedron relative positions (unit tetrahedron)
    basic_tet_1 = torch.tensor([[1, 1, 1]], device=device, dtype=torch.float64)
    basic_tet_1 = basic_tet_1.repeat(len(tet_centroids), 1)
    basic_tet_2 = torch.tensor([-1, -1, 1], device=device, dtype=torch.float64)    
    basic_tet_2 = basic_tet_2.repeat(len(tet_centroids), 1)
    basic_tet_3 = torch.tensor([-1, 1, -1], device=device, dtype=torch.float64)    
    basic_tet_3 = basic_tet_3.repeat(len(tet_centroids), 1)
    basic_tet_4 = torch.tensor([1, -1, -1], device=device, dtype=torch.float64)
    basic_tet_4 = basic_tet_4.repeat(len(tet_centroids), 1)


    #compute scale based on cell volume
    centroids = torch.tensor(np.array([vor.vertices[vor.regions[vor.point_region[i]]].mean(axis=0) for i in range(len(sites_np))]), device=device)
    #centroids = torch.tensor(np.array(centroids), device=sites.device, dtype=sites.dtype)
    cells_vertices = [vor.vertices[vor.regions[vor.point_region[i]]] for i in range(len(sites_np))]

    #compute the distance between each centroid  and each vertex in cells_vertices row
    distances = []
    for i in range(len(cells_vertices)):
        min_dist = 100000000000
        for j in range(len(cells_vertices[i])):
            dist = torch.norm(centroids[i] - torch.tensor(cells_vertices[i][j], device=device), p=2)
            if dist < min_dist:
                min_dist = dist
        distances.append(min_dist)
    distances = torch.tensor(distances, device=device)
 
    
    scale = distances[sites_to_upsample] / 2
    
    scale = scale.unsqueeze(1)
    
    
    new_sites = torch.cat((tet_centroids + basic_tet_1 * scale, tet_centroids + basic_tet_2 * scale, tet_centroids + basic_tet_3 * scale, tet_centroids + basic_tet_4 * scale), dim=0)

    updated_sites = torch.cat((sites, new_sites), dim=0)

    return updated_sites
                



In [19]:
# Per-epoch loss histories. They live at module level so they can be inspected
# after training; each name is bound to its own, initially empty, list.
(
    cvt_loss_values,
    min_distance_loss_values,
    edge_smoothing_loss_values,
    chamfer_distance_loss_values,
    eikonal_loss_values,
    domain_restriction_loss_values,
    zero_target_points_loss_values,
    loss_values,
) = ([] for _ in range(8))

def autograd(sites, model, max_iter=100, stop_train_threshold=1e-6, upsampling=0, lambda_weights=(0.1, 1.0, 0.1, 0.1, 1.0, 1.0, 0.1, 1.0)):
    """Jointly optimise the Voronoi sites and the SDF network with Adam.

    Every epoch the zero-crossing vertices and bisectors of the current
    Voronoi diagram are computed, the losses are evaluated and one Adam step
    is taken on both the network weights (``lr_model``) and the site positions
    (``lr_sites``). Raw loss values are appended to the module-level
    ``*_loss_values`` lists. Checkpoints of the model and the sites are
    written under ``destination`` every ``max_iter // 10`` epochs.

    Args:
        sites: tensor of site positions to optimise.
            NOTE(review): presumably a leaf tensor with requires_grad=True;
            otherwise Adam cannot update it -- confirm against the caller.
        model: SDF network.
        max_iter: index of the last epoch; epochs 0..max_iter are run.
        stop_train_threshold: when two consecutive losses differ by less than
            this, a message is printed (training is NOT stopped).
        upsampling: maximum number of upsampling rounds. A round can happen
            once per block of 100 epochs (first at epoch 101) when the loss is
            below 0.5.
        lambda_weights: sequence of at least 8 weights, indexed as
            [cvt, sdf, min_distance, laplace, chamfer, eikonal,
            domain_restriction, target_points]. Only cvt, chamfer, eikonal and
            target_points are used; the other terms are currently disabled.

    Returns:
        A detached copy of the sites that produced the lowest total loss, with
        that loss stored in the ``best_loss`` attribute of the tensor.
    """
    def _make_optimizer(site_tensor):
        # One place for the optimiser settings so the optimiser rebuilt after
        # upsampling uses the same betas as the initial one.
        return torch.optim.Adam([
            {'params': [p for _, p in model.named_parameters()], 'lr': lr_model},
            {'params': [site_tensor], 'lr': lr_sites},
        ], betas=(0.5, 0.999))

    def _snapshot(site_tensor):
        # Detached copy: must not keep the autograd graph alive nor change
        # when the optimiser updates the sites in place.
        return site_tensor.detach().clone().requires_grad_(site_tensor.requires_grad)

    lambda_cvt = lambda_weights[0]
    lambda_chamfer = lambda_weights[4]
    lambda_eikonal = lambda_weights[5]
    lambda_target_points = lambda_weights[7]

    optimizer = _make_optimizer(sites)

    prev_loss = float("inf")
    best_loss = float("inf")
    best_epoch = None  # stays None if the loss never improves (e.g. NaN)
    upsampled = 0.0

    best_sites = _snapshot(sites)
    best_sites.best_loss = best_loss

    # Integer interval: also works when max_iter is not a multiple of 10.
    save_interval = max(1, max_iter // 10)

    for epoch in range(max_iter + 1):
        optimizer.zero_grad()

        vertices_to_compute, bisectors_to_compute = su.compute_zero_crossing_vertices_3d(sites, model)
        vertices = su.compute_vertices_3d_vectorized(sites, vertices_to_compute)
        bisectors = su.compute_all_bisectors_vectorized(sites, bisectors_to_compute)
        # Vertices and bisectors together form the point set used for chamfer.
        points = torch.cat((vertices, bisectors), 0)

        # Individual loss terms
        cvt_loss = lf.compute_cvt_loss_vectorized(sites, model)
        chamfer_loss = lf.chamfer_distance(target_points, points)
        eikonal_loss = lf.eikonal(model, input_dimensions=input_dims)

        # The SDF should vanish on the sampled surface points; a single
        # forward pass feeds both the L1 and the L2 variant.
        sdf_values_target_points = model(target_points)[:, 0]
        zero_target_points_loss_L2 = torch.mean(sdf_values_target_points ** 2)
        zero_target_points_loss_L1 = torch.mean(torch.abs(sdf_values_target_points))
        lambda_1, lambda_2 = 0, 0.99  # Adjust weights as needed
        zero_target_points_loss = lambda_1 * zero_target_points_loss_L1 + lambda_2 * zero_target_points_loss_L2

        # Track raw (unweighted) losses
        cvt_loss_values.append(cvt_loss.item())
        chamfer_distance_loss_values.append(chamfer_loss.item())
        eikonal_loss_values.append(eikonal_loss.item())
        zero_target_points_loss_values.append(zero_target_points_loss.item())

        loss = (
            lambda_cvt * cvt_loss +
            lambda_chamfer * chamfer_loss +
            lambda_eikonal * eikonal_loss +
            lambda_target_points * zero_target_points_loss
        )
        current_loss = loss.item()
        loss_values.append(current_loss)
        print(f"Epoch {epoch}: loss = {current_loss}")

        loss.backward()

        # Record the best state BEFORE stepping: `current_loss` belongs to the
        # sites as they are now, not to the sites after the update.
        if current_loss < best_loss:
            best_loss = current_loss
            best_epoch = epoch
            best_sites = _snapshot(sites)
            best_sites.best_loss = best_loss
            if upsampled > 0:
                print(f"UPSAMPLED {upsampled} Best Epoch {best_epoch}: Best loss = {best_loss}")

        optimizer.step()

        if abs(prev_loss - current_loss) < stop_train_threshold:
            # Informational only; early stopping is intentionally disabled.
            print(f"Converged at epoch {epoch} with loss {current_loss}")

        prev_loss = current_loss

        if epoch > 100 and (epoch // 100) == upsampled + 1 and current_loss < 0.5 and upsampled < upsampling:
            print("sites length BEFORE UPSAMPLING: ",len(sites))

            sites = upsampling_vectorized(sites, model)
            sites = sites.detach().requires_grad_(True)

            # The parameter tensor changed size, so Adam's state must be rebuilt.
            optimizer = _make_optimizer(sites)
            upsampled += 1.0
            print("sites length AFTER: ",len(sites))

        if epoch % save_interval == 0:
            print(f"Best Epoch {best_epoch}: Best loss = {best_loss}")
            # Save model and sites
            site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lambda_chamfer}.pth'
            model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lambda_chamfer}.pth'
            os.makedirs(os.path.dirname(site_file_path) or ".", exist_ok=True)
            torch.save(model.state_dict(), model_file_path)
            torch.save(sites, site_file_path)

    return best_sites

In [20]:

# Loss weights in the order autograd() indexes them:
# [cvt, sdf, min_distance, laplace, chamfer, eikonal, domain_restriction, target_points]
# The chamfer weight is part of every checkpoint/cache file name, so changing
# its value also changes which files are looked up.
lambda_weights = [0.2,0,0.2,0,1.00011111111101101000101,0.1,0,2]

lambda_cvt = lambda_weights[0]
lambda_sdf = lambda_weights[1]
lambda_min_distance = lambda_weights[2]
lambda_laplace = lambda_weights[3]
lamda_chamfer = lambda_weights[4]
lambda_eikonal = lambda_weights[5]
lambda_domain_restriction = lambda_weights[6]
lambda_target_points = lambda_weights[7]

max_iter = 500

site_file_path = f'{destination}{mesh[0]}{max_iter}3d_sites_{num_centroids}_chamfer{lamda_chamfer}.npy'
# Reuse previously optimised sites when they are cached on disk
if os.path.exists(site_file_path):
    print("Importing sites")
    sites = np.load(site_file_path)
    sites = torch.from_numpy(sites).to(device).requires_grad_(True)
else:
    import cProfile
    import pstats
    import time

    # Checkpoints, the profile and the final sites are all written below this folder.
    os.makedirs(os.path.dirname(site_file_path) or ".", exist_ok=True)

    profiler = cProfile.Profile()
    profiler.enable()
    try:
        sites = autograd(sites, model, max_iter=max_iter, upsampling=0, lambda_weights=lambda_weights)
    finally:
        # Always stop profiling, also when the run is interrupted; otherwise
        # the profiler keeps tracing every later cell.
        profiler.disable()

    stats = pstats.Stats(profiler).sort_stats('cumtime')
    stats.print_stats()
    stats.dump_stats(f'{destination}{mesh[0]}{max_iter}_3d_profile_{num_centroids}_chamfer{lamda_chamfer}.prof')

    sites_np = sites.detach().cpu().numpy()
    np.save(site_file_path, sites_np)


print("Sites length: ", len(sites))
print("min sites: ", torch.min(sites))
print("max sites: ", torch.max(sites))
ps_cloud = ps.register_point_cloud("best_optimized_cvt_grid",sites.detach().cpu().numpy())

# Symmetric plot limit: the largest absolute coordinate (not the absolute value
# of the largest coordinate), plus a 10% margin.
lim=torch.max(torch.abs(sites)).detach().cpu().numpy()*1.1
#plot_voronoi_3d(sites,lim,lim,lim)

cvt_loss:  tensor(0.4153, grad_fn=<MeanBackward0>)
Epoch 0: loss = 7.620295848840643
Epoch 0: loss = 7.620295848840643
Best Epoch 0: Best loss = 7.620295848840643
cvt_loss:  tensor(0.4141, grad_fn=<MeanBackward0>)
Epoch 1: loss = 8.582582808141012
cvt_loss:  tensor(0.3454, grad_fn=<MeanBackward0>)
Epoch 2: loss = 3.722546308996204
cvt_loss:  tensor(0.3736, grad_fn=<MeanBackward0>)
Epoch 3: loss = 3.489603259257034
cvt_loss:  tensor(0.3462, grad_fn=<MeanBackward0>)
Epoch 4: loss = 2.522196314639081
cvt_loss:  tensor(0.3288, grad_fn=<MeanBackward0>)
Epoch 5: loss = 2.571463159578138
cvt_loss:  tensor(0.2980, grad_fn=<MeanBackward0>)
Epoch 6: loss = 2.297674042882702
cvt_loss:  tensor(0.2714, grad_fn=<MeanBackward0>)
Epoch 7: loss = 2.86603488587313
cvt_loss:  tensor(0.2766, grad_fn=<MeanBackward0>)
Epoch 8: loss = 2.7191242867769154
cvt_loss:  tensor(0.2551, grad_fn=<MeanBackward0>)
Epoch 9: loss = 2.2409177287968083
cvt_loss:  tensor(0.2574, grad_fn=<MeanBackward0>)
Epoch 10: loss = 2.0

KeyboardInterrupt: 

In [21]:
# Epoch of the checkpoint to inspect
epoch = 100

# Paths follow the naming scheme used when checkpoints are written during training
site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lamda_chamfer}.pth'
model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lamda_chamfer}.pth'

# Restore the sites first, then overwrite the current network weights
sites = torch.load(site_file_path)
sites_np = sites.detach().cpu().numpy()
model.load_state_dict(torch.load(model_file_path))

# Show the restored SDF as an implicit surface
polyscope_sdf(model)

print("model", model_file_path)
print("sites", site_file_path)

# ...together with the restored sites as a point cloud
ps_cloud = ps.register_point_cloud(f"{epoch} epoch_cvt_grid",sites_np)



model ./images/autograd/3D/bunny500_100_3d_model_4096_chamfer1.000111111111011.pth
sites ./images/autograd/3D/bunny500_100_3d_sites_4096_chamfer1.000111111111011.pth


In [22]:
# Extract the zero-crossing mesh of the current sites/model; element 0 is handed
# to polyscope as the vertex array and element 1 as the face array.
final_mesh = su.get_zero_crossing_mesh_3d(sites, model)
mesh_vertices, mesh_faces = final_mesh[0], final_mesh[1]

ps.register_surface_mesh("Zero-Crossing faces", mesh_vertices, mesh_faces)
ps.register_point_cloud("Mesh vertices", mesh_vertices)

ps.show()



In [None]:
def export_visualisation_3d():
    """Render every saved checkpoint and write two animated GIFs.

    For each checkpoint epoch (0, step, 2*step, ..., max_iter with
    step = max_iter // 10) whose sites and model files both exist, two
    polyscope screenshots are taken: the extracted zero-crossing mesh and the
    SDF implicit surface. The frames are written to
    ``*_optimization_mesh.gif`` and ``*_optimization_sdf.gif`` under
    ``destination``.

    Side effects: all polyscope structures are removed, and the global
    ``model`` is left holding the weights of the last checkpoint found.
    """
    import imageio

    img_buffer_mesh = []
    img_buffer_model = []

    # Only epochs up to max_iter can have been written by the training loop.
    step = max(1, max_iter // 10)
    for epoch in range(0, max_iter + 1, step):
        site_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lamda_chamfer}.pth'
        model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lamda_chamfer}.pth'
        if not (os.path.exists(site_file_path) and os.path.exists(model_file_path)):
            print("files not found")
            continue
        print("importing sites and model")
        print("mesh of epoch: ", epoch)

        model.load_state_dict(torch.load(model_file_path))
        checkpoint_sites = torch.load(site_file_path)

        # Frame 1: mesh extracted at the zero crossings
        current_mesh = su.get_zero_crossing_mesh_3d(checkpoint_sites, model)
        ps.remove_all_structures()
        ps.register_surface_mesh("Zero-Crossing faces", current_mesh[0], current_mesh[1])
        ps.register_point_cloud("Mesh vertices", current_mesh[0])
        img_buffer_mesh.append(ps.screenshot_to_buffer(transparent_bg=False))

        # Frame 2: implicit surface of the network itself
        ps.remove_all_structures()
        polyscope_sdf(model)
        img_buffer_model.append(ps.screenshot_to_buffer(transparent_bg=False))

    if not img_buffer_mesh:
        # Nothing was rendered; writing an empty animation would fail.
        print("no checkpoints found, nothing to export")
        return

    imageio.mimsave(f'{destination}{max_iter}_3d_{num_centroids}_optimization_mesh.gif',img_buffer_mesh, fps=1, duration=1, loop=0)
    imageio.mimsave(f'{destination}{max_iter}_3d_{num_centroids}_optimization_sdf.gif', img_buffer_model, fps=1, duration=1, loop=0)

export_visualisation_3d()