In [1]:
# Run the notebook from the DUSt3R checkout: later cells use paths relative to
# the working directory (e.g. the "checkpoints/..." model path).
import os
os.chdir("/dust3r")
# Echo the directory so the cell output confirms where we ended up.
print(os.getcwd())

/dust3r


In [2]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import json
import open3d as o3d

from dust3r.inference import inference_with_mask
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.cloud_opt.base_opt import global_alignment_loop
from masked_dust3r.scripts.utils.math import *
from masked_dust3r.scripts.utils.image import *


# ---------------------------------------------------------------------------
# Dataset layout
# ---------------------------------------------------------------------------
DATA_PATH = "/dust3r/masked_dust3r/data/jackal_training_data_0"
IMG_FILE_EXTENSION = ".png"
MASK_FILE_EXTENSION = ".png"

# Sigma handed to inference_with_mask together with the masks.
# TODO: Dynamically size GAUSSIAN_SIGMA based on mask size?
GAUSSIAN_SIGMA = 31.0

# ---------------------------------------------------------------------------
# Frame-window sizes (not referenced by the cells visible in this notebook
# section; presumably consumed by an incremental variant -- verify before use)
# ---------------------------------------------------------------------------
INIT_FRAMES = 50
NEW_FRAMES = 10
PREVIOUS_FRAMES = 40
TOTAL_FRAMES = 300

# ---------------------------------------------------------------------------
# Loss weights for the initial global alignment.
# The "* 0" factors switch a term off while keeping its nominal magnitude
# readable in place.
# ---------------------------------------------------------------------------
INIT_WEIGHT_FOCAL = 0
INIT_WEIGHT_Z = 0.01
INIT_WEIGHT_ROT = 0
INIT_WEIGHT_TRANS_SMOOTHNESS = 1e-4 * 0
INIT_WEIGHT_ROT_SMOOTHNESS = 1e-3 * 0

# ---------------------------------------------------------------------------
# Loss weights for newly added frames (not used by the cells visible here).
# ---------------------------------------------------------------------------
NEW_WEIGHT_FOCAL = 0.1
NEW_WEIGHT_Z = 0.1
NEW_WEIGHT_ROT = 0.1
NEW_WEIGHT_TRANS_SMOOTHNESS = 1e-5
NEW_WEIGHT_ROT_SMOOTHNESS = 1e-5

# ---------------------------------------------------------------------------
# Camera prior (not used by the cells visible here).
# ---------------------------------------------------------------------------
IS_FOCAL_FIXED = True
FOCAL_LENGTH = 4.74

# ---------------------------------------------------------------------------
# Runtime / optimiser settings shared by the inference and alignment cells.
# ---------------------------------------------------------------------------
device = "cuda"
batch_size = 1
schedule = "cosine"
lr = 0.01
niter = 300

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Load the pretrained DUSt3R network and place it on the configured device.
# `model_name` is a checkpoint path relative to the working directory; point it
# at any other local checkpoint if needed.
model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
model = AsymmetricCroCo3DStereo.from_pretrained(model_name)
model = model.to(device)

... loading model from checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth


instantiating : AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), landscape_only=False)
<All keys matched successfully>


In [4]:
# Every second frame index from 0 up to (but excluding) 250.
frame_ids = range(0, 250, 2)

# Matching image / mask file paths, one entry per selected frame.
images_array = [
    os.path.join(DATA_PATH, "masked_images/{}{}".format(frame_id, IMG_FILE_EXTENSION))
    for frame_id in frame_ids
]
masks_array = [
    os.path.join(DATA_PATH, "masks/{}{}".format(frame_id, MASK_FILE_EXTENSION))
    for frame_id in frame_ids
]

# Load the frames at the 512 setting, then load the masks at the height/width
# taken from the first loaded image tensor.
images = load_images(images_array, size=512, verbose=True)
H, W = images[0]["img"].shape[2:4]
masks = load_masks(masks_array, H, W, device)

>> Loading a list of 125 images
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/0.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/2.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/4.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/6.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/8.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/10.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/12.png with resolution 1280x720 --> 512x288


 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/14.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/16.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/18.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/20.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/22.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/24.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/26.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/28.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_dat

In [5]:
# Build the image pairs with the 'swin-2' scene graph (symmetrised, no
# prefilter) and run the masked DUSt3R inference over all of them.
pairs = make_pairs(
    images,
    scene_graph='swin-2',
    prefilter=None,
    symmetrize=True,
)
output = inference_with_mask(
    pairs,
    model,
    device,
    masks,
    GAUSSIAN_SIGMA,
    batch_size=batch_size,
)

# The pair list is not needed once the network output exists; drop the name.
del pairs

>> Inference with model on 500 image pairs


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
100%|██████████| 500/500 [02:45<00:00,  3.02it/s]


In [6]:
# Global alignment of all pairwise predictions with the plane-aware optimiser.
#
# Disabled experiment kept for reference: run an unweighted alignment first
# (`init_scene`), copy its poses/focals into `scene` via calculate_new_params,
# set the INIT_* weights as attributes and refine with global_alignment_loop.
# The executed path below passes the weights at construction time and runs a
# single MST-initialised optimisation instead.
init_weights = dict(
    weight_focal=INIT_WEIGHT_FOCAL,
    weight_z=INIT_WEIGHT_Z,
    weight_rot=INIT_WEIGHT_ROT,
    weight_trans_smoothness=INIT_WEIGHT_TRANS_SMOOTHNESS,
    weight_rot_smoothness=INIT_WEIGHT_ROT_SMOOTHNESS,
)
scene = global_aligner(
    output,
    device=device,
    mode=GlobalAlignerMode.PlanePointCloudOptimizer,
    **init_weights,
)

# `loss` holds whatever compute_global_alignment returns for this run.
loss = scene.compute_global_alignment(
    init="mst",
    niter=niter,
    schedule=schedule,
    lr=lr,
)


 init edge (84*,85*) score=1.7405554056167603
 init edge (84,86*) score=1.5887047052383423
 init edge (85,87*) score=1.5216656923294067
 init edge (84,83*) score=1.4369206428527832
 init edge (84,82*) score=1.3797416687011719
 init edge (87,89*) score=1.298582673072815
 init edge (80*,82) score=1.1279969215393066
 init edge (83,81*) score=1.1100265979766846
 init edge (90*,89) score=1.5650523900985718
 init edge (90,92*) score=1.489304542541504
 init edge (91*,90) score=1.4415972232818604
 init edge (78*,80) score=1.415124773979187
 init edge (90,88*) score=1.3684473037719727
 init edge (91,93*) score=1.27093505859375
 init edge (92,94*) score=1.1703802347183228
 init edge (96*,94) score=1.1592035293579102
 init edge (78,77*) score=1.6100058555603027
 init edge (78,79*) score=1.4876244068145752
 init edge (78,76*) score=1.4674479961395264
 init edge (97*,96) score=1.452447533607483
 init edge (97,99*) score=1.4073532819747925
 init edge (97,95*) score=1.309325933456421
 init edge (74*,

100%|██████████| 300/300 [06:07<00:00,  1.23s/it, lr=1.27413e-06 loss=0.00127661]


In [7]:
# Pull the optimised quantities out of the aligned scene. Later cells index
# each of these per frame with [i].
imgs = scene.imgs  # per-frame colours; reshaped to (-1, 3) in the export cell
focals = scene.get_focals()  # averaged into a single focal for transforms.json
poses = scene.get_im_poses()  # per-frame pose matrices written to transforms.json
pts3d = scene.get_pts3d()  # per-frame 3D points; reshaped to (-1, 3) on export
confidence_masks = scene.get_masks()  # per-frame masks selecting which points to keep

In [8]:
# Export one coloured point cloud (.ply) per frame into DATA_PATH/pointclouds,
# keeping only the points selected by that frame's confidence mask.
pointcloud_dir = "{DATA_PATH}/pointclouds".format(DATA_PATH=DATA_PATH)

# Make sure the output folder exists, then clear files left by a previous run
# so stale clouds never mix with the new ones. Sub-directories are left alone
# (os.remove cannot delete them and would abort the cell).
os.makedirs(pointcloud_dir, exist_ok=True)
for file in os.listdir(pointcloud_dir):
    stale_path = os.path.join(pointcloud_dir, file)
    if os.path.isdir(stale_path) and not os.path.islink(stale_path):
        continue
    os.remove(stale_path)

for i in range(len(images)):
    # Flatten points, colours and mask to one row / entry per pixel.
    pointcloud = pts3d[i].detach().cpu().numpy().reshape(-1, 3)
    color = np.asarray(imgs[i]).reshape(-1, 3)
    confidence_mask = confidence_masks[i].detach().cpu().numpy().reshape(-1)
    # Same truthiness test as `if confidence_mask[j]`, applied to all pixels.
    confidence_mask = confidence_mask.astype(bool)

    # Boolean indexing keeps the selected rows in their original order; this
    # replaces a per-pixel Python loop. float64 is the element type Open3D's
    # Vector3dVector stores.
    masked_pointcloud = pointcloud[confidence_mask].astype(np.float64)
    masked_color = color[confidence_mask].astype(np.float64)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(masked_pointcloud)
    pcd.colors = o3d.utility.Vector3dVector(masked_color)
    o3d.io.write_point_cloud("{DATA_PATH}/pointclouds/pointcloud{i}.ply".format(DATA_PATH=DATA_PATH, i=i), pcd)



In [9]:
#Create transform file: shared camera intrinsics plus one pose entry per frame,
#written to DATA_PATH/transforms.json.
#TODO: Per frame camera model?
transforms = {}
transforms["camera_model"] = "OPENCV"

#Find size of the images on disk (the context manager closes the file handle).
with Image.open(images_array[0]) as img:
    width, height = img.size

# NOTE(review): the optimiser ran on the resized frames (the load log shows
# 1280x720 --> 512x288), so its focal lengths are presumably expressed in
# pixels of that working resolution, while w/h/cx/cy below describe the files
# on disk. Rescale the focal by the width ratio so all intrinsics refer to the
# same image size. Assumes load_images resized uniformly -- confirm if the
# aspect ratio of the inputs ever changes.
focal_scale = width / images[0]["img"].shape[-1]

# One shared focal length: the mean over all frames, in on-disk pixels.
average_focal = focals.sum() / len(focals) * focal_scale
transforms["fl_x"] = average_focal.item()
transforms["fl_y"] = average_focal.item()

transforms["w"] = width
transforms["h"] = height
# Principal point at the image centre; true division keeps the half pixel for
# odd image sizes.
transforms["cx"] = width / 2
transforms["cy"] = height / 2

transforms["frames"] = []

# Flips the Y and Z axes of the camera frame (right-multiplied onto each pose).
OPENGL = np.array([[1, 0, 0, 0],
                    [0, -1, 0, 0],
                    [0, 0, -1, 0],
                    [0, 0, 0, 1]])

for i in range(len(poses)):
    # Skip frames whose confidence mask is entirely zero.
    if (confidence_masks[i] != 0).any():
        frame = {}
        # Paths are stored relative to DATA_PATH: "<folder>/<file>".
        frame["file_path"] = "/".join(images_array[i].split("/")[-2:])
        # The optimised pose is used as-is (not inverted) with the axis flip
        # applied on the camera side.
        frame["transform_matrix"] = np.dot(poses[i].detach().cpu().numpy(), OPENGL).tolist()
        frame["mask_path"] = "/".join(masks_array[i].split("/")[-2:])
        transforms["frames"].append(frame)

#Save transform file
with open("{}/transforms.json".format(DATA_PATH), 'w') as f:
    json.dump(transforms, f, indent=4)

In [10]:
# Debug dump of every exported frame entry (file path, 4x4 transform, mask path).
print(transforms["frames"])

[{'file_path': 'masked_images/0.png', 'transform_matrix': [[0.4565823972225189, 0.7333185076713562, -0.5037623643875122, -0.20346814393997192], [0.6956750750541687, -0.6472344398498535, -0.3116469085216522, -0.07301289588212967], [-0.554588794708252, -0.20816239714622498, -0.8056672811508179, 0.00013973591558169574], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/0.png'}, {'file_path': 'masked_images/2.png', 'transform_matrix': [[0.6376181840896606, 0.5709162354469299, -0.5172017812728882, -0.1934170424938202], [0.5692303776741028, -0.8015462160110474, -0.18303130567073822, -0.019342171028256416], [-0.5190566778182983, -0.1777028888463974, -0.8360633254051208, 6.780733383493498e-05], [0.0, 0.0, 0.0, 1.0]], 'mask_path': 'masks/2.png'}, {'file_path': 'masked_images/4.png', 'transform_matrix': [[0.7719601988792419, 0.39762797951698303, -0.49595317244529724, -0.1717391312122345], [0.43068215250968933, -0.9010038375854492, -0.05201065540313721, 0.030180320143699646], [-0.4675365388393402, -0.17

In [11]:
import roma

# Cross-check: rebuild the homogeneous pose matrices directly from the raw
# optimiser parameters and apply the same Y/Z axis flip used for
# transforms.json, so the printed tensor can be compared with the exported
# transform matrices.
all_poses = torch.stack(list(scene.im_poses))

# Columns 0..3 are treated as a quaternion (L2-normalised per row); columns
# 4..6 go through signed_expm1 to obtain the translation.
Q = torch.nn.functional.normalize(all_poses[:, :4], p=2, dim=1)
T = signed_expm1(all_poses[:, 4:7])
tf = roma.RigidUnitQuat(Q, T).normalize().to_homogeneous()

# NOTE: this rebinds OPENGL, which the transforms.json cell defines as a NumPy
# array, to a torch tensor on `device`; re-run that cell from its start if it
# is executed again after this one.
OPENGL = torch.tensor(
    [
        [1, 0, 0, 0],
        [0, -1, 0, 0],
        [0, 0, -1, 0],
        [0, 0, 0, 1],
    ],
    dtype=torch.float32,
).to(device)

tf = torch.matmul(tf, OPENGL)

print(tf)

tensor([[[ 4.5658e-01,  7.3332e-01, -5.0376e-01, -2.0347e-01],
         [ 6.9568e-01, -6.4723e-01, -3.1165e-01, -7.3013e-02],
         [-5.5459e-01, -2.0816e-01, -8.0567e-01,  1.3974e-04],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],

        [[ 6.3762e-01,  5.7092e-01, -5.1720e-01, -1.9342e-01],
         [ 5.6923e-01, -8.0155e-01, -1.8303e-01, -1.9342e-02],
         [-5.1906e-01, -1.7770e-01, -8.3606e-01,  6.7807e-05],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],

        [[ 7.7196e-01,  3.9763e-01, -4.9595e-01, -1.7174e-01],
         [ 4.3068e-01, -9.0100e-01, -5.2011e-02,  3.0180e-02],
         [-4.6754e-01, -1.7345e-01, -8.6679e-01,  1.3534e-03],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]],

        ...,

        [[-2.3567e-01,  8.2347e-01, -5.1610e-01, -1.1108e-01],
         [ 7.6463e-01, -1.7066e-01, -6.2146e-01, -1.7436e-01],
         [-5.9983e-01, -5.4108e-01, -5.8943e-01,  9.4025e-02],
         [ 0.0000e+00,  0.0000e+00,