In [None]:
import torch
import torch.nn.functional as F
import hydra
import torch
from lib.utils.config import load_config
from lib.optimizer.framework import NeuralOptimizer
from lib.data.loader import load_intrinsics
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.renderer import Renderer
from lib.renderer.camera import Camera
from lib.utils.visualize import visualize_point2plane_error
import matplotlib.pyplot as plt
from lib.utils.visualize import visualize_merged
from PIL import Image
import numpy as np
import torch


def path_to_abblation(path):
    """Derive the ablation name from a path string.

    The third-from-last "/"-separated component of ``path`` is split on
    underscores; its first piece is discarded and the remaining pieces are
    joined back together with underscores.
    """
    run_component = path.split("/")[-3]
    _, *name_parts = run_component.split("_")
    return "_".join(name_parts)


def draw_and_save_color(dataset, idx, path):
    """Re-save one color frame of a sequence at ``path``.

    The frame is read with PIL from the hard-coded dphm_kinect data directory,
    addressed by the sequence name ``dataset`` and the frame index ``idx``
    (zero-padded to five digits), and written to ``path`` unchanged.
    """
    source = f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/color/{idx:05}.png"
    Image.open(source).save(path)


def draw_and_save_weight(flame, renderer, out, path):
    """Plot the last optimisation weight map, masked by the rendered model.

    Args:
        flame: Model whose ``render(renderer, params)`` result provides "mask".
        renderer: Renderer; updated to 1 for the mask render, then to 8.
        out: Optimizer output holding "params" and "optim_weights".
        path: Destination of the saved figure.
    """
    # Render the mask with the renderer updated to 1, then switch it to 8.
    renderer.update(1)
    rendered = flame.render(renderer, out["params"])
    mask = rendered["mask"][0]
    renderer.update(8)

    # Upsample the final weight map by a factor of 8 so it can be indexed
    # with the mask rendered above.
    last_weights = out["optim_weights"][-1]
    upsampled = F.interpolate(
        last_weights.unsqueeze(0),
        scale_factor=8,
        mode="bilinear",
        align_corners=False,
    )
    weight_map = upsampled.detach()[0][0]
    # Zero every pixel outside the rendered mask.
    weight_map[~mask] = 0.0

    # 19.2 x 10.8 inches at 100 dpi; axes hidden, saved without padding.
    plt.figure(figsize=(19.2, 10.8), dpi=100)
    plt.imshow(weight_map.cpu().numpy())
    plt.axis("off")
    plt.savefig(path, bbox_inches="tight", pad_inches=0)
    plt.show()


def draw_and_save_overlay(optimizer, renderer, params, dataset, idx, path):
    """Merge a render of ``params`` with the captured color frame and save it.

    The color frame ``idx`` of ``dataset`` is loaded from the hard-coded
    dphm_kinect directory and moved to "cuda". ``optimizer.flame`` is rendered
    with the renderer at ``scale=1`` (it is set to ``scale=8`` afterwards) and
    both images are combined by ``visualize_merged``. The first image of the
    merged result is written to ``path``.
    """
    frame_file = f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/color/{idx:05}.png"
    frame = np.asarray(Image.open(frame_file))
    color = torch.tensor(frame).unsqueeze(0).to("cuda")

    renderer.update(scale=1)
    rendered = optimizer.flame.render(renderer, params)
    renderer.update(scale=8)

    merged = visualize_merged(
        s_color=color,
        t_color=rendered["color"],
        t_mask=rendered["mask"],
    )
    pixels = merged[0].detach().cpu().numpy()
    Image.fromarray(pixels).save(path)


def eval_iterations(
    optimizer,
    renderer,
    dataset,
    target_frame_idx,
    source_frame_idx,
    step_size=0.7,
    params=None,
    N=2,
):
    """Run ``optimizer`` for ``N`` iterations on a single target frame.

    A kinect datamodule is instantiated whose validation set covers only
    ``target_frame_idx`` of ``dataset``, with the jump size set to the distance
    between target and source frame. The optimizer is reconfigured in place
    (``max_iters``, ``max_optims``, ``step_size``) and applied to every
    validation batch without gradient tracking.

    Args:
        params: Optional value stored as ``batch["init_params"]`` before the
            optimizer is called.

    Returns:
        ``(out, batch)`` of the last processed batch, or ``(None, None)`` when
        the validation dataloader yields nothing.
    """
    cfg = load_config("train", ["data=kinect"])
    val_dataset = {
        "start_frame": target_frame_idx,
        "end_frame": target_frame_idx + 1,
        "jump_size": target_frame_idx - source_frame_idx,
        "datasets": [dataset],
    }
    datamodule = hydra.utils.instantiate(
        cfg.data, renderer=renderer, val_dataset=val_dataset
    )
    datamodule.setup("fit")

    # The optimizer passed in is mutated; these settings persist after the call.
    optimizer.max_iters = N
    optimizer.max_optims = 1
    optimizer.step_size = step_size

    out, batch = None, None
    for raw_batch in datamodule.val_dataloader():
        with torch.no_grad():
            batch = optimizer.transfer_batch_to_device(raw_batch, "cuda", 0)
            if params is not None:
                batch["init_params"] = params
            out = optimizer(batch)
    return out, batch


def draw_and_save(img, path):
    """Write the tensor ``img`` to ``path`` as an image file via PIL.

    The tensor is detached, moved to the CPU and converted to a NumPy array
    before being handed to ``Image.fromarray``.
    """
    pixels = img.detach().cpu().numpy()
    Image.fromarray(pixels).save(path)


def load_flame_renderer():
    """Build the model and renderer from the "train" config with ``data=kinect``.

    Returns:
        ``(flame, renderer)`` where ``flame`` is instantiated from ``cfg.model``
        and ``renderer`` wraps a camera and rasterizer built from ``cfg.data``.
    """
    cfg = load_config("train", ["data=kinect"])
    data_cfg = cfg.data

    # Camera intrinsics are loaded from the directory named in the data config.
    intrinsics = load_intrinsics(data_dir=data_cfg.intrinsics_dir, return_tensor="pt")
    camera = Camera(
        K=intrinsics,
        width=data_cfg.width,
        height=data_cfg.height,
        near=data_cfg.near,
        far=data_cfg.far,
        scale=data_cfg.scale,
    )
    # The rasterizer takes its resolution from the camera object.
    renderer = Renderer(
        rasterizer=Rasterizer(width=camera.width, height=camera.height),
        camera=camera,
    )
    flame = hydra.utils.instantiate(cfg.model)
    return flame, renderer


def render_output(renderer, optimizer, out, batch):
    """Render predicted and batch parameters and build a point-to-plane error map.

    Both ``out["params"]`` and ``batch["params"]`` are rendered through
    ``optimizer.flame`` with the renderer at ``scale=1``; the renderer is set
    to ``scale=8`` afterwards. The error map is capped at ``max_error=2e-03``.

    Returns:
        ``(color, normal, error_map)``: the first predicted color image and
        normal image (detached, on the CPU) and the error visualisation.
    """
    renderer.update(scale=1)
    flame = optimizer.flame
    prediction = flame.render(renderer, out["params"])
    reference = flame.render(renderer, batch["params"])
    error_map = visualize_point2plane_error(
        s_point=reference["point"][0],
        t_normal=prediction["normal"][0],
        t_point=prediction["point"][0],
        t_mask=prediction["mask"][0],
        max_error=2e-03,  # 2mm
    )
    renderer.update(scale=8)

    def first_on_cpu(key):
        # First batch element of a predicted image, detached and moved to CPU.
        return prediction[key][0].detach().cpu()

    return first_on_cpu("color"), first_on_cpu("normal_image"), error_map


def render(renderer, optimizer, out, batch):
    """Render predicted and batch parameters and build a point-to-plane error map.

    Same procedure as ``render_output`` but with a larger error cap
    (``max_error=6e-03``) and without returning the normal image. The renderer
    is set to ``scale=1`` for rendering and to ``scale=8`` afterwards.

    Returns:
        ``(color, error_map)``: the first predicted color image (detached, on
        the CPU) and the error visualisation.
    """
    renderer.update(scale=1)
    flame = optimizer.flame
    prediction = flame.render(renderer, out["params"])
    reference = flame.render(renderer, batch["params"])
    error_map = visualize_point2plane_error(
        s_point=reference["point"][0],
        t_normal=prediction["normal"][0],
        t_point=prediction["point"][0],
        t_mask=prediction["mask"][0],
        # The value is 6e-03; the earlier "2mm" note here did not match it.
        max_error=6e-03,
    )
    renderer.update(scale=8)
    return prediction["color"][0].detach().cpu(), error_map


def load_neural_optimizer(flame, renderer, path, override=None):
    """Restore a ``NeuralOptimizer`` from a checkpoint.

    The "train" config is loaded with ``data=kinect`` plus any extra overrides,
    and the correspondence, weighting, residuals and regularize components are
    instantiated from it before being passed to
    ``NeuralOptimizer.load_from_checkpoint``.

    Args:
        flame: Model handed to the optimizer.
        renderer: Renderer handed to the optimizer.
        path: Checkpoint path.
        override: Optional iterable of additional config override strings.
            Defaults to no extra overrides. (``None`` replaces the former
            mutable ``[]`` default, which is shared between calls.)

    Returns:
        The optimizer returned by ``load_from_checkpoint``.
    """
    overrides = ["data=kinect", *(override or [])]
    cfg = load_config("train", overrides)
    correspondence = hydra.utils.instantiate(cfg.correspondence)
    weighting = hydra.utils.instantiate(cfg.weighting)
    residuals = hydra.utils.instantiate(cfg.residuals)
    regularize = hydra.utils.instantiate(cfg.regularize)
    return NeuralOptimizer.load_from_checkpoint(
        path,
        renderer=renderer,
        flame=flame,
        correspondence=correspondence,
        regularize=regularize,
        residuals=residuals,
        weighting=weighting,
    )


def load_icp_optimizer(flame, renderer, overrides):
    """Instantiate the optimisation framework from config and move it to "cuda".

    The "train" config is loaded with ``data=kinect`` and
    ``optimizer.output_dir=none`` followed by the caller's ``overrides`` (a
    list). All sub-components are instantiated from that config and injected
    into ``cfg.framework``; no logger is attached.
    """
    base_overrides = ["data=kinect", "optimizer.output_dir=none"]
    cfg = load_config("train", base_overrides + overrides)

    build = hydra.utils.instantiate
    correspondence = build(cfg.correspondence)
    weighting = build(cfg.weighting)
    residuals = build(cfg.residuals)
    inner_optimizer = build(cfg.optimizer)
    regularize = build(cfg.regularize)

    framework = build(
        cfg.framework,
        flame=flame,
        logger=None,
        renderer=renderer,
        correspondence=correspondence,
        regularize=regularize,
        residuals=residuals,
        optimizer=inner_optimizer,
        weighting=weighting,
    )
    return framework.to("cuda")

def predict_mask(image, x_top, x_bottom):
    """Return a boolean mask of the pixels left of a straight wipe edge.

    The edge runs from ``(x_top, 0)`` at the top of the image to
    ``(x_bottom, H)`` at the bottom. A pixel ``(row, col)`` is ``True`` when
    the edge position in that row, ``x_top + row * (x_bottom - x_top) / H``,
    is strictly greater than ``col``.

    Args:
        image: Array or tensor whose ``shape`` is ``(H, W, C)``; only the shape
            is read.
        x_top: x coordinate of the edge at row 0 (may be outside the image).
        x_bottom: x coordinate of the edge at row ``H`` (may be outside the
            image).

    Returns:
        ``torch.bool`` tensor of shape ``(H, W)`` on the CPU.
    """
    H, W, _ = image.shape
    # float64 keeps integer coordinates exact and also accepts float inputs.
    rows = torch.arange(H, dtype=torch.float64).unsqueeze(1)  # (H, 1)
    cols = torch.arange(W, dtype=torch.float64).unsqueeze(0)  # (1, W)
    # edge(row) > col, multiplied through by H (> 0 whenever a row exists).
    # Writing it without a division avoids the ZeroDivisionError the slope
    # formula hit for x_bottom == 0, and it treats a vertical edge
    # (x_top == x_bottom) uniformly, so a negative vertical edge yields an
    # empty mask instead of wrapping around through negative slicing.
    return (x_top - cols) * H + rows * (x_bottom - x_top) > 0



# setup the datamodule
def load_datamodule(renderer, start_frame, end_frame, jump_size=1):
    """Instantiate the kinect datamodule for a frame range and set it up.

    The validation dataset is configured with ``start_frame``, ``end_frame``
    and ``jump_size``; ``setup("fit")`` is called before the datamodule is
    returned.
    """
    cfg = load_config("train", ["data=kinect"])
    val_dataset = {
        "start_frame": start_frame,
        "end_frame": end_frame,
        "jump_size": jump_size,
    }
    datamodule = hydra.utils.instantiate(
        cfg.data, renderer=renderer, val_dataset=val_dataset
    )
    datamodule.setup("fit")
    return datamodule

In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from PIL import Image
from lib.utils.video import create_video
from tqdm.notebook import tqdm
import math
from torchvision.transforms import v2
from torchvision.transforms.functional import pil_to_tensor

# For every sequence, write two sets of PNG frames and turn each into a video:
#   * "<dataset>_input":  the color frame with the cached normal map wiped in
#     from the left along a slanted edge that sweeps across the image and back.
#   * "<dataset>_output": the color frame with the render of the stored
#     per-frame parameters pasted in wherever the render mask is set.
# Relies on load_flame_renderer, load_icp_optimizer and predict_mask from the
# first cell (the identical copy of predict_mask that used to be re-defined
# here for every dataset has been removed).
for dataset in [
    "elias_wohlgemuth_eyeblink",
    "arnefucks_rotatemouth",
    "mykola_mouthmove",
    "medhansh_mouthmove",
]:
    scale = 2  # 2
    offset = 400  # horizontal distance between the top and bottom of the wipe edge
    margin_h = 20  # rows cropped from the top and bottom of every saved frame
    margin_w = 120  # columns cropped from the left and right of every saved frame

    path = Path(f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/cache")
    # Frame count is taken from the "2_color" cache folder.
    # NOTE(review): the folder name is fixed to "2_", independent of `scale`.
    sequence_length = len(list((path / "2_color").iterdir()))
    # The first cached mask is only needed for its (H, W) shape.
    mask = torch.load(path / f"{scale}_mask/{0:05}.pt")
    H, W = mask.shape

    flame, renderer = load_flame_renderer()
    override = ["residuals=face2face_wo_landmarks", "regularize=dummy", "weighting=dummy"]
    optimizer = load_icp_optimizer(flame, renderer, override)

    # Sweep the edge from -W to W and back; the step is chosen so that the
    # round trip takes roughly one entry per frame of the sequence.
    step_size = math.ceil(2 * W / sequence_length) * 2
    xs = [*range(-W, W, step_size), *range(W, -W, -step_size)]

    # Loop-invariant values, computed once per dataset.
    size = (int(1080 / scale), int(1920 / scale))
    input_dir = f"/home/borth/GuidedResearch/results/overlay/{dataset}_input/"
    output_dir = f"/home/borth/GuidedResearch/results/overlay/{dataset}_output/"
    Path(input_dir).mkdir(parents=True, exist_ok=True)
    Path(output_dir).mkdir(parents=True, exist_ok=True)

    for idx, x in tqdm(enumerate(xs), total=len(xs)):
        x_bottom = x + 1
        x_top = x_bottom + offset

        # Color frame, resized to the working resolution and laid out as HWC.
        cpath = f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/color/{idx:05}.png"
        color = v2.functional.resize(
            inpt=pil_to_tensor(Image.open(cpath)),
            size=size,
        ).permute(1, 2, 0)

        # ---- input frame: color with the normal map wiped in ----
        normal_mask = torch.load(path / f"{scale}_mask/{idx:05}.pt")
        normal = torch.load(path / f"{scale}_normal/{idx:05}.pt")
        # Map normals from [-1, 1] to [0, 255] so they can be shown as colors.
        normal = (((normal + 1) / 2) * 255).to(torch.uint8)

        image = color.clone()
        # Only pixels that are both left of the edge and inside the cached mask.
        wipe = predict_mask(image, x_top, x_bottom) & normal_mask
        image[wipe] = normal[wipe]
        # No plt.imshow here: calling it per frame stacks one image artist per
        # frame on the implicit figure and was not needed for the saved output.
        image = image[margin_h:-margin_h, margin_w:-margin_w, :]
        frame_path = Path(input_dir) / f"{idx:05}.png"
        Image.fromarray(image.detach().cpu().numpy()).save(frame_path)

        # ---- output frame: color with the parameter render pasted in ----
        params = torch.load(
            f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/params/{idx:05}.pt"
        )
        params = optimizer.transfer_batch_to_device(params, "cuda", 0)

        # NOTE(review): the render scale is fixed to 2 while the color frame is
        # resized by `scale`; they only agree while scale == 2 — confirm.
        renderer.update(scale=2)
        out = optimizer.flame.render(renderer, params)
        renderer.update(scale=scale)

        render_color = out["color"][0].detach().cpu()
        render_mask = out["mask"][0].detach().cpu()

        image = color.clone()
        image[render_mask] = render_color[render_mask]
        image = image[margin_h:-margin_h, margin_w:-margin_w, :]
        frame_path = Path(output_dir) / f"{idx:05}.png"
        Image.fromarray(image.detach().cpu().numpy()).save(frame_path)

    video_dir = input_dir
    video_path = f"/home/borth/GuidedResearch/results/overlay/{dataset}_input.mp4"
    create_video(video_dir=video_dir, video_path=video_path, framerate=16)

    video_dir = output_dir
    video_path = f"/home/borth/GuidedResearch/results/overlay/{dataset}_output.mp4"
    create_video(video_dir=video_dir, video_path=video_path, framerate=16)

In [None]:
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from PIL import Image
from lib.utils.video import create_video
from tqdm.notebook import tqdm
import math
from torchvision.transforms import v2
from torchvision.transforms.functional import pil_to_tensor

# Render the stored per-frame parameters of one sequence, save every render as
# a PNG frame, and finally build a video from a results directory.
# Uses load_flame_renderer and load_icp_optimizer defined in the first cell.
scale = 2  # 2
offset = 400  # 400
margin_w = 150
# NOTE(review): margin_h is only used by the commented-out overlay code below;
# with 0 the slice `margin_h:-margin_h` would be `0:0`, i.e. empty.
margin_h = 0
dataset = "christoph_mouthmove"
setup = "dataset"

path = Path(f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/cache")
# Frame count is taken from the "2_color" cache folder (name is fixed to "2_",
# independent of `scale`).
sequence_length = len(list((path / "2_color").iterdir()))
# The first cached mask is only loaded for its (H, W) shape.
mask = torch.load(path / f"{scale}_mask/{0:05}.pt")
H, W = mask.shape

flame, renderer = load_flame_renderer()
override = ["residuals=face2face_wo_landmarks", "regularize=dummy", "weighting=dummy"]
optimizer = load_icp_optimizer(flame, renderer, override)


# One x position per frame: sweep from 0 to W and back again. Here the values
# themselves only feed x_bottom/x_top; the loop effectively just counts frames.
step_size = math.ceil(2 * W / sequence_length)
xs = [*list(range(0, W, step_size)), *list(range(W, 0, -step_size))]

for idx, x in tqdm(enumerate(xs), total=len(xs)):
    # NOTE(review): x_bottom, x_top and `color` are only consumed by the
    # commented-out overlay code further down; they are unused otherwise.
    x_bottom = x + 1
    x_top = x_bottom + offset

    # Color frame, resized to the working resolution and laid out as HWC.
    cpath = f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/color/{idx:05}.png"
    color = pil_to_tensor(Image.open(cpath)).permute(1, 2, 0)
    size = (int(1080 / scale), int(1920 / scale))
    color = v2.functional.resize(
        inpt=color.permute(2, 0, 1),
        size=size,
    ).permute(1, 2, 0)

    # Stored parameters of this frame, moved to the GPU.
    params = torch.load(
        f"/home/borth/GuidedResearch/data/dphm_kinect/{dataset}/params/{idx:05}.pt"
    )
    params = optimizer.transfer_batch_to_device(params, "cuda", 0)
    # This directory is created here, but no frame is written into it while
    # the overlay code below stays commented out.
    video_dir = f"/home/borth/GuidedResearch/results/{setup}/{dataset}/"
    frame_path = Path(video_dir) / f"{idx:05}.png"
    frame_path.parent.mkdir(parents=True, exist_ok=True)

    # NOTE(review): the render scale is fixed to 2 and then set to `scale`;
    # the two only agree while scale == 2 — confirm this is intended.
    renderer.update(scale=2)
    out = optimizer.flame.render(renderer, params)
    renderer.update(scale=scale)

    # Disabled variant: wipe the render over the color frame and save it into
    # the "{dataset}/" directory.
    # img_left = color
    # img_right = out["color"][0].detach().cpu()
    # img_mask = out["mask"][0].detach().cpu()

    # image = img_left.clone()
    # mask = predict_mask(image, x_top, x_bottom)
    # image[mask & img_mask] = img_right[mask & img_mask]
    # image = image[margin_h:-margin_h, margin_w:-margin_w, :]
    # Image.fromarray(image.detach().cpu().numpy()).save(frame_path)

    # Active variant: save the plain render into "{dataset}_params/".
    image = out["color"][0].detach().cpu()
    video_dir = f"/home/borth/GuidedResearch/results/{setup}/{dataset}_params/"
    frame_path = Path(video_dir) / f"{idx:05}.png"
    frame_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(image.detach().cpu().numpy()).save(frame_path)

# NOTE(review): the video is built from "{dataset}/", whereas the loop above
# writes its frames to "{dataset}_params/". Unless "{dataset}/" already holds
# frames from an earlier run, this reads a directory the loop never filled —
# confirm which directory is intended.
video_dir = f"/home/borth/GuidedResearch/results/{setup}/{dataset}/"
video_path = f"/home/borth/GuidedResearch/results/{setup}/{dataset}.mp4"
create_video(video_dir=video_dir, video_path=video_path, framerate=16)