In [1]:
#!/bin/python3 python3.10

import os
# Switch the working directory to the dust3r checkout before the remaining imports run.
# The relative checkpoint path used when the model is loaded ("checkpoints/...") is
# resolved against this directory.
# NOTE(review): the project imports below (masked_dust3r.*, dust3r.*) presumably also
# rely on the current directory being importable -- confirm against sys.path.
os.chdir("/dust3r")
print(os.getcwd())

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import open3d as o3d
import torch
import json
from masked_dust3r.scripts.utils.image import *
from masked_dust3r.scripts.utils.constraint import *

from dust3r.inference import inference_with_mask
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode

# ---------------------------------------------------------------------------
# Notebook configuration: everything a reader might want to tune lives here.
# ---------------------------------------------------------------------------

# Dataset location and file naming. DATA_PATH must contain transforms.json plus
# the "masked_images/" and "masks/" folders referenced by the cells below.
DATA_PATH = "/dust3r/masked_dust3r/data/jackal_training_data_0"
IMG_FILE_EXTENSION = ".png"
MASK_FILE_EXTENSION = ".png"

# Passed to inference_with_mask() together with the loaded masks.
GAUSSIAN_SIGMA = 3.0

# RECURRING_FRAMES is the number of already-registered frames (taken from the end
# of transforms["frames"]) that are used alongside the new frame.
# INIT_FRAMES and TOTAL_IMGS are not referenced by the cells visible in this
# notebook section -- presumably consumed elsewhere; verify before removing.
INIT_FRAMES = 10
RECURRING_FRAMES = 5
TOTAL_IMGS = 11

# NOTE(review): none of the four settings below is read by the visible cells
# (the focal lengths are always preset from transforms["fl_x"]) -- confirm
# whether they are still needed.
IS_FOCAL_FIXED = False
IS_BEST_FIT_PLANE = True
IS_ZERO_Z = True
FOCAL_LENGTH = 4.74

# Run on the GPU when one is available; fall back to the CPU instead of failing
# at the first `.to(device)` call on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 1

# Global-alignment optimizer settings (forwarded to compute_global_alignment()).
schedule = 'cosine'
lr = 0.01
niter = 300

# Load the dataset description. The cells below read "fl_x" and the "frames" list
# (each frame providing "file_path", "mask_path" and "transform_matrix") from it.
# JSON is UTF-8 by specification, so decode it explicitly rather than relying on
# the platform's locale encoding.
with open(f"{DATA_PATH}/transforms.json", encoding="utf-8") as f:
    transforms = json.load(f)

/dust3r


Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
# Load the pretrained network, then move it onto the configured device.

# model_name may be any local checkpoint path; this one is relative to the
# current working directory.
model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
model = AsymmetricCroCo3DStereo.from_pretrained(model_name)
model = model.to(device)

... loading model from checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth


instantiating : AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), landscape_only=False)
<All keys matched successfully>


In [3]:
# Index of the frame being added; it always occupies slot 0 of every list below.
new_img_index = 10
print("Looking at frame {}...".format(new_img_index))

# One slot for the new frame plus one per previously registered frame.
n_views = RECURRING_FRAMES + 1

# Every view shares the focal length stored in transforms.json.
preset_focal = [transforms["fl_x"]] * n_views

# True marks a view whose pose is preset; slot 0 (the new frame) is left False.
preset_mask = [True] * n_views
preset_mask[0] = False

# Seed each list with the new frame's entry. Its pose starts as a placeholder.
images_array = [
    os.path.join(DATA_PATH, "masked_images/{}{}".format(new_img_index, IMG_FILE_EXTENSION))
]
masks_array = [
    os.path.join(DATA_PATH, "masks/{}{}".format(new_img_index, MASK_FILE_EXTENSION))
]
preset_pose = [np.eye(4)]

# Append the last RECURRING_FRAMES entries of transforms["frames"], oldest first.
for i in range(-RECURRING_FRAMES, 0):
    frame = transforms["frames"][i]
    images_array.append(os.path.join(DATA_PATH, frame["file_path"]))
    masks_array.append(os.path.join(DATA_PATH, frame["mask_path"]))
    preset_pose.append(np.array(frame["transform_matrix"]))
    print("Using {}...".format(frame["file_path"]))

# Replace the identity placeholder with the most recent known pose.
# NOTE(review): slot 0 is masked out in preset_mask, so this value may be
# ignored by preset_pose() -- confirm it has an effect.
preset_pose[0] = preset_pose[-1]

images = load_images(images_array, size=512, verbose=True)
# The loaded image tensor is 4-D; its last two dimensions are passed to load_masks.
_, _, H, W = images[0]["img"].shape
masks = load_masks(masks_array, H, W, device)

Looking at frame 10...
Using masked_images/5.png...
Using masked_images/6.png...
Using masked_images/7.png...
Using masked_images/8.png...
Using masked_images/9.png...
>> Loading a list of 6 images
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/10.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/5.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/6.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/7.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/8.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/9.png with resolution 1280x720 --> 512x288
 (Found 6 images)


In [4]:
# Build the image pairs with the 'oneref-0' scene graph and symmetrization enabled.
# With the 6 loaded images this produced 10 pairs in the recorded run, which is
# consistent with pairing image 0 with each of the other 5 in both directions.
pairs = make_pairs(images, scene_graph='oneref-0', prefilter=None, symmetrize=True)
# Run the network on every pair, passing along the loaded masks and GAUSSIAN_SIGMA.
# NOTE(review): how the masks and sigma are applied is internal to inference_with_mask.
output = inference_with_mask(pairs, model, device, masks, GAUSSIAN_SIGMA, batch_size=batch_size)

>> Inference with model on 10 image pairs


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 10/10 [00:06<00:00,  1.62it/s]


In [5]:
# Create the optimizer over the pairwise predictions.
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.ModularPointCloudOptimizer)

# Preset the focal length of every view (new frame included).
focal_mask = [True] * (RECURRING_FRAMES + 1)
scene.preset_focal(preset_focal, focal_mask)

# Preset the poses selected by preset_mask, i.e. every view except slot 0.
scene.preset_pose(preset_pose, preset_mask)

 (setting focal #0 = 491.58966064453125)
 (setting focal #1 = 491.58966064453125)
 (setting focal #2 = 491.58966064453125)
 (setting focal #3 = 491.58966064453125)
 (setting focal #4 = 491.58966064453125)
 (setting focal #5 = 491.58966064453125)
 (setting pose #1 = [ 0.10944864 -0.13494991  0.        ])
 (setting pose #2 = [ 0.04718999 -0.04209871  0.        ])
 (setting pose #3 = [ 0.08555276 -0.0451539   0.        ])
 (setting pose #4 = [ 0.10870501 -0.04764657  0.        ])
 (setting pose #5 = [ 0.1141178  -0.09485686  0.        ])


In [6]:
# Print the per-view 4x4 pose matrices before global alignment, as a baseline
# to compare against the poses printed after optimization.
im_pose =  scene.get_im_poses()
print(im_pose)


tensor([[[-0.9540,  0.1001, -0.2825, -2.9793],
         [-0.2448,  0.2837,  0.9272, 12.5863],
         [ 0.1729,  0.9537, -0.2462, -2.1108],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.7140,  0.4959, -0.4943,  0.1094],
         [-0.6461,  0.1945, -0.7381, -0.1349],
         [-0.2698,  0.8463,  0.4593,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.9659,  0.1422, -0.2164,  0.0472],
         [-0.2536,  0.3491, -0.9021, -0.0421],
         [-0.0528,  0.9262,  0.3732,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.9151,  0.2233, -0.3357,  0.0856],
         [-0.3906,  0.2842, -0.8756, -0.0452],
         [-0.1001,  0.9324,  0.3473,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.8687,  0.3050, -0.3903,  0.1087],
         [-0.4735,  0.2801, -0.8351, -0.0476],
         [-0.1454,  0.9103,  0.3877,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.7919,  0.4062, -0.4559,  0.1141],
   

In [7]:
# Run the global alignment with "mst" initialisation, using the iteration count,
# learning-rate schedule and learning rate from the configuration cell.
# In the recorded run only 'im_poses.0' appears among the optimized pose
# parameters, alongside 'pw_poses' and the depth maps.
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

 init edge (2*,0*) score=1.1480892896652222
 init edge (1*,0) score=1.1418921947479248


 init edge (3*,0) score=1.1078542470932007
 init edge (0,5*) score=1.0900238752365112
 init edge (4*,0) score=1.088693380355835
 init loss = 0.004925969522446394
Global alignement - optimizing for:
['pw_poses', 'im_depthmaps.0', 'im_depthmaps.1', 'im_depthmaps.2', 'im_depthmaps.3', 'im_depthmaps.4', 'im_depthmaps.5', 'im_poses.0']


100%|██████████| 300/300 [00:20<00:00, 14.52it/s, lr=1.27413e-06 loss=0.000385978]


In [8]:
# Print the per-view poses again after global alignment. In the recorded run only
# the first matrix (the new frame) differs from the pre-optimization printout;
# the other poses shown are unchanged.
im_pose =  scene.get_im_poses()
print(im_pose)

tensor([[[ 0.9362,  0.2554,  0.2413,  0.0037],
         [ 0.1449,  0.3448, -0.9274, -0.0990],
         [-0.3201,  0.9033,  0.2858,  0.0267],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.7140,  0.4959, -0.4943,  0.1094],
         [-0.6461,  0.1945, -0.7381, -0.1349],
         [-0.2698,  0.8463,  0.4593,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.9659,  0.1422, -0.2164,  0.0472],
         [-0.2536,  0.3491, -0.9021, -0.0421],
         [-0.0528,  0.9262,  0.3732,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.9151,  0.2233, -0.3357,  0.0856],
         [-0.3906,  0.2842, -0.8756, -0.0452],
         [-0.1001,  0.9324,  0.3473,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.8687,  0.3050, -0.3903,  0.1087],
         [-0.4735,  0.2801, -0.8351, -0.0476],
         [-0.1454,  0.9103,  0.3877,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  1.0000]],

        [[ 0.7919,  0.4062, -0.4559,  0.1141],
   

In [9]:

# Pull the optimized state out of the scene. Index 0 of every per-view
# collection corresponds to the newly added frame.
imgs = scene.imgs
focals = scene.get_focals()
poses = scene.get_im_poses()
pts3d = scene.get_pts3d()
confidence_masks = scene.get_masks()

# Report the new frame when *no* pixel passed the confidence test.
# The previous check, `(mask != 0).all()`, was inverted: it fired when every
# pixel was confident.
# NOTE(review): assumes a nonzero mask entry marks a confident pixel -- confirm
# against get_masks().
if (confidence_masks[0] == 0).all():
    print("No confidence in Frame {}".format(new_img_index))

# Convert the optimized pose of the new frame to a plain nested list so it can
# be stored in the JSON-backed transforms structure.
new_tf = poses[0].detach().cpu().numpy().tolist()

# The z translation is forced to zero below. Surface large out-of-plane
# estimates instead of discarding them silently (this branch used to be a no-op).
if abs(new_tf[2][3]) > 0.1:
    print("Warning: frame {} has |z| = {:.4f} > 0.1 before it is zeroed".format(
        new_img_index, abs(new_tf[2][3])))
new_tf[2][3] = 0

# Paths are stored relative to DATA_PATH: the last two components, "<folder>/<file>".
new_frame = {
    "file_path" : "/".join(images_array[0].split("/")[-2:]),
    "transform_matrix" : new_tf,
    "mask_path" : "/".join(masks_array[0].split("/")[-2:])
}
# Registration of the new frame is currently disabled; uncomment to append it.
#transforms["frames"].append(new_frame)
