In [None]:
from dust3r.inference import inference
from dust3r.model import AsymmetricCroCo3DStereo
from dust3r.utils.image import load_images
from dust3r.image_pairs import make_pairs
from dust3r.cloud_opt import global_aligner, GlobalAlignerMode
from dust3r.viz import show_raw_pointcloud,cat_3d
from BA import prepare_ba_options,dust3r_to_pycolmap,find_reciprocal_matches
from matplotlib import pyplot as pl
import cv2
import viz_3d
import pycolmap
import os
import open3d as o3d
import numpy as np
from dust3r.utils.geometry import xy_grid
import torch
from PIL import Image
from pathlib import Path
from hloc import visualization
import copy

# Run configuration shared by all later cells.
# device is where the DUSt3R model and the global aligner run; batch_size is
# passed to inference(); schedule / lr / niter are passed to
# scene.compute_global_alignment().
device = 'cuda'
batch_size = 1
schedule = 'cosine'
lr = 0.01
niter = 300
########################
#vis config
########################
# Each flag below gates one optional visualization in a later cell.
# NOTE: the "reconsturcion" spelling is kept because later cells reference
# these exact names.
draw_scene_init = False
draw_scene_match_points = False

show_reconsturcion_init = True
show_reconsturcion_match = False
show_reconsturcion_ba = True

show_p2p_icp = False
show_reconsturcion_ba2 = False

show_3d_error_cluster = True
show_sam_mask = True
show_final_mask= True

show_reconsturcion_ba3 = True
show_sam_icp = True
show_pts_after_sam = True

#######################
#save config
#######################
# `save` gates every write to disk; the directories below are the targets for
# the reconstructions (res_dir_*) and the exported point clouds (pcd_dir).
save = False
res_dir_init = './exp/barn_baseline_6imgs/sparse/0'
res_dir_BA = './exp/barn_ba2_6imgs/sparse/0'
res_dir_BA_dense = './exp/barn_ba3_6imgs/sparse/0'
pcd_dir = './exp/barn_ba3_6imgs/pcds'

#######################
#pipe config
#######################
p2p_icp = False # sometimes needed

model_name = "checkpoints/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth"
# you can put the path to a local checkpoint in model_name if needed
model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)



In [None]:
# Select the first six images under ./images. os.listdir() returns entries in
# arbitrary order, so the names are sorted first to make the chosen subset
# (and therefore every downstream result) reproducible across runs and machines.
name_list = [os.path.join('./images', fname) for fname in sorted(os.listdir('./images'))][:6]
# for i in name_list:
#     img = cv2.imread(i)
#     img = cv2.resize(img,(512,256))
#     img = cv2.imwrite(i,img)
images = load_images(name_list, size=512, square_ok=True)

# Build image pairs (scene_graph='complete', symmetrized) and run DUSt3R on them.
pairs = make_pairs(images, scene_graph='complete', prefilter=None, symmetrize=True)
output = inference(pairs, model, device, batch_size=batch_size)

# at this stage, you have the raw dust3r predictions
view1, pred1 = output['view1'], output['pred1']
view2, pred2 = output['view2'], output['pred2']

# Global alignment with the same preset focal length for every image.
# NOTE(review): 316.3648 is hard-coded; its origin is not shown here — confirm
# it matches the dataset/resolution in use.
scene = global_aligner(output, device=device, min_conf_thr=5,mode=GlobalAlignerMode.PointCloudOptimizer)
scene.preset_focal([316.3648]*len(name_list))
loss = scene.compute_global_alignment(init="mst", niter=niter, schedule=schedule, lr=lr)
focals = scene.get_focals()
avg_focal = sum(focals)/len(focals)

In [None]:
# Optional views of the aligned scene, gated by the vis-config flags.
if draw_scene_init:
    viz_3d.draw_dust3r_scene(scene)
if draw_scene_match_points:
    viz_3d.draw_dust3r_match_ez(scene)

In [None]:
# NOTE(review): star import — exhaustive_matching_roma and colmap_builder
# presumably come from this module; confirm and import them by name.
from colmap_builder_roma import *
pair_samples=10
# NOTE(review): pair_threshold is not referenced in any of the visible cells.
pair_threshold=0.2

# Match the selected images; the resulting match and warp dicts are handed to
# colmap_builder together with the aligned scene.
match_dict,warp_dict = exhaustive_matching_roma(name_list,pair_samples)
conf = 3
cb = colmap_builder(scene,name_list,warp_dict,match_dict,shared_cam=True)
#reconstruction = cb.pair_view_to_colmap()
cb.multi_view_to_pycolmap(3000,conf)
# Initial reconstruction; a later cell deep-copies it before bundle adjustment.
reconstruction = cb.add_points_to_colmap()
print(reconstruction.summary())

In [None]:
# Optional interactive 3D view of the initial reconstruction.
if show_reconsturcion_init:
    fig = viz_3d.init_figure()
    viz_3d.plot_reconstruction(
        fig, reconstruction, color="rgba(255,0,0,0.5)", name="mapping", points_rgb=True
    )
    fig.show()
# Optional 2D overlay of the reconstruction on the input images.
if show_reconsturcion_match:
    # Use a dedicated name for the image root directory: rebinding `images`
    # here would replace the list returned by load_images with a Path.
    image_root = Path('./')
    visualization.visualize_sfm_2d(reconstruction, image_root, color_by="depth", n=2)
    

In [None]:
# Bundle-adjust a deep copy so the initial reconstruction is left untouched.
reconstruction_ba = copy.deepcopy(reconstruction)

# Two BA passes on the same copy: first with focal-length refinement enabled,
# then with it disabled.
for refine_focal in (True, False):
    ba_options = prepare_ba_options(refine_focal_length=refine_focal)
    pycolmap.bundle_adjustment(reconstruction_ba, ba_options)

# Optional interactive 3D view of the adjusted reconstruction.
if show_reconsturcion_ba:
    fig = viz_3d.init_figure()
    viz_3d.plot_reconstruction(
        fig,
        reconstruction_ba,
        color="rgba(255,0,0,0.5)",
        name="mapping",
        points_rgb=True,
    )
    fig.show()


In [None]:
# Write the initial and the bundle-adjusted reconstruction to their configured
# directories; a target set to None is skipped.
if save:
    for out_dir, recon in ((res_dir_init, reconstruction), (res_dir_BA, reconstruction_ba)):
        if out_dir is None:
            continue
        os.makedirs(out_dir, exist_ok=True)
        recon.write(out_dir)

In [None]:
# NOTE(review): the module shares its name with the `p2p_icp` flag from the
# config cell. If the module exports a name `p2p_icp`, this star import would
# overwrite the flag — confirm, and prefer importing names explicitly.
from p2p_icp import *
# Optional step, off by default (see the pipe config).
if p2p_icp:
    # Match 3D points between the initial and the bundle-adjusted reconstruction.
    T_dict,res_dict,rmses,pts3d_ids = match_3d_points(reconstruction, reconstruction_ba)
    if show_p2p_icp:
        compare_pts(res_dict['source_t_all'],res_dict['target_all'])
    # Apply the resulting transforms to the builder's scene.
    # NOTE(review): presumably mutates `cb` in place — confirm.
    pts_origin, pts_trans = cb.transform_scene(T_dict)

In [None]:
# Second round: feed the BA result back into the builder, rebuild with a larger
# point budget (20000 vs. 3000 in the first round) and a different conf value,
# then run BA with focal length and extrinsics refinement both disabled.
conf = 2

cb.update_colmap(reconstruction_ba)
cb.multi_view_to_pycolmap(20000,conf)
# NOTE: this rebinds `reconstruction`; the next cell compares this new value
# against reconstruction_ba2, so cell execution order matters.
reconstruction = cb.add_points_to_colmap()
reconstruction_ba2 = copy.deepcopy(reconstruction)
ba_options = prepare_ba_options(refine_focal_length=False,refine_extrinsics=False)
pycolmap.bundle_adjustment(reconstruction_ba2, ba_options)
if show_reconsturcion_ba2:
    fig =viz_3d.init_figure()
    viz_3d.plot_reconstruction(
        fig, reconstruction_ba2, color="rgba(255,0,0,0.5)", name="mapping", points_rgb=True
    )
    fig.show()

In [None]:
# Match 3D points between the rebuilt reconstruction and its BA-refined copy,
# then apply the resulting transforms to the builder's scene.
T_dict,res_dict,rmses,pts3d_ids = match_3d_points(reconstruction, reconstruction_ba2)
pts_origin, pts_trans = cb.transform_scene(T_dict)

# Select badly matching points and cluster them; max(rmses) is passed to both
# helpers (presumably as a distance threshold — confirm against p2p_icp).
# NOTE(review): show=True is hard-coded here, unlike the flag used for pcd_cluster.
bad_match_indices,bad_pts,target = pcd_distance_filter(res_dict['source_t_all'],res_dict['target_all'],max(rmses),show=True)
labels = pcd_cluster(target,max(rmses),30,show=show_3d_error_cluster)

In [None]:
from sam_icp import *

# Cluster labels run from -1 up to labels.max(); -1 is the noise cluster.
pts3d_ids = np.asarray(pts3d_ids)
cluster_labels = range(-1, labels.max() + 1)

# group_ids[label] -> ids of the badly matched 3D points assigned to that cluster.
group_ids = {
    lbl: pts3d_ids[bad_match_indices[labels == lbl]]
    for lbl in cluster_labels
}

# group_2ds[label][fidx][xy_ind]: the 2D observations of each cluster's 3D points.
group_2ds = {}
for lbl, cluster_pt_ids in group_ids.items():
    pt2d_by_frame, _pt3d = search_3d_2d_corr(reconstruction_ba2, cluster_pt_ids)
    group_2ds[lbl] = pt2d_by_frame

In [None]:
#sam processing
# Load the Segment Anything ViT-H checkpoint and wrap it in a point-prompt predictor.
from sam_mask import *
#from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# NOTE: rebinds the `device` set in the config cell (same value).
device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

In [None]:
# For every error cluster, prompt SAM with the cluster's 2D points in each view
# and keep at most one cleaned-up mask per view.
# Result: masks_all_dict[label][fidx] -> mask with values in {0, 1}.
masks_all_dict = {}
for label, pts2d_by_view in group_2ds.items():
    if label < 0:
        # Label -1 is the noise cluster; no mask is generated for it.
        continue

    maskall = {}
    scoreall = {}
    view_mean_scores = []
    for fidx, pt2ds_test in pts2d_by_view.items():
        # scene.imgs are scaled by 255 and cast to uint8 before being given to SAM.
        img = (scene.imgs[fidx]*255).astype(np.uint8)
        predictor.set_image(img)
        input_point = np.asarray(pt2ds_test)
        input_label = np.array([1]*len(input_point))  # every prompt is a positive point
        masks, scores, logits = predictor.predict(
            point_coords=input_point,
            point_labels=input_label,
            multimask_output=True,
        )
        # Per-view area filter: drop candidates whose area is at least twice
        # the median candidate area.
        areas = np.asarray([np.sum(candidate) for candidate in masks])
        keep = areas < 2*np.median(areas)
        if not keep.any():
            # No candidate survives (e.g. all masks are empty); skipping the
            # view avoids a 0/0 average that would poison avg_score with NaN.
            continue
        masks = masks[keep]
        scores = scores[keep]

        view_mean_scores.append(np.sum(np.asarray(scores)) / len(scores))
        maskall[fidx] = masks
        scoreall[fidx] = scores

    if not view_mean_scores:
        # The cluster has no usable view: nothing to average, nothing to keep.
        continue
    avg_score = sum(view_mean_scores) / len(view_mean_scores)

    good_mask = {}
    good_mask_score = {}
    for fidx, masks in maskall.items():
        scores = scoreall[fidx]
        # The best-scoring candidate is selected by its index in `masks`.
        # (Indexing `masks` with a position taken from a filtered score list
        # would pick the wrong mask whenever an earlier candidate was rejected.)
        best_ind = int(np.argmax(scores))
        best_score = scores[best_ind]
        # A view is kept only if its best candidate reaches the cluster-wide
        # average score or the absolute 0.9 threshold.
        if best_score < avg_score and best_score < 0.9:
            continue
        mask = find_largest_connected_component(masks[best_ind].astype(np.uint8))
        mask[mask>0] = 1
        good_mask[fidx] = mask
        good_mask_score[fidx] = best_score

    if show_sam_mask:
        # Display only; the SAM predictor is not needed here, so no image is
        # re-encoded.
        for fidx, mask in good_mask.items():
            input_point = np.asarray(pts2d_by_view[fidx])
            input_label = np.array([1]*len(input_point))
            img = (scene.imgs[fidx]*255).astype(np.uint8)
            pl.figure(figsize=(10,10))
            pl.imshow(img)
            show_mask(mask, pl.gca())
            show_points(input_point, input_label, pl.gca(),marker_size=20)
            pl.title(f"Mask {fidx}, Score: {good_mask_score[fidx]:.3f}", fontsize=18)
            pl.axis('off')
            pl.show()
    masks_all_dict[label] = good_mask
# Combine the per-cluster masks through the builder; the result is indexed by
# view, in the same order as cb.imgs.
final_mask = backproject_mask(cb, masks_all_dict)

# Optional preview: each image multiplied by its mask replicated over 3 channels.
if show_final_mask:
    for view_idx, view_img in enumerate(cb.imgs):
        mask_rgb = np.stack([final_mask[view_idx]] * 3, 2)
        pl.imshow(view_img * mask_rgb)
        pl.show()

In [None]:
# Densify the error regions: rebuild from the first BA result, add a second
# batch of points with final_mask passed to the builder, then run BA again with
# focal length and extrinsics refinement both disabled.
conf = 2
cb.update_colmap(reconstruction_ba)

cb.multi_view_to_pycolmap(20000,conf)
# Second pass uses a lower conf value and the SAM-derived final_mask.
conf = 1
cb.multi_view_to_pycolmap(20000,conf,final_mask)
reconstruction = cb.add_points_to_colmap()

reconstruction_ba3 = copy.deepcopy(reconstruction)
ba_options = prepare_ba_options(refine_focal_length=False,refine_extrinsics=False)

pycolmap.bundle_adjustment(reconstruction_ba3, ba_options)
if save:
    if res_dir_BA_dense is not None:
        os.makedirs(res_dir_BA_dense,exist_ok=True)
        reconstruction_ba3.write(res_dir_BA_dense)
# Visualization is independent of `save`, like every other show_* flag in this
# notebook; nesting it under `if save:` hid the plot whenever saving was off.
if show_reconsturcion_ba3:
    fig = viz_3d.init_figure()
    viz_3d.plot_reconstruction(
        fig, reconstruction_ba3, color="rgba(255,0,0,0.5)", name="mapping", points_rgb=True
    )
    fig.show()

In [None]:
# Local ICP on the SAM-masked regions, using the second BA result, the
# per-cluster masks, the builder's per-view 3D points and the scene images.
# `res` is consumed by update_scene_pts in the next cell.
res,mask_t_pts_all,mask_t_pts_all_rgb,source_t_pts_all,target_pts_all = sam_icp_local(reconstruction_ba2,masks_all_dict,cb.pts3d_all,scene.imgs,show=show_sam_icp)


In [None]:
# Work on a deep copy of the builder so `cb` keeps its current state.
cb_new = copy.deepcopy(cb)
# NOTE(review): both helpers presumably update cb_new in place; the first
# return value is overwritten by the second, and pts_all is not used in any
# later visible cell — confirm.
pts_all = update_scene_pts(cb_new,res,masks_all_dict)
pts_all = backproject_colamp_points(reconstruction_ba3,cb_new)
if show_pts_after_sam:
    viz_3d.draw_dust3r_scene(scene,pose_refine = cb_new.poses,pts3d = cb_new.pts3d_all)

In [None]:
from depth_filter import *
# Export point clouds (.ply) into pcd_dir. The whole cell is a no-op unless
# `save` is set. Four files are written: *_conf, *_all, *_filted, *_filted_conf.
if save:
    pts2d_all_list = []
    pts3d_all_list = []
    all_rgbs_conf = []
    all_rgbs_list = []
    for i in range(len(cb_new.imgs)):
        # NOTE(review): the confidence mask and the colors gathered below are
        # read from `cb`, while the points come from `cb_new` — confirm this
        # mix is intended.
        conf_i = cb.confidence_masks[i]
        pts2d_all_list.append(xy_grid(*cb_new.imgs[i].shape[:2][::-1])[conf_i])  # imgs[i].shape[:2] = (H, W)
        pts3d_all_list.append(cb_new.pts3d_all[i].reshape((-1,3)))
        all_rgbs_list.append(cb_new.imgs[i].reshape((-1,3)))
        #pts3d_all_list.append(cb_new.pts3d_all[i][conf_i])
        
        # Look up the color of every selected (x, y) pixel, one at a time.
        for pos in pts2d_all_list[i]:
            x,y = pos
            all_rgbs_conf.append(cb.imgs[i][y,x,:])
    os.makedirs(pcd_dir,exist_ok=True)
    all_pts3d = np.concatenate(pts3d_all_list,axis=0)
    all_rgbs = np.concatenate(all_rgbs_list,axis=0)
    # NOTE(review): assumes cb_new.pts3d_list has one point per color collected
    # above (same masks, same order) — confirm.
    all_pts3d_conf = np.concatenate(cb_new.pts3d_list,axis=0)
    all_rgbs_conf = np.asarray(all_rgbs_conf)
    # Create the Open3D point cloud object
    pcd_conf = to_pcd(all_pts3d_conf,all_rgbs_conf)
    output_file = os.path.join(pcd_dir,"output_point_cloud_conf.ply")
    o3d.io.write_point_cloud(output_file, pcd_conf)

    # Unfiltered cloud: every pixel of every view.
    pcd = to_pcd(all_pts3d,all_rgbs)
    output_file = os.path.join(pcd_dir,"output_point_cloud_all.ply")
    o3d.io.write_point_cloud(output_file, pcd)

    # NOTE(review): W//2 and H//2 are presumably the principal point — confirm
    # against depth_filter's signature.
    pts3d_all_filted_list,mask_all_filted_list = depth_filter(cb_new.pts3d_all,cb_new.confidence_masks,cb_new.poses,focals,cb_new.W//2,cb_new.H//2)

    # Keep only the points (and matching colors) selected by the per-view masks.
    pts3d_all_filted_list = [pts3d_all_filted_list[i][mask_all_filted_list[i]] for i in range(len(pts3d_all_filted_list))]
    pts3d_filted = np.concatenate(pts3d_all_filted_list,axis=0)

    rgbs_filted = [all_rgbs_list[i][mask_all_filted_list[i]] for i in range(len(all_rgbs_list))]
    rgbs_filted = np.concatenate(rgbs_filted,axis=0)
    pcd_filted = to_pcd(pts3d_filted,rgbs_filted)
    output_file = os.path.join(pcd_dir,"output_point_cloud_filted.ply")
    o3d.io.write_point_cloud(output_file, pcd_filted)
    # The depth filter is reasonable, but it damages some thin structures
    ####################### This part can be tuned later; the parameters vary a lot #########################
    voxel_size = 0.0008  # voxel size; adjust as needed
    pcd_filted_down = pcd_filted.voxel_down_sample(voxel_size)
    # cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
    # pcd_filtered = pcd.select_by_index(ind)

    # # Use radius filtering to reduce noise and outliers
    pcd_filted_down, ind = pcd_filted_down.remove_radius_outlier(nb_points=4, radius=0.00135)

    # Combine the filtered cloud with the confidence-masked cloud, then
    # downsample the combined cloud again.
    pcd_filted_down_c = pcd_filted_down + pcd_conf

    voxel_size = 0.0008  # voxel size; adjust as needed
    pcd_filted_down_c = pcd_filted_down_c.voxel_down_sample(voxel_size)

    output_file = os.path.join(pcd_dir,"output_point_cloud_filted_conf.ply")
    o3d.io.write_point_cloud(output_file, pcd_filted_down_c)

In [None]:
# Show the image paths in use; the report below prints a best match for each.
print(name_list)

def find_top2_indices(lst):
    """Return the indices of the largest and second-largest elements.

    The input is never modified, so any indexable sequence with a length
    works (list, tuple, 1-D array). On ties the earliest index wins, and the
    two returned indices are always distinct.

    Args:
        lst: Sequence of mutually comparable values with at least two elements.

    Returns:
        Tuple ``(max1_index, max2_index)``.

    Raises:
        ValueError: If ``lst`` has fewer than two elements.
    """
    if len(lst) < 2:
        raise ValueError("List should contain at least two elements.")

    positions = range(len(lst))
    # max() returns the first maximal item, which keeps the earliest index on ties.
    max1_index = max(positions, key=lst.__getitem__)
    # Second place: the best of the remaining positions.
    max2_index = max((pos for pos in positions if pos != max1_index), key=lst.__getitem__)

    return max1_index, max2_index
# For every image, report the other image with the largest match count
# (the sum of the first entry of cb_new.matching_dict['i-j']).
for i, query_name in enumerate(name_list):
    inds = [j for j in range(len(name_list)) if j != i]
    match_pts = [cb_new.matching_dict[f'{i}-{j}'][0].sum() for j in inds]
    max1_index, max2_index = find_top2_indices(match_pts)
    print(f'{query_name} match {name_list[inds[max1_index]]}')
    #print(f'{name_list[i]} match {name_list[inds[max2_index]]}')