In [1]:
from collections import namedtuple
from pathlib import Path

import torch
import trimesh
import numpy as np
import pyglet
from trimesh import viewer

In [2]:
# Body-part name -> RGBA display color. The row order defines the integer class id
# (head=0 ... right_foot=11) used for the per-vertex segmentation labels.
_BODY_PART_TABLE = (
    ("head", (255, 0, 0, 255)),
    ("torso", (255, 0, 255, 255)),
    ("left_arm", (255, 255, 0, 255)),
    ("left_hand", (255, 128, 0, 255)),
    ("right_arm", (0, 255, 0, 255)),
    ("right_hand", (0, 255, 128, 255)),
    ("left_upper_leg", (0, 128, 255, 255)),
    ("left_lower_leg", (0, 255, 255, 255)),
    ("left_foot", (0, 0, 255, 255)),
    ("right_upper_leg", (128, 0, 255, 255)),
    ("right_lower_leg", (128, 255, 0, 255)),
    ("right_foot", (255, 0, 128, 255)),
)

# name -> integer class id
class_labels = {name: class_id for class_id, (name, _) in enumerate(_BODY_PART_TABLE)}

# name -> RGBA color as a numpy array
class_colors = {name: np.array(rgba) for name, rgba in _BODY_PART_TABLE}

In [3]:
# A semantic body-part mesh: `label` is a key of class_labels / class_colors, `mesh` its geometry.
# The typename matches the variable name so instances repr as BodyPart(...) and can be pickled
# (pickle looks the class up on its module by typename; "body_part" would not be found).
BodyPart = namedtuple("BodyPart", ["label", "mesh"])

In [4]:
# Reference registration (tr_reg_000): the body-part meshes are matched against it below.
complete_mesh_fname = Path("../datasets/faust/MPI-FAUST/training/registrations") / "tr_reg_000.ply"
complete_mesh = trimesh.load_mesh(complete_mesh_fname, process=False)

In [5]:
# One BodyPart per semantic-label mesh; the file stem is the body-part name, used as a key
# into class_labels / class_colors further down.
body_parts_dir = Path("../datasets/faust/MPI-FAUST/semantic_labels/")
# Sorted because glob order depends on the file system, and later parts overwrite earlier
# ones wherever they share vertices, so the processing order must be reproducible.
body_parts = [
    BodyPart(mesh_fname.stem, trimesh.load_mesh(mesh_fname, process=False))
    for mesh_fname in sorted(body_parts_dir.glob("*.ply"))
]
assert body_parts, f"no .ply body-part meshes found in {body_parts_dir}"

In [6]:
def find_vertices_on_mesh(vertices, mesh):
    """Locate each of ``vertices`` on ``mesh`` by exact coordinate equality.

    Args:
        vertices: 2-D array-like of query coordinates, one row per vertex.
        mesh: object exposing a 2-D ``vertices`` array (e.g. a trimesh mesh).

    Returns:
        list[int]: for every query vertex that occurs in ``mesh.vertices``, the index of
        its first occurrence, in query order. Query vertices that are not found are
        skipped, so the result may be shorter than ``vertices``.
    """
    # Hash lookup on exact coordinates: O(n + k) instead of scanning every mesh vertex
    # for every query. setdefault keeps the FIRST index when coordinates repeat, which
    # matches the previous np.where(...)[0][0] behaviour.
    first_index_of = {}
    for idx, coords in enumerate(np.asarray(mesh.vertices).tolist()):
        first_index_of.setdefault(tuple(coords), idx)

    found = (first_index_of.get(tuple(coords)) for coords in np.asarray(vertices).tolist())
    return [idx for idx in found if idx is not None]

In [7]:
def _subtract_faces(vertex_idx_to_subtract, faces):
    def _to_subtract(face):
        return any(vertex_idx in vertex_idx_to_subtract for vertex_idx in face)

    faces_to_keep = [
        not _to_subtract(face)
        for face in faces
    ]
    return faces[faces_to_keep, :]


def remap_faces(faces_to_keep, remapping):
    """Rewrite every vertex index in ``faces_to_keep`` through ``remapping``.

    Args:
        faces_to_keep: integer array of vertex indices (e.g. an (F, 3) face array).
        remapping: mapping from old vertex index to new vertex index.

    Returns:
        np.ndarray: int64 array with the same shape as ``faces_to_keep``.

    Raises:
        KeyError: if a vertex index has no entry in ``remapping``.
    """
    faces_to_keep = np.asarray(faces_to_keep)
    # A plain comprehension instead of np.vectorize: vectorize cannot infer its output
    # dtype from a size-0 input and raises ValueError when no faces survive.
    remapped = [remapping[vertex_idx] for vertex_idx in faces_to_keep.ravel().tolist()]
    return np.array(remapped, dtype=np.int64).reshape(faces_to_keep.shape)

def subtract_submesh(mesh, sub_mesh):
    """Return a new mesh equal to ``mesh`` minus the vertices of ``sub_mesh``.

    Vertices of ``sub_mesh`` are located on ``mesh`` by exact coordinate equality
    (``find_vertices_on_mesh``). Those vertices are removed, along with every face that
    touches one of them. Surviving vertices keep their original relative order, and the
    kept faces are re-indexed to the compacted vertex array.

    Args:
        mesh: trimesh mesh to subtract from.
        sub_mesh: mesh whose vertices should be removed from ``mesh``.

    Returns:
        trimesh.base.Trimesh: the remaining mesh, built with ``process=False``.
    """
    idx_to_subtract = set(find_vertices_on_mesh(sub_mesh.vertices, mesh))
    num_vertices = len(mesh.vertices)
    # Sorted explicitly so the surviving vertex order does not depend on set iteration order.
    kept_vertex_idx = np.array(
        sorted(set(range(num_vertices)).difference(idx_to_subtract)), dtype=np.int64
    )

    # Lookup table old vertex index -> new compacted index. Removed vertices map to -1,
    # but they never occur in the kept faces because _subtract_faces drops every face
    # that touches one of them. Array indexing also works when no faces survive.
    remapping = np.full(num_vertices, -1, dtype=np.int64)
    remapping[kept_vertex_idx] = np.arange(len(kept_vertex_idx), dtype=np.int64)

    faces_to_keep = _subtract_faces(idx_to_subtract, np.asarray(mesh.faces))
    return trimesh.base.Trimesh(
        vertices=mesh.vertices[kept_vertex_idx, :],
        faces=remapping[faces_to_keep],
        process=False,
    )

In [8]:
# subtracted_mesh = complete_mesh
# for _, sub_mesh in body_parts:
#     subtracted_mesh = subtract_submesh(subtracted_mesh, sub_mesh)

In [39]:
# Two further FAUST registrations, used below for the correspondence and coloring demos.
registrations_dir = Path("../datasets/faust/MPI-FAUST/training/registrations")

test_mesh_fname = registrations_dir / "tr_reg_060.ply"
test_mesh_fname_2 = registrations_dir / "tr_reg_070.ply"

test_mesh = trimesh.load_mesh(test_mesh_fname, process=False)
test_mesh_2 = trimesh.load_mesh(test_mesh_fname_2, process=False)

In [46]:
# Paint test_mesh_2 by body part. Vertex indices are found on complete_mesh (tr_reg_000)
# and reused for test_mesh_2, which assumes both registrations share the same vertex
# ordering — TODO confirm for all FAUST registrations.
part_colors = np.array(test_mesh_2.visual.vertex_colors)  # private copy to edit
for body_part in body_parts:
    body_part_idx = find_vertices_on_mesh(body_part.mesh.vertices, complete_mesh)
    part_colors[body_part_idx] = class_colors[body_part.label]
# Assign the whole array through the setter. Writing into the array returned by the
# `vertex_colors` property is not guaranteed to reach the mesh (the getter can hand back
# a derived array, e.g. when colors are stored per face).
test_mesh_2.visual.vertex_colors = part_colors

In [40]:
# Place test_mesh next to test_mesh_2 and draw lines between a random subset of vertices
# with the same index. This assumes the registrations share vertex ordering, so equal
# indices are corresponding points — TODO confirm.
CORRESPONDENCE_OFFSET = np.array([1.4, 0.0, 0.0])  # x-shift so the two bodies don't overlap
NUM_MATCHES = 200
CORRESPONDENCE_SEED = 0

# Reload before translating: the old in-place `+=` shifted the mesh again on every re-run.
test_mesh = trimesh.load_mesh(test_mesh_fname, process=False)
test_mesh.apply_translation(CORRESPONDENCE_OFFSET)

# Sample distinct vertex ids (no duplicate lines) from the real vertex count, not a
# hard-coded 6890, with a fixed seed so the figure is reproducible.
num_vertices = min(len(test_mesh.vertices), len(test_mesh_2.vertices))
num_matches = min(NUM_MATCHES, num_vertices)
rng = np.random.default_rng(CORRESPONDENCE_SEED)
matches = rng.choice(num_vertices, size=num_matches, replace=False)

# (num_matches, 2, 3): one segment per sampled vertex, from test_mesh to test_mesh_2.
segments = np.stack(
    [test_mesh.vertices[matches], test_mesh_2.vertices[matches]], axis=1
)

correspondences = trimesh.load_path(segments)
# Opaque blue; the previous alpha of 0 made every line fully transparent.
correspondences.colors = np.tile([0, 0, 255, 255], (num_matches, 1))

In [41]:
# Show both registrations together with the correspondence line segments built above.
scene = trimesh.Scene([test_mesh, test_mesh_2, correspondences])
scene.show()

In [47]:
# Display test_mesh_2 with the per-body-part vertex colors painted earlier.
test_mesh_2.show()

In [None]:
def generate_segmentation_labels(full_body_mesh, body_parts, class_labels):
    """Assign a semantic class id to every vertex of ``full_body_mesh``.

    Each body part's vertices are located on the full mesh by exact coordinate match
    (``find_vertices_on_mesh``) and tagged with ``class_labels[body_part.label]``.

    Args:
        full_body_mesh: mesh whose ``vertices`` array defines the output rows.
        body_parts: iterable of ``BodyPart(label, mesh)`` entries.
        class_labels (dict): body-part name -> integer class id.

    Returns:
        torch.Tensor: int32 tensor of shape (num_vertices, 1). Vertices not covered by
        any body part keep the value -1; where parts overlap, the part processed last wins.
    """
    labels = torch.full(
        (full_body_mesh.vertices.shape[0], 1), fill_value=-1, dtype=torch.int32
    )
    for part in body_parts:
        class_id = torch.tensor([class_labels[part.label]], dtype=torch.int32)
        labels[find_vertices_on_mesh(part.mesh.vertices, full_body_mesh)] = class_id
    return labels

In [None]:
# Ground-truth per-vertex class ids for the reference registration, shape (num_vertices, 1).
segmentation_labels = generate_segmentation_labels(
    full_body_mesh=complete_mesh,
    body_parts=body_parts,
    class_labels=class_labels,
)

In [None]:
# np.savez("../datasets/faust/semantic_labels/segementations.npz", segmentation_labels=segmentation_labels.numpy())

In [None]:
# test = np.load("../datasets/faust/semantic_labels/segementations.npz")

In [None]:
# np.all(test["segmentation_labels"] == segmentation_labels.numpy())

In [None]:
def color_mesh(mesh: trimesh.base.Trimesh, color_map):
    """Set per-vertex colors on ``mesh`` in place and return the same mesh.

    Args:
        mesh: mesh to recolor; it is modified in place.
        color_map: per-vertex colors, presumably one RGB(A) row per vertex of ``mesh``
            (as produced by ``map_seg_label_to_color``).

    Returns:
        The same ``mesh`` object, so calls can be chained, e.g. ``color_mesh(m, c).show()``.
    """
    mesh_visual = mesh.visual
    mesh_visual.vertex_colors = color_map
    return mesh

In [None]:
# Inverse lookup: integer class id -> RGBA color, built from the name-keyed tables above.
map_seg_id_to_color = {
    class_id: class_colors[class_name]
    for class_name, class_id in class_labels.items()
}

In [None]:
def map_seg_label_to_color(seg_ids, map_seg_id_to_color):
    """Turn per-vertex segmentation ids into a stacked per-vertex color array.

    Args:
        seg_ids: indexable of length N whose entries convert with ``int()``
            (e.g. an (N,) or (N, 1) tensor of class ids).
        map_seg_id_to_color (dict): class id -> color row.

    Returns:
        np.ndarray: the colors of the N vertices stacked row-wise.

    Raises:
        KeyError: if an id (e.g. -1 for an unlabeled vertex) has no color.
    """
    rows = []
    for position in range(seg_ids.shape[0]):
        segment_id = int(seg_ids[position])
        rows.append(map_seg_id_to_color[segment_id])
    return np.vstack(rows)

In [None]:
def visualize_predictions(predictions, data, map_seg_id_to_color):
    """Build a mesh whose vertices are colored by their predicted segmentation class.

    Args:
        predictions (torch.Tensor): per-vertex class scores; the predicted class is
            the argmax over the last dimension.
        data: sample object providing ``x`` (vertex coordinates) and ``face``
            (face indices) tensors.
        map_seg_id_to_color (dict): class id -> RGBA color.

    Returns:
        trimesh.base.Trimesh: the colored mesh; call ``.show()`` on it to display it.
        (Previously the mesh was built and then discarded, so nothing could be shown.)
    """
    # .cpu() so the per-element lookups in map_seg_label_to_color don't sync with a GPU.
    predicted_seg_ids = torch.argmax(predictions, dim=-1).cpu()
    mesh = trimesh.base.Trimesh(
        # NOTE(review): the transpose assumes data.x is stored as (3, num_vertices),
        # while data.face is passed untransposed — confirm both layouts.
        vertices=data.x.detach().cpu().numpy().T,
        faces=data.face.detach().cpu().numpy(),
        process=False,
    )
    color_map = map_seg_label_to_color(predicted_seg_ids, map_seg_id_to_color)
    return color_mesh(mesh, color_map)

In [None]:
# Color the test registration with the ground-truth labels computed on complete_mesh
# (assumes both registrations share vertex ordering) and display it.
# NOTE: raises KeyError if any vertex was left unlabeled (-1).
color_map = map_seg_label_to_color(segmentation_labels, map_seg_id_to_color)
colored_test_mesh = color_mesh(test_mesh, color_map)
colored_test_mesh.show()

In [None]:
# Per-vertex RGBA rows for the ground-truth labels. The class-id -> color dict built
# earlier is `map_seg_id_to_color`; `map_class_id_to_color` was never defined (NameError).
np.vstack([map_seg_id_to_color[int(label)] for label in segmentation_labels.numpy()[:, 0]])

In [None]:
# Color of vertex 5's ground-truth class (uses `map_seg_id_to_color`; the former
# `map_class_id_to_color` name was undefined).
map_seg_id_to_color[int(segmentation_labels[5])]

In [None]:
def accuracy(predictions, gt_segmentation_labels):
    """Return the mean accuracy of a model's predictions on a set of examples.

    Args:
        predictions (torch.Tensor): model scores/logits, shape (examples, classes).
            The predicted class is the argmax over the last dimension.
        gt_segmentation_labels (torch.Tensor): ground-truth class ids from 0 to
            num_classes - 1, shape (examples,) or (examples, 1). The (examples, 1) layout
            is what generate_segmentation_labels produces; it is squeezed to (examples,).

    Returns:
        float: fraction of examples whose argmax prediction equals the label.

    Raises:
        AssertionError: on unexpected tensor ranks or mismatched example counts.
    """
    assert predictions.dim() == 2
    if gt_segmentation_labels.dim() == 2 and gt_segmentation_labels.shape[1] == 1:
        # Comparing (N,) predictions with (N, 1) labels would silently broadcast to (N, N).
        gt_segmentation_labels = gt_segmentation_labels.squeeze(1)
    assert gt_segmentation_labels.dim() == 1
    assert predictions.shape[0] == gt_segmentation_labels.shape[0]

    correct_assignments = torch.argmax(predictions, dim=-1) == gt_segmentation_labels
    return correct_assignments.to(torch.float).mean().item()