In [None]:
import torch

def compute_sdf_gradient_per_tetrahedron(sites, tets, sdf_values):
    """
    Computes ∇sdf for each tetrahedron of a piecewise-linear SDF.

    Inside a tetrahedron the SDF is the barycentric interpolation
    sdf(x) = sum_i w_i(x) * sdf_i, so its (constant) gradient is
    sum_i sdf_i * ∇w_i.

    Args:
        sites: (N, 3) tensor of 3D point positions.
        tets: (M, 4) tensor of indices into sites, defining each tetrahedron.
        sdf_values: (N,) tensor of SDF values at each vertex.

    Returns:
        gradients: (M, 3) tensor of SDF gradients per tetrahedron.
    """
    # Get the coordinates of the 4 sites of each tetrahedron
    tet_pts = sites[tets]  # (M, 4, 3)
    v0, v1, v2, v3 = tet_pts[:, 0], tet_pts[:, 1], tet_pts[:, 2], tet_pts[:, 3]

    # Edge matrix with the edge vectors (v_k - v0) as COLUMNS: x = v0 + Dm @ w
    Dm = torch.stack([v1 - v0, v2 - v0, v3 - v0], dim=-1)  # (M, 3, 3)

    # w = Dm^{-1} (x - v0), hence ∂w_k/∂x = row k of Dm^{-1}.
    Dm_inv = torch.inverse(Dm)  # (M, 3, 3)

    # Gradients of barycentric coordinates are the ROWS of Dm^{-1}
    # (taking columns is only correct when Dm happens to be symmetric).
    grad_w1 = Dm_inv[:, 0, :]  # ∇w1
    grad_w2 = Dm_inv[:, 1, :]  # ∇w2
    grad_w3 = Dm_inv[:, 2, :]  # ∇w3
    grad_w0 = -grad_w1 - grad_w2 - grad_w3  # since ∑wi = 1

    # Stack gradients (M, 4, 3)
    grad_w = torch.stack([grad_w0, grad_w1, grad_w2, grad_w3], dim=1)

    # Get SDF values at tetrahedron sites (M, 4)
    sdf_tet = sdf_values[tets]  # (M, 4)

    # Compute SDF gradient per tet via weighted sum of grad_w * sdf
    sdf_grad = (grad_w * sdf_tet.unsqueeze(-1)).sum(dim=1)  # (M, 3)

    return sdf_grad


In [None]:
import torch

def compute_sdf_gradient_from_tets(sites, tets, sdf_values):
    """
    Computes a per-site SDF gradient as the volume-weighted average of
    per-tetrahedron least-squares gradients.

    For each tetrahedron, positions and SDF values are centred on the
    tetrahedron's vertex mean, and the gradient is the least-squares solution
    g = (dX^T dX)^{-1} dX^T (sdf - sdf_center). For a field that is linear over
    the tetrahedron this recovers the exact gradient. The original note says
    this is analogous to a CUDA implementation (not visible here).

    Args:
        sites:      (N, 3) tensor of 3D point coordinates.
        tets:       (M, 4) tensor of indices of tetrahedra (into `sites`).
        sdf_values: (N,) tensor of SDF values at each site.

    Returns:
        gradients:  (N, 3) tensor of SDF gradients averaged at each site,
                    weighted by the volumes of the incident tetrahedra
                    (sites not referenced by any tet get a zero gradient).
        weights:    (N,) tensor of total volume weight per site.
    """
    device = sites.device

    # Get positions and sdf values per tetrahedron: (M, 4, 3) and (M, 4)
    tet_pos = sites[tets]                  # (M, 4, 3)
    tet_sdf = sdf_values[tets]             # (M, 4)

    # Compute center point and center sdf of each tet
    center_pos = tet_pos.mean(dim=1)       # (M, 3)
    center_sdf = tet_sdf.mean(dim=1)       # (M,)

    # Compute dX: difference between each vertex and the center
    dX = tet_pos - center_pos[:, None, :]  # (M, 4, 3)

    # Construct G = dX^T dX (M, 3, 3)
    G = torch.einsum('mvi,mvj->mij', dX, dX)  # Gram matrix

    # Compute G inverse (M, 3, 3)
    G_inv = torch.inverse(G)

    # Per-vertex least-squares weights G^{-1} dX_v: (M, 4, 3)
    weights = torch.einsum('mij,mvj->mvi', G_inv, dX)

    # Compute per-tet sdf gradient: (M, 3)
    sdf_centered = tet_sdf - center_sdf[:, None]  # (M, 4)
    grad_sdf = torch.einsum('mvi,mv->mi', weights, sdf_centered)

    # Tetrahedron volumes |((b-a) x (c-a)) . (d-a)| / 6: (M,)
    # dim=-1 must be explicit: without it torch.cross uses the FIRST dimension
    # of size 3, which is the batch dimension whenever M == 3.
    a, b, c, d = tet_pos.unbind(dim=1)
    vol = torch.abs((torch.cross(b - a, c - a, dim=-1) * (d - a)).sum(dim=-1)) / 6.0

    # Accumulate weighted gradient per site. Accumulators follow the input
    # dtype so index_add_ also works for float64 inputs.
    N = sites.shape[0]
    grad_accum = torch.zeros((N, 3), dtype=grad_sdf.dtype, device=device)
    weight_accum = torch.zeros((N,), dtype=vol.dtype, device=device)

    weighted_grad = grad_sdf * vol[:, None]  # (M, 3), loop-invariant
    for i in range(4):
        indices = tets[:, i]  # (M,)
        grad_accum.index_add_(0, indices, weighted_grad)
        weight_accum.index_add_(0, indices, vol)

    # Normalize by total volume weights (epsilon guards unreferenced sites)
    grad_per_site = grad_accum / (weight_accum[:, None] + 1e-12)

    return grad_per_site, weight_accum


In [None]:
import torch

def sdf_grad_wrt_positions(sites, tets, sdf_values):
    """
    Computes the per-tetrahedron SDF gradient ∇phi and its exact Jacobian with
    respect to the four vertex positions of each tetrahedron.

    Inside a tetrahedron phi is linear, so ∇phi = D^{-T} Δphi with
    D = [x1-x0, x2-x0, x3-x0] (edge vectors as columns) and
    Δphi = [phi1-phi0, phi2-phi0, phi3-phi0]. Differentiating D^T ∇phi = Δphi
    (Δphi does not depend on x) gives, for every vertex i in 0..3,

        ∂∇phi / ∂x_i = -∇w_i ⊗ ∇phi,

    where ∇w_i is the gradient of the i-th barycentric coordinate
    (∇w_{1..3} are the rows of D^{-1}, ∇w_0 = -(∇w_1 + ∇w_2 + ∇w_3)).

    Args:
        sites:       (N, 3) tensor of site positions (x0, x1, ..., xN)
        tets:        (M, 4) indices of tetrahedron vertices into `sites`
        sdf_values:  (N,)  tensor of SDF values at the sites

    Returns:
        grad_phi:    (M, 3) spatial gradient of phi inside each tetrahedron
        jacobians:   (M, 4, 3, 3) ∂∇phi / ∂xi for each tetrahedron and each
                     vertex, laid out as jacobians[m, i, r, c] =
                     ∂(∇phi)_r / ∂(x_i)_c. Same dtype/device as the inputs.
    """
    x = sites[tets]              # (M, 4, 3)
    phi = sdf_values[tets]       # (M, 4)

    x0, x1, x2, x3 = x.unbind(dim=1)

    # Matrix D = [x1-x0, x2-x0, x3-x0] for each tet (edges as columns)
    D = torch.stack([x1 - x0, x2 - x0, x3 - x0], dim=-1)  # (M, 3, 3)
    D_inv = torch.inverse(D)                               # rows = ∇w1..∇w3
    delta_phi = phi[:, 1:] - phi[:, :1]                    # (M, 3)

    # ∇phi = D^{-T} Δphi
    grad_phi = torch.bmm(D_inv.transpose(1, 2), delta_phi.unsqueeze(-1)).squeeze(-1)  # (M, 3)

    # Barycentric gradients for all four vertices: (M, 4, 3)
    grad_w = torch.cat([-D_inv.sum(dim=1, keepdim=True), D_inv], dim=1)

    # Analytic Jacobian: J[m, i, r, c] = -∇w_i[r] * ∇phi[c]
    jacobians = -torch.einsum('mir,mc->mirc', grad_w, grad_phi)  # (M, 4, 3, 3)

    return grad_phi, jacobians


In [None]:
import torch

def compute_sdf_gradient_autograd(sites, tets, sdf_values):
    """
    Returns the constant gradient of the piecewise-linear SDF in every tetrahedron.

    The result is ∇phi = D^{-T} Δphi, where D holds the edge vectors
    (x_k - x_0, k = 1..3) as columns and Δphi = (phi_k - phi_0). Only
    differentiable torch ops are used, so autograd can back-propagate into
    ``sites`` and/or ``sdf_values`` when they require grad.

    Args:
        sites:      (N, 3) tensor of vertex positions, requires_grad=True if you want to optimize them.
        tets:       (M, 4) tensor of tetrahedron indices.
        sdf_values: (N,) tensor of SDF values, can also require_grad=True if optimizing.

    Returns:
        grad_phi:   (M, 3) tensor of spatial gradients ∇phi in each tetrahedron.
    """
    corners = sites[tets]        # (M, 4, 3)
    values = sdf_values[tets]    # (M, 4)

    # Edge matrix, one column per edge leaving the first corner.
    origin = corners[:, 0]
    edge_mat = torch.stack([corners[:, k] - origin for k in (1, 2, 3)], dim=-1)  # (M, 3, 3)
    inv_edge_t = torch.inverse(edge_mat).transpose(1, 2)                           # (M, 3, 3)

    # SDF differences relative to the first corner.
    base = values[:, 0]
    value_diffs = torch.stack([values[:, k] - base for k in (1, 2, 3)], dim=-1)  # (M, 3)

    return (inv_edge_t @ value_diffs.unsqueeze(-1)).squeeze(-1)
