In [None]:
from lib.model.flame.flame import Flame
import torch
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.renderer import Renderer
from lib.renderer.camera import Camera
from lib.model.weighting import ResidualWeightModule, DummyWeightModule
from lib.model.correspondence import (
    ProjectiveCorrespondenceModule,
    OpticalFlowCorrespondenceModule,
)
from lib.data.synthetic import generate_params
from lib.utils.visualize import visualize_grid, visualize_params
from lib.optimizer.residuals import Point2PlaneResiduals, VertexResiduals
from lib.optimizer.newton import GaussNewton
from lib.model.framework import FlameOptimizer, VertexOptimizer
from torch.utils.data import DataLoader
from lib.optimizer.solver import PytorchSolver

# settings
# NOTE(review): absolute, user-specific paths — consider making these configurable
# (e.g. relative to a project DATA_DIR) so the notebook runs on other machines.
data_dir = "/home/borth/GuidedResearch/data/dphm_christoph_mouthmove"
flame_dir = "/home/borth/GuidedResearch/checkpoints/flame2023"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's RNG so the randomly perturbed parameters sampled in later cells
# (torch.randn_like) are reproducible under Restart Kernel -> Run All.
SEED = 0
torch.manual_seed(SEED)

# setup camera, rasterizer and renderer
# NOTE(review): K is loaded but not used in the visible cells — the Camera below is
# built from width/height/scale only. Confirm whether the intrinsics should be passed on.
K = load_intrinsics(data_dir=data_dir, return_tensor="pt")
camera = Camera(width=640, height=480, scale=1)
rasterizer = Rasterizer(width=camera.width, height=camera.height)
renderer = Renderer(rasterizer=rasterizer, camera=camera)

# setup the FLAME model
flame = Flame(
    flame_dir=flame_dir,
    vertices_mask="full",
    expression_params=50,
    shape_params=100,
)

# create gt_params
gt_params = generate_params(
    flame=flame,
    window_size=1,
    default=dict(transl=[0.043, -0.003, -0.528]),
)

# initialize params close to gt_params: a slightly shifted translation plus a
# small global pose
params = generate_params(
    flame=flame,
    window_size=1,
    default=dict(transl=[0.053, -0.013, -0.528], global_pose=[0.01, 0.1, -0.001]),
)

# render both parameter sets (color codes 0 and 2) for a visual comparison
visualize_params(flame, renderer, gt_params, color=0)
visualize_params(flame, renderer, params, color=2)

In [None]:
from lib.utils.visualize import visualize_merged,visualize_depth_merged, visualize_grid, change_color

# Render the current estimate (source) and the ground truth (target) with the
# same renderer so the two can be overlaid.
s_out = flame.render(renderer=renderer, params=params)
t_out = flame.render(renderer=renderer, params=gt_params)

# Recolor each rendering inside its mask (code 0 = source, code 1 = target)
# so the two surfaces are distinguishable in the merged depth view.
s_color = change_color(s_out["color"], s_out["mask"], code=0)
t_color = change_color(t_out["color"], t_out["mask"], code=1)

imgs = visualize_depth_merged(
    s_color=s_color,
    s_point=s_out["point"],
    s_mask=s_out["mask"],
    t_color=t_color,
    t_point=t_out["point"],
    t_mask=t_out["mask"],
)
visualize_grid(imgs)

In [None]:
from collections import defaultdict



def generate_params(
    flame: "Flame",
    window_size: int = 1,
    default: "dict | None" = None,
    sigmas: "dict | None" = None,
):
    """Sample a (possibly noisy) FLAME parameter set for a window of frames.

    Starts from ``flame.generate_default_params()``, overrides selected entries
    with ``default``, and adds Gaussian noise whose std is given per parameter
    name by ``sigmas``.

    Args:
        flame: FLAME model providing ``generate_default_params()``, ``device``,
            ``global_params`` and ``local_params``.
        window_size: Number of independent samples drawn for every local
            parameter; the samples are concatenated along dim 0. Must be >= 1.
        default: Optional ``name -> value`` overrides. Each value is wrapped as
            ``torch.tensor([value], device=flame.device)``.
        sigmas: Optional ``name -> std`` mapping; names not present get no noise.

    Returns:
        A ``defaultdict(list)`` mapping each global and local parameter name to a
        tensor. Global params are sampled once; local params hold
        ``window_size`` samples concatenated along dim 0.

    Raises:
        ValueError: if ``window_size`` < 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    default = {} if default is None else default
    sigmas = {} if sigmas is None else sigmas

    # default params, with user overrides applied on top
    base_params = flame.generate_default_params()
    for p_name, value in default.items():
        base_params[p_name] = torch.tensor([value], device=flame.device)

    # defaultdict kept for backward compatibility with the original return type
    params = defaultdict(list)
    for p_name in flame.global_params:
        base = base_params[p_name]
        # Noise is shaped like the base param itself — previously it was drawn
        # like the notebook-global `gt_params`, a hidden-state dependency.
        params[p_name] = base + torch.randn_like(base) * sigmas.get(p_name, 0.0)
    for p_name in flame.local_params:
        base = base_params[p_name]
        sigma = sigmas.get(p_name, 0.0)
        samples = [base + torch.randn_like(base) * sigma for _ in range(window_size)]
        params[p_name] = torch.cat(samples, dim=0)
    return params


# Randomly perturbed starting point: default FLAME params with the translation
# overridden to [-0.0, -0.0, -0.6], plus Gaussian noise with the per-parameter
# standard deviations below (4 samples per local parameter).
param_sigmas = {
    "shape_params": 1.0,
    "expression_params": 1.0,
    "global_pose": 0.02,
    "neck_pose": 0.05,
    "transl": 0.05,
}
params = generate_params(
    flame,
    window_size=4,
    default={"transl": [-0.0, -0.0, -0.6]},
    sigmas=param_sigmas,
)

# A second, smaller random draw that is added on top of `params` afterwards.
# NOTE(review): this is default params + noise, so it is a pure offset only if
# flame.generate_default_params() returns zeros — confirm.
offset_sigmas = {
    "shape_params": 0.5,
    "expression_params": 0.5,
    "global_pose": 0.005,
    "neck_pose": 0.005,
    "transl": 0.005,
}
offset = generate_params(flame, window_size=4, sigmas=offset_sigmas)

visualize_params(flame, renderer, params, color=0)

In [None]:
# Perturb every parameter of `params` by the matching entry of `offset`.
new_params = {p_name: params[p_name] + offset[p_name] for p_name in params}
visualize_params(flame, renderer, new_params, color=0)

In [None]:
# Inspect the perturbed translation tensor (shown as this cell's output).
new_params["transl"]

In [None]:
def generate_offset(flame: "Flame", window_size: int = 1, sigmas: "dict | None" = None):
    """Sample Gaussian-perturbed copies of the default FLAME params for a window.

    NOTE(review): despite the name, the result is ``default params + noise``,
    not pure noise. It equals the noise alone only if
    ``flame.generate_default_params()`` returns zeros. Confirm the intent before
    adding the result to other parameter sets.

    Args:
        flame: FLAME model providing ``generate_default_params()``,
            ``global_params`` and ``local_params``.
        window_size: Number of independent samples drawn for every local
            parameter; the samples are concatenated along dim 0. Must be >= 1.
        sigmas: Optional ``name -> std`` mapping; names not present get no noise.

    Returns:
        A ``defaultdict(list)`` mapping each global and local parameter name to a
        tensor. Global params are sampled once; local params hold
        ``window_size`` samples concatenated along dim 0.

    Raises:
        ValueError: if ``window_size`` < 1.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    sigmas = {} if sigmas is None else sigmas

    base_params = flame.generate_default_params()
    # defaultdict kept for backward compatibility with the original return type
    params = defaultdict(list)
    for p_name in flame.global_params:
        base = base_params[p_name]
        # Noise is shaped like the base param — previously it was drawn like the
        # notebook-global `gt_params`, a hidden-state dependency.
        params[p_name] = base + torch.randn_like(base) * sigmas.get(p_name, 0.0)
    for p_name in flame.local_params:
        base = base_params[p_name]
        sigma = sigmas.get(p_name, 0.0)
        samples = [base + torch.randn_like(base) * sigma for _ in range(window_size)]
        params[p_name] = torch.cat(samples, dim=0)
    return params

In [None]:
# Pin the ground-truth translation.
# NOTE(review): this mutates gt_params in place, so renderings of gt_params
# produced by earlier cells no longer reflect its current value.
gt_transl = torch.tensor([[0.0, 0.0, -0.5]], device=device)
gt_params["transl"] = gt_transl