# 

In [6]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

import numpy as np
import torch

# dust3r_path = str(Path(os.path.join(os.getcwd())).parent / "dust3r" / "dust3r")
# sys.path.append(dust3r_path)
sys.path.append("../../dust3r")
sys.path.append("../..")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
from src.annotated.dust3r.dust3r import AnnotatedAsymmetricCroCo3DStereo as AsymmetricCroCo3DStereo

ModuleNotFoundError: No module named 'models'

In [3]:
device = "cpu"
batch_size = 1
schedule = "cosine"
lr = 0.01
niter = 300

model_name = "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
# you can put the path to a local checkpoint in model_name if needed
model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)

In [7]:
import sys

import numpy as np
import torch

sys.path.append("../..")

from src.dust3r.load_images import LoadConfig, load_images

In [5]:
image_paths = [
    "assets/house/IMG_0251.jpeg",
    "assets/house/IMG_0252.jpeg",
    "assets/house/IMG_0253.jpeg",
    "assets/house/IMG_0254.jpeg",
]
images_data = load_images(image_paths, config=LoadConfig(size=512))

INFO:src.load_images:>> Loading 4 images.
INFO:src.load_images: - Added assets/house/IMG_0251.jpeg with resolution 384x512
INFO:src.load_images: - Added assets/house/IMG_0252.jpeg with resolution 384x512
INFO:src.load_images: - Added assets/house/IMG_0253.jpeg with resolution 384x512
INFO:src.load_images: - Added assets/house/IMG_0254.jpeg with resolution 384x512
INFO:src.load_images: (Successfully loaded 4 images)


In [6]:
from src.make_pairs import make_pairs

In [7]:
pairs = make_pairs(images_data)

### Inference

In [1]:
from src.inference import collate_with_cat, inference, loss_of_one_batch

ModuleNotFoundError: No module named 'src'

In [9]:
batch_size = 2
dict_pairs = [(p[0].to_dict(), p[1].to_dict()) for p in pairs]
collated_dict_pairs = collate_with_cat(dict_pairs[0:batch_size])

In [10]:
output = loss_of_one_batch(collated_dict_pairs, model, criterion=None, device=device)

In [11]:
output = inference(pairs, model, device, batch_size=batch_size)

>> Inference with model on 12 image pairs


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 6/6 [05:11<00:00, 51.95s/it]


## Global Alignment

In [12]:
from src.optimizer import PointCloudOptimizer

view1, pred1 = output["view1"], output["pred1"]
view2, pred2 = output["view2"], output["pred2"]

In [13]:
# this is the relative distance computed up to a scale factor
print(torch.min(pred1["pts3d"]), torch.max(pred1["pts3d"]))
print(torch.min(pred2["pts3d_in_other_view"]), torch.max(pred2["pts3d_in_other_view"]))

tensor(-0.3400) tensor(0.8500)
tensor(-0.8543) tensor(0.8523)


In [14]:
# this is computed as conf = 1 + exp(conf_for_network_output)
print(torch.min(pred1["conf"]), torch.max(pred1["conf"]))
print(torch.min(pred2["conf"]), torch.max(pred2["conf"]))

tensor(1.0000) tensor(23.1831)
tensor(1.0000) tensor(14.2556)


In [15]:
scene = PointCloudOptimizer(view1, view2, pred1, pred2, device=device).to(device)

In [16]:
import roma
import torch.nn as nn
from src.minimum_spanning_tree import geotrf, get_med_dist_between_poses, init_minimum_spanning_tree
from src.utils import xy_grid, signed_expm1, signed_log1p

loss = scene.compute_global_alignment_v2(init="mst", niter=300, schedule=schedule, lr=lr)

 init edge (2*,3*) score=16.611814498901367
 init edge (1,2*) score=38.26947784423828
 init edge (0,1*) score=115.91004943847656
 init loss = 0.0010532001033425331
Global alignement - optimizing for:
['pw_poses', 'pw_adaptors', 'im_poses', 'im_depthmaps', 'im_focals']


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 300/300 [01:37<00:00,  3.07it/s, lr=1.27413e-06 loss=0.000220726]


In [None]:
scene.show(viewer="gl")

In [None]:
assert len(scene.imgs) == 2
assert scene.imgs[0].shape == (384, 512, 3)
assert len(scene.get_pts3d()) == 2
assert scene.get_pts3d()[0].shape == torch.Size([384, 512, 3])
assert len(scene.get_masks()) == 2
assert scene.get_masks()[0].shape == torch.Size([384, 512])

In [None]:
# retrieve useful values from scene:
imgs = scene.imgs
focals = scene.get_focals()
poses = scene.get_im_poses()
pts3d = scene.get_pts3d()
confidence_masks = scene.get_masks()

In [None]:
# find 2D-2D matches between the two images
from dust3r.utils.geometry import find_reciprocal_matches, xy_grid

pts2d_list, pts3d_list = [], []
for i in range(2):
    conf_i = confidence_masks[i].cpu().numpy()
    pts2d_list.append(xy_grid(*imgs[i].shape[:2][::-1])[conf_i])  # imgs[i].shape[:2] = (H, W)
    pts3d_list.append(pts3d[i].detach().cpu().numpy()[conf_i])

In [None]:
assert imgs[0].shape[:2][::-1] == (512, 384)

assert confidence_masks[0].cpu().numpy().shape == (384, 512)
assert xy_grid(*imgs[0].shape[:2][::-1]).shape == (384, 512, 2)
assert pts3d[0].detach().cpu().numpy().shape == (384, 512, 3)

# number of points selected for 2d should be the same as the mask where pixels are True
assert np.sum(confidence_masks[0].cpu().numpy()) == len(pts2d_list[0])
assert np.sum(confidence_masks[0].cpu().numpy()) == len(pts3d_list[0])

assert imgs[1].shape[:2][::-1] == (512, 384)

assert confidence_masks[1].cpu().numpy().shape == (384, 512)
assert xy_grid(*imgs[1].shape[:2][::-1]).shape == (384, 512, 2)
assert pts3d[1].detach().cpu().numpy().shape == (384, 512, 3)

# number of points selected for 2d should be the same as the mask where pixels are True
assert np.sum(confidence_masks[1].cpu().numpy()) == len(pts2d_list[1])
assert np.sum(confidence_masks[1].cpu().numpy()) == len(pts3d_list[1])

In [None]:
reciprocal_in_P2, nn2_in_P1, num_matches = find_reciprocal_matches(*pts3d_list)
print(f"found {num_matches} matches")
matches_im1 = pts2d_list[1][reciprocal_in_P2]
matches_im0 = pts2d_list[0][nn2_in_P1][reciprocal_in_P2]

In [None]:
# visualize a few matches
import numpy as np
from matplotlib import pyplot as pl

n_viz = 10
match_idx_to_viz = np.round(np.linspace(0, num_matches - 1, n_viz)).astype(int)
viz_matches_im0, viz_matches_im1 = (
    matches_im0[match_idx_to_viz],
    matches_im1[match_idx_to_viz],
)

H0, W0, H1, W1 = *imgs[0].shape[:2], *imgs[1].shape[:2]
img0 = np.pad(imgs[0], ((0, max(H1 - H0, 0)), (0, 0), (0, 0)), "constant", constant_values=0)
img1 = np.pad(imgs[1], ((0, max(H0 - H1, 0)), (0, 0), (0, 0)), "constant", constant_values=0)
img = np.concatenate((img0, img1), axis=1)
pl.figure()
pl.imshow(img)
cmap = pl.get_cmap("jet")
for i in range(n_viz):
    (x0, y0), (x1, y1) = viz_matches_im0[i].T, viz_matches_im1[i].T
    pl.plot(
        [x0, x1 + W0],
        [y0, y1],
        "-+",
        color=cmap(i / (n_viz - 1)),
        scalex=False,
        scaley=False,
    )
pl.show(block=True)