In [6]:
import hydra
import torch
from lib.utils.config import load_config
from lib.optimizer.framework import NeuralOptimizer
from lib.data.loader import load_intrinsics
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.renderer import Renderer
from lib.renderer.camera import Camera
from lib.utils.visualize import visualize_point2plane_error
import matplotlib.pyplot as plt
import torch.nn.functional as F

def path_to_abblation(path):
    """Derive an ablation name from a checkpoint path.

    Takes the third-from-last ``/``-separated component of ``path`` and drops
    its first ``_``-separated token, re-joining the rest with ``_``.
    Example: ``"a/b/ckpt_neural_prior/c/d"`` -> ``"neural_prior"``.

    Raises:
        IndexError: if ``path`` has fewer than three ``/``-separated parts.
    """
    components = path.split("/")
    run_dir = components[-3]
    _leading, *remaining = run_dir.split("_")
    return "_".join(remaining)



def draw_and_save_weight(flame, renderer, out, path):
    """Plot the last predicted weight map, zeroed outside the rendered mask, and save it.

    Renders the mask of ``out["params"]`` at renderer scale 1, upsamples the last
    entry of ``out["optim_weights"]`` by a factor of 8 (bilinear), zeroes it
    wherever the mask is False, then shows the image and writes it to ``path``.

    Args:
        flame: Model exposing ``render(renderer, params)`` returning a dict with ``"mask"``.
        renderer: Renderer whose scale is switched to 1 for the mask and set back to 8.
        out: Optimizer output with ``"params"`` and ``"optim_weights"`` entries.
        path: Output image path; missing parent directories are created.
    """
    import os

    # Render the mask at scale 1; always set the renderer back to scale 8, even if
    # rendering raises, so later notebook cells don't render at the wrong scale.
    renderer.update(scale=1)
    try:
        mask = flame.render(renderer, out["params"])["mask"][0]
    finally:
        renderer.update(scale=8)

    # NOTE(review): the 8x upsampling assumes the weight map is 1/8 of the
    # scale-1 render resolution so that it lines up with `mask` — confirm.
    weights = F.interpolate(
        out["optim_weights"][-1].unsqueeze(0),
        scale_factor=8,
        mode="bilinear",
        align_corners=False,
    )
    weights = weights.detach()[0][0]
    weights[~mask] = 0.0

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Use a dedicated figure instead of whatever figure happens to be current,
    # and close it afterwards so calls in a loop don't accumulate open figures.
    fig, ax = plt.subplots()
    ax.imshow(weights.cpu().numpy())
    ax.axis("off")  # hide axes
    fig.savefig(path, bbox_inches="tight", pad_inches=0)  # save without padding
    plt.show()
    plt.close(fig)


def eval_iterations(optimizer, datamodule, idx: int, device: str = "cuda"):
    """Run ``optimizer`` without gradients on the ``idx``-th validation batch.

    Iteration over the validation dataloader stops as soon as the requested
    batch is reached instead of consuming the remaining batches.

    Args:
        optimizer: Object with ``transfer_batch_to_device(batch, device, idx)``
            that is callable on the transferred batch.
        datamodule: Object whose ``val_dataloader()`` yields batches.
        idx: Zero-based index of the validation batch to evaluate.
        device: Target device passed to ``transfer_batch_to_device``.

    Returns:
        tuple: ``(out, batch)``, i.e. the optimizer output and the transferred batch.

    Raises:
        IndexError: if ``idx`` is negative or the dataloader yields no batch at ``idx``.
    """
    if idx < 0:
        raise IndexError(f"batch index must be non-negative, got {idx}")
    for i, raw_batch in enumerate(datamodule.val_dataloader()):
        if i == idx:
            with torch.no_grad():
                batch = optimizer.transfer_batch_to_device(raw_batch, device, 0)
                out = optimizer(batch)
            return out, batch
    # Previously this silently returned (None, None) and failed later downstream.
    raise IndexError(f"validation dataloader has no batch at index {idx}")

def draw_and_save(img, path):
    """Display ``img`` on a borderless 1920x1080 canvas and save it to ``path``.

    Missing parent directories of ``path`` are created. The figure is closed
    after showing so repeated calls in a loop don't accumulate open figures.

    Args:
        img: Array-like image accepted by ``Axes.imshow``.
        path: Output file path passed to ``savefig``.
    """
    import os

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    fig, ax = plt.subplots(figsize=(19.2, 10.8), dpi=100)  # Full HD size
    ax.imshow(img)
    ax.axis("off")  # hide axes
    fig.savefig(path, bbox_inches="tight", pad_inches=0)  # save without padding
    plt.show()
    plt.close(fig)

def load_flame_renderer():
    """Build the FLAME model and a renderer configured like the training run.

    Loads the ``train`` config with the ``data=synthetic`` override, reads the
    camera intrinsics as a tensor, and wires Camera -> Rasterizer -> Renderer.

    Returns:
        tuple: ``(flame, renderer)``, where ``flame`` is instantiated from ``cfg.model``.
    """
    cfg = load_config("train", ["data=synthetic"])
    data_cfg = cfg.data

    intrinsics = load_intrinsics(data_dir=data_cfg.intrinsics_dir, return_tensor="pt")
    camera = Camera(
        K=intrinsics,
        width=data_cfg.width,
        height=data_cfg.height,
        near=data_cfg.near,
        far=data_cfg.far,
        scale=data_cfg.scale,
    )
    renderer = Renderer(
        rasterizer=Rasterizer(width=camera.width, height=camera.height),
        camera=camera,
    )
    return hydra.utils.instantiate(cfg.model), renderer


def load_neural_optimizer(flame, renderer, path, override=None):
    """Restore a ``NeuralOptimizer`` checkpoint with freshly instantiated components.

    Args:
        flame: FLAME model passed to the restored optimizer.
        renderer: Renderer passed to the restored optimizer.
        path: Checkpoint path for ``NeuralOptimizer.load_from_checkpoint``.
        override: Extra config overrides appended after ``"data=synthetic"``
            (any sequence of strings; ``None`` means no extra overrides).

    Returns:
        The restored ``NeuralOptimizer``.
    """
    # `None` default avoids the shared-mutable-default pitfall of `override=[]`.
    overrides = ["data=synthetic"] + list(override or [])
    cfg = load_config("train", overrides)
    correspondence = hydra.utils.instantiate(cfg.correspondence)
    weighting = hydra.utils.instantiate(cfg.weighting)
    residuals = hydra.utils.instantiate(cfg.residuals)
    regularize = hydra.utils.instantiate(cfg.regularize)
    return NeuralOptimizer.load_from_checkpoint(
        path,
        renderer=renderer,
        flame=flame,
        correspondence=correspondence,
        regularize=regularize,
        residuals=residuals,
        weighting=weighting,
    )


def render_output(renderer, optimizer, out, batch):
    """Render predicted vs. ground-truth FLAME outputs and their point-to-plane error.

    Both parameter sets are rendered at renderer scale 1. The renderer is set
    back to scale 8 afterwards, even if rendering or error visualization raises.

    Args:
        renderer: Renderer used for both renders.
        optimizer: Object exposing ``flame.render(renderer, params)``.
        out: Optimizer output containing predicted ``"params"``.
        batch: Batch containing ground-truth ``"params"``.

    Returns:
        tuple: ``(color, error_map)``. ``color`` is the first predicted color
        image, detached and on CPU. ``error_map`` is the result of
        ``visualize_point2plane_error`` with a 2mm max error.
    """
    renderer.update(scale=1)
    try:
        pred_out = optimizer.flame.render(renderer, out["params"])
        gt_out = optimizer.flame.render(renderer, batch["params"])
        error_map = visualize_point2plane_error(
            s_point=gt_out["point"][0],
            t_normal=pred_out["normal"][0],
            t_point=pred_out["point"][0],
            t_mask=pred_out["mask"][0],
            max_error=2e-03,  # 2mm
        )
    finally:
        # Avoid leaving the renderer at scale 1 in the kernel after a failure.
        renderer.update(scale=8)
    color = pred_out["color"][0].detach().cpu()
    return color, error_map


def load_icp_optimizer(flame, renderer, overrides):
    """Instantiate ``cfg.framework`` with freshly built components and move it to CUDA.

    Args:
        flame: FLAME model passed to the framework.
        renderer: Renderer passed to the framework.
        overrides: List of config overrides appended after the synthetic-data
            defaults (``optimizer.output_dir=none`` is always set).

    Returns:
        The instantiated framework after ``.to("cuda")``.
    """
    base_overrides = ["data=synthetic", "optimizer.output_dir=none"]
    cfg = load_config("train", base_overrides + overrides)

    # Instantiate components in the same order as before, keyed by config group.
    component_names = ("correspondence", "weighting", "residuals", "optimizer", "regularize")
    components = {
        name: hydra.utils.instantiate(getattr(cfg, name)) for name in component_names
    }

    framework = hydra.utils.instantiate(
        cfg.framework,
        flame=flame,
        logger=None,
        renderer=renderer,
        **components,
    )
    return framework.to("cuda")


def load_datamodule(renderer, start_frame, end_frame, jump_size=1):
    """Instantiate the synthetic-data datamodule and run its ``setup("fit")``.

    Args:
        renderer: Renderer handed to the datamodule.
        start_frame: ``start_frame`` entry of the ``val_dataset`` override.
        end_frame: ``end_frame`` entry of the ``val_dataset`` override.
        jump_size: ``jump_size`` entry of the ``val_dataset`` override.

    Returns:
        The datamodule after ``setup("fit")``; ``val_dataset`` also sets ``landmarks=False``.
    """
    cfg = load_config("train", ["data=synthetic"])
    val_dataset = {
        "start_frame": start_frame,
        "end_frame": end_frame,
        "jump_size": jump_size,
        "landmarks": False,
    }
    datamodule = hydra.utils.instantiate(cfg.data, renderer=renderer, val_dataset=val_dataset)
    datamodule.setup("fit")
    return datamodule

In [None]:
# settings
import os

step_size = 0.5

# NOTE(review): `idx` indexes validation batches produced from the frame range
# [start_frame, end_frame]; confirm that range yields more than `idx` batches,
# otherwise no batch is found for `idx`.
idx = 5
start_frame = 10
end_frame = 11

setup = "weight_prior"
out_dir = f"results/{setup}"
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory is missing

# loadings
flame, renderer = load_flame_renderer()
datamodule = load_datamodule(renderer, start_frame, end_frame, 1)

# checkpoints (set GUIDED_RESEARCH_CKPT_DIR to avoid editing the absolute path)
checkpoint_dir = os.environ.get(
    "GUIDED_RESEARCH_CKPT_DIR", "/home/borth/GuidedResearch/checkpoints"
)
ours = os.path.join(checkpoint_dir, "synthetic_lr", "wo_neural_prior.ckpt")
path = ours
override = ["residuals=neural", "regularize=dummy", "weighting=unet", "weighting.size=256"]
optimizer = load_neural_optimizer(flame, renderer, path, override)

# sweep the number of optimizer iterations and save color/error/weight images
for N in range(1, 6):
    optimizer.optimizer.step_size = step_size
    optimizer.optimizer.max_iters = N
    out, batch = eval_iterations(optimizer, datamodule, idx)
    color, error = render_output(renderer, optimizer, out, batch)
    draw_and_save(color, f"{out_dir}/color_{N}.png")
    draw_and_save(error, f"{out_dir}/error_{N}.png")
    draw_and_save_weight(flame, renderer, out, f"{out_dir}/weight_{N}.png")

In [None]:
# Sanity check: shows the max_iters value most recently assigned by the sweep loop.
getattr(optimizer.optimizer, "max_iters")