In [1]:
import os
os.chdir("/dust3r")
print(os.getcwd())

/dust3r


In [2]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import open3d as o3d
import torch
import json

from dust3r.inference import inference_with_mask
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from masked_dust3r.scripts.utils.math import *
from masked_dust3r.scripts.utils.image import *


DATA_PATH = "/dust3r/masked_dust3r/data/jackal_training_data_0"
IMG_FILE_EXTENSION = ".png"
MASK_FILE_EXTENSION = ".png"
GAUSSIAN_SIGMA = 1.0
INIT_FRAMES = 50
NEW_FRAMES = 10
PREVIOUS_FRAMES = 10
TOTAL_FRAMES = 60

IS_FOCAL_FIXED = True
FOCAL_LENGTH = 4.74

device = 'cuda'
batch_size = 1
schedule = 'cosine'
lr = 0.01
niter = 300

with open(f"{DATA_PATH}/transforms.json") as f:
    transforms = json.load(f)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Load the model

model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
# you can put the path to a local checkpoint in model_name if needed
model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)

... loading model from checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
instantiating : AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), landscape_only=False)
<All keys matched successfully>


In [4]:
for start_frame_index in range(INIT_FRAMES, TOTAL_FRAMES, NEW_FRAMES):
    images_array = []
    masks_array = []

    preset_focal = [transforms["fl_x"] for _ in range(PREVIOUS_FRAMES+NEW_FRAMES)]
    preset_pose = []
    preset_mask = [True for _ in range(PREVIOUS_FRAMES+NEW_FRAMES)]
    preset_mask[PREVIOUS_FRAMES:] = [False for _ in range(NEW_FRAMES)]

    for i in range(len(transforms["frames"])-PREVIOUS_FRAMES, len(transforms["frames"])):
        images_array.append(os.path.join(DATA_PATH,transforms["frames"][i]["file_path"]))
        masks_array.append(os.path.join(DATA_PATH,transforms["frames"][i]["mask_path"]))
        preset_pose.append(np.array(transforms["frames"][i]["transform_matrix"]))
        print("Refering to {}...".format(transforms["frames"][i]["file_path"]))

    last_known_pose = preset_pose[-1]

    for i in range(start_frame_index, start_frame_index + NEW_FRAMES):
        images_array.append(os.path.join(DATA_PATH,"masked_images/{}{}".format(i,IMG_FILE_EXTENSION)))
        masks_array.append(os.path.join(DATA_PATH,"masks/{}{}".format(i,MASK_FILE_EXTENSION)))
        preset_pose.append(last_known_pose)
        print("Estimating for {}...".format(os.path.join(DATA_PATH,"masked_images/{}{}".format(i,IMG_FILE_EXTENSION))))

    images = load_images(images_array, size=512, verbose=True)
    _,_,H,W = images[0]["img"].shape
    masks = load_masks(masks_array, H, W, device)

Refering to masked_images/50.png...
Refering to masked_images/51.png...
Refering to masked_images/52.png...
Refering to masked_images/53.png...
Refering to masked_images/54.png...
Refering to masked_images/55.png...
Refering to masked_images/56.png...
Refering to masked_images/57.png...
Refering to masked_images/58.png...
Refering to masked_images/59.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/50.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/51.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/52.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/53.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/54.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/55.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/56.png...
Estimating for /d

 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/52.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/53.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/54.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/55.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/56.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/57.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/58.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/59.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_dat

In [5]:
pairs = make_pairs(images, scene_graph='swin-{}'.format(PREVIOUS_FRAMES), prefilter=None, symmetrize=True)
output = inference_with_mask(pairs, model, device, masks, GAUSSIAN_SIGMA, batch_size=batch_size)

>> Inference with model on 380 image pairs


  0%|          | 0/380 [00:00<?, ?it/s]

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 380/380 [02:36<00:00,  2.42it/s]


In [6]:
scene = global_aligner(output, device=device, mode=GlobalAlignerMode.PlanePointCloudOptimizer, 
                       weight_focal = 1, 
                       weight_z = 0.1, 
                       weight_rot = 0.1, 
                       weight_trans_smoothness = 0.001,
                       weight_rot_smoothness = 0.001)
scene.preset_focal(preset_focal, [True for _ in range(PREVIOUS_FRAMES+NEW_FRAMES)])
scene.preset_pose(preset_pose, preset_mask)

 (setting focal #0 = 469.3912353515625)
 (setting focal #1 = 469.3912353515625)
 (setting focal #2 = 469.3912353515625)
 (setting focal #3 = 469.3912353515625)
 (setting focal #4 = 469.3912353515625)
 (setting focal #5 = 469.3912353515625)
 (setting focal #6 = 469.3912353515625)
 (setting focal #7 = 469.3912353515625)
 (setting focal #8 = 469.3912353515625)
 (setting focal #9 = 469.3912353515625)
 (setting focal #10 = 469.3912353515625)
 (setting focal #11 = 469.3912353515625)
 (setting focal #12 = 469.3912353515625)
 (setting focal #13 = 469.3912353515625)
 (setting focal #14 = 469.3912353515625)
 (setting focal #15 = 469.3912353515625)
 (setting focal #16 = 469.3912353515625)
 (setting focal #17 = 469.3912353515625)
 (setting focal #18 = 469.3912353515625)
 (setting focal #19 = 469.3912353515625)
 (setting pose #0 = [ 0.08369204 -0.23462954  0.15426631])
 (setting pose #1 = [ 0.06546892 -0.22099422  0.15435892])
 (setting pose #2 = [ 0.06222615 -0.21649924  0.15444268])
 (setting pos

In [7]:
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)

 init edge (15*,18*) score=1.2693949937820435
 init edge (5*,18) score=1.2693949937820435
 init edge (5,8*) score=1.2693949937820435
 init edge (12*,18) score=1.2579251527786255
 init edge (2*,8) score=1.2579251527786255
 init edge (6*,18) score=1.2536730766296387
 init edge (14*,18) score=1.2412238121032715
 init edge (4*,18) score=1.2412238121032715
 init edge (6,19*) score=1.258823275566101
 init edge (6,9*) score=1.258823275566101
 init edge (3*,6) score=1.2579948902130127
 init edge (17*,19) score=1.2557786703109741
 init edge (7*,19) score=1.2557786703109741
 init edge (7,11*) score=1.214026689529419
 init edge (7,1*) score=1.214026689529419
 init edge (3,10*) score=1.1488021612167358
 init edge (3,0*) score=1.1488021612167358
 init edge (16*,19) score=1.258823275566101
 init edge (13*,16) score=1.2579948902130127
 init loss = 0.07411404699087143
Global alignement - optimizing for:
['pw_poses', 'im_depthmaps.0', 'im_depthmaps.1', 'im_depthmaps.2', 'im_depthmaps.3', 'im_depthmaps.

100%|██████████| 300/300 [05:18<00:00,  1.06s/it, lr=1.27413e-06 loss=0.00091122] 


In [14]:
print(confidence_masks[PREVIOUS_FRAMES])

tensor([[False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False]], device='cuda:0')


In [15]:
imgs = scene.imgs
focals = scene.get_focals()
poses = scene.get_im_poses()
pts3d = scene.get_pts3d()
confidence_masks = scene.get_masks()

for i in range(PREVIOUS_FRAMES, PREVIOUS_FRAMES+NEW_FRAMES):
    new_frame = {
        "file_path" : "/".join(images_array[i].split("/")[-2:]),
        "transform_matrix" : poses[i].tolist(),
        "mask_path" : "/".join(masks_array[i].split("/")[-2:])
    }
    if confidence_masks[i].sum() > 0:
        transforms["frames"].append(new_frame)
    else:
        print("Reject frame {} due to low confidence".format(i))

with open(f"{DATA_PATH}/transforms.json", "w") as f:
    json.dump(transforms, f, indent=4)

{'file_path': 'masked_images/50.png', 'transform_matrix': [[-0.730869710445404, -0.5621897578239441, 0.38700374960899353, 0.06657484173774719], [0.6821752190589905, -0.6196547150611877, 0.38815563917160034, -0.22021174430847168], [0.02159157395362854, 0.5476956367492676, 0.8363988995552063, 0.1544608473777771], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/50.png'}
{'file_path': 'masked_images/51.png', 'transform_matrix': [[-0.715075671672821, -0.5763584971427917, 0.39557260274887085, 0.06622093915939331], [0.6987153887748718, -0.60672527551651, 0.37905314564704895, -0.21998773515224457], [0.021533414721488953, 0.5474443435668945, 0.8365650177001953, 0.15437352657318115], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/51.png'}
{'file_path': 'masked_images/52.png', 'transform_matrix': [[-0.6411790251731873, -0.6347156763076782, 0.4313070476055145, 0.06388243287801743], [0.7671033143997192, -0.545526921749115, 0.3375694155693054, -0.2170315980911255], [0.021028965711593628, 0.5472994446754456, 