In [1]:
import kaolin
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
import polyscope as ps
import diffvoronoi
import sdfpred_utils.sdfpred_utils as su
import sdfpred_utils.loss_functions as lf
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import knn_points, knn_gather
import torch
from torch import nn

# Compute device: use the first CUDA GPU when one is available, otherwise fall back
# to the CPU so this cell does not raise on machines without CUDA (the model-loading
# cell further down already guards on torch.cuda.is_available() in the same way).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device: ", torch.cuda.get_device_name(device) if device.type == "cuda" else "cpu")

input_dims = 3  # number of coordinates per site
lr_sites = 0.005  # base learning rate; train_DCCVT scales it by 0.1 for both parameter groups
# lr_model = 0.00001
destination = "./images/autograd/End2End_DCCVT_interpolSDF/"  # prefix of every file written by this notebook
model_trained_it = ""  # suffix inserted into the checkpoint file name "model{...}.pth"

# `mesh` is [short name, path of the input shape without extension]; exactly one
# mesh / trained_model_path pair should be active at a time.

# mesh = ["sphere"]

# mesh = ["gargoyle", "/home/wylliam/dev/Kyushu_experiments/data/gargoyle"]
# trained_model_path = f"/home/wylliam/dev/HotSpot/log/3D/pc/HotSpot-all-2025-04-24-18-16-03/gargoyle/gargoyle/trained_models/model{model_trained_it}.pth"


# NOTE(review): absolute, user-specific paths; consider deriving them from a configurable root.
mesh = ["chair", "/home/wylliam/dev/Kyushu_experiments/data/chair"]
trained_model_path = f"/home/wylliam/dev/HotSpot/log/3D/pc/HotSpot-all-2025-05-02-17-56-25/chair/chair/trained_models/model{model_trained_it}.pth"

# mesh = ["bunny", "/home/wylliam/dev/Kyushu_experiments/data/bunny"]
# trained_model_path = f"/home/wylliam/dev/HotSpot/log/3D/pc/HotSpot-all-2025-04-25-17-32-49/bunny/bunny/trained_models/model{model_trained_it}.pth"


Using device:  NVIDIA GeForce RTX 3090


In [2]:
# Initial sites: a jittered regular grid with num_centroids points in [-domain_limit, domain_limit]^3.
num_centroids = 16**3
grid = 32  # 128  (also sizes the target point sample and caps upsampling at grid**3 sites)
print("Creating new sites")
noise_scale = 0.005  # standard deviation of the Gaussian jitter added to the grid
domain_limit = 1

# Number of grid points along each axis (cube root of the requested site count).
sites_per_axis = int(round(num_centroids ** (1 / 3)))
x = torch.linspace(-domain_limit, domain_limit, sites_per_axis)
y = torch.linspace(-domain_limit, domain_limit, sites_per_axis)
z = torch.linspace(-domain_limit, domain_limit, sites_per_axis)
# indexing="ij" is what torch.meshgrid already did by default; passing it explicitly
# keeps the same values and removes the deprecation warning.
meshgrid = torch.meshgrid(x, y, z, indexing="ij")
meshgrid = torch.stack(meshgrid, dim=3).view(-1, 3)

# Seed right before drawing the noise so the jitter is reproducible.
torch.manual_seed(69)
# add noise to meshgrid
meshgrid += torch.randn_like(meshgrid) * noise_scale


# Optimisable leaf tensor holding the site positions.
sites = meshgrid.to(device, dtype=torch.float32).requires_grad_(True)

print("Sites shape: ", sites.shape)
print("Sites: ", sites[0])
ps.init()


Creating new sites
Sites shape:  torch.Size([4096, 3])
Sites:  tensor([-1.0027, -1.0065, -0.9978], device='cuda:0', grad_fn=<SelectBackward0>)
[polyscope] Backend: openGL3_glfw -- Loaded openGL version: 3.3.0 NVIDIA 575.64.03


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
# Alternative site initialisation from su.octahedral_grid_points, jittered with the
# same noise scale as the regular grid. It only replaces `sites` if the commented-out
# assignment below is re-enabled.
vs = su.octahedral_grid_points(grid=8, domain=(-1.0, 1.0))
print("Octahedral grid points shape: ", vs.shape)

vs = vs.add(torch.randn_like(vs).mul(noise_scale))

# sites = vs.to(device, dtype=torch.float32).requires_grad_(True)

# ps_cloud = ps.register_point_cloud("vs",vs.detach().cpu().numpy())
# ps.show()


Octahedral grid points shape:  torch.Size([1728, 3])


In [4]:
# LOAD MODEL WITH HOTSPOT
# `dataset` and `models` are imported after 3rdparty/HotSpot is appended to sys.path
# (presumably that is where they live; verify against the repository layout).
import sys

sys.path.append("3rdparty/HotSpot")
from dataset import shape_3d
import models.Net as Net

loss_type = "igr_w_heat"
loss_weights = [350, 0, 0, 1, 0, 0, 20]

# Dataset built from the input .ply; the trailing comments name the HotSpot
# command-line argument each hard-coded value stands in for.
train_set = shape_3d.ReconDataset(
    file_path=f"{mesh[1]}.ply",
    n_points=grid * grid * 150,  # 15000, #args.n_points,
    n_samples=10001,  # args.n_iterations,
    grid_res=256,  # args.grid_res,
    grid_range=1.1,  # args.grid_range,
    sample_type="uniform_central_gaussian",  # args.nonmnfld_sample_type,
    sampling_std=0.5,  # args.nonmnfld_sample_std,
    n_random_samples=7500,  # args.n_random_samples,
    resample=True,
    # The expression is already a bool; it is only true for "sal" losses with a positive weight.
    compute_sal_dist_gt="sal" in loss_type and loss_weights[5] > 0,
    scale_method="mean",  # "mean" #args.pcd_scale_method,
)

model = Net.Network(
    latent_size=0,  # args.latent_size,
    in_dim=3,
    decoder_hidden_dim=128,  # args.decoder_hidden_dim,
    nl="sine",  # args.nl,
    encoder_type="none",  # args.encoder_type,
    decoder_n_hidden_layers=5,  # args.decoder_n_hidden_layers,
    neuron_type="quadratic",  # args.neuron_type,
    init_type="mfgi",  # args.init_type,
    sphere_init_params=[1.6, 0.1],  # args.sphere_init_params,
    n_repeat_period=30,  # args.n_repeat_period,
)
model.to(device)

######
# Pull a single batch from the dataset; its "mnfld_points" entry is the target point
# cloud used by the rest of the notebook.
test_dataloader = torch.utils.data.DataLoader(
    train_set,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    pin_memory=False,
)
test_data = next(iter(test_dataloader))
mnfld_points = test_data["mnfld_points"].to(device)

# add noise to mnfld_points
# mnfld_points += torch.randn_like(mnfld_points) * noise_scale * 2

mnfld_points.requires_grad_()
print("mnfld_points shape: ", mnfld_points.shape)

# Restore the pretrained weights, mapping them to the GPU when there is one.
map_location = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state_dict = torch.load(trained_model_path, weights_only=True, map_location=map_location)
model.load_state_dict(state_dict)


# def sphere_sdf(points: torch.Tensor, center: torch.Tensor, radius: float) -> torch.Tensor:
#     """
#     Compute the SDF of a sphere at given 3D points.

#     Args:
#         points: (N, 3) tensor of 3D query points
#         center: (3,) tensor specifying the center of the sphere
#         radius: float, radius of the sphere

#     Returns:
#         sdf: (N,) tensor of signed distances
#     """
#     return torch.norm(points - center, dim=-1) - radius


# # generate points on the sphere
# mnfld_points = torch.randn(grid * grid * 150, 3, device=device)
# mnfld_points = mnfld_points / torch.norm(mnfld_points, dim=-1, keepdim=True) * 0.5
# mnfld_points = mnfld_points.unsqueeze(0).requires_grad_()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.
mnfld_points shape:  torch.Size([1, 153600, 3])


  return self.fget.__get__(instance, owner)()


<All keys matched successfully>

In [5]:
##add mnfld points with random noise to sites
# N = mnfld_points.squeeze(0).shape[0]
# num_samples = 24**3 - (num_centroids)
# idx = torch.randint(0, N, (num_samples,))
# sampled = mnfld_points.squeeze(0)[idx]
# perturbed = sampled + (torch.rand_like(sampled)-0.5)*0.05
# sites = torch.cat((sites, perturbed), dim=0)

# Rebuild `sites` as a fresh leaf tensor so it can be handed to an optimizer.
sites = sites.detach().requires_grad_()
print(sites.dtype)
print(sites.shape)
allocated_mb = torch.cuda.memory_allocated() / 1e6
reserved_mb = torch.cuda.memory_reserved() / 1e6
print(f"Allocated: {allocated_mb} MB, Reserved: {reserved_mb} MB")

# sdf0 = sphere_sdf(sites, torch.zeros(3).to(device), 0.50)
##sdf0 += torch.randn_like(sdf0) * noise_scale/2

# Evaluate the pretrained network at the sites, then detach the result from the
# network's graph: the per-site SDF values become their own optimisable leaf tensor.
sdf0 = model(sites).detach().squeeze(-1).requires_grad_()


print(sdf0.shape)
print(sdf0.is_leaf)

# print(sdf_grad0.shape)
# print(sdf_grad0.is_leaf)


torch.float32
torch.Size([4096, 3])
Allocated: 2.694656 MB, Reserved: 23.068672 MB
torch.Size([4096])
True


In [6]:
# Delaunay simplices of the initial sites; diffvoronoi receives the coordinates as one flat array.
sites_np = sites.detach().cpu().numpy()
flat_site_coords = sites_np.reshape(input_dims * sites_np.shape[0])
d3dsimplices = np.array(diffvoronoi.get_delaunay_simplices(flat_site_coords))

print("sites shape: ", sites.shape)

# Sites coloured by their initial SDF value, plus the target point cloud.
ps_cloud = ps.register_point_cloud(
    "initial_cvt_grid+pc_gt",
    sites.detach().cpu().numpy(),
    enabled=False,
)
ps_cloud.add_scalar_quantity(
    "vis_grid_pred",
    sdf0.detach().cpu().numpy(),
    enabled=True,
    cmap="coolwarm",
    vminmax=(-5e-5, 5e-5),
)
mnf_cloud = ps.register_point_cloud(
    "mnfld_points_pred",
    mnfld_points.squeeze(0).detach().cpu().numpy(),
    enabled=False,
)

# Initial surface without clipping (4th argument False): shown as a mesh and as bare vertices.
v_vect, f_vect, sdf_verts, sdf_verts_grads, _ = su.get_clipped_mesh_numba(
    sites, None, d3dsimplices, False, sdf0, True
)
ps_mesh = ps.register_surface_mesh(
    "sdf unclipped initial mesh",
    v_vect.detach().cpu().numpy(),
    f_vect,
    back_face_policy="identical",
)
ps_vert = ps.register_point_cloud(
    "sdf unclipped initial verts",
    v_vect.detach().cpu().numpy(),
    enabled=False,
)

# Same extraction with clipping enabled; the fifth return value is kept this time
# (entry 2 is drawn as the "active sites", entry 0 as a per-site step direction).
v_vect, f_vect, sdf_verts, sdf_verts_grads, tet_probs = su.get_clipped_mesh_numba(
    sites, None, d3dsimplices, True, sdf0, True
)
ps_mesh = ps.register_surface_mesh(
    "sdf clipped initial mesh",
    v_vect.detach().cpu().numpy(),
    f_vect,
    back_face_policy="identical",
)
active_sites_np = tet_probs[2].reshape(-1, 3).detach().cpu().numpy()
site_step_dir_np = tet_probs[0].reshape(-1, 3).detach().cpu().numpy()
ps_cloud = ps.register_point_cloud("active sites", active_sites_np, enabled=False)
ps_cloud.add_vector_quantity("site step dir", site_step_dir_np)
# ps_vert.add_vector_quantity("verts step dir", tet_probs[1].detach().cpu().numpy())


ps.show()

sites shape:  torch.Size([4096, 3])


In [7]:
# # def eikonal_loss(grad_sdf: torch.Tensor) -> torch.Tensor:
# #     """
# #     Eikonal regularization loss.

# #     Args:
# #         grad_sdf: Tensor of shape (N, 3) containing ∇φ at each site.
# #     Note: there is no `variant` argument; the code below evaluates ½ mean((||∇φ||² - 1)²), i.e. it uses the squared norm.
# #     Returns:
# #         A scalar tensor containing the eikonal loss.
# #     """
# #     norms = torch.norm(grad_sdf, dim=1)  # (N,)
# #     loss = 0.5 * torch.mean((norms**2 - 1.0) ** 2)
# #     return loss


# def motion_by_mean_curvature_loss(
#     sdf: torch.Tensor, grad_sdf: torch.Tensor, sites: torch.Tensor, d3dsimplices: torch.Tensor, factor: float = 1.5
# ) -> torch.Tensor:
#     """
#     Motion-by-mean-curvature smoothing loss via a smeared Heaviside function.

#     Args:
#         sdf: Tensor of shape (N,) containing φ at each site.
#         grad_sdf: Tensor of shape (N, 3) containing ∇φ at each site.
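#         sites: Tensor of shape (N, 3) containing the site positions.
#         d3dsimplices: (M, 4) tetrahedron vertex indices into `sites`.
#         factor: Multiplier applied to the mean length of sign-changing edges to obtain the smearing bandwidth ε_H.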

#     Returns:
#         A scalar tensor containing the smoothing loss.
#     """
#     # compute epsilon_H from sites and d3dsimplices
#     d3d = torch.tensor(d3dsimplices).to(device).detach()  # (M,4)
#     comb = torch.combinations(torch.arange(d3d.shape[1], device=device), r=2)  # (6,2)
#     edges = d3d[:, comb]  # (M,6,2)
#     edges = edges.reshape(-1, 2)  # (M*6,2)
#     edges, _ = torch.sort(edges, dim=1)  # sort each row so (a,b) == (b,a)
#     unique_edges = torch.unique(edges, dim=0)
#     v0, v1 = sites[unique_edges[:, 0]], sites[unique_edges[:, 1]]  # (N,3)
#     i, j = unique_edges[:, 0], unique_edges[:, 1]  # (N,3)

#     phi = sdf
#     sign_mask = phi[i] * phi[j] < 0
#     v0, v1 = v0[sign_mask], v1[sign_mask]  # only keep edges with opposite signs
#     edge_lengths = torch.norm(v1 - v0, dim=1)  # (N,)
#     epsilon_H = factor * torch.mean(edge_lengths)  # Bandwidth for the smeared Heaviside function

#     # Compute the derivative of the smeared Heaviside H'(φ)
#     mask = torch.abs(phi) <= epsilon_H
#     H_prime = torch.zeros_like(phi)
#     # H'(φ) = (1/(2ε_H)) * (1 + cos(π φ / ε_H)) for |φ| ≤ ε_H
#     H_prime[mask] = (1.0 / (2.0 * epsilon_H)) * (1.0 + torch.cos(np.pi * phi[mask] / epsilon_H))

#     # Compute |∇H| = |H'(φ)| * ||∇φ||
#     norms = torch.norm(grad_sdf, dim=1)
#     magnitude = H_prime * norms

#     # Ignore very small contributions (tetrahedra already smooth enough)
#     valid = magnitude > 1e-8
#     if valid.any():
#         return torch.mean(magnitude[valid])
#     else:
#         return torch.tensor(0.0, device=sdf.device)


# def smoothness_loss(sites, phi, d3dsimplices, eps=1e-8, weighted=True):
#     """
#     Compute ∑_{(i,j)} w_ij (φ_i - φ_j)^2 over edges from tetrahedral simplices.

#     Args:
#         sites:     (N,3) float tensor of point positions.
#         phi:       (N,)   float tensor of SDF values at sites.
#         simplices: (M,4)  long tensor of tetrahedron indices.
#         eps:       small constant to avoid divide-by-zero.
#         weighted:  if True, weight edges by 1/dist else w=1.
#     Returns:
#         scalar smoothing loss.
#     """
#     # 1) extract all simplex edges
#     d3d = torch.tensor(d3dsimplices).to(device).detach()  # (M,4)
#     comb = torch.combinations(torch.arange(d3d.shape[1], device=sites.device), r=2)  # (6,2)
#     edges = d3d[:, comb]  # (M,6,2)
#     edges = edges.reshape(-1, 2)  # (M*6,2)
#     edges = torch.sort(edges, dim=1)[0]  # sort lexicographically
#     edges = torch.unique(edges, dim=0)  # keep unique undirected edges
#     i, j = edges[:, 0], edges[:, 1]  # index pairs
#     sign_mask = phi[i] * phi[j] < 0

#     i, j = i[sign_mask], j[sign_mask]  # only keep edges with opposite signs

#     if weighted:
#         dij = (sites[i] - sites[j]).norm(dim=1)  # (E,)
#         w = 1.0 / (dij + eps)
#     else:
#         w = 1.0

#     diff = phi[i] - phi[j]  # (E,)
#     loss = torch.mean(w * diff * diff)
#     return loss


In [8]:
# SITES OPTIMISATION LOOP


# Loss histories, one independent list each. In the code visible in this notebook
# only `loss_values` is appended to (by train_DCCVT); the others stay empty.
(
    cvt_loss_values,
    min_distance_loss_values,
    chamfer_distance_loss_values,
    eikonal_loss_values,
    domain_restriction_loss_values,
    sdf_loss_values,
    div_loss_values,
    loss_values,
) = ([] for _ in range(8))

# Voroloss module; only referenced from commented-out code in the training loop.
voroloss = lf.Voroloss_opt().to(device)


def _make_site_optimizer(sites, sites_sdf):
    """Create the Adam optimizer over site positions and site SDF values, plus its LR scheduler."""
    optimizer = torch.optim.Adam(
        [
            {"params": [sites], "lr": lr_sites * 0.1},
            {"params": [sites_sdf], "lr": lr_sites * 0.1},
        ]
    )
    # gamma=1.0 keeps the learning rate constant; use a value below 1 (e.g. 0.95) for decay.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1.0)
    return optimizer, scheduler


def train_DCCVT(
    sites,
    sites_sdf,
    max_iter=100,
    stop_train_threshold=1e-6,
    upsampling=0,
    lambda_weights=[0.1, 1.0, 0.1, 0.1, 1.0, 1.0, 0.1],
):
    """Jointly optimise site positions and per-site SDF values.

    Each iteration recomputes the Delaunay simplices of the current sites, extracts the
    clipped surface vertices, and minimises
    ``lambda_cvt / 10 * cvt_loss + lambda_chamfer * chamfer(mnfld_points, vertices)``.
    The loop runs for epochs ``0..max_iter`` inclusive.

    Besides its arguments the function reads the notebook globals ``lr_sites``,
    ``input_dims``, ``device``, ``mnfld_points``, ``grid``, ``destination``, ``mesh`` and
    ``num_centroids``, appends to ``loss_values``, registers geometry in polyscope and
    writes checkpoints to disk.

    Args:
        sites: leaf tensor of site positions (indexed as (N, input_dims)).
        sites_sdf: leaf tensor with one SDF value per site.
        max_iter: index of the last epoch; checkpoints are written every
            ``max_iter / 10`` epochs and at the last epoch.
        stop_train_threshold: currently unused; kept so existing callers keep working.
        upsampling: number of upsampling rounds, spread over the first 80% of the
            epochs. ``0`` disables upsampling.
        lambda_weights: loss weights; only index 0 (CVT) and index 4 (chamfer) are read.

    Returns:
        Tuple ``(sites, sites_sdf)`` holding the final (possibly upsampled) leaf tensors.
    """
    optimizer, scheduler = _make_site_optimizer(sites, sites_sdf)

    upsampled = 0.0
    epoch = 0
    lambda_cvt = lambda_weights[0]
    lambda_chamfer = lambda_weights[4]

    while epoch <= max_iter:
        optimizer.zero_grad()

        # The triangulation is computed on a detached CPU copy: connectivity is treated
        # as constant within an iteration and rebuilt on the next one.
        sites_np = sites.detach().cpu().numpy()
        d3dsimplices = diffvoronoi.get_delaunay_simplices(sites_np.reshape(input_dims * sites_np.shape[0]))
        d3dsimplices = np.array(d3dsimplices)

        cvt_loss = lf.compute_cvt_loss_vectorized_delaunay(sites, None, d3dsimplices)

        build_mesh = False
        clip = True

        v_vect, f_vect, sdf_verts, sdf_verts_grads, _ = su.get_clipped_mesh_numba(
            sites, None, d3dsimplices, clip, sites_sdf, build_mesh
        )

        if build_mesh:
            # Fan-triangulate the polygonal faces, then compare surface samples to the target.
            triangle_faces = [[f[0], f[i], f[i + 1]] for f in f_vect for i in range(1, len(f) - 1)]
            triangle_faces = torch.tensor(triangle_faces, device=device)
            # shape[-2] is the number of target points; shape[0] is the batch dimension
            # (1 for the (1, N, 3) tensor built above) and would request a single sample.
            hs_p = su.sample_mesh_points_heitz(v_vect, triangle_faces, num_samples=mnfld_points.shape[-2])
            chamfer_loss_mesh, _ = chamfer_distance(mnfld_points.detach(), hs_p.unsqueeze(0))
        else:
            chamfer_loss_mesh, _ = chamfer_distance(mnfld_points.detach(), v_vect.unsqueeze(0))

        sites_loss = lambda_cvt / 10 * cvt_loss + lambda_chamfer * chamfer_loss_mesh

        # Additional SDF regularisers (eikonal, mean-curvature motion, smoothness) are
        # currently disabled, so the total loss is the site loss alone.
        loss = sites_loss
        loss_value = loss.item()
        loss_values.append(loss_value)
        print(f"Epoch {epoch}: loss = {loss_value}")

        loss.backward()
        print("-----------------")

        optimizer.step()
        scheduler.step()
        print("Learning rate: ", optimizer.param_groups[0]["lr"])

        # Upsampling rounds are spread evenly over the first 80% of the epochs. The
        # `upsampling > 0` and `max_iter > 0` guards come first so the default
        # upsampling=0 (or max_iter=0) no longer divides by zero.
        upsampling_due = (
            upsampling > 0
            and max_iter > 0
            and upsampled < upsampling
            and epoch / (max_iter * 0.80) > upsampled / upsampling
        )
        if upsampling_due:
            print("sites length BEFORE UPSAMPLING: ", len(sites))
            if len(sites) * 1.09 > grid**3:
                print("Skipping upsampling, too many sites, sites length: ", len(sites), "grid size: ", grid**3)
                # Mark every round as done so no further upsampling is attempted. The
                # epoch then finishes normally instead of being re-run via `continue`.
                upsampled = upsampling
            else:
                # Other strategies available in `su`: upsampling_vectorized_sites_sites_sdf,
                # upsampling_curvature_vectorized_sites_sites_sdf,
                # upsampling_chamfer_vectorized_sites_sites_sdf.
                sites, sites_sdf = su.upsampling_adaptive_vectorized_sites_sites_sdf(
                    sites, simplices=d3dsimplices, model=sites_sdf
                )
                upsampled += 1.0
                print("sites shape AFTER: ", sites.shape)
                print("sites sdf shape AFTER: ", sites_sdf.shape)

            # The tensors were replaced, so they must become leaves again and the
            # optimizer (and the scheduler bound to it) must be rebuilt around them.
            sites = sites.detach().requires_grad_(True)
            sites_sdf = sites_sdf.detach().requires_grad_(True)
            optimizer, scheduler = _make_site_optimizer(sites, sites_sdf)

        # `epoch == max_iter` is tested first so max_iter=0 never reaches the modulo.
        if epoch == max_iter or epoch % (max_iter / 10) == 0:
            ps.register_point_cloud("sampled points end", v_vect.detach().cpu().numpy(), enabled=False)

            if f_vect is not None:
                ps_mesh = ps.register_surface_mesh(
                    f"{epoch} sdf clipped pmesh",
                    v_vect.detach().cpu().numpy(),
                    f_vect,
                    back_face_policy="identical",
                    enabled=False,
                )
                ps_mesh.add_vector_quantity(
                    f"{epoch} sdf verts grads",
                    sdf_verts_grads.detach().cpu().numpy(),
                    enabled=False,
                )

            site_file_path = (
                f"{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lambda_chamfer}.pth"
            )
            sdf_file_path = (
                f"{destination}{mesh[0]}{max_iter}_{epoch}_3d_sdf_{num_centroids}_chamfer{lambda_chamfer}.pth"
            )
            # Make sure the checkpoint directory exists before the first save.
            os.makedirs(os.path.dirname(site_file_path) or ".", exist_ok=True)
            torch.save(sites_sdf, sdf_file_path)
            torch.save(sites, site_file_path)

        epoch += 1

    return sites, sites_sdf

In [9]:
# lambda_weights = [252,0,0,0,10.211111,0,100,0]
# lambda_weights = [500,0,0,0,1000,0,100,0]
# Loss weights by position: cvt, sdf, min-distance, laplace, chamfer, eikonal,
# domain restriction, true points.
lambda_weights = [100, 0, 0, 0, 1000, 0, 100, 0]


# Give each weight a name by unpacking the list in order.
(
    lambda_cvt,
    lambda_sdf,
    lambda_min_distance,
    lambda_laplace,
    lambda_chamfer,
    lambda_eikonal,
    lambda_domain_restriction,
    lambda_true_points,
) = lambda_weights

max_iter = 1000

In [10]:
# Cache file for the optimised sites.
# NOTE(review): the name does not include the mesh name, so different meshes share one cache file.
site_file_path = f"{destination}{max_iter}_cvt_{lambda_cvt}_chamfer_{lambda_chamfer}_eikonal_{lambda_eikonal}.npy"

# Reuse the cached sites when the file exists, otherwise optimise and cache them.
# The test used to be inverted (`if not os.path.exists(...)`), which called np.load
# exactly when the file was missing.
if os.path.exists(site_file_path):
    # import sites
    # NOTE: `optimized_sites_sdf` is not defined on this branch; later cells reload
    # the SDF values from the .pth checkpoints instead.
    print("Importing sites")
    sites = np.load(site_file_path)
    sites = torch.from_numpy(sites).to(device).requires_grad_(True)
else:
    # To profile the run, wrap this call in cProfile or torch.profiler.profile(...).
    sites, optimized_sites_sdf = train_DCCVT(
        sites, sdf0, max_iter=max_iter, upsampling=10, lambda_weights=lambda_weights
    )

    sites_np = sites.detach().cpu().numpy()
    # Create the output directory if needed so np.save cannot fail on a fresh checkout.
    os.makedirs(os.path.dirname(site_file_path) or ".", exist_ok=True)
    np.save(site_file_path, sites_np)

print("Sites length: ", len(sites))
print("min sites: ", torch.min(sites))
print("max sites: ", torch.max(sites))

Epoch 0: loss = 211.19393920898438
-----------------
Learning rate:  0.0005
Epoch 1: loss = 208.1784210205078
-----------------
Learning rate:  0.0005
sites length BEFORE UPSAMPLING:  4096
tensor(0.1250, device='cuda:0', grad_fn=<MedianBackward0>) tensor(0.1500, device='cuda:0', grad_fn=<MulBackward0>) tensor(0.1100, device='cuda:0', grad_fn=<MulBackward0>)
Hybrid upsampling regime
Number of candidates in hybrid regime: 34
Before upsampling, number of sites: 4096 amount added: 136
sites shape AFTER:  torch.Size([4232, 3])
sites sdf shape AFTER:  torch.Size([4232])
Epoch 2: loss = 206.19044494628906
-----------------
Learning rate:  0.0005
Epoch 3: loss = 205.99549865722656
-----------------
Learning rate:  0.0005
Epoch 4: loss = 204.67477416992188
-----------------
Learning rate:  0.0005
Epoch 5: loss = 204.05130004882812
-----------------
Learning rate:  0.0005
Epoch 6: loss = 203.15631103515625
-----------------
Learning rate:  0.0005
Epoch 7: loss = 202.34732055664062
--------------

In [None]:
# Reload the checkpoint that train_DCCVT wrote at this epoch.
epoch = 1000

# model_file_path = f'{destination}{mesh[0]}{max_iter}_{epoch}_3d_model_{num_centroids}_chamfer{lambda_chamfer}.pth'
site_file_path = f"{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lambda_chamfer}.pth"
sdf_file_path = f"{destination}{mesh[0]}{max_iter}_{epoch}_3d_sdf_{num_centroids}_chamfer{lambda_chamfer}.pth"


# map_location puts both tensors on the active device even if they were saved from another one.
sites = torch.load(site_file_path, map_location=device)
sdf_v = torch.load(sdf_file_path, map_location=device)

sites_np = sites.detach().cpu().numpy()
print("sdf", sdf_v.shape)
print("sites", site_file_path)
print("sites_np shape: ", sites_np.shape)

# Drop sites with NaN coordinates. The same row mask is applied to the SDF values:
# filtering only `sites` would leave `sdf_v` with a different length and shift every
# following value onto the wrong site.
valid_rows = ~np.isnan(sites_np).any(axis=1)
if not valid_rows.all():
    print("sites_np contains NaN values")
    print("sites_np NaN values: ", np.isnan(sites_np).sum())
    keep = torch.from_numpy(valid_rows).to(sdf_v.device)
    sdf_v = sdf_v.detach()[keep].requires_grad_(sdf_v.requires_grad)
sites_np = sites_np[valid_rows]
sites = torch.from_numpy(sites_np).to(device).requires_grad_(True)

# Register the (filtered) sites coloured by their SDF value.
ps_cloud_f = ps.register_point_cloud(f"{epoch} epoch_cvt_grid", sites_np)
ps_cloud_f.add_scalar_quantity(
    "vis_grid_pred",
    sdf_v.detach().cpu().numpy(),
    enabled=True,
    cmap="coolwarm",
    vminmax=(-0.15, 0.15),
)

sdf torch.Size([40832])
sites ./images/autograd/End2End_DCCVT_interpolSDF/chair1000_1000_3d_sites_4096_chamfer1000.pth
sites_np shape:  (40832, 3)


In [15]:
# v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, None, True)
# ps.register_surface_mesh("model final clipped polygon mesh", v_vect.detach().cpu().numpy(), f_vect)

# v_vect, f_vect = su.get_clipped_mesh_numba(sites, model, None, False)
# ps.register_surface_mesh("model final polygon mesh", v_vect.detach().cpu().numpy(), f_vect)

######################################################
# Final surface from the optimised sites and SDF values: first without clipping
# (4th argument False), then with clipping. Both are registered in polyscope.
v_vect, f_vect, _, _, _ = su.get_clipped_mesh_numba(sites, None, None, False, sdf_v, True)
ps.register_surface_mesh(
    "sdf final unclipped polygon mesh",
    v_vect.detach().cpu().numpy(),
    f_vect,
)


v_vect, f_vect, _, _, _ = su.get_clipped_mesh_numba(sites, None, None, True, sdf_v, True)
ps.register_surface_mesh(
    "sdf final clipped polygon mesh",
    v_vect.detach().cpu().numpy(),
    f_vect,
)
# f_vect = [[f[0], f[i], f[i + 1]] for f in f_vect for i in range(1, len(f) - 1)]


# Export the clipped mesh (.obj) and the target point cloud (.ply); both file names
# share the same run-specific stem.
export_stem = f"{destination}{mesh[0]}{max_iter}_{epoch}_3d_sites_{num_centroids}_chamfer{lambda_chamfer}"
output_obj_file = f"{export_stem}_outputmesh.obj"
output_ply_file = f"{export_stem}_targetpointcloud.ply"
su.save_obj(output_obj_file, v_vect.detach().cpu().numpy(), f_vect)
su.save_target_pc_ply(output_ply_file, mnfld_points.squeeze(0).detach().cpu().numpy())

ps.show()

Computing Delaunay simplices...
Computing Delaunay simplices...


In [14]:
# chamfer metric
# add sampled points to polyscope and ground truth mesh to polyscope

import trimesh


def sample_points_on_mesh(mesh_path, n_points=100000):
    """Load a mesh, normalise it, and sample points on its surface.

    The mesh is translated so its centroid is at the origin and scaled so the largest
    absolute vertex coordinate is 1. The normalised mesh is written next to the input
    as ``<name>_normalized<ext>``. The original ``mesh_path.replace(".obj", ".obj")``
    was a no-op, so the export silently overwrote the input file.

    Args:
        mesh_path: path of the mesh file to load.
        n_points: number of surface samples to draw.

    Returns:
        Tuple ``(points, mesh)``: the sampled points and the normalised trimesh object.
    """
    # force="mesh" makes trimesh return a single mesh even when the file holds several
    # geometries (a Scene has no `.vertices`).
    tri_mesh = trimesh.load(mesh_path, force="mesh")
    # normalize mesh
    tri_mesh.apply_translation(-tri_mesh.centroid)
    tri_mesh.apply_scale(1.0 / np.max(np.abs(tri_mesh.vertices)))
    # export the normalised copy without touching the input file
    stem, ext = os.path.splitext(mesh_path)
    tri_mesh.export(f"{stem}_normalized{ext}")
    points, _ = trimesh.sample.sample_surface(tri_mesh, n_points)
    return points, tri_mesh


from trimesh.proximity import ProximityQuery


def point_to_mesh_distance(points, mesh):
    """Return the unsigned distance from each query point to `mesh`.

    The distances come from trimesh's ProximityQuery.signed_distance; the sign is
    discarded by taking the absolute value.
    """
    signed = ProximityQuery(mesh).signed_distance(points)
    return np.absolute(signed)


import numpy as np
from scipy.spatial import cKDTree


def chamfer_accuracy_completeness(ours_pts, gt_pts):
    """Compute the two one-sided squared chamfer terms between two point sets.

    Args:
        ours_pts: (N, D) array of points sampled from the reconstruction.
        gt_pts: (M, D) array of points sampled from the ground truth.

    Returns:
        Tuple ``(accuracy, completeness)``: the mean squared nearest-neighbour distance
        from ours to GT, and from GT to ours, respectively.
    """

    def _mean_sq_nn_distance(queries, targets):
        # Distance from every query point to its nearest neighbour in `targets`.
        nn_dist, _ = cKDTree(targets).query(queries, k=1)
        return np.mean(nn_dist**2)

    # Completeness: GT → Ours
    completeness = _mean_sq_nn_distance(gt_pts, ours_pts)
    # Accuracy: Ours → GT
    accuracy = _mean_sq_nn_distance(ours_pts, gt_pts)

    return accuracy, completeness


# Sample the exported reconstruction and the ground-truth mesh, then report both
# one-sided chamfer terms and their sum. The ground-truth path is derived from the
# input path by replacing "data" with "mesh" and appending ".obj".
ours_pts, _ = sample_points_on_mesh(output_obj_file, n_points=100000)
m = mesh[1].replace("data", "mesh")
gt_pts, _ = sample_points_on_mesh(f"{m}.obj", n_points=100000)

acc, comp = chamfer_accuracy_completeness(ours_pts, gt_pts)

for metric_label, metric_value in (
    ("Chamfer Accuracy (Ours → GT)", acc),
    ("Chamfer Completeness (GT → Ours)", comp),
    ("Chamfer Distance (symmetric)", acc + comp),
):
    print(f"{metric_label}: {metric_value:.6f}")


Chamfer Accuracy (Ours → GT): 0.000177
Chamfer Completeness (GT → Ours): 0.000123
Chamfer Distance (symmetric): 0.000300
