In [None]:
import torch
from torch.autograd.functional import jacobian
from torch.func import jacrev, vmap, jacfwd
from lib.trainer.timer import TimeTracker
from lib.model.layers import MLP
from tqdm import tqdm

# Benchmark configuration: problem size (unknowns -> residuals) and MLP capacity.
# Fall back to the CPU when CUDA is unavailable so the cell still runs after a
# fresh "Restart & Run All" on a machine without a GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
steps = 10
n_unknowns = 159
n_residuals = 1135
hidden_dim = 1000
num_layers = 8


tracker = TimeTracker()

# Input vector the Jacobians are taken with respect to (one tracking gradients,
# one not).
params = torch.randn(n_unknowns, requires_grad=True, device=device)
params_no_grad = torch.randn(n_unknowns, requires_grad=False, device=device)

# Stand-in residual function: an MLP mapping n_unknowns inputs to n_residuals outputs.
mlp = MLP(
    in_dim=n_unknowns,
    out_dim=n_residuals,
    hidden_dim=hidden_dim,
    num_layers=num_layers,
).to(device)
print("MLP sum params:", sum(p.numel() for p in mlp.parameters()))


def f(x):
    """Evaluate the benchmark MLP at ``x`` and return its output (the residuals)."""
    return mlp(x)


def f_no_grad(x):
    """Evaluate the benchmark MLP at ``x`` inside a ``torch.no_grad()`` context.

    NOTE(review): torch.func's grad-based transforms are documented to respect a
    ``no_grad`` placed inside the transformed function, so reverse-mode results
    computed through this function may not be the true Jacobian - confirm before
    comparing its timings against ``f``.
    """
    with torch.no_grad():
        return mlp(x)


# Time every Jacobian strategy once per step. Each `tracker.start(..., stop=True)`
# closes the previous measurement and opens the next one, so the call that follows
# a `start` is what gets attributed to that label.
# NOTE(review): on CUDA the kernels launch asynchronously; unless TimeTracker
# synchronizes the device internally these numbers may under-report - confirm.
for _ in tqdm(range(steps)):
    # torch.autograd.functional.jacobian variants
    tracker.start("reverse_jacobian_vectorize_graph")
    jacobian(f, params, strategy="reverse-mode", vectorize=True, create_graph=True)
    tracker.start("reverse_jacobian_vectorize", stop=True)
    jacobian(f, params, strategy="reverse-mode", vectorize=True, create_graph=False)
    tracker.start("reverse_jacobian_graph", stop=True)
    jacobian(f, params, strategy="reverse-mode", vectorize=False, create_graph=True)
    tracker.start("reverse_jacobian", stop=True)
    jacobian(f, params, strategy="reverse-mode", vectorize=False, create_graph=False)
    tracker.start("forward_jacobian_vectorize", stop=True)
    jacobian(f, params, strategy="forward-mode", vectorize=True, create_graph=False)

    # torch.func reverse-mode variants (plain, vmapped, and chunked)
    tracker.start("reverse_vmapjac", stop=True)
    vmap(jacrev(f))(params.unsqueeze(0))[0]
    tracker.start("reverse_jac", stop=True)
    jacrev(f)(params)
    tracker.start("reverse_jac_chunk10", stop=True)
    jacrev(f, chunk_size=10)(params)
    tracker.start("reverse_jac_chunk1000", stop=True)
    jacrev(f, chunk_size=1000)(params)
    tracker.start("reverse_jac_chunk10000", stop=True)
    jacrev(f, chunk_size=10000)(params)
    tracker.start("reverse_jac_chunk100000", stop=True)
    jacrev(f, chunk_size=100000)(params)
    tracker.start("reverse_jac_no_grad", stop=True)
    jacrev(f_no_grad)(params)

    # torch.func forward-mode variants; each label is measured exactly once per
    # step so the per-label statistics stay comparable.
    tracker.start("forward_jac", stop=True)
    jacfwd(f)(params)
    tracker.start("forward_jac_no_grad", stop=True)
    jacfwd(f_no_grad)(params)
    tracker.stop()

print(tracker.print_summary())

In [None]:
def f_comp(x):
    """Residual function used by the compiled benchmark: the MLP output at ``x``."""
    return mlp(x)


@torch.compile
def jac():
    """Compiled reverse-mode Jacobian of ``f_comp`` at the global ``params``.

    Returns:
        The Jacobian produced by ``jacrev(f_comp)(params)``. Returning it keeps the
        computation a live output of the compiled function; when the result is
        discarded, the compiler is free to drop the work, which would make the
        timing below meaningless.
    """
    return jacrev(f_comp)(params)


tracker = TimeTracker()

# Warm-up: the first call of a torch.compile'd function triggers compilation,
# which would otherwise dominate the statistics of a 10-step benchmark.
jac()

for _ in tqdm(range(steps)):
    tracker.start("reverse_jac")
    jac()
    tracker.stop()

print(tracker.print_summary())

In [None]:
# Reverse-mode Jacobian of the MLP output with respect to `params`, shown as the cell output.
jacrev(f)(params)

In [None]:
# Forward-mode Jacobian of the MLP output with respect to `params`, shown as the cell output.
jacfwd(f)(params)

# FLAME

In [None]:
from lib.model.flame import Flame
import torch
import hydra
from lib.utils.config import load_config
from lib.data.datamodule import DPHMDataModule
from lib.trainer.logger import FlameLogger
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.renderer import Renderer
from lib.renderer.camera import Camera
from lib.utils.config import set_configs
from lib.optimizer.newton import GaussNewton
from lib.optimizer.pcg import PytorchSolver
import matplotlib.pyplot as plt

# Local locations of the FLAME assets and of the DPHM sequence.
# NOTE(review): `data_dir` is not referenced again in the visible cells (the
# intrinsics are loaded from cfg.data.data_dir) - confirm whether it is still needed.
flame_dir = "/home/borth/GuidedResearch/checkpoints/flame2023"
data_dir = "/home/borth/GuidedResearch/data/dphm_christoph_mouthmove"

# Load the "optimize" configuration with the override strings below and pass the
# result through set_configs in one step.
cfg = set_configs(
    load_config(
        "optimize",
        overrides=[
            "optimizer=gauss_newton",
            "joint_trainer.init_idxs=[0]",
            "joint_trainer.max_iters=1",
            "joint_trainer.max_optims=1",
            "joint_trainer.scheduler.milestones=[0]",
            "joint_trainer.scheduler.params=[[global_pose,transl]]",
            "joint_trainer.coarse2fine.milestones=[0]",
            "joint_trainer.coarse2fine.scales=[8]",
            "sequential_trainer=null",
        ],
    )
)

# Use the GPU when one is present, otherwise the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

# Set up camera, rasterizer and renderer.
# The camera is built from the intrinsics loaded from cfg.data.data_dir plus the
# image size and near/far values stored in cfg.data.
K = load_intrinsics(data_dir=cfg.data.data_dir, return_tensor="pt")
camera = Camera(
    K=K,
    width=cfg.data.width,
    height=cfg.data.height,
    near=cfg.data.near,
    far=cfg.data.far,
    # NOTE(review): presumably meant to match the coarse2fine scale of 8 set in the
    # config overrides - confirm.
    scale=8,
)

# The rasterizer is sized from the camera; the renderer wraps both objects.
rasterizer = Rasterizer(width=camera.width, height=camera.height)
renderer = Renderer(rasterizer=rasterizer, camera=camera)

# Set up the FLAME model (full vertex mask, 50 expression parameters).
flame = Flame(flame_dir=flame_dir, vertices_mask="full", expression_params=50)

# Instantiate the datamodule from cfg.data and fetch one batch for indices 0 and 1.
# The keyword is spelled `device`; the former `devie` spelling never handed the
# selected device to the datamodule under the intended name.
datamodule = hydra.utils.instantiate(cfg.data, device=device)
datamodule.update_dataset(camera=camera, rasterizer=rasterizer)
datamodule.update_idxs([0, 1])
batch = datamodule.fetch()


# Initial FLAME parameters for a batch of two samples.
# Each tensor is created directly on `device` and then wrapped in nn.Parameter.
# Wrapping first and calling `.to(device)` afterwards returns, on CUDA, a plain
# non-leaf tensor rather than a Parameter, which is not what an optimizer expects.
params = {}
params["shape_params"] = torch.nn.Parameter(torch.zeros(2, 100, device=device))
params["expression_params"] = torch.nn.Parameter(torch.zeros(2, 50, device=device))
params["global_pose"] = torch.nn.Parameter(
    torch.tensor([[0.08, -0.24, -0.02], [0.08, -0.24, -0.02]], device=device)
)
params["neck_pose"] = torch.nn.Parameter(torch.zeros(2, 3, device=device))
params["jaw_pose"] = torch.nn.Parameter(torch.zeros(2, 3, device=device))
params["eye_pose"] = torch.nn.Parameter(torch.zeros(2, 6, device=device))
params["transl"] = torch.nn.Parameter(
    torch.tensor([[0.043, -0.003, -0.528], [0.043, -0.003, -0.528]], device=device)
)
params["scale"] = torch.nn.Parameter(torch.ones(2, 1, device=device))

# One parameter group per named parameter, registered with the Gauss-Newton optimizer.
optimizer = GaussNewton(lin_solver=PytorchSolver())
param_groups = [{"params": [v], "p_name": k} for k, v in params.items()]
optimizer.set_param_groups(param_groups)

In [None]:
# Model step on the current parameter estimates (jaw/eye pose and scale are not
# passed here).
m_out = flame.model_step(
    **{
        name: params[name]
        for name in (
            "global_pose",
            "transl",
            "neck_pose",
            "expression_params",
            "shape_params",
        )
    }
)

# Render the resulting vertices.
r_out = flame.render_step(vertices=m_out["vertices"], renderer=renderer)

# Correspondences between the fetched batch (s_*) and the rendering (t_*).
c_out = flame.correspondence_step(
    s_point=batch["point"],
    s_mask=batch["mask"],
    s_normal=batch["normal"],
    t_point=r_out["point"],
    t_mask=r_out["r_mask"],
    t_normal=r_out["normal"],
)

# Keep the correspondence mask for the cells below and display its sum.
mask = c_out["mask"]
mask.sum()

In [None]:
from torch.autograd.functional import jacobian
from torch.func import jacrev, vmap, jacfwd
import torch.utils.benchmark as benchmark


def jacobian():
    """Run one correspondence + forward-mode Jacobian + linear solve pass.

    Uses the notebook globals ``flame``, ``renderer``, ``batch`` and ``params``.
    The function first runs the model, render and correspondence steps to obtain
    the correspondence mask, then differentiates the residual vector
    (point-to-plane terms followed by the flattened expression parameters) with
    ``vmap(jacfwd(...))`` and solves ``H @ delta = grad_f``.

    Note: this definition reuses the name of ``torch.autograd.functional.jacobian``
    imported at the top of the cell and therefore replaces it in the notebook
    namespace. The name is kept because the benchmark statement calls
    ``jacobian()``.

    Returns:
        ``delta``, the solution returned by ``torch.linalg.solve(H, grad_f)``.
    """
    # Correspondences at the current parameter estimate.
    m_out = flame.model_step(
        global_pose=params["global_pose"],
        transl=params["transl"],
        neck_pose=params["neck_pose"],
        expression_params=params["expression_params"],
        shape_params=params["shape_params"],
    )

    r_out = flame.render_step(
        renderer=renderer,
        vertices=m_out["vertices"],
    )
    c_out = flame.correspondence_step(
        s_point=batch["point"],
        s_mask=batch["mask"],
        s_normal=batch["normal"],
        t_point=r_out["point"],
        t_mask=r_out["r_mask"],
        t_normal=r_out["normal"],
    )
    mask = c_out["mask"]

    # These do not depend on the differentiated inputs, so gather them once
    # instead of on every evaluation of the closure.
    s_point = batch["point"][mask]
    t_normal = r_out["normal"][mask]

    def closure(global_pose, transl, neck_pose, scale, expression_params, shape_params):
        """Residual vector for one set of parameters; returned twice (value, aux)."""
        model = flame.model_step(
            global_pose=global_pose,
            transl=transl,
            expression_params=expression_params,
            scale=scale,
            neck_pose=neck_pose,
            shape_params=shape_params,
        )
        t_point = renderer.mask_interpolate(
            vertices_idx=r_out["vertices_idx"],
            bary_coords=r_out["bary_coords"],
            attributes=model["vertices"],
            mask=mask,
        )
        point2plane = ((s_point - t_point) * t_normal).sum(-1)  # (C)
        regularization = expression_params.flatten()
        F = torch.cat([point2plane, regularization])
        # has_aux=True: the second element is passed through undifferentiated.
        return F, F

    jacobian_fn = vmap(jacfwd(closure, has_aux=True, argnums=(0, 1, 2, 3, 4, 5)))
    # `jac_blocks` holds one Jacobian per differentiated argument; a distinct local
    # name avoids shadowing this function's own name.
    jac_blocks, F = jacobian_fn(
        params["global_pose"],
        params["transl"],
        params["neck_pose"],
        params["scale"],
        params["expression_params"],
        params["shape_params"],
    )
    # NOTE(review): under vmap each block presumably has shape (B, M, n_i), in
    # which case flatten(-2) merges M with n_i rather than producing the (M, N)
    # layout the original comment describes - confirm the intended layout.
    J = torch.cat([j.flatten(-2) for j in jac_blocks], dim=-1)  # (M, N)

    # Solve for delta.
    # NOTE(review): H carries a factor of 2 that grad_f does not - confirm intended.
    H = 2 * J.T @ J
    grad_f = J.T @ F
    delta = torch.linalg.solve(H, grad_f)
    return delta

# Time 200 calls of `jacobian()` with torch.utils.benchmark and print the measurement.
t0 = benchmark.Timer(
    globals=globals(),
    setup="from __main__ import jacobian",
    stmt="jacobian()",
)
print(t0.timeit(number=200))

In [None]:
# Pack the per-sample parameters into one row per sample by concatenating along
# the last dimension. Concatenating along dim 0 fails for tensors whose last
# dimensions differ (3 vs. 50). The order must match the slices used by the
# closure below: global_pose, transl, neck_pose, expression_params.
p = torch.cat(
    [
        params["global_pose"],  # p[..., 0:3]
        params["transl"],  # p[..., 3:6]
        params["neck_pose"],  # p[..., 6:9]
        params["expression_params"],  # p[..., 9:59]
        # params["shape_params"]  # would follow as p[..., 59:159]
    ],
    dim=-1,
)
# Always expose a leading batch dimension for vmap, including for 1-D parameters.
p = p.reshape(-1, p.shape[-1])

def jacobian():
    """Forward-mode Jacobian of the residuals w.r.t. the packed vector ``p``.

    Uses the notebook globals ``p``, ``flame``, ``renderer``, ``batch``, ``mask``,
    ``r_out`` and ``c_out``. Each row of ``p`` is sliced as
    ``[0:3]`` global pose, ``[3:6]`` translation, ``[6:9]`` neck pose and
    ``[9:59]`` expression parameters (shape parameters are currently disabled).

    Note: this reuses the name of the function defined in the previous benchmark
    cell, so that earlier definition is replaced once this cell runs.

    Returns:
        ``(J, F)`` as produced by ``vmap(jacfwd(closure, has_aux=True))(p)``:
        the Jacobian of the residual vector (point-to-plane terms followed by the
        expression parameters) and the auxiliary point-to-plane terms.
    """
    # Independent of `p`: gather once instead of on every closure evaluation.
    s_point = batch["point"][mask]
    t_normal = r_out["normal"][mask]

    def closure(p):
        """Return (residual vector to differentiate, point-to-plane terms as aux)."""
        m_out = flame.model_step(
            global_pose=p[0:3],
            transl=p[3:6],
            neck_pose=p[6:9],
            expression_params=p[9:59],
            # shape_params=p[59:159],
        )
        t_point = renderer.mask_interpolate(
            vertices_idx=r_out["vertices_idx"],
            bary_coords=r_out["bary_coords"],
            attributes=m_out["vertices"],
            mask=c_out["mask"],
        )
        point2plane = ((s_point - t_point) * t_normal).sum(-1)  # (C)
        residuals = torch.cat([point2plane, p[9:59].flatten()])
        return residuals, point2plane

    J, F = vmap(jacfwd(closure, has_aux=True))(p)
    return J, F

# Time 200 calls of the packed-vector `jacobian()` and print the measurement.
t0 = benchmark.Timer(
    globals=globals(),
    setup="from __main__ import jacobian",
    stmt="jacobian()",
)
print(t0.timeit(number=200))

In [None]:
# Recorded measurements.
# NOTE(review): presumably benchmark timings for runs with (6 + 40) and (6 + 20)
# parameters; the units and the run that produced them are not shown here - confirm.
# 17.88 (6 + 40)
# 12.77 (6 + 20)
# Each measurement divided by its total parameter count (46 and 26).
17.88 / 46, 12.77 / 26

In [None]:
import time

# Wall-clock time of a single `flame.model_step()` call with default arguments.
# CUDA work is launched asynchronously, so pending kernels are drained before the
# clock is read on either side; perf_counter is the monotonic high-resolution
# clock intended for interval measurements.
if torch.cuda.is_available():
    torch.cuda.synchronize()
s = time.perf_counter()
m_out = flame.model_step()
if torch.cuda.is_available():
    torch.cuda.synchronize()
(time.perf_counter() - s)

In [None]:
# Display the first entry of the batch's "color" tensor as an image.
plt.imshow(batch["color"][0].detach().cpu().numpy())

In [None]:
# Call the FLAME module directly with three of its inputs.
# NOTE(review): the commented-out signature below looks like a leftover closure
# listing the full parameter set; only shape, expression and global pose are
# passed in the call that follows - confirm whether the stub can be removed.
# def closure(
#     shape_params,
#     expression_params,
#     global_pose,
#     neck_pose,
#     jaw_pose,
#     eye_pose,
#     transl,
#     scale,
# ):
flame(
    shape_params=params["shape_params"],
    expression_params=params["expression_params"],
    global_pose=params["global_pose"],
)

In [None]:
# Inspect the image width stored in the data config.
cfg.data.width

In [None]:
# Inspect which device the FLAME model's default shape parameters live on.
flame.default_shape_params.device

In [None]:
import torch

# Check the FLAME output shape for one set of shape parameters and 32 sets of
# expression parameters. The tensors are allocated directly on the notebook's
# `device` rather than on a hard-coded "cuda" via a CPU round-trip, so the cell
# also runs on machines without a GPU.
shape_params = torch.rand((1, 100), device=device)
expression_params = torch.rand((32, 50), device=device)
flame(shape_params=shape_params, expression_params=expression_params).shape

In [None]:
# Shape of the random shape parameters created in the previous cell.
shape_params.shape