In [1]:
import os

import numpy as np
import torch
import trimesh
from skimage.measure import marching_cubes

from lib.models.deepsdf import DeepSDF

def remove_faces_outside_sphere(mesh: "trimesh.Trimesh", sphere=1.03):
    """Remove every face that has at least one vertex outside an origin-centred sphere.

    The mesh is modified in place: its face array is replaced and
    unreferenced vertices are dropped. The same object is also returned
    for convenience.

    Args:
        mesh: mesh to filter. It must expose ``vertices`` (one row per
            vertex), an integer ``faces`` array (one row of vertex indices
            per face) and ``remove_unreferenced_vertices()``.
        sphere: radius of the sphere centred at the origin. A vertex whose
            Euclidean norm is strictly greater than this is "outside".

    Returns:
        The same ``mesh`` object, after filtering.
    """
    # Per-vertex flag: True if the vertex lies strictly outside the sphere.
    outside = np.linalg.norm(mesh.vertices, axis=1) > sphere
    # Indexing the per-vertex flags with the face array gives one row of flags
    # per face; a face is dropped if any of its vertices is outside. This
    # replaces a Python loop that scanned the index array for every vertex of
    # every face (O(F * V)) with a single vectorised pass.
    face_outside = outside[mesh.faces].any(axis=1)

    mesh.faces = mesh.faces[~face_outside]
    mesh.remove_unreferenced_vertices()
    return mesh


# Load the trained DeepSDF checkpoint. The path can be overridden with the
# SKETCH2SHAPE_CKPT environment variable so the notebook runs on other
# machines; the default is the original author's local checkpoint path.
CKPT_PATH = os.environ.get(
    "SKETCH2SHAPE_CKPT", "/Users/robinborth/Code/sketch2shape/logs/last.ckpt"
)
model = DeepSDF.load_from_checkpoint(CKPT_PATH)
_ = model.eval()  # evaluation mode for inference; `_ =` suppresses the module repr

/opt/homebrew/Caskroom/miniforge/base/envs/sketch2shape/lib/python3.9/site-packages/lightning/pytorch/utilities/migration/utils.py:55: The loaded checkpoint was produced with Lightning v2.1.2, which is newer than your current Lightning version: v2.1.0
/opt/homebrew/Caskroom/miniforge/base/envs/sketch2shape/lib/python3.9/site-packages/lightning/pytorch/utilities/parsing.py:198: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.


In [2]:
# Regular query grid with `resolution` samples per axis covering [-1, 1):
# the sample at index i is -1 + i * (2 / resolution). Building it from an
# integer arange guarantees exactly `resolution` samples; a float-step
# arange can gain an extra element through rounding for some resolutions,
# and then the flattened grid no longer reshapes back to a cube.
resolution = 128
idx = 34  # row of model.lat_vecs to decode (presumably one training shape)
grid_vals = torch.arange(resolution) * (2 / resolution) - 1
# indexing="ij" is the current default; stating it explicitly silences the
# torch.meshgrid UserWarning about the default changing in the future.
grid = torch.meshgrid(grid_vals, grid_vals, grid_vals, indexing="ij")
# (resolution**3, 3) array of query points, in C order over (i, j, k).
xyz = torch.stack([axis.ravel() for axis in grid], dim=1)
# The same latent code is repeated once per query point.
lat_vec_idx = torch.full((xyz.shape[0],), idx)
lat_vec = model.lat_vecs(lat_vec_idx)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
# Evaluate the SDF over the whole grid in fixed-size chunks so a single
# forward pass never sees all resolution**3 points at once.
chunk_size = 50000  # query points per model.predict call
sd_list = []
# Gradients are not needed to read off SDF values.
# NOTE(review): assumes model.predict does not itself rely on autograd.
with torch.no_grad():
    for _xyz, _lat_vec in zip(xyz.split(chunk_size), lat_vec.split(chunk_size)):
        # predict is called with a leading singleton dim (presumably a batch axis).
        chunk_sd = model.predict((_xyz.unsqueeze(0), _lat_vec.unsqueeze(0)))
        # Use reshape(-1) instead of squeeze(). A one-point chunk must stay
        # 1-D; squeeze() would make it 0-d, which np.concatenate rejects.
        sd_list.append(np.asarray(chunk_sd).reshape(-1))
sd = np.concatenate(sd_list)
# Points were generated in C order over (i, j, k), so a C-order reshape
# restores the cube: sd_r[i, j, k] is the SDF at grid point (i, j, k).
sd_r = sd.reshape(resolution, resolution, resolution)

In [4]:
# Extract the zero level set of the SDF cube. marching_cubes reports
# vertices in voxel-index coordinates, so they are mapped back to world
# space below.
verts, faces, _, _ = marching_cubes(sd_r, level=0.0)

# Index i maps to x_min + i * voxel_size. This is the same affine map used
# to build the query grid (-1 + i * 2 / resolution per axis).
x_min = np.array([-1, -1, -1])
x_max = np.array([1, 1, 1])
voxel_size = (x_max - x_min) / resolution
verts = x_min + verts * voxel_size

# Wrap the triangles in a trimesh object and display it.
mesh = trimesh.Trimesh(vertices=verts, faces=faces)
mesh.show()