# Template Mesh Fitting Demo

For Tesla virtual onsite interview presentation.

The notebook will walk through the process of fitting a template mesh onto a target mesh.

In [24]:
import json
from pathlib import Path

from star.pytorch.star import STAR
from star.config import set_model_path
import vedo
import trimesh
import numpy as np
import torch

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
set_model_path('star_1_1')
vedo.embedWindow('k3d')

In [26]:
def tmesh(mesh, c='grey', alpha=0.5):
    """convenience function for transparent mesh"""
    return vedo.Mesh(mesh, c=c, alpha=alpha)

In [27]:
target_mesh = trimesh.load_mesh('ORPH_20200122_2354_healed.obj')

In [28]:
num_betas = 20
star = STAR(gender='neutral', num_betas=num_betas)

A_pose_np = np.zeros((1, 72))
A_pose_np[0, 3*1+2] = 0.2
A_pose_np[0, 3*2+2] = -0.2
A_pose_np[0, 3*13+2] = -0.2
A_pose_np[0, 3*14+2] = 0.2
A_pose_np[0, 3*16+2] = -0.8
A_pose_np[0, 3*17+2] = 0.8
A_pose = torch.tensor(A_pose_np)

betas = torch.tensor(np.zeros((1, num_betas)))
trans = torch.tensor(np.zeros((1, 3)))

template_verts = np.roll(star.forward(A_pose, betas, trans)[0], 1, axis=1) * 1000
non_aligned_template_mesh = trimesh.Trimesh(template_verts, star.f, process=False)

# centroid align
trans_align_np = np.roll(
    target_mesh.vertices.mean(axis=0) - non_aligned_template_mesh.vertices.mean(axis=0), 
    -1
) / 1000
trans_align = torch.tensor(trans_align_np[None, :])
align_template_verts = np.roll(star.forward(A_pose, betas, trans_align)[0], 1, axis=1) * 1000
template_mesh = trimesh.Trimesh(align_template_verts, star.f, process=False)

In [29]:
with open('vertex_groups.json') as fin:
    vertex_groups = json.load(fin)
exclude_idx = [i for name in ('head', 'hand_left', 'hand_right', 'toes') for i in vertex_groups[name]]
inc_verts = np.ones(template_mesh.vertices.shape[0], dtype=bool)
inc_verts[exclude_idx] = False

In [30]:
def find_heels(mesh):
    v = mesh.vertices
    mid_y = (v[:, 1].max()+v[:, 1].min())/2
    left_idx, = (v[:, 1] > mid_y).nonzero()
    left_heel_idx = left_idx[(v[left_idx, 0]+v[left_idx, 2]).argmin()]
    right_idx, = (v[:, 1] < mid_y).nonzero()
    right_heel_idx = right_idx[(v[right_idx, 0]+v[right_idx, 2]).argmin()]
    
    return left_heel_idx, right_heel_idx

# find template heels
template_left_heel_idx, template_right_heel_idx = find_heels(template_mesh)

# find target heels
target_left_heel, target_right_heel = target_mesh.vertices[find_heels(target_mesh), :]

landmarks = [(template_left_heel_idx, target_left_heel), (template_right_heel_idx, target_right_heel)]

Plot below shows the template mesh. The vertices that are excluded from fitting are highlighted in red. Heels are highlighted in blue. 

In [31]:
vedo.show(
    template_mesh,
    vedo.Spheres(template_mesh.vertices[exclude_idx], r=10),
    vedo.Spheres(template_mesh.vertices[[template_left_heel_idx, template_right_heel_idx]], r=20, c='blue'),
    viewup='z',
    axes=1
)

Plot(antialias=3, axes=['x', 'y', 'z'], axes_helper=1.0, background_color=16777215, camera=[1140.0429027815971…

In [32]:
def arrow_plot(template_mesh, target_mesh, inc_verts):
    template_verts = template_mesh.vertices[inc_verts]
    dists, idxs = target_mesh.kdtree.query(template_verts)
    arrows = vedo.Arrows(template_verts, target_mesh.vertices[idxs])
    return vedo.show(
        tmesh(template_mesh),
        tmesh(target_mesh), 
        arrows,
        viewup='z'
    )

### Diff plot: before fit

The arrows point from a template vertex to the closest target vertex, which are the pairs used to calculate data loss. The optimizer will try to deform the vertices in the arrows direction, thus matching the shapes.

In [33]:
arrow_plot(template_mesh, target_mesh, inc_verts)

Plot(antialias=3, axes=['x', 'y', 'z'], axes_helper=1.0, background_color=16777215, camera=[1120.2648393608604…

### Define the loss functions

We have

- data loss: the ICP sum of squared distances between corresponding vertices
- marking loss: sum of squared distances between markings
- regularization loss: try to keep the shape parameters close to mean shape so things don't go crazy

In [34]:
"""define loss functions"""
def calc_loss(
    deformed_v: torch.Tensor, 
    betas: torch.Tensor, 
    exc_verts=[], min_comp_cosine=0, landmarks=[], E_weights=(1, 1, 0)
):
    E_data = data_loss(deformed_v, target_mesh, exc_verts, min_comp_cosine)
    E_lm = landmark_loss(deformed_v, landmarks)
    E_reg = shape_reg_loss(betas)
#     breakpoint()
    
    E_total = E_weights[0]*E_data + E_weights[1]*E_lm + E_weights[2]*E_reg
    
    return E_total

def data_loss(deformed_v : torch.Tensor, target_mesh, exc_verts, min_comp_cosine):
    deformed_mesh = trimesh.Trimesh(deformed_v.detach().numpy(), star.f, process=False)
    dists, idxs = target_mesh.kdtree.query(deformed_mesh.vertices)
    deformed_normals = deformed_mesh.vertex_normals
    target_normals = target_mesh.vertex_normals
    incompatible = (target_normals[idxs]*deformed_normals).sum(axis=1) < min_comp_cosine
    weights = np.ones(template_mesh.vertices.shape[0])
    weights[incompatible] = 0
    weights[exc_verts] = 0

    dist_tensor = (
        (
            (
                deformed_v-torch.tensor(target_mesh.vertices[idxs], requires_grad=False)
            ) * torch.tensor(weights[:, None], requires_grad=False)
        )**2
    ).sum() / weights.sum()
    return dist_tensor

def landmark_loss(deformed_v: torch.Tensor, landmarks):
    loss = torch.tensor(0.)
    for template_idx, target in landmarks:
        loss += ((deformed_v[template_idx] - torch.tensor(target))**2).sum()
    if len(landmarks) > 0:
        loss /= len(landmarks)
    return loss

def shape_reg_loss(betas: torch.Tensor):
    return (betas**2).mean()

## Perform the fitting!

Using PyTorch's `LBFGS` built-in optimizer for convenience. 

In [35]:
from torch.optim import LBFGS

poses = torch.tensor(A_pose_np, requires_grad=True)
betas = torch.tensor(np.random.normal(0, 0, (1, num_betas)), requires_grad=True)
trans = torch.tensor(trans_align_np[None, :], requires_grad=True)

optimizer = LBFGS([poses, betas, trans], lr=1, max_iter=100)

def simple_loss(poses, betas, trans, min_comp_cosine, E_weight):
    forwarded_v = star.forward(poses, betas, trans)[0]
    forwarded_v = torch.roll(forwarded_v, 1, 1) * 1000
    return calc_loss(forwarded_v, betas, exclude_idx, min_comp_cosine, landmarks, E_weights)

def constrain_poses(poses):
    """constrain some poses to not change, which maintain left-right symmetry"""
    with torch.no_grad():
        poses.grad[0, 2] = 0
        poses.grad[0, 3*12+1] = 0
        poses.grad[0, 3*12+2] = 0
        poses.grad[0, 3*15+1] = 0
        poses.grad[0, 3*15+2] = 0
        
def closure():
    optimizer.zero_grad()
    loss = simple_loss(poses, betas, trans, min_comp_cosine=0, E_weights=(10, 1, 0))
    loss.backward()
    constrain_poses(poses)
    return loss
loss = optimizer.step(closure)
    
new_v = star.forward(poses, betas, trans).detach().numpy()[0]
new_v = np.roll(new_v, 1, axis=1) * 1000
fitted_mesh = trimesh.Trimesh(new_v, star.f, process=False)

TypeError: simple_loss() got an unexpected keyword argument 'E_weights'

In [None]:
arrow_plot(fitted_mesh, target_mesh, inc_verts)

## Landmarking

In [None]:
with open('template_landmark_idxs.json') as fin:
    template_lm_idxs = json.load(fin)
    ""
names_to_use = template_lm_idxs.keys()
idxs_to_use = [template_lm_idxs[n] for n in names_to_use]
template_lm_p = fitted_mesh.vertices[idxs_to_use]
found_idx = target_mesh.kdtree.query(template_lm_p)[1]
extracted_landmarks = {n: p for n, p in zip(names_to_use, target_mesh.vertices[found_idx])}

In [None]:
from pprint import pprint

print('Landmark template vertex indices:')
pprint(template_lm_idxs)

In [None]:
vedo.show(
    tmesh(target_mesh),
    vedo.Spheres(list(extracted_landmarks.values()), r=10),
    viewup='z',
    axes=1
)