In [1]:
import torch
import numpy as np

from skimage import measure
import trimesh
import open3d as o3d


def make_3d_grid(min_val, max_val, resolution, device):
    """Return every node of a regular cubic lattice as a flat list of points.

    Each axis is sampled at ``resolution`` evenly spaced values from
    ``min_val`` to ``max_val`` (both inclusive), identically for x, y and z.
    Points are ordered with 'ij' indexing (first coordinate slowest, last
    fastest), so ``values.reshape(resolution, resolution, resolution)``
    gives ``volume[i, j, k] = f(x_i, y_j, z_k)``.

    Returns:
        Tensor of shape ``(resolution**3, 3)`` on ``device``.
    """
    axis = torch.linspace(min_val, max_val, resolution, device=device)
    mesh_axes = torch.meshgrid(axis, axis, axis, indexing='ij')
    lattice = torch.stack(mesh_axes, dim=-1)
    return lattice.reshape(-1, 3)

In [2]:
def implicit_func(selected_pts, shape_id):
    """Evaluate the implicit function of the shape named ``shape_id``.

    Args:
        selected_pts: tensor of query points whose last dimension holds
            (x, y, z); the shape functions index it with ``[..., i]`` so any
            leading dimensions are allowed.
        shape_id: one of 'sphere', 'urchin', 'rippley', 'torus',
            'twisted_torus'.

    Returns:
        Whatever the selected shape function returns (for the shapes
        implemented in this notebook, one value per query point).

    Raises:
        ValueError: if ``shape_id`` is not one of the names above.
    """
    if shape_id == 'sphere':
        return sphere_sdf(selected_pts)
    elif shape_id == 'urchin':
        return urchin_implicit(selected_pts)
    elif shape_id == 'rippley':
        return rippley_implicit(selected_pts)
    elif shape_id == 'torus':
        return torus_sdf(selected_pts)
    elif shape_id == 'twisted_torus':
        return twisted_torus_implicit(selected_pts)

    # Fail loudly: silently returning None only surfaces later as a confusing
    # AttributeError (e.g. when the caller does `.cpu()` on the result).
    raise ValueError(
        f"Unknown shape_id {shape_id!r}; expected one of "
        "'sphere', 'urchin', 'rippley', 'torus', 'twisted_torus'."
    )


In [3]:

def sphere_sdf(selected_pts, radius=0.5):
    """Signed distance from each point to a sphere centred at the origin.

    Args:
        selected_pts: tensor whose last dimension holds (x, y, z).
        radius: sphere radius; defaults to 0.5, the previously hard-coded value.

    Returns:
        Tensor with the last dimension reduced away: negative inside the
        sphere, zero on its surface, positive outside.
    """
    distance_to_origin = selected_pts.pow(2).sum(-1).sqrt()
    return distance_to_origin - radius


def urchin_implicit(selected_pts, squared=False, n_spikes=5, amplitude=0.1, radius=1.0):
    """Implicit function of an 'urchin': a sphere whose radius is modulated
    by ``n_spikes`` lobes around the z axis.

        f(p) = |p| - amplitude * sin(n_spikes * theta) * (z**2 - 1) - radius,
        theta = atan2(y, x)

    The surface is the zero level set. In general f is not a true signed
    distance (its gradient norm is not 1), only an implicit function with the
    same sign convention (e.g. f(origin) = -radius < 0).

    Args:
        selected_pts: tensor whose last dimension holds (x, y, z).
        squared: accepted for interface compatibility but ignored; the
            function never read it.
        n_spikes: number of lobes around the z axis (default 5).
        amplitude: modulation amplitude (default 0.1).
        radius: base sphere radius (default 1.0).

    Returns:
        Tensor of function values with the last dimension reduced away.
    """
    x, y, z = selected_pts[..., 0], selected_pts[..., 1], selected_pts[..., 2]
    r = (x**2 + y**2 + z**2).sqrt()
    theta = torch.arctan2(y, x)
    # z**2 is already non-negative, so wrapping it in abs() would be a no-op.
    modulation = amplitude * torch.sin(n_spikes * theta) * (z**2 - 1)
    return r - modulation - radius


def rippley_implicit(selected_pts, squared=False, n_ripples=15, z_twist=5, amplitude=0.1, radius=1.0):
    """Implicit function of a 'rippley' sphere: radial ripples around the
    z axis whose phase shifts with height.

        f(p) = |p| - amplitude * sin(n_ripples * theta + z_twist * z) * (z**2 - 1) - radius,
        theta = atan2(y, x)

    The surface is the zero level set. In general f is not a true signed
    distance, only an implicit function with the same sign convention
    (e.g. f(origin) = -radius < 0).

    Args:
        selected_pts: tensor whose last dimension holds (x, y, z).
        squared: accepted for interface compatibility but ignored; the
            function never read it.
        n_ripples: number of ripples around the z axis (default 15).
        z_twist: phase shift per unit of z (default 5).
        amplitude: ripple amplitude (default 0.1).
        radius: base sphere radius (default 1.0).

    Returns:
        Tensor of function values with the last dimension reduced away.
    """
    x, y, z = selected_pts[..., 0], selected_pts[..., 1], selected_pts[..., 2]
    r = (x**2 + y**2 + z**2).sqrt()
    theta = torch.arctan2(y, x)
    phase = n_ripples * theta + z_twist * z
    # z**2 is already non-negative, so wrapping it in abs() would be a no-op.
    modulation = amplitude * torch.sin(phase) * (z**2 - 1)
    return r - modulation - radius


def torus_sdf(selected_pts, major_radius=0.5, minor_radius=0.2):
    """Signed distance to a torus centred at the origin with the z axis as
    its axis of symmetry (the ring lies in the xy-plane).

    The tube's centre line is the circle of radius ``major_radius`` in the
    xy-plane. Points within ``minor_radius`` of that circle are inside
    (negative distance).

    Args:
        selected_pts: tensor whose last dimension holds (x, y, z).
        major_radius: distance from the origin to the tube centre line.
        minor_radius: tube radius.

    The default radii are a choice made here (the stub had none) so the
    torus fits the [-1, 1]^3 meshing grid. TODO confirm against the intended
    exercise spec.

    Returns:
        Tensor of signed distances with the last dimension reduced away.
    """
    x, y, z = selected_pts[..., 0], selected_pts[..., 1], selected_pts[..., 2]
    # In-plane offset from the tube centre line, measured in the half-plane
    # that contains the z axis and the point.
    ring_offset = (x**2 + y**2).sqrt() - major_radius
    return (ring_offset**2 + z**2).sqrt() - minor_radius


def twisted_torus_implicit(selected_pts):
    """Placeholder for the twisted-torus implicit function (exercise stub).

    Not implemented: every call raises NotImplementedError, so choosing
    shape_id='twisted_torus' in implicit_func currently fails here.
    """
    message = 'You need to write this function yourself.'
    raise NotImplementedError(message)

In [6]:

import os

shape_id = 'sphere'  # alternatives: 'rippley', 'urchin'
output_filepath = f'data/surfaces/{shape_id}-mc.ply'


def my_implicit_func(pts):
    """Evaluate the currently selected shape on ``pts`` reshaped to a (1, N, 3) batch."""
    return implicit_func(pts.reshape(1, -1, 3), shape_id=shape_id)


# --- Configuration ---
res = 50
# 'urchin' and 'rippley' have a base radius of 1.0 with ripples of up to about
# 0.1, so inside [-1, 1]^3 their surface runs into the grid boundary and the
# extracted mesh is clipped open. Give them headroom. The 0.5-radius sphere
# keeps the original box.
half_extent = 1.5 if shape_id in ('urchin', 'rippley') else 1.0
min_val = -half_extent
max_val = half_extent
device = 'cpu'

# 1. Query points on a regular grid (first coordinate varies slowest).
points_3d = make_3d_grid(min_val, max_val, res, device)

# 2. Evaluate the implicit function; no autograd graph is needed.
with torch.no_grad():
    sdf_flat = implicit_func(points_3d, shape_id=shape_id).cpu().numpy()

# make_3d_grid uses 'ij' ordering, so this recovers volume[i, j, k] = f(x_i, y_j, z_k).
sdf_volume = sdf_flat.reshape(res, res, res)

# 3. Marching cubes, with spacing so vertices come out in world units.
voxel_size = (max_val - min_val) / (res - 1)
verts, faces, normals, values = measure.marching_cubes(
    sdf_volume,
    level=0.0,
    spacing=(voxel_size, voxel_size, voxel_size),
)

# 4. marching_cubes puts voxel (0, 0, 0) at the origin; shift it to min_val.
verts += np.array([min_val, min_val, min_val])

# 5. Export (creating the output directory first) and preview with trimesh.
output_dir = os.path.dirname(output_filepath)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
mesh = trimesh.Trimesh(
    vertices=verts,
    faces=faces,
    vertex_normals=normals * -1,
)
mesh.export(output_filepath)
mesh.show()

# 6. Preview in Open3D, built from the trimesh object's vertices and faces.
mesh_o3d = o3d.geometry.TriangleMesh(
    vertices=o3d.utility.Vector3dVector(mesh.vertices),
    triangles=o3d.utility.Vector3iVector(mesh.faces),
)

# Pale blue for every vertex. The color array is sized from mesh_o3d itself:
# trimesh may merge duplicate vertices on construction, so len(verts) is not
# guaranteed to equal the mesh's vertex count.
pale_blue = np.array([0.6, 0.8, 1.0])  # RGB in [0, 1]
mesh_o3d.vertex_colors = o3d.utility.Vector3dVector(
    np.tile(pale_blue, (len(mesh_o3d.vertices), 1))
)

# Shading normals come from the triangles. This call overwrites any
# previously assigned vertex normals, so none are assigned beforehand.
mesh_o3d.compute_vertex_normals()
o3d.visualization.draw_geometries([mesh_o3d])

In [10]:
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider, Dropdown
import ipywidgets as widgets



def get_sdf_slice(slice_val, slice_dim='Z', resolution=20, plot_range=3.0, transition_width=0.1,
                  implicit_fn=None, query_device=None):
    """Sample the implicit function on a square planar slice ``slice_dim = slice_val``.

    In-plane axes (U is horizontal in the plot, V is vertical):

        slice_dim='Z' -> U = X, V = Y
        slice_dim='Y' -> U = X, V = Z
        slice_dim='X' -> U = Y, V = Z

    Args:
        slice_val: coordinate of the slicing plane along ``slice_dim``.
        slice_dim: 'X', 'Y' or 'Z'.
        resolution: number of samples along each in-plane axis.
        plot_range: U and V both span [-plot_range, plot_range].
        transition_width: unused; kept so existing callers keep working.
        implicit_fn: callable evaluated on a (1, N, 3) tensor of points and
            expected to return N values (any shape that squeezes to N).
            Defaults to the notebook-level ``my_implicit_func``.
        query_device: torch device for the query points. Defaults to the
            notebook-level ``device``.

    Returns:
        (resolution, resolution) numpy array laid out for
        ``imshow(..., origin='lower')``: the column index follows U and the
        row index follows V.

    Raises:
        ValueError: if ``slice_dim`` is not 'X', 'Y' or 'Z'.
    """
    # Only fall back to the notebook globals when no override is given.
    if implicit_fn is None:
        implicit_fn = my_implicit_func
    if query_device is None:
        query_device = device

    grid_coords = torch.linspace(-plot_range, plot_range, resolution, device=query_device)
    # 'xy' indexing (not 'ij') makes U vary along columns and V along rows,
    # which is the layout imshow/contour map to the horizontal and vertical
    # axes. With 'ij', every slice is drawn transposed relative to its labels.
    U, V = torch.meshgrid(grid_coords, grid_coords, indexing='xy')
    fixed = torch.full_like(U, slice_val)

    if slice_dim == 'Z':
        query_points = torch.stack([U, V, fixed], dim=-1)   # plane XY
    elif slice_dim == 'Y':
        query_points = torch.stack([U, fixed, V], dim=-1)   # plane XZ
    elif slice_dim == 'X':
        query_points = torch.stack([fixed, U, V], dim=-1)   # plane YZ
    else:
        raise ValueError("Invalid slice dimension")

    with torch.no_grad():
        result = implicit_fn(query_points.reshape(1, -1, 3))

    # Make sure asynchronous accelerator work has finished before copying to host.
    if query_device == 'mps':
        torch.mps.synchronize()
    elif query_device == 'cuda':
        torch.cuda.synchronize()

    return result.squeeze().cpu().numpy().reshape(resolution, resolution)

def interactive_plot(slice_pos=0.0, slice_dim='Z'):
    """Redraw the slice figure; used as the ipywidgets ``interact`` callback.

    Draws the plane ``slice_dim = slice_pos`` as an 'RdBu' heat map clipped
    to [-0.4, 0.4]. The zero level set is overlaid as a thick black contour,
    and faint iso-lines are drawn every 0.1.
    """
    resolution, limit, plot_range = 150, 0.4, 2.0

    sdf_values = get_sdf_slice(slice_pos, slice_dim, resolution, plot_range)

    extent_bounds = (-plot_range, plot_range, -plot_range, plot_range)
    # Every layer shares the same placement so they line up exactly.
    placement = dict(origin='lower', extent=extent_bounds)

    plt.figure(figsize=(7, 6))

    # Heat map of the raw values.
    im = plt.imshow(sdf_values, cmap='RdBu', vmin=-limit, vmax=limit, **placement)

    # The surface itself: the zero level set.
    plt.contour(sdf_values, levels=[0], colors='black', linewidths=2, **placement)

    # Thin background iso-lines over [-5, 5) in steps of 0.1.
    plt.contour(
        sdf_values,
        levels=np.arange(-5, 5, 0.1),
        colors='black',
        linewidths=0.3,
        alpha=0.5,
        **placement,
    )

    plt.colorbar(im, label='SDF Value')
    plt.title(f'Slice along {slice_dim}-axis at {slice_pos:.2f}')

    # (horizontal, vertical) axis names for each slicing direction.
    axis_names = {
        'Z': ('X axis', 'Y axis'),
        'Y': ('X axis', 'Z axis'),
        'X': ('Y axis', 'Z axis'),
    }
    if slice_dim in axis_names:
        horizontal_name, vertical_name = axis_names[slice_dim]
        plt.xlabel(horizontal_name)
        plt.ylabel(vertical_name)

    plt.show()

# --- 3. Define Widgets ---

# Which coordinate is held fixed for the slice.
axis_dropdown = Dropdown(options=['X', 'Y', 'Z'], value='Z', description='Axis:')

# Position of the slicing plane along the chosen axis. The range matches the
# plot_range used inside interactive_plot.
pos_slider = FloatSlider(
    value=0.0, min=-2.0, max=2.0, step=0.05,
    description='Slice Pos:',
    continuous_update=True,  # Set False if it lags
)

# NOTE(review): this slider is never passed to `interact` below, and
# get_sdf_slice does not use its `transition_width` argument, so moving it
# has no effect.
transition_width_slider = FloatSlider(
    value=0.6, min=0.0, max=10.0, step=0.1,
    description='Transition Width:',
    continuous_update=True,
)

# --- 4. Launch Interface ---
interact(interactive_plot, slice_pos=pos_slider, slice_dim=axis_dropdown);

interactive(children=(FloatSlider(value=0.0, description='Slice Pos:', max=2.0, min=-2.0, step=0.05), Dropdown…