In [None]:
import bz2
import urllib.request
from pathlib import Path

import numba as nb
import numpy as np
import pyceres as crs
import rerun as rr
import scipy.spatial.transform as spt
import wrenfold as wf
from wrenfold.geometry import Quaternion

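# Bundle Adjustment in the Large
Dataset homepage: https://grail.cs.washington.edu/projects/bal/

Each problem is a plain-text file with the layout below. Every camera is 9 values (rotation vector, translation, `f`, `k1`, `k2`) and every point is 3 values, written one value per line: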

```
<num_cameras> <num_points> <num_observations>
<camera_index_1> <point_index_1> <x_1> <y_1>
...
<camera_index_num_observations> <point_index_num_observations> <x_num_observations> <y_num_observations>
<camera_1>
...
<camera_num_cameras>
<point_1>
...
<point_num_points>
```

In [None]:
bal_link = (
    "https://grail.cs.washington.edu/projects/bal/data/trafalgar/problem-21-11315-pre.txt.bz2"
)


# The archive is stored in the current working directory under the URL's file name.
bz2_file = Path(bal_link).name

# Only hit the network when the archive is not already on disk, so a
# Restart & Run All does not download the same file again.
if Path(bz2_file).exists():
    print(f"Cached \t{bz2_file}")
else:
    # Download under a temporary name and rename on success: an interrupted
    # transfer then never leaves a truncated file that passes the cache check.
    partial_file = f"{bz2_file}.part"
    urllib.request.urlretrieve(bal_link, partial_file)
    Path(partial_file).replace(bz2_file)
    print(f"Downloaded \t{bz2_file}")

# Dropping the ".bz2" suffix gives the name of the decompressed text file.
txt_file = Path(bz2_file).stem
with bz2.open(bz2_file, "rb") as f_in, open(txt_file, "wb") as f_out:
    f_out.write(f_in.read())
print(f"Extracted \t{txt_file}")

In [None]:
# Parse the bundle adjustment problem txt file
NUM_OBSERVATION_DIS = 2  # (x, y) per observation
NUM_CAMERA_PARAMS = 9  # rvec_cam_world (3), trans_cam_world (3), f, k1, k2
NUM_POINT_PARAMS = 3  # one 3D point


with Path(txt_file).open("r") as f:
    # Read the first line to parse num_cameras, num_points, num_observations
    first_line = f.readline().strip()
    num_cameras, num_points, num_observations = map(int, first_line.split())
    print(f"{num_cameras=}, {num_points=}, {num_observations=}")

    camera_inds = np.zeros(num_observations, dtype=np.int32)
    point_inds = np.zeros(num_observations, dtype=np.int32)
    observations = np.zeros((num_observations, NUM_OBSERVATION_DIS), dtype=np.float64)
    # rvec_cam_world (3), trans_cam_world (3), f, k1, k2
    camera_params = np.zeros((num_cameras, NUM_CAMERA_PARAMS), dtype=np.float64)
    point_params = np.zeros((num_points, NUM_POINT_PARAMS), dtype=np.float64)

    # Observation lines: <camera_index> <point_index> <x> <y>
    for i in range(num_observations):
        line = f.readline().strip()
        camera_ind, point_ind, obs_x, obs_y = line.split()
        camera_inds[i] = int(camera_ind)
        point_inds[i] = int(point_ind)
        observations[i] = (float(obs_x), float(obs_y))

    # Camera and point blocks hold one scalar per line; `.flat` fills the
    # 2D arrays in row-major order. Loop bounds come from the constants that
    # also size the arrays, so the two cannot drift apart.
    for i in range(num_cameras * NUM_CAMERA_PARAMS):
        line = f.readline().strip()
        camera_params.flat[i] = float(line)

    for i in range(num_points * NUM_POINT_PARAMS):
        line = f.readline().strip()
        point_params.flat[i] = float(line)

# Reject out-of-range indices here: a negative index would otherwise wrap
# around silently when used to index the parameter arrays later on.
assert ((camera_inds >= 0) & (camera_inds < num_cameras)).all(), "camera index out of range"
assert ((point_inds >= 0) & (point_inds < num_points)).all(), "point index out of range"

In [None]:
# Write the wrenfold symbolic reprojection error cost function


def snavely_reprojection_error(
    obs: wf.Vector2, r_c_w: wf.Vector3, t_c_w: wf.Vector3, f_k1_k2: wf.Vector3, p_w: wf.Vector3
):
    """
    Symbolic BAL reprojection residual and its jacobians.

    Camera model (https://grail.cs.washington.edu/projects/bal/):
    P  =  R * X + t       (conversion from world to camera coordinates)
    p  = -P / P.z         (perspective division)
    p' =  f * r(p) * p    (conversion to pixel coordinates)
    r(p) = 1.0 + k1 * ||p||^2 + k2 * ||p||^4

    Args:
        obs: observed 2D location the prediction is compared against.
        r_c_w: rotation vector taking world coordinates to camera coordinates.
        t_c_w: translation added after the rotation.
        f_k1_k2: focal length and the two radial distortion coefficients.
        p_w: 3D point in world coordinates.

    Returns:
        A tuple holding the residual (predicted minus observed) as the return
        value, followed by one optional output argument per jacobian of the
        residual with respect to r_c_w, t_c_w, f_k1_k2 and p_w.
    """
    focal, radial_k1, radial_k2 = f_k1_k2
    rotation_c_w = Quaternion.from_rotation_vector(r_c_w, epsilon=None)

    # World -> camera frame.
    point_cam = rotation_c_w.rotate(p_w) + t_c_w

    # Perspective division; note the sign flip from the model above.
    projected = -point_cam[:2] / point_cam[2]

    # Radial distortion factor r(p).
    radius_sq = projected[0] ** 2 + projected[1] ** 2
    radius_4th = radius_sq**2
    distortion = 1.0 + radial_k1 * radius_sq + radial_k2 * radius_4th

    # Predicted pixel location minus the measurement.
    residual = focal * distortion * projected - obs

    # One optional jacobian output per input block, in argument order.
    outputs = [wf.ReturnValue(residual)]
    for out_name, wrt in (
        ("jac_r_c_w", r_c_w),
        ("jac_t_c_w", t_c_w),
        ("jac_f_k1_k2", f_k1_k2),
        ("jac_p_w", p_w),
    ):
        outputs.append(wf.OutputArg(residual.jacobian(wrt), name=out_name, is_optional=True))
    return tuple(outputs)


# Generate python code for the symbolic cost function. The jacobians are
# requested as output arguments rather than extra return values.
python_generator = wf.PythonGenerator(use_output_arguments=True)
snavely_reprojection_error_wf, gen_code = wf.generate_python(
    snavely_reprojection_error, python_generator
)

# Optional step: wrap the generated function with numba's JIT compiler.
snavely_reprojection_error_nb = nb.njit(snavely_reprojection_error_wf)

In [None]:
class SnavelyReprojectionError(crs.CostFunction):
    """Cost function for one observation, backed by the generated residual code.

    Declares 2 residuals and four parameter blocks of size 3, in the order
    r_c_w, t_c_w, f_k1_k2, p_w.
    """

    def __init__(self, obs: np.ndarray):
        """Store a float64 copy of the observation for later evaluations."""
        super().__init__()
        self.set_num_residuals(2)
        # Block order: r_c_w, t_c_w, f_k1_k2, p_w
        self.set_parameter_block_sizes([3, 3, 3, 3])
        self.obs = obs.astype(np.float64)

    def Evaluate(
        self,
        parameters: list[np.ndarray],
        residuals: np.ndarray,
        jacobians: list[np.ndarray] | None,
    ) -> bool:
        """Write the residual into `residuals` and pass jacobian outputs through.

        When `jacobians` is None, a None placeholder is passed for every
        parameter block instead. Always returns True.
        """
        if jacobians is None:
            jac_outputs = [None] * self.num_parameter_blocks()
        else:
            jac_outputs = jacobians
        residual = snavely_reprojection_error_nb(self.obs, *parameters, *jac_outputs)
        residuals[:] = residual.ravel()
        return True

In [None]:
problem = crs.Problem()
loss = crs.TrivialLoss()

# Work on copies so the initial estimates stay available for comparison.
camera_params_opt = camera_params.copy()
point_params_opt = point_params.copy()

# One residual block per observation, tying a camera to a 3D point.
for obs_idx, (cam_idx, pt_idx, pixel) in enumerate(zip(camera_inds, point_inds, observations)):
    # Views into the camera row: rotation vector, translation, (f, k1, k2).
    camera_blocks = [camera_params_opt[cam_idx, start : start + 3] for start in (0, 3, 6)]
    params = [*camera_blocks, point_params_opt[pt_idx]]

    cost = SnavelyReprojectionError(pixel)
    _ = problem.add_residual_block(cost, loss, params)

    # Hold the pose (rotation and translation) of the first observation's
    # camera constant.
    if obs_idx == 0:
        for pose_block in params[:2]:
            problem.set_parameter_block_constant(pose_block)


options = crs.SolverOptions()
options.minimizer_progress_to_stdout = True
options.linear_solver_type = crs.LinearSolverType.SPARSE_SCHUR


summary = crs.SolverSummary()
crs.solve(options, problem, summary)
print(summary.BriefReport())

In [None]:
rr.init(f"{txt_file}")

# Root of the entity hierarchy. Entity paths are rendered with `as_posix()`
# because `str(Path)` uses backslashes on Windows, which are not "/" separators.
world = Path("world")

rr.log(world.as_posix(), rr.ViewCoordinates.RIGHT_HAND_Y_UP)
rr.log(
    world.as_posix(),
    rr.Transform3D(translation=np.zeros(3), mat3x3=np.eye(3)),
    rr.TransformAxes3D(1.0),
    static=True,
)

# Negates the y and z axes; applied to every camera's "pinhole" child entity.
# NOTE(review): presumably maps the BAL camera convention onto the one rerun's
# pinhole expects -- confirm.
rotm_gl_cam = np.array([[1, 0, 0], [0, -1, 0], [0, 0, -1]])


def rr_log_camera(params: np.ndarray, frame: Path, color: tuple[int, int, int]) -> None:
    """Log one camera to rerun as a posed pinhole frustum.

    Args:
        params: camera parameter row; entries 0:3 are read as the cam-from-world
            rotation vector, 3:6 as the cam-from-world translation and entry 6
            as the focal length.
        frame: entity path of the camera; the pinhole is logged on its
            "pinhole" child.
        color: colour passed to the pinhole archetype.
    """
    R_cam_world = spt.Rotation.from_rotvec(params[0:3])
    t_cam_world = params[3:6]
    # Invert the cam-from-world pose to get the camera's pose in the world.
    R_world_cam = R_cam_world.inv()
    t_world_cam = -R_world_cam.apply(t_cam_world)
    focal = params[6]

    # `as_posix()` keeps "/" as the separator on every platform; `str(frame)`
    # would produce backslashes on Windows.
    frame_entity = frame.as_posix()
    pinhole_entity = (frame / "pinhole").as_posix()

    rr.log(
        frame_entity,
        rr.Transform3D(translation=t_world_cam, mat3x3=R_world_cam.as_matrix()),
        static=True,
    )
    rr.log(pinhole_entity, rr.Transform3D(mat3x3=rotm_gl_cam), static=True)
    rr.log(
        pinhole_entity,
        rr.Pinhole(
            # Fixed image size with the principal point at its centre.
            resolution=(1600, 1200),
            focal_length=(focal, focal),
            principal_point=(800, 600),
            image_plane_distance=0.2,
            color=color,
        ),
        static=True,
    )


# Log every camera twice: the initial estimate under "init" and the optimized
# one under "opt", each in its own colour.
for i in range(num_cameras):
    cam_frame_init = world / "init" / f"camera_{i:03d}"
    rr_log_camera(camera_params[i], cam_frame_init, (255, 50, 255))

    cam_frame_opt = world / "opt" / f"camera_{i:03d}"
    rr_log_camera(camera_params_opt[i], cam_frame_opt, (50, 255, 255))

# Point clouds before (red) and after (green) optimization. `as_posix()` keeps
# "/" separators in the entity path on every platform, unlike `str(Path)`.
rr.log(
    (world / "init" / "points").as_posix(),
    rr.Points3D(positions=point_params, colors=(255, 0, 0)),
    static=True,
)
rr.log(
    (world / "opt" / "points").as_posix(),
    rr.Points3D(positions=point_params_opt, colors=(0, 255, 0)),
    static=True,
)

rr.notebook_show(width=1000, height=600)