In [1]:
from pathlib import Path

import torch
from torch import nn
from torch import optim
from torch.nn import MSELoss
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import trimesh
from mesh_to_sdf import sample_sdf_near_surface
import pyrender
from tqdm import tqdm, trange

In [2]:
# Load the bunny mesh from disk and preview it. `show()` is kept as the last
# expression so the notebook can render whatever viewer object it returns.
mesh_path = 'data/bunny.obj'
mesh = trimesh.load(mesh_path)
mesh.show()



In [3]:
# Seed NumPy's global RNG and PyTorch's RNGs (the latter drive weight init and
# DataLoader shuffling below) so the run is repeatable.
SEED = 42
for seed_fn in (np.random.seed, torch.manual_seed):
    seed_fn(SEED)

In [4]:
# Draw training samples near the mesh surface together with their signed distances.
points, sdf = sample_sdf_near_surface(mesh, number_of_points=3_000_000)

# Color every sample by the sign of its distance: red channel where sdf > 0,
# blue channel where sdf < 0; samples with sdf == 0 stay black.
colors = np.zeros(points.shape)
colors[:, 0] = sdf > 0
colors[:, 2] = sdf < 0
cloud = pyrender.Mesh.from_points(points, colors=colors)
scene = pyrender.Scene()
scene.add(cloud)
viewer = pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2)

In [5]:
# Small MLP mapping a 3-D point to a single value squashed into (-1, 1) by Tanh:
# one input layer, three hidden 32->32 layers, and a 32->1 output layer, with
# LeakyReLU(0.1) after every layer except the last.
hidden_width = 32
layers = [nn.Linear(3, hidden_width), nn.LeakyReLU(0.1)]
for _ in range(3):
    layers += [nn.Linear(hidden_width, hidden_width), nn.LeakyReLU(0.1)]
layers += [nn.Linear(hidden_width, 1), nn.Tanh()]
model = nn.Sequential(*layers)

In [6]:
def init_weights(m):
    """Initialize an ``nn.Linear`` with Xavier-uniform weights and a 0.01 bias.

    Meant to be passed to ``nn.Module.apply``, which calls it on every
    submodule; modules that are not ``nn.Linear`` are left untouched.

    Args:
        m: a submodule visited by ``Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    torch.nn.init.xavier_uniform_(m.weight)
    # `bias` is None for nn.Linear(..., bias=False); the previous `.data.fill_`
    # raised AttributeError there. `nn.init.constant_` is the supported in-place
    # initializer (it runs under no_grad) instead of mutating `.data` directly.
    if m.bias is not None:
        torch.nn.init.constant_(m.bias, 0.01)


# Module.apply walks every submodule recursively and mutates them in place
# (it returns the same module); the trailing `;` keeps the cell from echoing it.
model.apply(init_weights);

In [7]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU,
# and move the model there.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
model = model.to(device)

In [8]:
device  # echo the selected device string ('cuda' or 'cpu') as the cell output

'cuda'

In [9]:
# L1 (mean absolute error) loss on the predicted signed distance,
# optimized with Adam starting from a learning rate of 1e-2.
criterion = nn.L1Loss()
optimizer = optim.Adam(params=model.parameters(), lr=0.01)

In [10]:
# Number of full passes over the sampled SDF dataset.
n_epochs = 100

In [11]:
# Inputs are the sampled points; targets are their signed distances with a
# trailing unit axis added so they line up with the model's single output column.
tensor_x = torch.tensor(points, dtype=torch.float32)
tensor_y = torch.tensor(sdf, dtype=torch.float32)
dataset = TensorDataset(tensor_x, tensor_y.unsqueeze(1))
dataloader = DataLoader(dataset, batch_size=2048, shuffle=True)

In [12]:
# Train for `n_epochs`, halving the learning rate at the end of epochs
# 10, 20, ..., 90 (0-based), exactly as before.
# The batch losses are summed on-device and synced to the host once per epoch:
# calling `.item()` on every batch forces a device->host sync on each step.
model.train()  # make the training mode explicit
pbar = trange(n_epochs)
for epoch_idx in pbar:
    epoch_loss = torch.zeros((), device=device)
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach()
    pbar.set_postfix(loss=epoch_loss.item() / len(dataloader))
    if epoch_idx and epoch_idx % 10 == 0:
        for group in optimizer.param_groups:
            group['lr'] /= 2

0.0011391115714176065: 100%|███████████████████████████████████████████████████████████████████████████████████████| 100/100 [29:05<00:00, 17.45s/it]


In [14]:
# Random query points drawn uniformly from the [-1, 1]^3 cube, cast to float32
# (PyTorch's default parameter dtype) and moved to the compute device.
points = torch.from_numpy(
    np.random.uniform(low=-1, high=1, size=(3_000_000, 3)).astype(np.float32)
).to(device)

In [15]:
# Predict the SDF at every query point. This is inference only: without
# `torch.no_grad()` autograd would record the graph for all 3M points and keep
# every intermediate activation alive on the device for a backward pass that
# never happens.
model.eval()
with torch.no_grad():
    sdf_pred = model(points)[:, 0]

In [16]:
# Copy the predictions and query points back to host memory as NumPy arrays.
# NOTE(review): both names are rebound, so re-running this cell on its own
# fails (NumPy arrays have no `.detach()`/`.cpu()`); re-run the cells above first.
sdf_pred, points = [t.detach().cpu().numpy() for t in (sdf_pred, points)]

In [17]:
# Keep only the query points whose predicted SDF is <= 0.
# NOTE(review): this rebinds `points`, so re-running the cell alone fails with a
# mask/array length mismatch.
inside_mask = sdf_pred <= 0
points = points[inside_mask]

In [23]:
# Render the current `points` as a uniformly blue point cloud.
colors = np.tile(np.array([0.0, 0.0, 1.0]), (len(points), 1))
cloud = pyrender.Mesh.from_points(points, colors=colors)
scene = pyrender.Scene()
scene.add(cloud)
viewer = pyrender.Viewer(scene, use_raymond_lighting=True, point_size=2)

In [19]:
# Output directory for the exported model files. `parents=True` also creates
# `data/` when it is missing instead of raising FileNotFoundError, and
# `exist_ok=True` keeps the cell safe to re-run.
models_dir = Path('data') / 'models'
models_dir.mkdir(parents=True, exist_ok=True)

In [20]:
# Export the trained network to ONNX.
# `nn.Module.cpu()` moves the model in place, so the move is done explicitly here
# (later cells also see a CPU model) instead of inside the argument list, and the
# model is switched to eval mode for export. `x` is only an example input used to
# record the graph, so it needs no gradient. The batch dimension is declared
# dynamic, so the exported model accepts any batch size.
model = model.cpu().eval()
x = torch.randn(1, 3)
torch.onnx.export(
    model,
    x,
    models_dir / 'sdf.onnx',
    export_params=True,
    opset_version=10,
    do_constant_folding=True,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

In [22]:
# Persist only the learned parameters (the state dict) under `models_dir`.
weights_path = models_dir / 'sdf.pth'
torch.save(model.state_dict(), weights_path)

In [24]:
# Trace a TorchScript version of the network. The example input is created here
# on the CPU, and the model is moved to the CPU (in place) and put in eval mode.
# This way the cell no longer depends on the global `x`, which the training loop
# binds to a device-resident batch, or on another cell having moved the model.
example_input = torch.randn(1, 3)
traced_script_module = torch.jit.trace(model.cpu().eval(), example_input)

In [25]:
# Older TorchScript `save` bindings accept only a string filename, so convert
# the Path explicitly; newer versions call str() on it themselves.
traced_script_module.save(str(models_dir / 'traced_sdf.pt'))