In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json
import numpy as np
import torch
import xml.etree.ElementTree as ET
import sqlite3
import internal.utils.colmap as colmap

In [None]:
# Dataset root directory ("~" is expanded to the current user's home directory).
basic_path = os.path.expanduser("~/data/image_set/dbl/")
# Name of the image directory; later joined onto basic_path and used as the file_path prefix.
image_dir_relative = "AerialPhotography"
def image_path_to_name(image_path):
    """Return the text between the first and second ``:`` of ``image_path``, minus its first character.

    Example: ``"D:/folder/a.jpg"`` -> ``"folder/a.jpg"``.
    Raises ``IndexError`` when ``image_path`` contains no ``:``.
    """
    after_first_colon = image_path.split(":", 2)[1]
    return after_first_colon[1:]

# Open XML
It looks like the new version of the software has been renamed to iTwin.

In [None]:
# `fixed_pp` means the principal points are not adjusted
# Parse the exported poses XML that sits directly under basic_path.
tree = ET.parse(os.path.expanduser(os.path.join(basic_path, "Smart3DExportedPoses-ENU-x_right-y_down-fixed_pp.xml")))
tree

In [None]:
# Root element of the parsed XML document (shown as the cell output).
root = tree.getroot()
root

In [None]:
# First direct child of the root named "Block"; find() yields None when there is no such child.
block = root.find("Block")
block

# Parse XML

In [None]:
# Collect per-photogroup intrinsics and per-photo extrinsics from the XML block.
cameras = []  # one intrinsics dict per <Photogroup>
image_paths = []  # raw <ImagePath> text of every kept photo
image_names = []  # image_path_to_name() of every kept photo
poses = []  # per kept photo: float children of <Pose/Rotation> (all but the last child)
centers = []  # per kept photo: float children of <Pose/Center> (all but the last child)
image_camera_ids = []  # per kept photo: index of its photogroup in `cameras`
for photogroup in block.findall("Photogroups/Photogroup"):
    image_dimensions = photogroup.find("ImageDimensions")
    width = int(image_dimensions.find("Width").text)
    height = int(image_dimensions.find("Height").text)
    # NOTE(review): FocalLength and SensorSize are presumably both in millimeters — confirm.
    focal_length = float(photogroup.find("FocalLength").text)
    sensor_size = float(photogroup.find("SensorSize").text)

    # <SensorSize> describes the sensor's largest dimension, so it has to be paired with
    # the largest image dimension; multiplying by `width` is only right for landscape images.
    focal_length_in_pixel = focal_length / sensor_size * max(width, height)

    cameras.append({
        "width": width,
        "height": height,
        "focal_length": focal_length_in_pixel,
        "principal_point": (
            float(photogroup.find("PrincipalPoint/x").text), float(photogroup.find("PrincipalPoint/y").text)),
        # One entry per child of <Distortion>, keyed by the child's tag name.
        "distortion": {i.tag: float(i.text) for i in photogroup.find("Distortion")}
    })
    camera_idx = len(cameras) - 1

    for photo in photogroup.findall("Photo"):
        rotation_element = photo.find("Pose/Rotation")
        center_element = photo.find("Pose/Center")
        if rotation_element is None or center_element is None:
            # No pose recorded for this photo: skip it instead of crashing on list(None).
            continue
        rotation = list(rotation_element)
        center = list(center_element)
        # The last child of each element is a flag; "false" excludes the photo
        # (presumably the exporter's <Accurate> marker — TODO confirm).
        if rotation[-1].text == "false" or center[-1].text == "false":
            continue
        image_path = photo.find("ImagePath").text
        image_paths.append(image_path)
        image_names.append(image_path_to_name(image_path))
        poses.append([float(i.text) for i in rotation[:-1]])
        centers.append([float(i.text) for i in center[:-1]])
        image_camera_ids.append(camera_idx)
cameras

In [None]:
# Sanity check: show the first parsed rotation, center and raw image path.
poses[0], centers[0], image_paths[0]

# Convert

In [None]:
# Turn the flat per-photo rotation lists into a stack of 3x3 float64 matrices.
pose_reshaped = torch.tensor(poses, dtype=torch.float64).reshape((-1, 3, 3))
pose_reshaped.shape, pose_reshaped[0]

In [None]:
# Transpose every 3x3 matrix; the result is used as the camera-to-world rotation below.
# NOTE(review): this presumes the XML stores world-to-camera rotations — confirm against the exporter.
c2w_rotations = torch.transpose(pose_reshaped, 1, 2)
# Check the batched transpose against a plain transpose of the first matrix.
(c2w_rotations[0] == pose_reshaped[0].T).all(), c2w_rotations[0]

In [None]:
# Assemble 4x4 camera-to-world matrices: [rotation | center] on top, [0, 0, 0, 1] at the bottom.
centers_column = torch.tensor(centers, dtype=torch.float64).unsqueeze(-1)
top_three_rows = torch.concat([c2w_rotations, centers_column], dim=-1)
homogeneous_row = torch.tensor([0., 0., 0., 1.], dtype=torch.float64).reshape(1, 1, 4)
bottom_rows = homogeneous_row.repeat(c2w_rotations.shape[0], 1, 1)
c2w = torch.concat([top_three_rows, bottom_rows], dim=1)
c2w.shape, c2w[1], centers[1]

In [None]:
# Translation column of every 4x4 matrix, i.e. the camera centers placed there above.
camera_centers = c2w[:, :3, 3]
camera_centers[0], centers[0]

# Rescaling and Translation

In [None]:
# Per-axis mean of all camera centers (used further down for the preview selection).
mean_center = torch.mean(camera_centers, dim=0)
mean_center

In [None]:
# Per-axis bounding box of the camera centers, and its extent along each axis.
camera_center_min = camera_centers.min(dim=0).values
camera_center_max = camera_centers.max(dim=0).values
camera_center_range = camera_center_max - camera_center_min
camera_center_range

In [None]:
# Midpoint of the per-axis bounding box of the camera centers.
mid_center = (camera_center_min + camera_center_max) * 0.5
mid_center

In [None]:
# Target extent of the largest camera-center axis after rescaling.
max_range = 100.
# Divisor that maps the largest per-axis extent onto max_range.
# NOTE(review): this is 0 when all camera centers coincide, which would break the later division.
scale = camera_center_range.max() / max_range
scale

In [None]:
# Copy the poses, then recenter the translations on mid_center and divide them by scale.
c2w_rescaled_and_moved = c2w.clone()
c2w_rescaled_and_moved[:, :3, 3] = (c2w[:, :3, 3] - mid_center) / scale
c2w[0], c2w_rescaled_and_moved[0]

In [None]:
# Persist the parsed data; the original (not rescaled) c2w is stored alongside center and scale.
parsed_output_path = os.path.join(basic_path, "parsed_from_xml.pt")
parsed_payload = {
    "image_names": image_names,
    "cameras": cameras,
    "c2w": c2w,
    "image_camera_ids": image_camera_ids,
    "center": mid_center,
    "scale": scale,
}
torch.save(parsed_payload, parsed_output_path)

# Select a few for preview in NeRFStudio

In [None]:
# Euclidean distance of every camera center from the mean camera center.
distance2center = torch.norm(camera_centers - mean_center[None, :], dim=-1)
distance2center

In [None]:
# Keep only cameras closer than 128 (in the units of the XML centers) to the mean center.
select_mask = distance2center < 128.
# Number of selected cameras.
select_mask.sum()

In [None]:
# Indices (into the parsed image lists) of the selected cameras.
selected_image_ids = select_mask.nonzero().squeeze(-1)
selected_image_ids

In [None]:
# Write one preview entry per selected camera; positions are scaled by 0.01 and the
# image size / focal length are fixed placeholder values.
preview_json_path = os.path.join(os.path.expanduser("~/data/image_set/dbl"), "preview.json")
camera_list = [
    {
        "id": preview_idx,
        "img_name": "{:06d}".format(preview_idx),
        "width": 1920,
        "height": 1080,
        "position": (selected_pose[:3, 3] * 0.01).tolist(),
        "rotation": selected_pose[:3, :3].tolist(),
        "fx": 1600,
        "fy": 1600,
        "color": [255, 0, 0],
    }
    for preview_idx, selected_pose in enumerate(c2w[select_mask])
]
with open(preview_json_path, "w") as preview_file:
    json.dump(camera_list, preview_file)
preview_json_path

In [None]:
# Assemble a transforms.json (one frame per selected image) for the NeRFStudio preview.
frames = []
for image_idx in selected_image_ids.tolist():
    intrinsics = cameras[image_camera_ids[image_idx]]
    distortion = intrinsics["distortion"]
    image_width, image_height = intrinsics["width"], intrinsics["height"]
    # Negate the 2nd and 3rd columns of a copy of the camera-to-world matrix
    # (presumably x-right/y-down -> the y-up/z-back convention NeRFStudio expects; confirm).
    flipped_c2w = c2w[image_idx].clone()
    flipped_c2w[:, 1:3] *= -1
    frames.append({
        "file_path": os.path.join(image_dir_relative, image_names[image_idx]),
        "camera_model": "OPENCV",
        "fl_x": intrinsics["focal_length"],
        "fl_y": intrinsics["focal_length"],
        "k1": distortion["K1"],
        "k2": distortion["K2"],
        "p1": distortion["P1"],
        "p2": distortion["P2"],
        # The image center is used here, not the principal point parsed from the XML.
        "cx": image_width // 2,
        "cy": image_height // 2,
        "w": image_width,
        "h": image_height,
        "transform_matrix": flipped_c2w.tolist(),
    })

transforms = {
    "aabb_scale": 16,
    "frames": frames,
}

transforms_json_path = os.path.join(os.path.expanduser("~/data/image_set/dbl"), "transforms.json")
with open(transforms_json_path, "w") as transforms_file:
    json.dump(transforms, transforms_file, indent=2)
transforms_json_path

# Colmap

In [None]:
# Working directory for COLMAP artifacts (database and sparse models).
# NOTE(review): nothing before the printed feature_extractor command creates this directory — confirm it exists.
colmap_output_path = os.path.join(basic_path, "colmap")
# Directory containing the images handed to COLMAP.
colmap_image_path = os.path.join(basic_path, image_dir_relative)

extract features

In [None]:
colmap_db_path = os.path.join(colmap_output_path, "colmap.db")
# Refuse to continue when a database file is already present.
assert not os.path.exists(colmap_db_path)
# Print the command to run by hand, one argument per line with shell line continuations.
feature_extractor_args = [
    "colmap",
    "feature_extractor",
    "--database_path=" + colmap_db_path,
    "--image_path=" + colmap_image_path,
    "--ImageReader.camera_model=OPENCV",
]
print(" \\\n    ".join(feature_extractor_args))

create a sparse model from known poses

In [None]:
# Output directory for the sparse model built from the known poses.
sparse_manually_model_dir = os.path.join(colmap_output_path, "sparse_manually")

In [None]:
# Preconditions: the output model must not exist yet, and no "-shm" side file may sit next to the database.
assert os.path.exists(sparse_manually_model_dir) is False
assert os.path.exists(colmap_db_path + "-shm") is False, "{} is opened by another process".format(colmap_db_path)
# sqlite3.connect() silently creates an empty database when the file is missing; fail early
# with a clear message instead of a later "no such table" error and a stray empty file.
assert os.path.exists(colmap_db_path), "{} does not exist; run the feature_extractor command first".format(colmap_db_path)

colmap_db = sqlite3.connect(colmap_db_path)

def array_to_blob(array):
    """Serialize a NumPy array to the raw bytes of its data.

    Uses ``ndarray.tobytes()``; the former ``tostring()`` alias was deprecated and is
    removed in NumPy 2.0. The returned bytes are identical to what ``tostring()`` produced.
    """
    return array.tobytes()


def select_image(image_name: str):
    """Look up an image by name in the COLMAP database.

    Returns the ``(image_id, camera_id)`` row of the first match, or ``None`` when no
    image has that name. The cursor is closed in every case.
    """
    query = "SELECT image_id, camera_id FROM images WHERE name = ?"
    cursor = colmap_db.cursor()
    try:
        row = cursor.execute(query, (image_name,)).fetchone()
    finally:
        cursor.close()
    return row


def set_image_camera_id(image_id: int, camera_id: int):
    """Point image ``image_id`` at camera ``camera_id`` and commit right away."""
    statement = "UPDATE images SET camera_id = ? WHERE image_id = ?"
    cursor = colmap_db.cursor()
    try:
        cursor.execute(statement, (camera_id, image_id))
        colmap_db.commit()
    finally:
        cursor.close()


def update_camera_params(camera_id: int, params: np.ndarray):
    """Overwrite the ``params`` column of camera ``camera_id`` and commit right away.

    ``params`` is serialized with ``array_to_blob`` before it is stored.
    """
    statement = "UPDATE cameras SET params = ? WHERE camera_id = ?"
    cursor = colmap_db.cursor()
    try:
        params_blob = array_to_blob(params)
        cursor.execute(statement, (params_blob, camera_id))
        colmap_db.commit()
    finally:
        cursor.close()


def delete_unused_cameras():
    """Delete every camera row that no image row refers to, then commit."""
    statement = "DELETE FROM cameras WHERE camera_id NOT IN (SELECT camera_id FROM images)"
    cursor = colmap_db.cursor()
    try:
        cursor.execute(statement)
        colmap_db.commit()
    finally:
        cursor.close()

In [None]:
# Invert the rescaled camera-to-world matrices to obtain world-to-camera matrices.
w2cs = torch.linalg.inv(c2w_rescaled_and_moved)
w2cs[0], c2w_rescaled_and_moved[0]

In [None]:
# Build the COLMAP model (images + cameras) from the known poses and update the database to match.
colmap_cameras = {}
colmap_images = {}
# Maps an index into `cameras` to the COLMAP camera id shared by all images of that photogroup.
context_camera_idx_to_colmap_camera_idx = {}

for idx in range(c2w.shape[0]):
    image_name = image_names[idx]
    image_row = select_image(image_name)
    if image_row is None:
        # An image missing from the database would otherwise surface as an opaque
        # "cannot unpack non-iterable NoneType" TypeError on the next line.
        raise LookupError("image {!r} was not found in {}".format(image_name, colmap_db_path))
    colmap_image_idx, colmap_camera_idx = image_row
    # share intrinsics: the first COLMAP camera seen for a photogroup is reused for all of its images
    context_camera_idx = image_camera_ids[idx]
    colmap_camera_idx = context_camera_idx_to_colmap_camera_idx.setdefault(context_camera_idx, colmap_camera_idx)
    set_image_camera_id(colmap_image_idx, colmap_camera_idx)

    w2c = w2cs[idx]

    # Pose-only image entry: no 2D observations and no 3D point references.
    colmap_images[colmap_image_idx] = colmap.Image(
        id=colmap_image_idx,
        qvec=colmap.rotmat2qvec(w2c[:3, :3].numpy()),
        tvec=w2c[:3, 3].numpy(),
        camera_id=colmap_camera_idx,
        name=image_name,
        xys=np.array([], dtype=np.float64),
        point3D_ids=np.asarray([], dtype=np.int64),
    )

    if colmap_camera_idx not in colmap_cameras:
        camera = cameras[context_camera_idx]
        # [fx, fy, cx, cy, k1, k2, p1, p2]
        # cx/cy use the image center, not the principal point parsed from the XML.
        # NOTE(review): confirm the exporter's P1/P2 tangential convention matches OpenCV's.
        camera_params = torch.tensor([
            camera["focal_length"],
            camera["focal_length"],
            camera["width"] // 2,
            camera["height"] // 2,
            camera["distortion"]["K1"],
            camera["distortion"]["K2"],
            camera["distortion"]["P1"],
            camera["distortion"]["P2"],
        ], dtype=torch.float64)
        # Keep the database and the written model in sync for this camera.
        update_camera_params(colmap_camera_idx, camera_params.numpy())
        colmap_cameras[colmap_camera_idx] = colmap.Camera(
            id=colmap_camera_idx,
            model="OPENCV",
            width=camera["width"],
            height=camera["height"],
            params=camera_params.numpy(),
        )

# Cameras that no image references any more are dropped from the database.
delete_unused_cameras()
colmap_db.close()

In [None]:
# Write the pose-only model; the points file is written from an empty dict.
os.makedirs(sparse_manually_model_dir)
images_bin_path = os.path.join(sparse_manually_model_dir, "images.bin")
cameras_bin_path = os.path.join(sparse_manually_model_dir, "cameras.bin")
points3d_bin_path = os.path.join(sparse_manually_model_dir, "points3D.bin")
colmap.write_images_binary(colmap_images, images_bin_path)
colmap.write_cameras_binary(colmap_cameras, cameras_bin_path)
colmap.write_points3D_binary({}, points3d_bin_path)

In [None]:
# Read the cameras file back to check what was written.
colmap.read_cameras_binary(os.path.join(sparse_manually_model_dir, "cameras.bin"))

feature matcher

In [None]:
# Print the matching command to run by hand; the vocabulary tree is expected under ~/.cache/colmap.
vocab_tree_path = os.path.expanduser("~/.cache/colmap/vocab_tree_flickr100K_words256K.bin")
vocab_tree_matcher_args = [
    "colmap",
    "vocab_tree_matcher",
    "--database_path=" + colmap_db_path,
    "--VocabTreeMatching.vocab_tree_path=" + vocab_tree_path,
]
print(" \\\n    ".join(vocab_tree_matcher_args))

point triangulator

In [None]:
# Create the output directory, then print the triangulation command to run by hand.
sparse_dir_triangulated = os.path.join(colmap_output_path, "sparse")
os.makedirs(sparse_dir_triangulated, exist_ok=True)
point_triangulator_args = [
    "colmap",
    "point_triangulator",
    "--database_path=" + colmap_db_path,
    "--image_path=" + colmap_image_path,
    "--input_path=" + sparse_manually_model_dir,
    "--output_path=" + sparse_dir_triangulated,
]
print(" \\\n    ".join(point_triangulator_args))