# FLAME Projection 
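Visually check that the FLAME mesh lines up with the captured image. The mesh is posed with the per-frame SE(3) transform and projected through the camera intrinsics and extrinsics.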

In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import torch

# Re-running this cell must not walk further up the directory tree, so the
# first execution leaves a marker variable behind and later runs only report.
if 'has_been_executed' in locals():
    print("The directory is correct")
else:
    print("Changing directory to the root of the project")
    os.chdir("../../../")
    has_been_executed = True

In [None]:
# Set up
from thesis.data_management import SequenceManager, UnbatchedFlameParams, FlameParams
from thesis.flame import FlameHead

# Sequence 3 with cameras 0, 8 and 15. Later cells index into this camera
# list, presumably by position (e.g. cam = 1 -> camera 8); TODO confirm.
sm = SequenceManager(3, cameras=[0, 8, 15])

# Build the FLAME head model, then move it to the GPU.
flame_head = FlameHead()
flame_head = flame_head.to("cuda")

In [None]:
# projection drawing
from dreifus.render import project
from jaxtyping import Float, UInt8
from dreifus.matrix import Pose, Intrinsics
from beartype import beartype


@beartype
def draw(
    vertices: Float[torch.Tensor, "num_vertices 3"],
    image: Float[torch.Tensor, "H W 3"],
    intrinsics: Float[torch.Tensor, "3 3"],
    world_2_cam: Float[torch.Tensor, "4 4"],
    overlay: bool = True,
    radius: int = 2,
) -> UInt8[np.ndarray, "H W 3"]:
    """
    Project the vertices into the image and draw each one as a white square.

    Args:
        vertices: The vertices to project.
        image: The image to draw on. Values are expected in [0, 1]; anything
            outside that range is clipped before conversion to uint8.
        intrinsics: The 3x3 camera intrinsics.
        world_2_cam: The 4x4 world-to-camera transformation matrix.
        overlay: If True, draw on top of ``image``; otherwise draw on a black
            canvas of the same size.
        radius: Half-width of the marker in pixels. Each visible vertex is
            drawn as a (2 * radius + 1) x (2 * radius + 1) square centred on
            its projected pixel.

    Returns:
        The uint8 image with the vertices drawn on it.

    Raises:
        ValueError: If ``radius`` is negative.
    """
    if radius < 0:
        raise ValueError(f"radius must be non-negative, got {radius}")

    # detach() so .numpy() also works on tensors that are part of an autograd graph.
    vertices_np = vertices.detach().cpu().numpy()
    image_np = image.detach().cpu().numpy()
    cam_intrinsics = Intrinsics(intrinsics.detach().cpu().numpy())
    # NOTE(review): relies on dreifus' default pose_type for Pose(...) being
    # WORLD_2_CAM, as in the original code -- confirm against dreifus.
    cam_pose = Pose(world_2_cam.detach().cpu().numpy())

    # Clip first so slightly out-of-range values saturate instead of wrapping
    # around in the uint8 cast.
    canvas = (np.clip(image_np, 0.0, 1.0) * 255).astype(np.uint8)
    height, width = canvas.shape[:2]

    projected = project(vertices_np, cam_pose, cam_intrinsics)
    projected_x = projected[:, 0].round().astype(int)
    projected_y = projected[:, 1].round().astype(int)
    # NOTE(review): vertices behind the camera are not culled here; if
    # `project` returns depth in column 2, also require projected[:, 2] > 0.
    valid_x = (0 <= projected_x) & (projected_x < width)
    valid_y = (0 <= projected_y) & (projected_y < height)
    valid_xy = valid_x & valid_y
    print(f"{valid_xy.sum() / projected.shape[0] * 100:.1f}% of the vertices are visible"
          f" i.e. {valid_xy.sum()} out of {projected.shape[0]}\n")
    if not overlay:
        canvas = np.zeros_like(canvas)

    xs = projected_x[valid_xy]
    ys = projected_y[valid_xy]
    # The range is symmetric (-radius..radius inclusive), so the square is
    # centred on the vertex. Each offset is stamped for all vertices at once,
    # with out-of-bounds pixels masked away near the image border.
    offsets = range(-radius, radius + 1)
    for dy in offsets:
        rows = ys + dy
        rows_ok = (0 <= rows) & (rows < height)
        for dx in offsets:
            cols = xs + dx
            inside = rows_ok & (0 <= cols) & (cols < width)
            canvas[rows[inside], cols[inside]] = 255

    return canvas

In [None]:
# get data
from thesis.utils import datum_to_device, apply_se3_to_point
import matplotlib.pyplot as plt

# Frame to visualise. Swap the two lines below to pick a random frame instead.
# t = np.random.randint(0, len(sm))
t = 95
# Index into the camera list `sm` was created with (presumably positional, so
# 1 -> camera 8 for cameras=[0, 8, 15]); TODO confirm against SequenceManager.
cam = 1
# Slice rather than index so a leading batch dimension of size 1 is kept;
# it is removed again by squeeze(0) below.
params = sm.flame_params[t:t + 1]
params = datum_to_device(params, "cuda")
image = sm.images[t, cam].numpy()
intrinsics, extrinsics, _ = sm.cameras
extrinsics = extrinsics[cam]
se3 = sm.se3_transforms[t]

# get the vertices -- pure inference, so disable autograd. This avoids building
# a graph and keeps the later .numpy() conversions from failing on tensors that
# require grad.
with torch.no_grad():
    vertices = flame_head(params).squeeze(0).to("cpu")
    vertices = apply_se3_to_point(*se3, vertices)

# draw
image = draw(vertices, torch.tensor(image), intrinsics, extrinsics, overlay=True)
plt.figure(figsize=(12, 8))
plt.imshow(image)
plt.show()

In [None]:
# plot mesh
import trimesh

# Interactive 3D view of the same posed vertices that were projected above.
faces = flame_head.faces.cpu().numpy()

# `vertices` is a torch tensor. Convert it explicitly because trimesh coerces
# its inputs through NumPy, which raises on tensors that require grad.
mesh_vertices = vertices.detach().cpu().numpy()
mesh = trimesh.Trimesh(vertices=mesh_vertices, faces=faces)
mesh.show()