In [None]:
import hydra
import torch
from tqdm.notebook import tqdm
from lib.utils.config import load_config
from lib.optimizer.framework import NeuralOptimizer
from lib.data.loader import load_intrinsics
from lib.data.loader import load_intrinsics
from lib.rasterizer import Rasterizer
from lib.renderer.renderer import Renderer
from lib.renderer.camera import Camera
from lib.tracker.timer import TimeTracker

# settings
# N: largest iteration budget passed to eval_iterations (budgets 0..N are evaluated).
N = 2
# value: key of the loss entry read from the dict returned by compute_loss.
value = "loss_param"
# path: checkpoint handed to NeuralOptimizer.load_from_checkpoint.
# NOTE(review): absolute, machine-specific path — the notebook will not run elsewhere
# without editing this line.
path = "/home/borth/GuidedResearch/logs/2024-10-01/15-13-21_train/checkpoints/epoch_029.ckpt"
# start_frame / end_frame: forwarded to the validation dataset override of the datamodule.
# NOTE(review): whether end_frame is inclusive is decided by the dataset — confirm.
start_frame = 10
end_frame = 12

# instantiate similar to training
# Load the "train" config with the synthetic data override.
cfg = load_config("train", ["data=synthetic"])
# Camera intrinsics, loaded as a PyTorch tensor (return_tensor="pt").
K = load_intrinsics(data_dir=cfg.data.intrinsics_dir, return_tensor="pt")
camera = Camera(
    K=K,
    width=cfg.data.width,
    height=cfg.data.height,
    near=cfg.data.near,
    far=cfg.data.far,
    scale=cfg.data.scale,
)
# The rasterizer takes its resolution from the camera object rather than from cfg directly.
rasterizer = Rasterizer(width=camera.width, height=camera.height)
renderer = Renderer(rasterizer=rasterizer, camera=camera)
# The model described by cfg.model; the same `flame` instance is passed to both optimizers below.
flame = hydra.utils.instantiate(cfg.model)

# setup the neural optimizer
# The components are built from the current `cfg` and handed to the checkpoint loader
# as keyword arguments alongside the shared renderer and flame model.
correspondence = hydra.utils.instantiate(cfg.correspondence)
weighting = hydra.utils.instantiate(cfg.weighting)
residuals = hydra.utils.instantiate(cfg.residuals)
regularize = hydra.utils.instantiate(cfg.regularize)
# Restore the trained optimizer from the checkpoint file given by `path`.
neural_optimizer = NeuralOptimizer.load_from_checkpoint(
    path,
    renderer=renderer,
    flame=flame,
    correspondence=correspondence,
    regularize=regularize,
    residuals=residuals,
    weighting=weighting,
)

# setup icp optimizer
# NOTE: this rebinds `cfg` (and the component names below); everything after this line
# that reads `cfg` sees the ICP configuration, not the one loaded at the top of the cell.
cfg = load_config("train", ["data=synthetic", "residuals=face2face", "weighting=dummy", "regularize=dummy","optimizer.output_dir=none"])
correspondence = hydra.utils.instantiate(cfg.correspondence)
weighting = hydra.utils.instantiate(cfg.weighting)
residuals = hydra.utils.instantiate(cfg.residuals)
optimizer = hydra.utils.instantiate(cfg.optimizer)
regularize = hydra.utils.instantiate(cfg.regularize)
# Built from cfg.framework instead of a checkpoint; shares `flame` and `renderer`
# with the neural optimizer and runs without a logger.
icp_optimizer = hydra.utils.instantiate(
    cfg.framework,
    flame=flame,
    logger=None,
    renderer=renderer,
    correspondence=correspondence,
    regularize=regularize,
    residuals=residuals,
    optimizer=optimizer,
    weighting=weighting,
)

# setup the datamodule
# `cfg` is whichever config was loaded last in this cell; the validation dataset is
# overridden with the frame range chosen in the settings.
datamodule = hydra.utils.instantiate(
    cfg.data,
    renderer=renderer,
    val_dataset=dict(
        start_frame=start_frame,
        end_frame=end_frame,
    ),
)
# Run the "fit" stage setup before requesting the validation dataloader.
datamodule.setup("fit")
# Only the validation dataloader is used by the evaluation cells.
dataloader = datamodule.val_dataloader()

In [None]:
from lib.utils.progress import reset_progress, close_progress

def eval_iterations(optimizer, N: int = 1, value: str = "loss_param", device: str = "cuda"):
    """Evaluate ``optimizer`` on the validation batches for iteration budgets ``0..N``.

    Budget ``0`` is the baseline: the optimizer is called, but its output
    parameters are replaced by ``batch["init_params"]`` before the loss is
    computed. For every budget ``1..N`` the optimizer's ``max_iters`` is set to
    the budget, each batch is optimized under ``torch.no_grad()`` and the
    optimizer call is timed with a fresh ``TimeTracker``.

    The batches come from the notebook-level ``dataloader``.

    Args:
        optimizer: Object providing ``transfer_batch_to_device``, ``__call__``,
            ``compute_loss`` and a settable ``max_iters`` attribute.
        N: Largest iteration budget to evaluate.
        value: Key of the entry collected from the dict returned by
            ``compute_loss``.
        device: Device handed to ``transfer_batch_to_device``. Defaults to
            ``"cuda"``, the value that was previously hard-coded.

    Returns:
        A tuple ``(iters_loss, iters_time)`` of dicts keyed by iteration
        budget. ``iters_loss[k]`` stacks one loss value per batch;
        ``iters_time[k]`` stacks the recorded ``time_ms`` values (all zeros for
        budget ``0``).

    Note:
        ``optimizer.max_iters`` is restored to its previous value on exit (when
        it existed), and both progress bars are closed even if an evaluation
        step raises.
    """
    # The inner bar advances once per batch, so it is sized by the number of
    # batches rather than by the number of dataset samples.
    total_evals = len(dataloader)
    outer_progress = tqdm(total=N + 1, desc="Iter Loop", position=0)
    inner_progress = tqdm(total=total_evals, desc="Eval Loop", leave=True, position=1)
    iters_loss = {}
    iters_time = {}

    # Remember the configured budget so the sweep does not leak into later cells.
    missing = object()
    configured_max_iters = getattr(optimizer, "max_iters", missing)

    try:
        # initial evaluation no optimization
        reset_progress(inner_progress, total_evals)
        loss = []
        for batch in dataloader:
            with torch.no_grad():
                batch = optimizer.transfer_batch_to_device(batch, device, 0)
                out = optimizer(batch)
                # Score the initial parameters instead of the optimized ones.
                out["params"] = batch["init_params"]
                loss_info = optimizer.compute_loss(batch=batch, out=out)
                loss.append(loss_info[value])
            inner_progress.update(1)
        iters_loss[0] = torch.stack(loss)
        iters_time[0] = torch.zeros_like(iters_loss[0])
        outer_progress.update(1)

        # evaluation after some optimization
        for iters in range(1, N + 1):
            reset_progress(inner_progress, total_evals)
            optimizer.max_iters = iters
            time_tracker = TimeTracker()  # fresh tracker per budget
            loss = []
            for batch in dataloader:
                with torch.no_grad():
                    batch = optimizer.transfer_batch_to_device(batch, device, 0)
                    # NOTE(review): if TimeTracker does not synchronize the GPU,
                    # this measures launch time only — confirm.
                    time_tracker.start("optimize")
                    out = optimizer(batch)
                    time_tracker.stop("optimize")
                    loss_info = optimizer.compute_loss(batch=batch, out=out)
                    loss.append(loss_info[value])
                inner_progress.update(1)
            iters_loss[iters] = torch.stack(loss)
            # Takes the first track of the tracker; presumably that is the
            # "optimize" track started above — verify against TimeTracker.
            iters_time[iters] = torch.stack(
                [torch.tensor(t.time_ms) for t in list(time_tracker.tracks.values())[0]]
            )
            outer_progress.update(1)
    finally:
        if configured_max_iters is not missing:
            optimizer.max_iters = configured_max_iters
        close_progress([outer_progress, inner_progress])
    return iters_loss, iters_time

# Run the same iteration sweep (budgets 0..N, collecting `value`) for both optimizers.
# Each call returns a (loss-per-budget, time-per-budget) pair of dicts.
icp_loss, icp_time = eval_iterations(icp_optimizer, N=N, value=value) 
neural_loss, neural_time = eval_iterations(neural_optimizer, N=N, value=value)

In [None]:
# Timing summary collected by the neural optimizer's own time tracker.
# NOTE(review): if print_summary() prints by itself and returns None, the outer
# print() only adds a stray "None" line — confirm against TimeTracker.
print(neural_optimizer.time_tracker.print_summary())

In [None]:
# Timing summary collected by the ICP optimizer's own time tracker.
# NOTE(review): if print_summary() prints by itself and returns None, the outer
# print() only adds a stray "None" line — confirm against TimeTracker.
print(icp_optimizer.time_tracker.print_summary())

In [None]:
# Median recorded optimize time per iteration budget: neural optimizer first, then ICP.
print({budget: timings.median() for budget, timings in neural_time.items()})
print({budget: timings.median() for budget, timings in icp_time.items()})

In [None]:
# Mean collected loss per iteration budget: neural optimizer first, then ICP.
print({budget: losses.mean() for budget, losses in neural_loss.items()})
print({budget: losses.mean() for budget, losses in icp_loss.items()})

In [None]:
# Show the raw ICP losses (dict keyed by iteration budget) as the cell output.
icp_loss