In [None]:
import importlib

import matplotlib.pyplot as plt
import numpy as np
import warp as wp

import condorgmm
import condorgmm.warp_gmm as warp_gmm
from condorgmm import Pose
from condorgmm.data import Frame

In [None]:
# Setup: start Rerun logging, load the video, and fit an initial GMM to frame 0.
condorgmm.rr_init("scannet")
video = condorgmm.data.ScanNetVideo(0)

# Seven learning rates, passed to optimize_params together with camera_posquat.
# NOTE(review): presumably 3 position components followed by 4 quaternion
# components -- confirm against optimize_params.
learning_rates = wp.array(
    [0.001, 0.001, 0.001, 0.0004, 0.0004, 0.0004, 0.0004], dtype=wp.float32
)

frame = video[0]
STRIDE = 25  # pixel subsampling step (rows and columns) used to seed the GMM
frame_warp = frame.as_warp()
camera_pose = condorgmm.Pose(frame.camera_pose)

hyperparams = warp_gmm.state.Hyperparams(
    outlier_probability=0.99,
    outlier_volume=1e4,
    window_half_width=7,
)

# Unproject the depth image, keep every STRIDE-th pixel, and move the points
# through the frame's camera pose to obtain the initial spatial means.
spatial_means = condorgmm.xyz_from_depth_image(
    frame.depth.astype(np.float32), *frame.intrinsics
)[::STRIDE, ::STRIDE].reshape(-1, 3)
spatial_means = camera_pose.apply(spatial_means).astype(np.float32)
# Colour means come from the same subsampled pixel grid.
rgb_means = frame.rgb[::STRIDE, ::STRIDE].reshape(-1, 3).astype(np.float32)

gmm = warp_gmm.gmm_warp_from_numpy(spatial_means, rgb_means)
gmm.camera_posquat = wp.array(camera_pose.posquat.astype(np.float32))

warp_gmm_state = warp_gmm.initialize_state(gmm=gmm, frame=frame, hyperparams=hyperparams)

warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)

# Refine the freshly seeded mixture with a few EM steps on frame 0.
for _ in range(5):
    warp_gmm.warp_gmm_EM_step(frame_warp, warp_gmm_state)

warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)
assert warp_gmm_state.gmm.is_valid()

# Copy the pose values: wp.array.numpy() may return a view that shares memory
# with the Warp array (CPU arrays), in which case un-copied "previous" poses
# would silently follow every later update of camera_posquat.
two_prev_camera_poses = (
    warp_gmm_state.gmm.camera_posquat.numpy().copy(),
    warp_gmm_state.gmm.camera_posquat.numpy().copy(),
)
state = (
    warp_gmm_state,
    two_prev_camera_poses,
)

warp_gmm_state.gmm.camera_posquat.requires_grad = True

# Development aid: reload the modules so source edits are picked up by later
# cells. Objects created above (e.g. warp_gmm_state) were built before the reload.
importlib.reload(condorgmm.warp_gmm.optimize)
importlib.reload(condorgmm.warp_gmm)

In [None]:
# Track the camera through the whole video: optimise camera_posquat on each
# frame, log inferred vs. ground-truth pose, and re-seed the GMM from the
# current frame every 5 frames.
inferred_camera_poses = {}
for T in range(len(video)):
    condorgmm.rr_set_time(T)
    frame = video[T]
    frame_warp = frame.as_warp()

    # Re-enable gradients every iteration: camera_posquat is replaced by a
    # fresh array whenever the GMM is re-seeded below.
    warp_gmm_state.gmm.camera_posquat.requires_grad = True
    results = warp_gmm.optimize_params(
        [warp_gmm_state.gmm.camera_posquat],
        frame_warp,
        warp_gmm_state,
        100,
        learning_rates,
        storing_stuff=True,
    )

    # Snapshot with explicit copies: wp.array.numpy() may alias the Warp
    # buffer (CPU arrays), and that buffer keeps changing on later frames.
    inferred_posquat = warp_gmm_state.gmm.camera_posquat.numpy().copy()
    inferred_log_score_image = warp_gmm_state.log_score_image.numpy().copy()
    inferred_log_score_image_sum = inferred_log_score_image.sum()
    # warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)
    condorgmm.rr_log_posquat(inferred_posquat, "inferred_pose")
    # `frame` is already video[T]; no need to index the video a second time.
    condorgmm.rr_log_posquat(frame.camera_pose, "gt_pose")

    inferred_camera_poses[T] = inferred_posquat

    if T > 0 and T % 5 == 0:
        # Re-seed the mixture from the current frame, placed with the pose
        # that was just inferred (not the ground-truth pose).
        camera_pose = condorgmm.Pose(inferred_posquat)
        spatial_means = condorgmm.xyz_from_depth_image(
            frame.depth.astype(np.float32), *frame.intrinsics
        )[::STRIDE, ::STRIDE].reshape(-1, 3)
        spatial_means = camera_pose.apply(spatial_means).astype(np.float32)
        rgb_means = frame.rgb[::STRIDE, ::STRIDE].reshape(-1, 3).astype(np.float32)

        gmm = warp_gmm.gmm_warp_from_numpy(spatial_means, rgb_means)
        gmm.camera_posquat = wp.array(camera_pose.posquat.astype(np.float32))
        warp_gmm_state.gmm = gmm

        # Refine the new mixture on the current frame before moving on.
        for _ in range(5):
            warp_gmm.warp_gmm_EM_step(frame_warp, warp_gmm_state)


In [None]:
# Plot the "likelihoods" entry of the most recent optimize_params result.
# `results` is whatever the last executed optimisation cell assigned, so this
# plot depends on cell execution order.
# NOTE(review): presumably one value per optimisation step -- confirm.
plt.plot(results["likelihoods"])

In [None]:

# Single-frame diagnostic: optimise the camera pose on frame T, then compare
# the per-pixel log score at the inferred pose with the score at the
# ground-truth pose.
T = 2
frame = video[T]
frame_warp = frame.as_warp()

warp_gmm_state.gmm.camera_posquat.requires_grad = True
results = warp_gmm.optimize_params(
    [warp_gmm_state.gmm.camera_posquat],
    frame_warp,
    warp_gmm_state,
    500,
    learning_rates,
    storing_stuff=True,
)
print(results["likelihoods"][-20:])
plt.plot(results["likelihoods"])

# Copy before re-running the forward pass below: wp.array.numpy() may alias
# the Warp buffer (CPU arrays), so an un-copied view could be overwritten by
# the ground-truth scores and make the difference image identically zero.
inferred_log_score_image = warp_gmm_state.log_score_image.numpy().copy()
inferred_log_score_image_sum = inferred_log_score_image.sum()
# warp_gmm.rr_log_gmm_warp(warp_gmm_state.gmm)
# NOTE(review): no rr_set_time(T) is issued in this cell, so this is logged at
# whatever Rerun time was set last -- confirm that is intended.
condorgmm.rr_log_posquat(frame.camera_pose, "gt_pose")

# Swap in the ground-truth pose and re-score the same frame.
# The state is left at the ground-truth pose after this cell.
warp_gmm_state.gmm.camera_posquat = wp.array(frame.camera_pose.astype(np.float32))
warp_gmm.warp_gmm_forward(frame_warp, warp_gmm_state)
gt_log_score_image = warp_gmm_state.log_score_image.numpy().copy()
gt_log_score_image_sum = gt_log_score_image.sum()
print("Inferred pose log score: ", inferred_log_score_image_sum)
print("GT pose log score: ", gt_log_score_image_sum)
plt.matshow(gt_log_score_image - inferred_log_score_image,cmap="bwr")
plt.colorbar()
plt.title("GT - Inferred")


In [None]:
# Boolean mask of pixels whose log score differs by more than 1e-3 between
# the ground-truth pose and the inferred pose.
plt.matshow(
    np.greater(np.abs(gt_log_score_image - inferred_log_score_image), 1e-3)
)

In [None]:
# Log frame 0 at its dataset-provided camera pose, on a channel named after
# the frame index.
T = 0
condorgmm.rr_log_frame(
    video[T],
    camera_pose=video[T].camera_pose,
    channel=f"{T}",
)

In [None]:
# Log frame 10 at its dataset-provided camera pose, on a channel named after
# the frame index.
T = 10
condorgmm.rr_log_frame(
    video[T],
    camera_pose=video[T].camera_pose,
    channel=f"{T}",
)