# FLAME 3D visualization

In [None]:
%load_ext autoreload
%autoreload 2

import os
import numpy as np
import torch

# Notebook cells may be re-run. The relative chdir below must happen only once,
# otherwise every re-run would climb another four directories up.
if 'has_been_executed' in locals():
    print("The directory is correct")
else:
    print("Changing directory to the root of the project")
    os.chdir("../../../../")
    # Flag checked on later runs of this cell to skip the chdir.
    has_been_executed = True

In [None]:
from thesis.datasets import get_data_loader
from thesis.helpers import load_config

# Load the project configuration (path is relative to the working directory).
config = load_config("configs/config.yaml")

# Keyword options forwarded to the loader factory.
loader_options = {
    "time_step": 55,
    "batch_size": 1,
    "data_keys": ["image", "camera", "flame_params"],
}
# The two positional arguments are presumably the dataset name and the
# recording/sequence identifier -- TODO confirm against get_data_loader.
data_loader = get_data_loader(
    "nersemble",
    "SEN-01-cramp_small_danger",
    **loader_options,
)
# Keep an explicit iterator so batches can be pulled one at a time.
iterator = iter(data_loader)

In [None]:
from thesis.external.flame import FlameHead

# Build the FLAME head model with its default arguments and move it to the
# "cuda" device in one step (the device string is hard-coded).
flame_head = FlameHead().to("cuda")

In [None]:
# Fetch the next batch. When the iterator is exhausted, start a fresh pass
# over the data loader and take its first batch instead.
_exhausted = object()  # unique marker that no real batch can be identical to
batch = next(iterator, _exhausted)
if batch is _exhausted:
    iterator = iter(data_loader)
    batch = next(iterator)

In [None]:
# Evaluate the FLAME model on the batch parameters and keep element 0 of the
# result (presumably the first mesh of the batch -- TODO confirm), then convert
# both vertices and faces to NumPy arrays on the host.
vertices = flame_head(**batch["flame_params"])[0].cpu().detach().numpy()
faces = flame_head.faces.cpu().detach().numpy()
# The faces come as (n, 3) rows, but the mesh format used for rendering wants
# (n, 4) rows whose first entry is the number of vertices in that face.
# Every face is a triangle, so prepend a column of threes.
vertex_counts = np.full((faces.shape[0], 1), 3, dtype=np.int64)
faces = np.concatenate((vertex_counts, faces), axis=1)

In [None]:
import pyvista as pv


def render_mesh_offscreen(vertices, faces, image_size=(1024, 768)):
    """Render a mesh off-screen with PyVista and return a screenshot of it.

    Args:
        vertices: Vertex positions, passed unchanged to ``pv.PolyData``.
        faces: Face connectivity, passed unchanged to ``pv.PolyData``.
            PyVista expects every face to be prefixed with its vertex count.
        image_size: Value assigned to the plotter's ``window_size``
            (width and height in pixels per the PyVista documentation).

    Returns:
        The image produced by ``Plotter.screenshot`` with a transparent
        background.
    """
    # Create a PyVista PolyData object
    mesh = pv.PolyData(vertices, faces)
    # Create a plotter for off-screen rendering
    plotter = pv.Plotter(off_screen=True)
    try:
        # Set the window size
        plotter.window_size = image_size
        # Add the mesh to the plotter
        plotter.add_mesh(
            mesh,
            lighting=True,
            smooth_shading=True,
            show_edges=False,
        )
        # Render the image
        return plotter.screenshot(
            transparent_background=True,
            return_img=True,
        )
    finally:
        # Close the plotter even when adding the mesh or taking the screenshot
        # raises, so a failed render does not leave the plotter open.
        plotter.close()


import matplotlib.pyplot as plt

# Render the mesh to an off-screen image and show it inline.
image = render_mesh_offscreen(vertices, faces)
plt.imshow(image)