In [None]:
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev, vmap, jacfwd
from lib.trainer.timer import TimeTracker
from lib.model.common import MLP
from tqdm import tqdm

# Benchmark configuration. The device falls back to CPU when CUDA is not
# available, matching the device selection used by the FLAME setup cell.
device = "cuda" if torch.cuda.is_available() else "cpu"
steps = 10  # timed repetitions of every Jacobian variant
n_unknowns = 159  # input dimension of the residual function
n_residuals = 1135  # output dimension of the residual function
hidden_dim = 1000
num_layers = 8


tracker = TimeTracker()

# Point at which the Jacobians are evaluated; params_no_grad is the same kind
# of vector without gradient tracking.
params = torch.randn((n_unknowns), requires_grad=True, device=device)
params_no_grad = torch.randn((n_unknowns), requires_grad=False, device=device)
# Stand-in residual function: an MLP mapping n_unknowns -> n_residuals.
mlp = MLP(
    in_dim=n_unknowns,
    out_dim=n_residuals,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
).to(device)
print("MLP sum params:", sum(p.numel() for p in mlp.parameters()))


def f(x):
    """Residual function under test: the MLP evaluated at the unknowns ``x``."""
    return mlp(x)


def f_no_grad(x):
    """Same residual function as ``f``, with the MLP forward run under ``torch.no_grad()``."""
    with torch.no_grad():
        return mlp(x)


# Each entry pairs a tracker label with a zero-argument callable computing one
# Jacobian variant of f at params. Every variant is listed exactly once (the
# original timed "forward_jac_no_grad" twice per step under the same label).
benchmarks = [
    (
        "reverse_jacobian_vectorize_graph",
        lambda: jacobian(f, params, strategy="reverse-mode", vectorize=True, create_graph=True),
    ),
    (
        "reverse_jacobian_vectorize",
        lambda: jacobian(f, params, strategy="reverse-mode", vectorize=True, create_graph=False),
    ),
    (
        "reverse_jacobian_graph",
        lambda: jacobian(f, params, strategy="reverse-mode", vectorize=False, create_graph=True),
    ),
    (
        "reverse_jacobian",
        lambda: jacobian(f, params, strategy="reverse-mode", vectorize=False, create_graph=False),
    ),
    (
        "forward_jacobian_vectorize",
        lambda: jacobian(f, params, strategy="forward-mode", vectorize=True, create_graph=False),
    ),
    ("reverse_vmapjac", lambda: vmap(jacrev(f))(params.unsqueeze(0))[0]),
    ("reverse_jac", lambda: jacrev(f)(params)),
    ("reverse_jac_chunk10", lambda: jacrev(f, chunk_size=10)(params)),
    ("reverse_jac_chunk1000", lambda: jacrev(f, chunk_size=1000)(params)),
    ("reverse_jac_chunk10000", lambda: jacrev(f, chunk_size=10000)(params)),
    ("reverse_jac_chunk100000", lambda: jacrev(f, chunk_size=100000)(params)),
    ("reverse_jac_no_grad", lambda: jacrev(f_no_grad)(params)),
    ("forward_jac", lambda: jacfwd(f)(params)),
    ("forward_jac_no_grad", lambda: jacfwd(f_no_grad)(params)),
]

for _ in tqdm(range(steps)):
    for position, (label, run) in enumerate(benchmarks):
        if position == 0:
            # First section of the step: nothing is running yet.
            tracker.start(label)
        else:
            # Close the previous section and open the next one.
            tracker.start(label, stop=True)
        run()
    tracker.stop()

# NOTE(review): if print_summary() prints by itself and returns None, the outer
# print() emits a stray "None" — confirm against TimeTracker.
print(tracker.print_summary())

In [None]:
def f_comp(x):
    """Residual function used by the compiled benchmark: the MLP evaluated at ``x``."""
    return mlp(x)


@torch.compile
def jac():
    """Reverse-mode Jacobian of ``f_comp`` at the global ``params``, compiled with torch.compile.

    The Jacobian is returned so that the compiled graph has a live output; a
    function that discards its only result gives the compiler nothing it has
    to compute, which would make the timing below meaningless.
    """
    return jacrev(f_comp)(params)


# Time the torch.compile-wrapped jac() over `steps` iterations with a fresh tracker.
# NOTE(review): the first jac() call presumably also pays the compilation cost —
# consider a warm-up call before the timed loop.
tracker = TimeTracker()

for _ in tqdm(range(steps)):
    tracker.start("reverse_jac")
    jac()
    tracker.stop()

# NOTE(review): if print_summary() prints by itself and returns None, the outer
# print() emits a stray "None" — confirm against TimeTracker.
print(tracker.print_summary())

In [None]:
# Display the reverse-mode Jacobian of f evaluated at params.
jacrev(f)(params)

In [None]:
# Display the forward-mode Jacobian of f evaluated at params.
jacfwd(f)(params)

# FLAME

In [None]:
from lib.model.flame.flame import Flame
import torch
import hydra
from lib.utils.config import load_config
from lib.data.datamodule import DPHMDataModule
from lib.trainer.logger import FlameLogger
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.renderer import Renderer
from lib.renderer.camera import Camera
from lib.utils.config import set_configs
from lib.optimizer.newton import GaussNewton
from lib.optimizer.pcg import PytorchSolver
import matplotlib.pyplot as plt

# NOTE(review): absolute, machine-specific paths; data_dir is not read below
# (the intrinsics are loaded from cfg.data.data_dir).
data_dir = "/home/borth/GuidedResearch/data/dphm_christoph_mouthmove"
flame_dir = "/home/borth/GuidedResearch/checkpoints/flame2023"

# Load the "optimize" config with overrides for a single Gauss-Newton
# iteration on frame 0 over global_pose and transl.
cfg = load_config(
    "optimize",
    overrides=[
        "optimizer=gauss_newton",
        "joint_trainer.init_idxs=[0]",
        "joint_trainer.max_iters=1",
        "joint_trainer.max_optims=1",
        "joint_trainer.scheduler.milestones=[0]",
        "joint_trainer.scheduler.params=[[global_pose,transl]]",
        "joint_trainer.coarse2fine.milestones=[0]",
        "joint_trainer.coarse2fine.scales=[8]",
        "sequential_trainer=null",
    ],
)
cfg = set_configs(cfg)
device = "cuda" if torch.cuda.is_available() else "cpu"

# setup camera, rasterizer and renderer
K = load_intrinsics(data_dir=cfg.data.data_dir, return_tensor="pt")
camera = Camera(
    K=K,
    width=cfg.data.width,
    height=cfg.data.height,
    near=cfg.data.near,
    far=cfg.data.far,
    scale=4,
)

rasterizer = Rasterizer(width=camera.width, height=camera.height)
renderer = Renderer(rasterizer=rasterizer, camera=camera)

# setup flame model
flame = Flame(flame_dir=flame_dir, vertices_mask="full", expression_params=50)

# datamodule fetch batch
# The keyword was misspelled as `devie`, so the selected device never reached
# the datamodule under its intended name.
datamodule = hydra.utils.instantiate(cfg.data, device=device)
datamodule.update_dataset(camera=camera, rasterizer=rasterizer)
datamodule.update_idxs([0, 1])
batch = datamodule.fetch()


# Initial FLAME parameters for a batch of two frames. Each tensor is moved to
# the target device BEFORE it is wrapped in nn.Parameter: calling
# `.to(device)` on a Parameter returns a plain, non-leaf tensor whenever an
# actual device copy happens, so the entries would silently stop being
# Parameters on CUDA while remaining Parameters on CPU.
# For a single-frame experiment use a leading dimension of 1 instead of 2
# (one row for global_pose / transl, torch.zeros(1, ...) / torch.ones(1, 1)).
initial_values = {
    "shape_params": torch.zeros(2, 100),
    "expression_params": torch.zeros(2, 50),
    "global_pose": torch.tensor([[0.08, -0.24, -0.02], [0.08, -0.24, -0.02]]),
    "neck_pose": torch.zeros(2, 3),
    "jaw_pose": torch.zeros(2, 3),
    "eye_pose": torch.zeros(2, 6),
    "transl": torch.tensor([[0.043, -0.003, -0.528], [0.043, -0.003, -0.528]]),
    "scale": torch.ones(2, 1),
}
params = {
    name: torch.nn.Parameter(value.to(device))
    for name, value in initial_values.items()
}

# One parameter group per named FLAME parameter.
optimizer = GaussNewton(lin_solver=PytorchSolver())
param_groups = [{"params": [v], "p_name": k} for k, v in params.items()]
optimizer.set_param_groups(param_groups)

In [None]:
def interpolate(
    vertices_idx: torch.Tensor,
    bary_coords: torch.Tensor,
    attributes: torch.Tensor,
):
    """Barycentric interpolation of per-vertex attributes for every pixel.

    Args:
        vertices_idx: (B, H, W, 3) integer indices of the three vertices per pixel.
        bary_coords: (B, H, W, 3) barycentric weights of those vertices.
        attributes: (B, V, D) per-vertex attributes.

    Returns:
        (B, H, W, D) tensor with the weighted sum of the three vertex attributes.
    """
    B, H, W, _ = vertices_idx.shape  # (B, H, W, 3)
    _, V, D = attributes.shape  # (B, V, D)

    # reshape (not view) so that non-contiguous index tensors, e.g. the result
    # of a permute, are accepted instead of raising.
    flat_vertices_idx = vertices_idx.reshape(B, -1)  # (B, H*W*3)

    # Gather the attributes of all referenced vertices in a single call.
    vertex_attributes = attributes.gather(
        1, flat_vertices_idx.unsqueeze(-1).expand(-1, -1, D)
    )  # (B, H*W*3, D)
    vertex_attributes = vertex_attributes.reshape(B, H, W, 3, D)

    # Weighted sum over the three triangle vertices.
    bary_coords = bary_coords.unsqueeze(-1)  # (B, H, W, 3, 1)
    return (bary_coords * vertex_attributes).sum(dim=-2)  # (B, H, W, D)


def mask_interpolate(
    vertices_idx: torch.Tensor,  # (B, H, W, 3)
    bary_coords: torch.Tensor,  # (B, H, W, 3)
    attributes: torch.Tensor,  # (B, V, D)
    mask: torch.Tensor,  # boolean, indexes the leading (B, H, W) dims
):
    """Barycentric interpolation restricted to the pixels selected by ``mask``.

    Returns a (C, D) tensor, where C is the number of selected pixels.
    """
    num_batches, num_vertices, feat_dim = attributes.shape

    # Shift each batch's vertex indices so they address the flattened
    # (B * V, D) attribute table.
    batch_offset = num_vertices * torch.arange(num_batches, device=vertices_idx.device)
    global_idx = vertices_idx + batch_offset.view(num_batches, 1, 1, 1)  # (B, H, W, 3)

    selected_idx = global_idx[mask]  # (C, 3)
    selected_attributes = attributes.reshape(-1, feat_dim)[selected_idx]  # (C, 3, D)
    selected_weights = bary_coords[mask].unsqueeze(-1)  # (C, 3, 1)

    return (selected_weights * selected_attributes).sum(-2)  # (C, D)

In [None]:
from torch.autograd.functional import jacobian
from torch.func import jacrev, vmap, jacfwd
import torch.utils.benchmark as benchmark
from lib.utils.mesh import vertex_normals


def jacobian():
    """Run one forward pass plus a forward-mode Jacobian and solve for a step.

    Reads the notebook globals ``flame``, ``renderer``, ``batch`` and ``params``.

    NOTE(review): this definition rebinds the name ``jacobian`` imported from
    torch.autograd.functional at the top of this cell.
    NOTE(review): ``H`` carries a factor of 2 that ``grad_f`` does not, and no
    sign is applied to the returned step — confirm against the optimizer.
    """
    # Forward pass at the current parameters; its render output fixes the
    # vertex indices, barycentric coordinates and normals used by the residual.
    forward_names = (
        "global_pose",
        "transl",
        "neck_pose",
        "expression_params",
        "shape_params",
    )
    model_out = flame.model_step(**{name: params[name] for name in forward_names})
    render_out = flame.render_step(renderer=renderer, vertices=model_out["vertices"])
    corr_out = flame.correspondence_step(
        s_point=batch["point"],
        s_mask=batch["mask"],
        s_normal=batch["normal"],
        t_point=render_out["point"],
        t_mask=render_out["r_mask"],
        t_normal=render_out["normal"],
    )
    valid = corr_out["mask"]

    def closure(global_pose, transl, neck_pose, scale, expression_params, shape_params):
        """Residual vector: point-to-plane terms followed by the flattened
        expression parameters; returned twice (differentiated output and aux)."""
        vertices = flame.model_step(
            global_pose=global_pose,
            transl=transl,
            expression_params=expression_params,
            scale=scale,
            neck_pose=neck_pose,
            shape_params=shape_params,
        )["vertices"]
        surface = interpolate(
            vertices_idx=render_out["vertices_idx"],
            bary_coords=render_out["bary_coords"],
            attributes=vertices,
        )
        offset = batch["point"][valid] - surface[valid]
        point2plane = (offset * render_out["normal"][valid]).sum(-1)  # (C)
        residual = torch.cat([point2plane, expression_params.flatten()])
        return residual, residual

    # Differentiate with respect to every closure argument, in signature order.
    wrt = ("global_pose", "transl", "neck_pose", "scale", "expression_params", "shape_params")
    jacobian_fn = jacfwd(closure, has_aux=True, argnums=tuple(range(len(wrt))))
    jac_blocks, F = jacobian_fn(*(params[name] for name in wrt))

    # Flatten each block's trailing parameter dims and concatenate column-wise.
    J = torch.cat([block.flatten(-2) for block in jac_blocks], dim=-1)  # (M, N)

    # solve for delta
    H = 2 * J.T @ J
    grad_f = J.T @ F
    return torch.linalg.solve(H, grad_f)

# Benchmark jacobian() with torch.utils.benchmark.Timer (200 runs) and print the measurement.
t0 = benchmark.Timer(
    stmt="jacobian()",
    setup="from __main__ import jacobian",
    globals=globals(),
)
print(t0.timeit(200))

In [None]:
# Run the model / render / correspondence steps once at the current params and
# keep the outputs as notebook globals (m_out, r_out, c_out, mask); several
# later cells read these names.
m_out = flame.model_step(
    global_pose=params["global_pose"],
    transl=params["transl"],
    neck_pose=params["neck_pose"],
    expression_params=params["expression_params"],
    shape_params=params["shape_params"],
)

r_out = flame.render_step(
    renderer=renderer,
    vertices=m_out["vertices"],
)
c_out = flame.correspondence_step(
    s_point=batch["point"],
    s_mask=batch["mask"],
    s_normal=batch["normal"],
    t_point=r_out["point"],
    t_mask=r_out["r_mask"],
    t_normal=r_out["normal"],
)
mask = c_out["mask"]

In [None]:
# Experiment: sample batch["point"] at slightly jittered pixel locations.
# NOTE(review): this cell overwrites the globals `mask` and `t_point`, so it is
# not idempotent — re-run the forward-pass cell before running it again.
s_mask = batch["mask"]
t_mask = r_out["r_mask"]
# Barycentric weights of pixels outside the correspondence mask become NaN.
bary_coords = r_out["bary_coords"].clone()
bary_coords[~mask] = torch.nan
t_point = interpolate(
    vertices_idx=r_out["vertices_idx"],
    bary_coords=bary_coords,
    attributes=m_out["vertices"],
)

# Normalized sampling grid in [-1, 1] with a small random jitter per column / row.
B, H, W, C = t_point.shape
x = torch.linspace(-1, 1, steps=W, device=t_point.device)
x += torch.rand_like(x) * 0.01
y = torch.linspace(-1, 1, steps=H, device=t_point.device)
y += torch.rand_like(y) * 0.01
y, x = torch.meshgrid(y, x, indexing="ij")
# grid_sample reads grid[..., 0] as x (width) and grid[..., 1] as y (height);
# stacking (y, x) would sample the transposed location.
grid = torch.stack([x, y], dim=-1).expand(B, H, W, 2)

# NOTE(review): the interpolated t_point above is only used for its shape and
# device; the tensor actually sampled is batch["point"] — confirm intent.
t_point = batch["point"].permute(0, 3, 1, 2)
samples = torch.nn.functional.grid_sample(
    input=t_point,
    grid=grid,
    mode="bilinear",
    align_corners=False,
    padding_mode="zeros",
)
samples = samples.permute(0, 2, 3, 1)
mask = ~torch.isnan(samples)


In [None]:
def optical_flow(
    s_point: torch.Tensor,
    t_point: torch.Tensor,
):
    """Placeholder optical flow: small random offsets in normalized grid units.

    Args:
        s_point: (B, H, W, C) source points; only the shape and device are used.
        t_point: unused by this placeholder.

    Returns:
        (B, H, W, 2) offsets in [0, 0.01). The last dimension is ordered (x, y),
        the layout torch.nn.functional.grid_sample expects for its grid. The x
        offset varies per column and the y offset per row.
    """
    B, H, W, _ = s_point.shape
    # One random offset per column (x) and per row (y), drawn directly instead
    # of building a linspace whose values were immediately discarded.
    x = torch.rand(W, device=s_point.device) * 0.01
    y = torch.rand(H, device=s_point.device) * 0.01
    y, x = torch.meshgrid(y, x, indexing="ij")
    delta = torch.stack([x, y], dim=-1).expand(B, H, W, 2)
    return delta


def weight():
    """Placeholder weights: a (B, H, W) view of per-column random values in [0, 0.01).

    Reads the notebook globals ``B``, ``H``, ``W`` and ``t_point`` (device only).
    """
    target_device = t_point.device
    # Column noise is drawn first, then row noise, so the random stream is
    # consumed in the same order even though only the column noise is returned.
    col_noise = torch.rand_like(torch.linspace(-1, 1, steps=W, device=target_device)) * 0.01
    row_noise = torch.rand_like(torch.linspace(-1, 1, steps=H, device=target_device)) * 0.01
    _, col_grid = torch.meshgrid(row_noise, col_noise, indexing="ij")
    return col_grid.expand(B, H, W)

def correspondences(
    s_delta: torch.Tensor,
    s_mask: torch.Tensor,
    t_value: torch.Tensor,
    t_mask: torch.Tensor,
):
    """Sample target values at the flow-displaced pixel positions.

    Args:
        s_delta: (B, H, W, 2) flow offsets in normalized [-1, 1] grid units,
            last dimension ordered (x, y) as required by grid_sample.
        s_mask: (B, H, W) boolean mask of valid source pixels.
        t_value: (B, H, W, C) floating-point target values.
        t_mask: (B, H, W) boolean mask of valid target pixels.

    Returns:
        mask: (B, H, W) boolean, True where the source pixel is valid and the
            bilinear sample touched no invalid target pixel.
        values: (B, H, W, C) sampled values, with invalid samples set to 0.
    """
    B, H, W, C = t_value.shape
    # Build the grid on the input's own device/dtype instead of reading them
    # from an unrelated notebook global.
    device, dtype = t_value.device, t_value.dtype

    # create the pixel grid; grid_sample reads the last dim as (x, y)
    x = torch.linspace(-1, 1, steps=W, device=device, dtype=dtype)
    y = torch.linspace(-1, 1, steps=H, device=device, dtype=dtype)
    y, x = torch.meshgrid(y, x, indexing="ij")
    grid = torch.stack([x, y], dim=-1).expand(B, H, W, 2)

    # add the optical flow delta between (-1, 1)
    grid = grid + s_delta

    # Values outside the target mask are marked dirty (NaN) so that any
    # bilinear sample touching them becomes NaN as well.
    value = t_value.clone()
    value[~t_mask] = torch.nan
    value = value.permute(0, 3, 1, 2)

    # align_corners=True makes -1 / 1 refer to the centers of the corner
    # pixels, so the linspace grid hits pixel centers and a zero flow is an
    # identity mapping.
    samples = torch.nn.functional.grid_sample(
        input=value,
        grid=grid,
        mode="bilinear",
        align_corners=True,
        padding_mode="zeros",
    )
    samples = samples.permute(0, 2, 3, 1)

    # update the target mask
    t_new_mask = ~torch.isnan(samples[..., 0])

    # compute the mask, where the source points found a correspondence point
    mask = t_new_mask & s_mask
    values = samples.nan_to_num(0.0)

    return mask, values

# Try the pipeline: placeholder flow -> interpolated target points -> sampled
# correspondences, then compare how many pixels each mask keeps.
s_delta = optical_flow(s_point=batch["point"], t_point=r_out["point"])
s_mask = batch["mask"]
t_mask = r_out["r_mask"]
t_value = interpolate(
    vertices_idx=r_out["vertices_idx"],
    bary_coords=r_out["bary_coords"],
    attributes=m_out["vertices"],
)
m, v = correspondences(
    s_delta = s_delta,
    s_mask = s_mask,
    t_value = t_value,
    t_mask = t_mask,
)
# Displayed: number of valid source pixels, target pixels and correspondences.
s_mask.sum(), t_mask.sum(), m.sum()

In [None]:
# Experiment: bilinearly sample the source mask at jittered pixel locations.
# Shape and device come from the tensor that is actually sampled; the global
# `t_point` left behind by an earlier cell may already be permuted to
# (B, C, H, W), which would scramble H and W here.
B, H, W = batch["mask"].shape
C = 1  # the mask is sampled as a single channel

x = torch.linspace(-1, 1, steps=W, device=batch["mask"].device)
x += torch.rand_like(x) * 0.01
y = torch.linspace(-1, 1, steps=H, device=batch["mask"].device)
y += torch.rand_like(y) * 0.01
y, x = torch.meshgrid(y, x, indexing="ij")
# grid_sample reads grid[..., 0] as x (width) and grid[..., 1] as y (height).
grid = torch.stack([x, y], dim=-1).expand(B, H, W, 2)

# grid_sample requires input and grid to share a dtype, so the mask is cast to
# the grid's floating-point dtype before sampling.
t_point = batch["mask"].unsqueeze(-1).permute(0, 3, 1, 2).to(grid.dtype)
samples = torch.nn.functional.grid_sample(
    input=t_point,
    grid=grid,
    mode="bilinear",
    align_corners=False,
    padding_mode="zeros",
)
samples = samples.permute(0, 2, 3, 1)
# NOTE(review): this overwrites the global `mask` used by later cells.
mask = ~torch.isnan(samples)
samples

In [None]:
def foo():
    """Benchmark body: index batch["point"] with the global ``mask``; the result is discarded."""
    batch["point"][mask]

# Benchmark foo() with torch.utils.benchmark.Timer (1000 runs) and print the measurement.
t0 = benchmark.Timer(
    stmt="foo()",
    setup="from __main__ import foo",
    globals=globals(),
)
print(t0.timeit(1000))

In [None]:
def interpolate(
    vertices_idx: torch.Tensor,
    bary_coords: torch.Tensor,
    attributes: torch.Tensor,
):
    """Gather-based barycentric interpolation of per-vertex attributes.

    Args:
        vertices_idx: (B, H, W, 3) integer indices of the three vertices per pixel.
        bary_coords: (B, H, W, 3) barycentric weights of those vertices.
        attributes: (B, V, D) per-vertex attributes.

    Returns:
        (B, H, W, D) interpolated attributes.
    """
    B, H, W, _ = vertices_idx.shape  # (B, H, W, 3)
    _, V, D = attributes.shape  # (B, V, D)

    # reshape instead of view: also accepts non-contiguous index tensors,
    # matching what interpolate1 already does.
    flat_vertices_idx = vertices_idx.reshape(B, -1)  # (B, H*W*3)

    # Gather the attributes of every referenced vertex.
    gather_index = flat_vertices_idx.unsqueeze(-1).expand(-1, -1, D)
    vertex_attributes = attributes.gather(1, gather_index)  # (B, H*W*3, D)

    # Restore the per-pixel layout (B, H, W, 3, D).
    vertex_attributes = vertex_attributes.reshape(B, H, W, 3, D)

    # Broadcast the weights over the attribute dimension and sum the three vertices.
    bary_coords = bary_coords.unsqueeze(-1)  # (B, H, W, 3, 1)
    interpolated_attributes = (bary_coords * vertex_attributes).sum(dim=-2)  # (B, H, W, D)

    return interpolated_attributes

def interpolate1(
    vertices_idx: torch.Tensor,
    bary_coords: torch.Tensor,
    attributes: torch.Tensor,
):
    """Barycentric interpolation using advanced indexing instead of gather.

    Takes (B, H, W, 3) vertex indices, (B, H, W, 3) barycentric weights and
    (B, V, D) per-vertex attributes; returns (B, H, W, D).
    """
    batch_size, height, width, _ = vertices_idx.shape
    _, _, feat_dim = attributes.shape

    # Flatten a copy of the indices per batch element: (B, H*W*3).
    flat_idx = vertices_idx.clone().reshape(batch_size, -1)
    # Column of batch indices that broadcasts against flat_idx.
    batch_rows = torch.arange(flat_idx.size(0), device=flat_idx.device)[:, None]

    # Look up the attributes of each referenced vertex: (B, H*W*3, D),
    # then restore the per-pixel layout (B, H, W, 3, D).
    per_vertex = attributes[batch_rows, flat_idx].reshape(
        batch_size, height, width, 3, feat_dim
    )

    # Weighted sum over the three triangle vertices: (B, H, W, D).
    weights = bary_coords.unsqueeze(-1)
    return (weights * per_vertex).sum(-2)

def interpolate2(
    vertices_idx: torch.Tensor,
    bary_coords: torch.Tensor,
    attributes: torch.Tensor,
):
    """Gather-based barycentric interpolation with the weighting done in one expression.

    Args:
        vertices_idx: (B, H, W, 3) integer indices of the three vertices per pixel.
        bary_coords: (B, H, W, 3) barycentric weights of those vertices.
        attributes: (B, V, D) per-vertex attributes.

    Returns:
        (B, H, W, D) interpolated attributes.
    """
    B, H, W, _ = vertices_idx.shape  # (B, H, W, 3)
    _, V, D = attributes.shape  # (B, V, D)

    # reshape instead of view so that non-contiguous index tensors are
    # accepted, consistent with interpolate1.
    flat_vertices_idx = vertices_idx.reshape(B, -1)  # (B, H*W*3)

    # Gather the vertex attributes in one step.
    vertex_attributes = attributes.gather(
        1, flat_vertices_idx.unsqueeze(-1).expand(-1, -1, D)
    )  # (B, H*W*3, D)

    # Back to the per-pixel layout (B, H, W, 3, D).
    vertex_attributes = vertex_attributes.reshape(B, H, W, 3, D)

    # Weighted sum using the barycentric coordinates.
    interpolated_attributes = (bary_coords.unsqueeze(-1) * vertex_attributes).sum(dim=-2)  # (B, H, W, D)

    return interpolated_attributes

def mask_interpolate(
    vertices_idx: torch.Tensor,  # (B, H, W, 3)
    bary_coords: torch.Tensor,  # (B, H, W, 3)
    attributes: torch.Tensor,  # (B, V, D)
    mask: torch.Tensor,  # boolean, indexes the leading (B, H, W) dims
):
    """Interpolate per-vertex attributes only at the pixels selected by ``mask``.

    Returns a (C, D) tensor with one row per selected pixel.
    """
    n_batch, n_vertices, n_feat = attributes.shape

    # Offset the indices of batch b by b * V so that they point into the
    # attribute table flattened to (B * V, D).
    per_batch_offset = n_vertices * torch.arange(n_batch, device=vertices_idx.device)
    shifted_idx = vertices_idx + per_batch_offset.view(n_batch, 1, 1, 1)  # (B, H, W, 3)

    picked_idx = shifted_idx[mask]  # (C, 3)
    picked_attributes = attributes.reshape(-1, n_feat)[picked_idx]  # (C, 3, D)
    picked_weights = bary_coords[mask].unsqueeze(-1)  # (C, 3, 1)

    weighted = picked_weights * picked_attributes
    return weighted.sum(-2)  # (C, D)

def foo():
    """Benchmark body: dense interpolate2() over all pixels, then index with the global ``mask``; the result is discarded."""
    interpolate2(
        vertices_idx=r_out["vertices_idx"],
        bary_coords=r_out["bary_coords"],
        attributes=m_out["vertices"],
    )[mask]

# def foo():
#     mask_interpolate(
#         vertices_idx=r_out["vertices_idx"],
#         bary_coords=r_out["bary_coords"],
#         attributes=m_out["vertices"],
#         mask=mask
#     )

# Benchmark the currently defined foo() (1000 runs) and print the measurement.
t0 = benchmark.Timer(
    stmt="foo()",
    setup="from __main__ import foo",
    globals=globals(),
)
print(t0.timeit(1000))

In [None]:
# Experiment: manual bilinear lookup of rasterizer outputs at sub-pixel coordinates.
mask = r_out["r_mask"]
bary_coords = r_out["bary_coords"]
vertices_idx = r_out["vertices_idx"]
# NOTE(review): `attibutes` looks like a typo for `attributes`; the name is
# kept because other cells may refer to it.
attibutes = m_out["vertices"]

B, H, W, C = bary_coords.shape
# Index tensors are created on the device of the tensor they index, instead of
# a hard-coded "cuda" that fails when the notebook runs on CPU.
coords = torch.tensor(
    [[[121.7, 42.3], [122.7, 45.3], [121.7, 42.3]]] * B,
    device=bary_coords.device,
)  # Example batch of (x, y) pixel coordinates, (B, N, 2)
_, N, _ = coords.shape

# Separate the coordinates into x and y parts
x, y = coords[..., 0], coords[..., 1]  # each (B, N)

# Get the integer part of the coordinates
x0 = torch.floor(x).long()
x1 = x0 + 1
y0 = torch.floor(y).long()
y1 = y0 + 1

# Keep the four neighbours inside the image.
x0 = torch.clamp(x0, 0, W - 1)  # (B, N)
x1 = torch.clamp(x1, 0, W - 1)  # (B, N)
y0 = torch.clamp(y0, 0, H - 1)  # (B, N)
y1 = torch.clamp(y1, 0, H - 1)  # (B, N)

batch_indices = (
    torch.arange(B, dtype=torch.long, device=bary_coords.device)
    .view(-1, 1, 1)
    .expand(-1, N, -1)
)  # (B, N, 1)

# Gather pixel values at the corners; each lookup is (B, N, 1, C) before the squeeze
Ia = bary_coords[batch_indices, y0.unsqueeze(-1), x0.unsqueeze(-1)].squeeze(2)  # top-left
Ib = bary_coords[batch_indices, y1.unsqueeze(-1), x0.unsqueeze(-1)].squeeze(2)  # bottom-left
Ic = bary_coords[batch_indices, y0.unsqueeze(-1), x1.unsqueeze(-1)].squeeze(2)  # top-right
Id = bary_coords[batch_indices, y1.unsqueeze(-1), x1.unsqueeze(-1)].squeeze(2)  # bottom-right

# Bilinear weights from the fractional part of the coordinates
wa = (x1.float() - x) * (y1.float() - y)
wb = (x1.float() - x) * (y - y0.float())
wc = (x - x0.float()) * (y1.float() - y)
wd = (x - x0.float()) * (y - y0.float())

interpolated_values = (wa.unsqueeze(-1) * Ia +
                        wb.unsqueeze(-1) * Ib +
                        wc.unsqueeze(-1) * Ic +
                        wd.unsqueeze(-1) * Id)

# Same corner lookup for the vertex indices (not blended, only gathered).
Ia = vertices_idx[batch_indices, y0.unsqueeze(-1), x0.unsqueeze(-1)].squeeze(2)  # top-left
Ib = vertices_idx[batch_indices, y1.unsqueeze(-1), x0.unsqueeze(-1)].squeeze(2)  # bottom-left
Ic = vertices_idx[batch_indices, y0.unsqueeze(-1), x1.unsqueeze(-1)].squeeze(2)  # top-right
Id = vertices_idx[batch_indices, y1.unsqueeze(-1), x1.unsqueeze(-1)].squeeze(2)  # bottom-right

# batch_indices.shape, x0.shape
# bary_coords[torch.tensor([0]).view(1, 1, 1), torch.tensor([0, 1]).view(1, 2, 1), torch.tensor([0, 1, 3]).view(1, 1, 3)].shape
# bary_coords[torch.tensor([0, 1]).view(2, 1, 1), y0].shape
# Ia = bary_coords[batch_indices, y0.unsqueeze(-1), x0.unsqueeze(-1)].squeeze(2)  # top-left
# Ia.shape
Ia, Ib, Ic, Id

In [None]:
# Pack several parameter tensors into one tensor `p` for the vmap/jacfwd experiment below.
# NOTE(review): torch.cat defaults to dim=0, but the pieces created in the
# setup cell have different trailing sizes (3, 3, 50, 3) — this presumably
# needs dim=-1; confirm.
# NOTE(review): the order here (global_pose, transl, expression_params,
# neck_pose) differs from the slices closure() takes (neck_pose at 6:9,
# expression_params at 9:59), and shape_params is commented out although
# closure() reads p[59:159].
p = torch.cat([
    params["global_pose"],
    params["transl"],
    params["expression_params"],
    params["neck_pose"],
    # params["shape_params"]
]).expand(1, -1)

def jacobian():
    """Benchmark body: forward-mode Jacobian of the residual w.r.t. the packed vector ``p``.

    Reads the notebook globals ``p``, ``flame``, ``batch``, ``r_out``, ``c_out``,
    ``mask`` and ``renderer``. The computed ``J`` and ``F`` are not returned.
    NOTE(review): this redefines the earlier ``jacobian`` function of this notebook.
    """
    def closure(p):
        # Unpack the parameter vector by fixed slices.
        # NOTE(review): these offsets assume the layout global_pose, transl,
        # neck_pose, expression_params, shape_params — verify against how `p` is built.
        m_out = flame.model_step(
            global_pose=p[0:3],
            transl=p[3:6],
            neck_pose=p[6:9],
            expression_params=p[9:59],
            shape_params=p[59:159],
        )
        # NOTE(review): s_point / t_normal are indexed with the global `mask`,
        # while the interpolation below uses c_out["mask"] — confirm they match.
        s_point = batch["point"][mask]
        t_normal = r_out["normal"][mask]
        t_point = renderer.mask_interpolate(
            vertices_idx=r_out["vertices_idx"],
            bary_coords=r_out["bary_coords"],
            attributes=m_out["vertices"],
            mask=c_out["mask"],
        )
        # Point-to-plane residual per selected pixel.
        F = ((s_point - t_point) * t_normal).sum(-1)  # (C)
        # Despite its name, J here is the residual vector (point-to-plane terms
        # plus the p[9:59] slice) that jacfwd differentiates; F is the aux output.
        J = torch.cat([F, p[9:59].flatten()])
        return J, F

    J, F = vmap(jacfwd(closure, has_aux=True))(p)
    # J, F = jacfwd(closure, has_aux=True, argnums=(0, 1, 2, 3, 4))(
    #     params["global_pose"],
    #     params["transl"],
    #     params["neck_pose"],
    #     params["expression_params"][0],
    #     params["shape_params"],
    # )
    # J = torch.cat([j.flatten(-2) for j in J], dim=-1)  # (M, N)

# Benchmark the currently defined jacobian() (200 runs) and print the measurement.
t0 = benchmark.Timer(
    stmt="jacobian()",
    setup="from __main__ import jacobian",
    globals=globals(),
)
print(t0.timeit(200))

In [None]:
# Two recorded measurements and the parameter counts they were taken with;
# the expression divides each measurement by its total count (46 and 26).
# NOTE(review): units and the meaning of the "6 + N" split are not recorded here.
# 17.88 (6 + 40)
# 12.77 (6 + 20)
17.88 / 46, 12.77 / 26

In [None]:
import time

# Time a single flame.model_step() call with default arguments.
# perf_counter() is the monotonic, high-resolution clock meant for short
# intervals. CUDA kernels are launched asynchronously, so the device is
# synchronized before reading the clock on either side of the call.
# NOTE(review): this rebinds the global m_out to the default-argument output.
if torch.cuda.is_available():
    torch.cuda.synchronize()
s = time.perf_counter()
m_out = flame.model_step()
if torch.cuda.is_available():
    torch.cuda.synchronize()
(time.perf_counter() - s)

In [None]:
# Show the color image of the first sample in the batch.
plt.imshow(batch["color"][0].detach().cpu().numpy())

In [None]:
# Call the FLAME model directly with a subset of the parameters.
# The commented-out signature below lists the full parameter set of an
# unfinished closure wrapper.
# def closure(
#     shape_params,
#     expression_params,
#     global_pose,
#     neck_pose,
#     jaw_pose,
#     eye_pose,
#     transl,
#     scale,
# ):
flame(
    shape_params=params["shape_params"],
    expression_params=params["expression_params"],
    global_pose=params["global_pose"],
)

In [None]:
# Inspect the image width configured for the data.
cfg.data.width

In [None]:
# Check which device the FLAME default shape parameters live on.
flame.default_shape_params.device

In [None]:
import torch

# Check that FLAME accepts one shape vector together with a batch of 32
# expression vectors, and display the output shape. The tensors follow the
# notebook's `device` setting instead of a hard-coded "cuda", which raises on
# machines without a GPU.
shape_params = torch.rand((1, 100)).to(device)
expression_params = torch.rand((32, 50)).to(device)
flame(shape_params=shape_params, expression_params=expression_params).shape

In [None]:
# Inspect the shape of the shape_params tensor created in the previous cell.
shape_params.shape