In [1]:
from pathlib import Path

import torch
from torch import nn
from torch import optim
from torch.nn import MSELoss
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import trimesh
from mesh_to_sdf import sample_sdf_near_surface, scale_to_unit_sphere
import pyrender
from tqdm import tqdm, trange

In [2]:
mesh = trimesh.load('data/bunny.obj')
mesh = scale_to_unit_sphere(mesh)
mesh.show()



In [3]:
mesh.bounds

array([[-0.73137298, -0.73454537, -0.5585681 ],
       [ 0.73137298,  0.73454537,  0.5585681 ]])

In [4]:
_ = np.random.seed(42)
_ = torch.manual_seed(42)

In [5]:
width = 32

model = nn.Sequential(
    nn.Linear(3, width),
    nn.LeakyReLU(0.1),
    nn.Linear(width, width),
    nn.LeakyReLU(0.1),
    nn.Linear(width, width),
    nn.LeakyReLU(0.1),
    nn.Linear(width, width),
    nn.LeakyReLU(0.1),
    nn.Linear(width, 1),
    nn.Tanh(),
)

In [6]:
def init_weights(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        m.bias.data.fill_(0.01)


model = model.apply(init_weights)
model = model.train()

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

In [8]:
device

'cuda'

In [9]:
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [10]:
n_epochs = 100

In [11]:
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    'min',
    factor=0.5,
    patience=10,
    min_lr=1e-7,
)

In [12]:
points, sdf, grad = sample_sdf_near_surface(
    mesh,
    number_of_points=2000000,
    return_gradients=True
)

In [13]:
tensor_x = torch.Tensor(points)
tensor_y = torch.Tensor(sdf)
tensor_grad = torch.Tensor(grad)

dataset = TensorDataset(tensor_x, tensor_y[:, None], tensor_grad)
dataloader = DataLoader(
    dataset,
    batch_size=4096,
    shuffle=False,
)

In [14]:
def train(
        model,
        dataloader,
        optimizer,
        scheduler,
        n_epochs,
        lambda_1):
    pbar = trange(n_epochs)
    l2_loss = MSELoss()
    for epoch_idx in pbar:
        mean_loss = 0.0
        for x, y, grad_tch in dataloader:
            x, y, grad_tch = x.to(device), y.to(device), grad_tch.to(device)

            # basic l2 loss
            sdf_approx = model(x)
            loss = l2_loss(sdf_approx, y)            

            # normal loss
            if lambda_1 > 0.0:
                res = []
                for i in range(3):
                    x1, x2 = torch.clone(x), torch.clone(x)
                    x1[:, i] += 1e-4
                    x2[:, i] -= 1e-4
                    res_x1 = model(x1)
                    res_x2 = model(x2)
                    res.append(res_x1 - res_x2)
                res = torch.cat(res, dim=1)
                grad_approx = torch.nn.functional.normalize(res, dim=1)
                loss += lambda_1 * l2_loss(grad_approx, grad_tch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            mean_loss += loss.item()

        mean_loss /= len(dataloader)
        pbar.set_description(str(mean_loss))
        scheduler.step(mean_loss)

In [15]:
train(
    model,
    dataloader,
    optimizer,
    scheduler,
    n_epochs,
    0.0,
)

2.6944116043132956e-05: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [24:42<00:00, 14.83s/it]


In [17]:
train(
    model,
    dataloader,
    optimizer,
    scheduler,
    n_epochs,
    0.0,
)

1.695751350528001e-05: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [24:57<00:00, 14.98s/it]


In [23]:
train(
    model,
    dataloader,
    optimizer,
    scheduler,
    n_epochs,
    0.1,
)

0.001384290710857218: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [30:55<00:00, 18.56s/it]


In [27]:
train(
    model,
    dataloader,
    optimizer,
    scheduler,
    n_epochs,
    0.1,
)

0.0013807056641209955: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [30:04<00:00, 18.04s/it]


In [29]:
out_dir = Path('data') / 'NeuralNetworkWeightsCustom'
out_dir.mkdir(exist_ok=True)

state = model.state_dict()
for k, v in state.items():
    idx = int(k.split('.')[0])
    if k.endswith('weight'):
        k = 'weights' + str(idx // 2) + '.txt'
    else:
        k = 'biases' + str(idx // 2) + '.txt'

    np.savetxt(
        out_dir/ k,
        v.cpu().numpy(),
        fmt='%.12f',
        delimiter=', ',
        newline=', ',
    )