In [1]:
import os
# Run from the dust3r repository root so that relative paths used later
# (e.g. the "checkpoints/..." model path) resolve correctly.
# Set the DUST3R_ROOT environment variable to use a different checkout location;
# the default keeps the original "/dust3r".
DUST3R_ROOT = os.environ.get("DUST3R_ROOT", "/dust3r")
os.chdir(DUST3R_ROOT)
print(os.getcwd())

/dust3r


In [2]:
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import torch
import json
import open3d as o3d

from dust3r.inference import inference_with_mask, create_gaussian_kernel
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.cloud_opt.base_opt import global_alignment_loop
from masked_dust3r.scripts.utils.math import *
from masked_dust3r.scripts.utils.image import *


# ---- Dataset layout -------------------------------------------------------
DATA_PATH = "/dust3r/masked_dust3r/data/jackal_training_data_0"
IMG_FILE_EXTENSION = ".png"
MASK_FILE_EXTENSION = ".png"

# ---- Frame schedule -------------------------------------------------------
INIT_FRAMES = 25       # index of the first frame to estimate (start of the window loop)
NEW_FRAMES = 5         # frames estimated per window
PREVIOUS_FRAMES = 10   # already-posed reference frames reused per window
TOTAL_FRAMES = 300     # exclusive upper bound of the window loop

# ---- Regularisation weights for the initial alignment ---------------------
# NOTE(review): the INIT_* weights are not referenced by the cells shown here.
INIT_WEIGHT_FOCAL = 0.0               # disabled (was 0.1)
INIT_WEIGHT_Z = 0.1
INIT_WEIGHT_ROT = 1e-2
INIT_WEIGHT_TRANS_SMOOTHNESS = 1e-4
INIT_WEIGHT_ROT_SMOOTHNESS = 0.0      # disabled (was 0.001)

# ---- Regularisation weights for each new window ---------------------------
NEW_WEIGHT_FOCAL = 0.0                # disabled (was 0.1)
NEW_WEIGHT_Z = 0.1
NEW_WEIGHT_ROT = 1e-2
NEW_WEIGHT_TRANS_SMOOTHNESS = 1e-4
NEW_WEIGHT_ROT_SMOOTHNESS = 0.0       # disabled (was 0.00001)

# ---- Intrinsics -----------------------------------------------------------
USE_COMMON_INTRINSICS = False
# Original 1280px image width divided by the 512px size passed to load_images.
RESCALE_FACTOR = 1280 / 512

# ---- Optimiser ------------------------------------------------------------
device = 'cuda'
batch_size = 1
schedule = 'cosine'
lr = 1e-2
niter = 300

# Known frames (poses, focals, file paths) that seed the sliding-window estimation.
with open(f"{DATA_PATH}/transforms_init.json") as transforms_file:
    transforms = json.load(transforms_file)

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [3]:
# Kernel handed to inference_with_mask together with the masks.
# Previously a Gaussian kernel was built and then immediately overwritten by the
# box kernel. The choice is now explicit, so only the kernel actually used is allocated.
USE_GAUSSIAN_KERNEL = False

GAUSSIAN_SIGMA = 51.0

if USE_GAUSSIAN_KERNEL:
    SIZE = int(GAUSSIAN_SIGMA * 3)
    kernel = create_gaussian_kernel(SIZE, GAUSSIAN_SIGMA).to(device)
else:
    SIZE = 11
    # NOTE(review): this box kernel is not normalised (it sums to SIZE**2); confirm
    # inference_with_mask does not expect a normalised kernel.
    kernel = torch.ones(SIZE, SIZE, device=device)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [4]:
# Load the pretrained DUSt3R network and move it to the compute device.
# The checkpoint name encodes the variant: ViT-Large encoder, base decoder,
# 512px input, DPT head. A local checkpoint path can be used for model_name instead.
model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
model = AsymmetricCroCo3DStereo.from_pretrained(model_name)
model = model.to(device)

... loading model from checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth
instantiating : AsymmetricCroCo3DStereo(enc_depth=24, dec_depth=12, enc_embed_dim=1024, dec_embed_dim=768, enc_num_heads=16, dec_num_heads=12, pos_embed='RoPE100', patch_embed_cls='PatchEmbedDust3R', img_size=(512, 512), head_type='dpt', output_mode='pts3d', depth_mode=('exp', -inf, inf), conf_mode=('exp', 1, inf), landscape_only=False)
<All keys matched successfully>


In [5]:
# Assemble one optimisation window. The window holds the last PREVIOUS_FRAMES frames
# already stored in `transforms` (references with preset pose/focal), followed by
# NEW_FRAMES new frames, start_frame_index .. start_frame_index + NEW_FRAMES - 1,
# whose poses are to be estimated.
# NOTE: the trailing `break` stops after the first window (start_frame_index == INIT_FRAMES).
# The cells below consume the variables built here for that single window.
for start_frame_index in range(INIT_FRAMES, TOTAL_FRAMES, NEW_FRAMES):
    known_frames = transforms["frames"]
    # With fewer known frames the reference range below would start at a negative
    # index, and Python would silently wrap around to frames from the start of the list.
    assert len(known_frames) >= PREVIOUS_FRAMES, (
        f"transforms has {len(known_frames)} frames, "
        f"need at least PREVIOUS_FRAMES={PREVIOUS_FRAMES}"
    )

    images_array = []
    masks_array = []

    if USE_COMMON_INTRINSICS:
        # One shared focal (top-level "fl_x"), rescaled to the 512px model input.
        preset_focal = [transforms["fl_x"] / RESCALE_FACTOR for _ in range(PREVIOUS_FRAMES + NEW_FRAMES)]
    else:
        preset_focal = []

    preset_pose = []
    # True -> value is preset from a reference frame; False -> frame to be estimated.
    preset_mask = [True] * PREVIOUS_FRAMES + [False] * NEW_FRAMES

    # Reference frames: the last PREVIOUS_FRAMES entries of transforms["frames"].
    for i in range(len(known_frames) - PREVIOUS_FRAMES, len(known_frames)):
        frame = known_frames[i]
        images_array.append(os.path.join(DATA_PATH, frame["file_path"]))
        masks_array.append(os.path.join(DATA_PATH, frame["mask_path"]))
        preset_pose.append(np.array(frame["transform_matrix"]))
        if not USE_COMMON_INTRINSICS:
            preset_focal.append(frame["fl_x"] / RESCALE_FACTOR)
        print("Refering to {}...".format(frame["file_path"]))

    # New frames get the last reference pose/focal as placeholder values
    # (their preset_mask entry is False).
    last_known_pose = preset_pose[-1]
    last_known_focal = preset_focal[-1]

    for i in range(start_frame_index, start_frame_index + NEW_FRAMES):
        image_path = os.path.join(DATA_PATH, "masked_images/{}{}".format(i, IMG_FILE_EXTENSION))
        images_array.append(image_path)
        masks_array.append(os.path.join(DATA_PATH, "masks/{}{}".format(i, MASK_FILE_EXTENSION)))
        preset_pose.append(last_known_pose)
        if not USE_COMMON_INTRINSICS:
            preset_focal.append(last_known_focal)
        print("Estimating for {}...".format(image_path))

    images = load_images(images_array, size=512, verbose=True)
    _, _, H, W = images[0]["img"].shape
    masks = load_masks(masks_array, H, W, device)
    break

Refering to masked_images/15.png...
Refering to masked_images/16.png...
Refering to masked_images/17.png...
Refering to masked_images/18.png...
Refering to masked_images/19.png...
Refering to masked_images/20.png...
Refering to masked_images/21.png...
Refering to masked_images/22.png...
Refering to masked_images/23.png...
Refering to masked_images/24.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/25.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/26.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/27.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/28.png...
Estimating for /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/29.png...
>> Loading a list of 15 images
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/15.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_tr

 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/22.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/23.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/24.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/25.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/26.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/27.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/28.png with resolution 1280x720 --> 512x288
 - adding /dust3r/masked_dust3r/data/jackal_training_data_0/masked_images/29.png with resolution 1280x720 --> 512x288
 (Found 15 images)


In [6]:
# Build a symmetrised sliding-window ("swin") pair graph over the window's images,
# then run the masked DUSt3R inference on every pair.
swin_window = 5
pairs = make_pairs(
    images,
    scene_graph=f"swin-{swin_window}",
    prefilter=None,
    symmetrize=True,
)
output = inference_with_mask(pairs, model, device, masks, kernel, batch_size=batch_size)

>> Inference with model on 150 image pairs


100%|██████████| 150/150 [00:46<00:00,  3.25it/s]


In [7]:
# Global alignment with the plane-constrained optimiser, using the per-window weights.
scene = global_aligner(
    output,
    device=device,
    mode=GlobalAlignerMode.PlanePointCloudOptimizer,
    weight_focal=NEW_WEIGHT_FOCAL,
    weight_z=NEW_WEIGHT_Z,
    weight_rot=NEW_WEIGHT_ROT,
    weight_trans_smoothness=NEW_WEIGHT_TRANS_SMOOTHNESS,
    weight_rot_smoothness=NEW_WEIGHT_ROT_SMOOTHNESS,
)
# Last frame's pose parameters before optimisation.
print(scene.im_poses[-1])

# With shared intrinsics every frame's focal is preset; otherwise only the
# reference frames selected by preset_mask.
focal_mask = [True] * (PREVIOUS_FRAMES + NEW_FRAMES) if USE_COMMON_INTRINSICS else preset_mask
scene.preset_focal(preset_focal, focal_mask)
scene.preset_pose(preset_pose, preset_mask)

loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
# Same parameters after optimisation, for comparison.
print(scene.im_poses[-1])

Parameter containing:
tensor([-1.4003,  0.7331, -0.6836, -1.2265, -0.6846, -0.0236, -1.6716],
       device='cuda:0', requires_grad=True)
 (setting focal #0 = 459.58935546875)
 (setting focal #1 = 440.4442138671875)
 (setting focal #2 = 528.2886962890625)
 (setting focal #3 = 492.1719970703125)
 (setting focal #4 = 516.344970703125)
 (setting focal #5 = 506.750732421875)
 (setting focal #6 = 498.5870361328125)
 (setting focal #7 = 438.8687438964844)
 (setting focal #8 = 449.9277648925781)
 (setting focal #9 = 523.98876953125)
 (setting pose #0 = [-0.01574595  0.00103616  0.00583295])
 (setting pose #1 = [-0.00787305  0.0020128   0.00599961])
 (setting pose #2 = [0. 0. 0.])
 (setting pose #3 = [ 0.02126615 -0.00295147  0.00213417])
 (setting pose #4 = [ 0.04253045 -0.01279747  0.00425808])
 (setting pose #5 = [ 0.06041355 -0.02349862  0.00590006])
 (setting pose #6 = [ 0.07402549 -0.03513174  0.00733913])
 (setting pose #7 = [ 0.08753103 -0.04698689  0.00875222])
 (setting pose #8 = [ 0

 17%|█▋        | 50/300 [00:30<02:31,  1.66it/s, lr=0.00935613 loss=107.667]

In [None]:
# Diagnostic: for each edge (i, j) of the pair graph, print |z| of the normalised
# vector part of the relative rotation tf[j] @ tf[i]^-1.
# NOTE(review): `roma` and `signed_expm1` are not imported explicitly in this notebook.
# They presumably arrive via the `from masked_dust3r.scripts.utils.math import *`
# star import; confirm this and prefer an explicit import.
# This is read-only inspection of the current parameters, so autograd is disabled.
with torch.no_grad():
    all_focal = torch.stack(list(scene.im_focals))
    all_poses = torch.stack(list(scene.im_poses))
    # Columns 0:4 are treated as a quaternion, columns 4:7 as a signed-log translation.
    Q = torch.nn.functional.normalize(all_poses[:, :4], p=2, dim=1)
    T = signed_expm1(all_poses[:, 4:7])
    tf = roma.RigidUnitQuat(Q, T).normalize()
    tf_inv = tf.inverse()

    for i, j in scene.edges:
        # Assuming roma's XYZW quaternion layout, linear[:3] is the vector (axis) part.
        # A value near 1 means the relative rotation axis is close to the z axis.
        rel_axis = torch.nn.functional.normalize((tf[j] @ tf_inv[i]).linear[:3], p=2, dim=0, eps=1e-12)
        print(rel_axis[2].abs())

tensor(0.9995, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9997, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9996, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9996, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9999, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9997, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(1.0000, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(1.0000, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9996, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9998, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9997, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9998, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9997, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(1.0000, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9997, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9997, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(1.0000, device='cuda:0', grad_fn=<AbsBackward0>)
tensor(0.9998, device='cuda:0', grad_fn=<AbsBack

In [None]:
# Pull the optimised quantities out of the aligned scene for the export cells below.
imgs = scene.imgs
focals, poses, pts3d = scene.get_focals(), scene.get_im_poses(), scene.get_pts3d()
confidence_masks, intrinsics = scene.get_masks(), scene.get_intrinsics()

In [None]:
# Export one coloured .ply point cloud per image, keeping only pixels selected by
# the confidence mask. The output folder is emptied first so that stale clouds
# from a previous run do not linger.
POINTCLOUD_DIR = f"{DATA_PATH}/pointclouds"

os.makedirs(POINTCLOUD_DIR, exist_ok=True)
for entry in os.listdir(POINTCLOUD_DIR):
    entry_path = os.path.join(POINTCLOUD_DIR, entry)
    # os.remove() cannot delete directories, so sub-directories are left in place.
    if os.path.isfile(entry_path) or os.path.islink(entry_path):
        os.remove(entry_path)

for i in range(len(images)):
    # Flatten per-pixel points/colours to (N, 3) and the mask to (N,).
    points = pts3d[i].detach().cpu().numpy().reshape(-1, 3)
    colors = np.asarray(imgs[i]).reshape(-1, 3)
    keep = confidence_masks[i].detach().cpu().numpy().reshape(-1).astype(bool)

    pcd = o3d.geometry.PointCloud()
    # Vectorised masking replaces the former per-pixel Python loop;
    # Open3D's Vector3dVector takes float64 (N, 3) arrays.
    pcd.points = o3d.utility.Vector3dVector(points[keep].astype(np.float64))
    pcd.colors = o3d.utility.Vector3dVector(colors[keep].astype(np.float64))
    o3d.io.write_point_cloud(f"{POINTCLOUD_DIR}/pointcloud{i}.ply", pcd)

In [None]:
# Axis flip applied to each estimated pose before it is written out: y and z are negated.
# NOTE(review): this presumably converts OpenCV-style camera axes to the OpenGL/NeRF
# convention; confirm against the consumer of transforms.json.
OPENGL = np.array([[1, 0, 0, 0],
                   [0, -1, 0, 0],
                   [0, 0, -1, 0],
                   [0, 0, 0, 1]])

# Append the newly estimated frames of this window to `transforms` and save it.
# NOTE: `transforms["frames"]` is mutated in place; re-running this cell without
# reloading transforms_init.json appends the same frames again.
for i in range(PREVIOUS_FRAMES, PREVIOUS_FRAMES + NEW_FRAMES):
    # Skip frames whose confidence mask is entirely empty.
    if (confidence_masks[i] == 0).all():
        continue

    frame = {
        "file_path": "/".join(images_array[i].split("/")[-2:]),
        "transform_matrix": np.dot(poses[i].detach().cpu().numpy(), OPENGL).tolist(),
        "mask_path": "/".join(masks_array[i].split("/")[-2:]),
    }

    if not USE_COMMON_INTRINSICS:
        # Intrinsics were estimated at the 512px model resolution; scale back to full size.
        frame["fl_x"] = intrinsics[i, 0, 0].item() * RESCALE_FACTOR
        frame["fl_y"] = intrinsics[i, 1, 1].item() * RESCALE_FACTOR
        frame["cx"] = intrinsics[i, 0, 2].item() * RESCALE_FACTOR
        frame["cy"] = intrinsics[i, 1, 2].item() * RESCALE_FACTOR
        # Close the file handle; Image.open keeps it open until the image is closed.
        with Image.open(images_array[i]) as img:
            width, height = img.size
        transforms["w"] = width
        transforms["h"] = height

    transforms["frames"].append(frame)

with open(f"{DATA_PATH}/transforms.json", "w") as f:
    json.dump(transforms, f, indent=4)