In [8]:
import os
os.chdir("/dust3r")
print(os.getcwd())

/dust3r


In [9]:
import json
import roma
import torch
import numpy as np
from scipy.spatial.transform import Rotation 

DATA_PATH = "/dust3r/masked_dust3r/data/jackal_irl_one_spin"
FILE_1 = os.path.join(DATA_PATH, "transforms_1.json")
FILE_2 = os.path.join(DATA_PATH, "transforms_2.json")

with open(FILE_1, "r") as f:
    data_1 = json.load(f)
with open(FILE_2, "r") as f:
    data_2 = json.load(f)

In [10]:
file_path_1 = [f["file_path"] for f in data_1["frames"]]
file_path_2 = [f["file_path"] for f in data_2["frames"]]

# Find matching keys in data_1 and data_2
matching_keys = []
for i, f_1 in enumerate(file_path_1):
    for j, f_2 in enumerate(file_path_2):
        if f_1 == f_2:
            matching_keys.append((i, j))
            break
print(matching_keys)

X = []
Y = []

for (i,j) in matching_keys:
    X.append(data_1["frames"][i]["transform_matrix"])
    Y.append(data_2["frames"][j]["transform_matrix"])

X = torch.tensor(X)
Y = torch.tensor(Y)

X = X[:,:3,3]
Y = Y[:,:3,3]

X = X.unsqueeze(0)
Y = Y.unsqueeze(0)

(R,t,s) = roma.rigid_points_registration(Y,X,compute_scaling=True)

corr = torch.tensor(np.zeros((4,4)))
corr[:3,:3] = R[0]
corr[:3,3] = t[0]/s[0]
corr[3,3] = 1
corr = corr.numpy()

print(corr)

[(10, 0), (11, 1), (12, 2), (13, 3), (14, 4), (15, 5), (16, 6), (17, 7), (18, 8), (19, 9)]
[[ 0.15529452  0.90177882 -0.40333468  1.20198894]
 [-0.16230246 -0.3794491  -0.91086584 -0.12750249]
 [-0.97444439  0.20691469  0.08743466  0.49195433]
 [ 0.          0.          0.          1.        ]]


In [11]:
all_frames = {}

for f in data_1["frames"]:
    all_frames[f["file_path"]] = f

for f in data_2["frames"]:
    transform_f = f.copy()
    transform_f["transform_matrix"] = np.dot(corr, transform_f["transform_matrix"]).tolist()
    all_frames[f["file_path"]] = transform_f

print(all_frames)

{'masked_images/0.png': {'file_path': 'masked_images/0.png', 'transform_matrix': [[0.804954879347945, 0.39127038052422497, 0.44604377872431766, 0.10346801014166823], [0.3985746213061731, -0.913462663740458, 0.08200155822916588, -0.004094556099282245], [0.43952911834941294, 0.11177417107492905, -0.8912466559721222, 0.06137701512133444], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/0.png', 'fl_x': 696.2276458740234, 'fl_y': 696.2276458740234, 'cx': 320.0, 'cy': 240.0}, 'masked_images/1.png': {'file_path': 'masked_images/1.png', 'transform_matrix': [[0.7933903481791179, 0.39043935740413155, 0.4670000038614324, 0.11214370701638388], [0.3805610579564126, -0.9169281830520767, 0.12006720087397982, 0.006230642941854678], [0.47508442296961667, 0.08246185477558027, -0.8760678814590015, 0.06066276601744281], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/1.png', 'fl_x': 745.8506011962891, 'fl_y': 745.8506011962891, 'cx': 320.0, 'cy': 240.0}, 'masked_images/2.png': {'file_path': 'masked_images/2.png', '

In [12]:
#Write out transforms

transforms = data_1.copy()
transforms["frames"] = list(all_frames.values())
print(transforms)

with open(os.path.join(DATA_PATH, "transforms.json"), "w") as f:
    json.dump(transforms, f, indent=4)

{'camera_model': 'OPENCV', 'frames': [{'file_path': 'masked_images/0.png', 'transform_matrix': [[0.804954879347945, 0.39127038052422497, 0.44604377872431766, 0.10346801014166823], [0.3985746213061731, -0.913462663740458, 0.08200155822916588, -0.004094556099282245], [0.43952911834941294, 0.11177417107492905, -0.8912466559721222, 0.06137701512133444], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/0.png', 'fl_x': 696.2276458740234, 'fl_y': 696.2276458740234, 'cx': 320.0, 'cy': 240.0}, {'file_path': 'masked_images/1.png', 'transform_matrix': [[0.7933903481791179, 0.39043935740413155, 0.4670000038614324, 0.11214370701638388], [0.3805610579564126, -0.9169281830520767, 0.12006720087397982, 0.006230642941854678], [0.47508442296961667, 0.08246185477558027, -0.8760678814590015, 0.06066276601744281], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/1.png', 'fl_x': 745.8506011962891, 'fl_y': 745.8506011962891, 'cx': 320.0, 'cy': 240.0}, {'file_path': 'masked_images/2.png', 'transform_matrix': [[0.986265584