<a href="https://colab.research.google.com/github/yongsun-yoon/deep-learning-paper-implementation/blob/main/02-computer-vision/DMTet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# DMTet

## 0. Info

### Paper
* title: Deep Marching Tetrahedra: a Hybrid Representation for High-Resolution 3D Shape Synthesis
* author: Tianchang Shen et al.
* url: https://arxiv.org/abs/2111.04276

### Features
* dataset: `bear_pointcloud.usd` point cloud from the Kaolin example samples

### Reference
* https://github.com/NVIDIAGameWorks/kaolin/blob/master/examples/tutorial/dmtet_tutorial.ipynb

## 1. Setup

In [1]:
!nvidia-smi

Tue Oct  4 21:58:56 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   54C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# install kaolin
import os
os.environ['IGNORE_TORCH_VER'] = '1'
os.environ['KAOLIN_INSTALL_EXPERIMENTAL'] = '1'

!git clone --recursive https://github.com/NVIDIAGameWorks/kaolin
%cd kaolin
# regular (non-editable) install, so the clone can be deleted afterwards
!pip install .
!pip install -q usd-core pyngrok einops

%cd /content
!rm -rf kaolin

# restart the runtime so the newly installed kaolin can be imported

In [1]:
import easydict
import numpy as np
from tqdm.auto import tqdm
from pyngrok import ngrok
from einops import rearrange

import kaolin
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
cfg = easydict.EasyDict(
    device = 'cuda',
    num_max_points = 100000,
    multires = 2,
    lr = 1e-3,
    num_training_steps = 5000,
    laplacian_weight = 0.1,
)

## 2. Data

### 2.1. Prepare

In [3]:
!wget https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/bear_pointcloud.usd
!wget https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/128_verts.npz
!wget https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/128_tets_0.npz
!wget https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/128_tets_1.npz
!wget https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/128_tets_2.npz
!wget https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/128_tets_3.npz

--2022-10-04 22:09:41--  https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/bear_pointcloud.usd
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 3210920 (3.1M) [application/octet-stream]
Saving to: ‘bear_pointcloud.usd.1’


2022-10-04 22:09:41 (197 MB/s) - ‘bear_pointcloud.usd.1’ saved [3210920/3210920]

--2022-10-04 22:09:41--  https://raw.githubusercontent.com/NVIDIAGameWorks/kaolin/master/examples/samples/128_verts.npz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 610391 (596K) [application/octet-

### 2.2. Pointcloud

In [4]:
timelapse = kaolin.visualize.Timelapse('logs')

In [5]:
# load pointcloud
points = kaolin.io.usd.import_pointclouds('bear_pointcloud.usd')[0].points.to(cfg.device)

In [6]:
# randomly subsample to at most cfg.num_max_points points
num_points = points.size(0)
if num_points > cfg.num_max_points:
    rand_idxs = torch.randperm(num_points, device=points.device)[:cfg.num_max_points]
    points = points[rand_idxs]

In [7]:
# normalize: center the bounding box at the origin and scale its longest side to 0.9
max_val = points.max(dim=0).values
min_val = points.min(dim=0).values

center = (max_val + min_val) / 2
maxlen = (max_val - min_val).max()

points = (points - center) / maxlen * 0.9
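The normalization above centers the bounding box at the origin and scales its longest side to 0.9, so every coordinate ends up in [-0.45, 0.45]. A minimal standalone check of that claim; the random box-shaped input is made up for illustration and is not the bear point cloud:

```python
import torch

def normalize_points(points: torch.Tensor, scale: float = 0.9) -> torch.Tensor:
    # bounding-box center and longest side, as in the cell above
    max_val = points.max(dim=0).values
    min_val = points.min(dim=0).values
    center = (max_val + min_val) / 2
    maxlen = (max_val - min_val).max()
    return (points - center) / maxlen * scale

torch.manual_seed(0)
# arbitrary box: x spans ~4 units, y ~2, z ~1, offset away from the origin
raw = torch.rand(1000, 3) * torch.tensor([4.0, 2.0, 1.0]) + 10.0
normed = normalize_points(raw)
print(normed.min().item(), normed.max().item())  # approximately -0.45 and 0.45
```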

In [8]:
points = points.unsqueeze(dim=0)

In [9]:
timelapse.add_pointcloud_batch(category='input', pointcloud_list=points.cpu(), points_type = "usd_geom_points")

### 2.3. Tetrahedral grid

In [10]:
tet_verts = torch.tensor(np.load('128_verts.npz')['data'], dtype=torch.float, device=cfg.device)

tets = np.stack([np.load(f'128_tets_{i}.npz')['data'] for i in range(4)], axis=1)
tets = torch.tensor(tets, dtype=torch.long, device=cfg.device)

tet_verts.size(), tets.size()

(torch.Size([277410, 3]), torch.Size([1524684, 4]))
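Each row of `tets` holds four vertex indices, and each tetrahedron has six edges; `base_tet_edges` in section 3.1 lists them as the vertex-slot pairs (0,1), (0,2), (0,3), (1,2), (1,3), (2,3). Neighboring tets share edges, which is why the marching-tetrahedra code sorts each edge's endpoints and calls `torch.unique` before interpolating, so a shared edge yields a single surface vertex. A toy sketch with two hypothetical tets sharing the face (1, 2, 3); it uses `torch.sort` in place of the notebook's `_sort_edges` helper:

```python
import torch

# same slot pairs as base_tet_edges in section 3.1, reshaped to (6, 2)
edge_slots = torch.tensor([[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3]])

# two hypothetical tets that share the triangle (1, 2, 3)
toy_tets = torch.tensor([[0, 1, 2, 3],
                         [4, 3, 2, 1]])

all_edges = toy_tets[:, edge_slots].reshape(-1, 2)  # (2 * 6, 2)
all_edges, _ = torch.sort(all_edges, dim=1)         # (i, j) and (j, i) become identical
unique_edges = torch.unique(all_edges, dim=0)
print(all_edges.shape[0], unique_edges.shape[0])    # 12 raw edges, 9 unique
```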

## 3. Model

### 3.1. Marching Tetrahedra

In [29]:
# https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/conversions/tetmesh.py
triangle_table = torch.tensor([
    [-1, -1, -1, -1, -1, -1],
    [1, 0, 2, -1, -1, -1],
    [4, 0, 3, -1, -1, -1],
    [1, 4, 2, 1, 3, 4],
    [3, 1, 5, -1, -1, -1],
    [2, 3, 0, 2, 5, 3],
    [1, 4, 0, 1, 5, 4],
    [4, 2, 5, -1, -1, -1],
    [4, 5, 2, -1, -1, -1],
    [4, 1, 0, 4, 5, 1],
    [3, 2, 0, 3, 5, 2],
    [1, 3, 5, -1, -1, -1],
    [4, 1, 2, 4, 3, 1],
    [3, 0, 4, -1, -1, -1],
    [2, 0, 1, -1, -1, -1],
    [-1, -1, -1, -1, -1, -1]
], dtype=torch.long)

num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long)
base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long)
v_id = torch.pow(2, torch.arange(4, dtype=torch.long))


def _sort_edges(edges):
    with torch.no_grad():
        order = (edges[:, 0] > edges[:, 1]).long()
        order = order.unsqueeze(dim=1)

        a = torch.gather(input=edges, index=order, dim=1)
        b = torch.gather(input=edges, index=1 - order, dim=1)

    return torch.stack([a, b], -1)


def _unbatched_marching_tetrahedra(vertices, tets, sdf, return_tet_idx):
    device = vertices.device
    with torch.no_grad():
        occ_n = sdf > 0
        occ_fx4 = occ_n[tets.reshape(-1)].reshape(-1, 4)
        occ_sum = torch.sum(occ_fx4, -1)
        valid_tets = (occ_sum > 0) & (occ_sum < 4)
        occ_sum = occ_sum[valid_tets]

        # find all vertices
        all_edges = tets[valid_tets][:, base_tet_edges.to(device)].reshape(-1, 2)
        all_edges = _sort_edges(all_edges)
        unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)

        unique_edges = unique_edges.long()
        mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
        mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=device) * -1
        mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=device)
        idx_map = mapping[idx_map]

        interp_v = unique_edges[mask_edges]
    edges_to_interp = vertices[interp_v.reshape(-1)].reshape(-1, 2, 3)
    edges_to_interp_sdf = sdf[interp_v.reshape(-1)].reshape(-1, 2, 1)
    edges_to_interp_sdf[:, -1] *= -1

    denominator = edges_to_interp_sdf.sum(1, keepdim=True)

    edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
    verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

    idx_map = idx_map.reshape(-1, 6)

    tetindex = (occ_fx4[valid_tets] * v_id.to(device).unsqueeze(0)).sum(-1)
    num_triangles = num_triangles_table.to(device)[tetindex]
    triangle_table_device = triangle_table.to(device)

    # Generate triangle indices
    faces = torch.cat((
        torch.gather(input=idx_map[num_triangles == 1], dim=1,
                     index=triangle_table_device[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3),
        torch.gather(input=idx_map[num_triangles == 2], dim=1,
                     index=triangle_table_device[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3),
    ), dim=0)

    if return_tet_idx:
        tet_idx = torch.arange(tets.shape[0], device=device)[valid_tets]
        tet_idx = torch.cat((tet_idx[num_triangles == 1], tet_idx[num_triangles ==
                            2].unsqueeze(-1).expand(-1, 2).reshape(-1)), dim=0)
        return verts, faces, tet_idx
    return verts, faces


def marching_tetrahedra(vertices, tets, sdf, return_tet_idx=False):
    list_of_outputs = [_unbatched_marching_tetrahedra(vertices[b], tets, sdf[b], return_tet_idx) for b in range(vertices.shape[0])]
    return list(zip(*list_of_outputs))
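Two details of the extraction code, isolated. First, the four occupancy bits (`sdf > 0`) of a tet are packed into a case index in 0..15 with weights `v_id = [1, 2, 4, 8]`, and `num_triangles_table` maps that index to 0, 1 or 2 triangles. Second, on each edge whose endpoints have opposite SDF signs, the new vertex sits at the zero crossing of the linearly interpolated SDF, p = (-s_b * v_a + s_a * v_b) / (s_a - s_b), which is what the sign flip, `denominator` and `torch.flip` lines compute. A minimal sketch with made-up occupancy and SDF values:

```python
import torch

v_id = torch.tensor([1, 2, 4, 8])  # 2 ** arange(4), as above
num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0])

# occupancy (sdf > 0) of the four corners of one hypothetical tet
occ = torch.tensor([True, False, True, False])
case = int((occ.long() * v_id).sum())         # 1 + 4 = 5
print(case, int(num_triangles_table[case]))   # 5 2: two occupied corners give a quad

def zero_crossing(v_a, v_b, s_a, s_b):
    # linear interpolation of the SDF along the edge; assumes sign(s_a) != sign(s_b)
    return (-s_b * v_a + s_a * v_b) / (s_a - s_b)

p = zero_crossing(torch.tensor([0.0, 0.0, 0.0]), torch.tensor([1.0, 0.0, 0.0]), 0.25, -0.75)
print(p)  # zero crossing a quarter of the way along the edge, at x = 0.25
```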

### 3.2. Decoder

In [12]:
class Embedding(object):
    def __init__(
        self, 
        include_inputs = True,
        input_dim = 3,
        max_freq_log2 = 1,
        num_freqs = 2,
        log_sampling = True,
        periodic_fns = [torch.sin, torch.cos]
    ):

        self.output_dim = 0
        self.fns = []
        if include_inputs:
            self.fns.append(lambda x: x)
            self.output_dim += input_dim
        
        if log_sampling:
            freq_bands = 2. ** torch.linspace(0., max_freq_log2, steps=num_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq_log2, steps=num_freqs)
        
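        for freq in freq_bands:
            for pfn in periodic_fns:
                # bind pfn and freq now; a plain closure would make every entry use the last pair
                self.fns.append(lambda x, pfn=pfn, freq=freq: pfn(x * freq))
                self.output_dim += input_dim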
        

    def __call__(self, x):
        return torch.cat([fn(x) for fn in self.fns], dim=-1)


class Decoder(torch.nn.Module):
    def __init__(
        self,
        input_dim = 3, 
        hidden_dim = 128, 
        output_dim = 4, 
        num_layers = 5, 
        multires = 2
    ):
        super().__init__()
        self.embedding = Embedding(
            input_dim = input_dim,
            max_freq_log2 = multires-1,
            num_freqs = multires,
        )

        net = [nn.Linear(self.embedding.output_dim, hidden_dim, bias=False), nn.ReLU()]
        for _ in range(num_layers - 1):
            net += [nn.Linear(hidden_dim, hidden_dim, bias=False), nn.ReLU()]
        net += [nn.Linear(hidden_dim, output_dim, bias=False)]
        self.net = torch.nn.Sequential(*net)


    def forward(self, x):
        x = self.embedding(x)
        out = self.net(x)
        return out


    def pretrain_sphere(self, num_pretraining_steps=1000):
        device = next(self.parameters()).device
        optimizer = torch.optim.Adam(list(self.parameters()), lr=1e-4)

        pbar = tqdm(range(num_pretraining_steps))
        for _ in pbar:
            p = torch.rand((1024, 3), device=device) - 0.5
            y = torch.sqrt((p**2).sum(-1)) - 0.3
            output = self(p)
            loss = F.mse_loss(output[...,0], y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            pbar.set_postfix({'loss': loss.item()})

        print("Finished pretraining")
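With `multires = 2` the embedding concatenates x, sin(x), cos(x), sin(2x), cos(2x), so a 3-D input becomes 15 features before the first linear layer. Each lambda in `Embedding` has to capture its own `pfn` and `freq` (for example via default arguments); a bare closure binds late, and every periodic term would silently become cos(2x). A standalone sketch of the same encoding, written as a plain function rather than the `Embedding` class:

```python
import torch

def positional_encoding(x: torch.Tensor, multires: int = 2) -> torch.Tensor:
    # frequencies 2**0 .. 2**(multires-1), log-spaced as in Embedding(log_sampling=True)
    freqs = 2.0 ** torch.linspace(0.0, multires - 1, steps=multires)
    feats = [x]  # include_inputs=True
    for freq in freqs:
        feats += [torch.sin(x * freq), torch.cos(x * freq)]
    return torch.cat(feats, dim=-1)

x = torch.rand(8, 3) - 0.5
emb = positional_encoding(x)
print(emb.shape)  # torch.Size([8, 15])
```

The column layout is [x, sin(x), cos(x), sin(2x), cos(2x)], three columns each.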

### 3.3. Loss

In [13]:
def laplace_regularizer_const(mesh_verts, mesh_faces):
    term = torch.zeros_like(mesh_verts)
    norm = torch.zeros_like(mesh_verts[..., 0:1])

    v0 = mesh_verts[mesh_faces[:, 0], :]
    v1 = mesh_verts[mesh_faces[:, 1], :]
    v2 = mesh_verts[mesh_faces[:, 2], :]

    term.scatter_add_(0, mesh_faces[:, 0:1].repeat(1,3), (v1 - v0) + (v2 - v0))
    term.scatter_add_(0, mesh_faces[:, 1:2].repeat(1,3), (v0 - v1) + (v2 - v1))
    term.scatter_add_(0, mesh_faces[:, 2:3].repeat(1,3), (v0 - v2) + (v1 - v2))

    two = torch.ones_like(v0) * 2.0
    norm.scatter_add_(0, mesh_faces[:, 0:1], two)
    norm.scatter_add_(0, mesh_faces[:, 1:2], two)
    norm.scatter_add_(0, mesh_faces[:, 2:3], two)

    term = term / torch.clamp(norm, min=1.0)
    return torch.mean(term**2)

def loss_fn(mesh_verts, mesh_faces, points, st, cfg):
    pred_points = kaolin.ops.mesh.sample_points(mesh_verts.unsqueeze(0), mesh_faces, 50000)[0]
    loss = kaolin.metrics.pointcloud.chamfer_distance(pred_points, points).mean()
    if st > cfg.num_training_steps // 2:
        lap = laplace_regularizer_const(mesh_verts, mesh_faces)
        loss += lap * cfg.laplacian_weight
    return loss
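`loss_fn` compares 50,000 points sampled on the extracted mesh with the input cloud via Chamfer distance, and adds the Laplacian smoothness term only in the second half of training. For intuition, a brute-force Chamfer sketch, assuming the symmetric form with squared nearest-neighbor distances averaged in each direction; we believe this mirrors kaolin's default, but check `kaolin.metrics.pointcloud.chamfer_distance` for the exact weighting. `torch.cdist` builds the full N-by-M distance matrix, so this is only practical for small clouds:

```python
import torch

def chamfer_bruteforce(p1: torch.Tensor, p2: torch.Tensor) -> torch.Tensor:
    # p1: (B, N, 3), p2: (B, M, 3); returns one value per batch element
    d = torch.cdist(p1, p2) ** 2  # (B, N, M) squared distances
    # nearest neighbor p1 -> p2 plus nearest neighbor p2 -> p1
    return d.min(dim=2).values.mean(dim=1) + d.min(dim=1).values.mean(dim=1)

a = torch.tensor([[[0.0, 0.0, 0.0]]])
b = torch.tensor([[[1.0, 0.0, 0.0]]])
print(chamfer_bruteforce(a, b))  # 1 + 1 = 2 for the single batch element
```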

## 4. Train

In [30]:
model = Decoder(multires=cfg.multires).to(cfg.device)

In [31]:
model.pretrain_sphere(1000)


Finished pretraining


In [32]:
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.lr)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda x: max(0.0, 10**(-x*0.0002))) # decay the LR by 10x every 5000 steps
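The `LambdaLR` factor is 10^(-0.0002 * step), and `scheduler.step()` runs once per training step, so the learning rate shrinks by 10x every 5,000 steps; with `lr = 1e-3` and `num_training_steps = 5000` it decays to about 1e-4 by the end. A quick check of the multiplier using the same expression:

```python
def lr_factor(step: int) -> float:
    # same expression as the lr_lambda above
    return max(0.0, 10 ** (-step * 0.0002))

base_lr = 1e-3  # cfg.lr
for step in (0, 2500, 5000):
    print(step, base_lr * lr_factor(step))
```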

In [34]:
pbar = tqdm(range(1, cfg.num_training_steps+1))
for st in pbar:
    preds = model(tet_verts) # predict SDF and per-vertex deformation
    sdf, deform = preds[:, 0], preds[:, 1:]
    verts_deformed = tet_verts + torch.tanh(deform) / 128 # constrain deformation to avoid flipping tets
    mesh_verts, mesh_faces = marching_tetrahedra(verts_deformed.unsqueeze(0), tets, sdf.unsqueeze(0)) # running MT (batched) to extract surface mesh
    mesh_verts, mesh_faces = mesh_verts[0], mesh_faces[0]
    loss = loss_fn(mesh_verts, mesh_faces, points, st, cfg)
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()
    pbar.set_postfix({'loss': loss.item()})
    
    if st % 100 == 0:
        print(f'Step {st:04d} | loss {loss:.6f} | mesh vertices {mesh_verts.shape[0]:06d} | mesh faces {mesh_faces.shape[0]:06d}')
        # save reconstructed mesh
        timelapse.add_mesh_batch(
            iteration=st,
            category='extracted_mesh',
            vertices_list=[mesh_verts.cpu()],
            faces_list=[mesh_faces.cpu()]
        )


Step 0100 | loss 0.004 | mesh vertices 023354 | mesh faces 046704
Step 0200 | loss 0.008 | mesh vertices 028916 | mesh faces 057828
Step 0300 | loss 0.002 | mesh vertices 025290 | mesh faces 050576
Step 0400 | loss 0.003 | mesh vertices 023204 | mesh faces 046404
Step 0500 | loss 0.002 | mesh vertices 027368 | mesh faces 054732
Step 0600 | loss 0.001 | mesh vertices 027898 | mesh faces 055792
Step 0700 | loss 0.001 | mesh vertices 029706 | mesh faces 059420
Step 0800 | loss 0.001 | mesh vertices 030172 | mesh faces 060356
Step 0900 | loss 0.001 | mesh vertices 024530 | mesh faces 049052
Step 1000 | loss 0.000 | mesh vertices 029516 | mesh faces 059032
Step 1100 | loss 0.000 | mesh vertices 028488 | mesh faces 056972
Step 1200 | loss 0.000 | mesh vertices 028570 | mesh faces 057140
Step 1300 | loss 0.000 | mesh vertices 027982 | mesh faces 055960
Step 1400 | loss 0.000 | mesh vertices 028248 | mesh faces 056492
Step 1500 | loss 0.000 | mesh vertices 028750 | mesh faces 057496
Step 1600 
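In the training loop, the first decoder output channel is read as the SDF value and the remaining three as a per-vertex offset. `tanh(deform) / 128` keeps every offset component within plus or minus 1/128, so grid vertices can only move a small, bounded distance; 128 presumably matches the grid resolution in the `128_*.npz` file names. A small sketch of that split on random stand-ins (the random tensors replace the real grid and decoder output and are purely illustrative):

```python
import torch

grid_res = 128  # assumed to match the 128_*.npz grid files
torch.manual_seed(0)
tet_verts = torch.rand(1000, 3) - 0.5  # stand-in for the real grid vertices
preds = torch.randn(1000, 4) * 10      # stand-in for model(tet_verts)

sdf, deform = preds[:, 0], preds[:, 1:]
offsets = torch.tanh(deform) / grid_res  # each component in [-1/128, 1/128]
verts_deformed = tet_verts + offsets
print(sdf.shape, deform.shape, offsets.abs().max().item() <= 1 / grid_res)
```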

## 5. Test

In [26]:
ngrok.set_auth_token("###")



In [27]:
!nohup kaolin-dash3d --logdir=logs --port=80 &

nohup: appending output to 'nohup.out'


In [28]:
ngrok.connect(addr="80", proto="http")

<NgrokTunnel: "http://0939-34-86-244-124.ngrok.io" -> "http://localhost:80">

In [35]:
ngrok.kill()