In [1]:
%matplotlib inline

import os
import sys
sys.path.append('../')
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

import glob
import random
import cv2
import numpy as np
import networkx as nx
import torch
import torch.nn.functional as F

import io
import evo
import evo.main_ape as main_ape
import evo.main_rpe as main_rpe

from tqdm import tqdm
from evo.core.metrics import PoseRelation, Unit
from evo.core.trajectory import PoseTrajectory3D
from evo.core import lie_algebra
from evo.tools.plot import PlotMode
from copy import deepcopy
from scipy.spatial.transform import Rotation
from PIL import Image
from matplotlib import pyplot as plt

from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images_ratio, load_and_preprocess_images_square
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
from vggt.utils.helper import create_pixel_coordinate_grid, randomly_limit_trues
from vggt.dependency.track_predict import predict_tracks
from vggt.dependency.np_to_pycolmap import batch_np_matrix_to_pycolmap, batch_np_matrix_to_pycolmap_wo_track

from utils.umeyama import umeyama

device = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 is supported on Ampere GPUs (Compute Capability 8.0+) 
dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

In [2]:
def run_VGGT(model, images, dtype, resolution=518, track_feat=False):
    # images: [B, 3, H, W]

    assert len(images.shape) == 4
    assert images.shape[1] == 3

    device = next(model.parameters()).device
    images = images.to(device)
    
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            images = images[None]  # add batch dimension
            valid_layers = model.depth_head.intermediate_layer_idx
            if valid_layers[-1] != model.aggregator.aa_block_num - 1:
                valid_layers.append(model.aggregator.aa_block_num - 1)
            aggregated_tokens_list, ps_idx = model.aggregator(images, valid_layers)
            aggregated_tokens_list = [tokens.to(device) if tokens is not None else None for tokens in aggregated_tokens_list]

        # Predict Cameras
        pose_enc = model.camera_head(aggregated_tokens_list)[-1]
        # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
        extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, images.shape[-2:])
        # Predict Depth Maps
        depth_map, depth_conf = model.depth_head(aggregated_tokens_list, images, ps_idx)

        extrinsic = extrinsic.squeeze(0).cpu().numpy()
        intrinsic = intrinsic.squeeze(0).cpu().numpy()
        depth_map = depth_map.squeeze(0).cpu().numpy()
        depth_conf = depth_conf.squeeze(0).cpu().numpy()

        track_feature_maps = None if not track_feat else model.track_head.feature_extractor(aggregated_tokens_list, images, ps_idx)
        
    return extrinsic, intrinsic, depth_map, depth_conf, track_feature_maps


In [3]:
# Initialize the model and load the pretrained weights.
# This will automatically download the model weights the first time it's run, which may take a while.
model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
model.eval()

VGGT(
  (aggregator): Aggregator(
    (patch_embed): DinoVisionTransformer(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14))
        (norm): Identity()
      )
      (blocks): ModuleList(
        (0-23): 24 x NestedTensorBlock(
          (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (attn): MemEffAttention(
            (qkv): Linear(in_features=1024, out_features=3072, bias=True)
            (q_norm): Identity()
            (k_norm): Identity()
            (attn_drop): Dropout(p=0.0, inplace=False)
            (proj): Linear(in_features=1024, out_features=1024, bias=True)
            (proj_drop): Dropout(p=0.0, inplace=False)
          )
          (ls1): LayerScale()
          (drop_path1): Identity()
          (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
          (mlp): Mlp(
            (fc1): Linear(in_features=1024, out_features=4096, bias=True)
            (act): GELU(approximate=

In [4]:
import utils.colmap as colmap_utils

# Get image paths and preprocess them
sparse_dir_gt = "../data/MipNeRF360/treehill/sparse/0"
sparse_dir_pred = "../data/MipNeRF360_vggt/treehill/sparse/0"
images_dir = "../data/MipNeRF360/treehill/images"

cameras_gt = colmap_utils.read_cameras_binary(os.path.join(sparse_dir_gt, "cameras.bin"))
images_gt = colmap_utils.read_images_binary(os.path.join(sparse_dir_gt, "images.bin"))
pcd_gt = colmap_utils.read_points3D_binary(os.path.join(sparse_dir_gt, "points3D.bin"))
# images_gt = dict(sorted(images_gt.items(), key=lambda item: item[0]))

# cameras_pred = colmap_utils.read_cameras_binary(os.path.join(sparse_dir_pred, "cameras.bin"))
# images_pred = colmap_utils.read_images_binary(os.path.join(sparse_dir_pred, "images.bin"))
# pcd_pred = colmap_utils.read_points3D_binary(os.path.join(sparse_dir_pred, "points3D.bin"))
# images_pred = dict(sorted(images_pred.items(), key=lambda item: item[0]))

In [5]:
images_gt_updated = {id: images_gt[id] for id in list(images_gt.keys())}
image_path_list = [os.path.join(images_dir, images_gt_updated[id].name) for id in images_gt_updated.keys()]
base_image_path_list = [os.path.basename(path) for path in image_path_list]
total_frame_num = len(image_path_list)

vggt_fixed_resolution = 518
images, original_coords = load_and_preprocess_images_ratio(image_path_list, vggt_fixed_resolution)

# plot images
# plt.figure(figsize=(16, 10))
# for i, img in enumerate(images):
#     plt.subplot(5, 5, i + 1)
#     plt.imshow(img.permute(1, 2, 0).cpu().numpy())
#     plt.title(base_image_path_list[i].split('_Zenmuse')[0])
#     plt.axis('off')

In [6]:
# Run VGGT to estimate camera and depth
# Run with 518x518 images
extrinsic, intrinsic, depth_map, depth_conf, track_feats = run_VGGT(model, images, dtype, vggt_fixed_resolution, track_feat=True)
points_3d = unproject_depth_map_to_point_map(depth_map, extrinsic, intrinsic)
torch.cuda.empty_cache()

In [7]:
conf_threshold = 2
num_track_pts = 50
mask = depth_conf > conf_threshold

# idx = random.randint(0, total_frame_num - 1)
idx = 0
image_masked = images[idx].cpu().numpy().transpose(1, 2, 0) * mask[idx][:, :, None]
image_org = images[idx].cpu().numpy().transpose(1, 2, 0)

plt.style.use("seaborn-v0_8-whitegrid")
plt.figure(figsize=(10, 5))
plt.subplot(1, 2, 1)
plt.imshow(image_org)
plt.title(f"Original Image of {idx}")
plt.tight_layout()
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(image_masked)
plt.title("Masked Image with Depth Confidence > {}".format(conf_threshold))
# pick up num_track_pts points with highest depth confidence and mark them as red x on the image
# mark_points = np.argsort(depth_conf[idx].flatten())[-num_track_pts:]
# randomly select num_track_pts points with highest depth confidence and mark them as red x on the image
valid_index = np.where(mask[idx].flatten())[0].tolist()
mark_points = random.sample(valid_index, num_track_pts)
for point in mark_points:
    y, x = divmod(point, depth_conf[idx].shape[1])
    plt.scatter(x, y, color='red', s=10, marker='x')

plt.tight_layout()
plt.axis('off')

  plt.tight_layout()


(-0.5, 517.5, 335.5, -0.5)

In [8]:
ANALYZE_EXAMPLE=False

if ANALYZE_EXAMPLE:
    # Predict Tracks
    # choose your own points to track, with shape (N, 2) for one scene
    from vggt.utils.visual_track import visualize_tracks_on_images

    conf_threshold = 2
    num_track_pts = 50
    mask = depth_conf > conf_threshold

    start_idx = 0
    corr_mask = np.zeros(depth_conf.shape[0], dtype=bool)

    with torch.no_grad():
        torch.cuda.empty_cache()
        aggregated_tokens_list, ps_idx = run_VGGT(model, images, dtype, vggt_fixed_resolution, feat_only=True)

    while any(corr_mask) is False:

        query_points_list = []
        valid_index = np.where(mask[start_idx].flatten())[0].tolist()
        mark_points = random.sample(valid_index, num_track_pts)
        for point in mark_points:
            y, x = divmod(point, depth_conf[start_idx].shape[1])
            query_points_list.append([x, y])
        query_points = torch.FloatTensor(query_points_list).to(device)

        # reorder the image, make the start_idx image the first one
        reordered_idx = list(range(start_idx, total_frame_num)) + list(range(0, start_idx))
        reordered_aggregated_tokens_list = [aggregated_tokens_list[i] if aggregated_tokens_list[i] is None \
                                            else aggregated_tokens_list[i][:, reordered_idx] for i in range(len(aggregated_tokens_list))]

        with torch.no_grad():
            track_list, vis_score, conf_score = model.track_head(reordered_aggregated_tokens_list, images[None, reordered_idx], 
                                                                ps_idx, query_points=query_points[None])
            valid_track_score_mask = (conf_score > 0.2) & (vis_score > 0.2)
            valid_track_num = valid_track_score_mask.sum(dim=-1)
            valid_track_num_mask = valid_track_num > num_track_pts // 2
            valid_frame_idx = torch.where(valid_track_num_mask[0])[0].tolist()

    visualize_tracks_on_images(images, track_list[-1], (conf_score>0.2) & (vis_score>0.2), out_dir="track_visuals")

In [9]:
# Predict Tracks
# choose your own points to track, with shape (N, 2) for one scene
from vggt.utils.visual_track import visualize_tracks_on_images

# delete the track_visuals directory if it exists
if os.path.exists("track_visuals"):
    import shutil
    shutil.rmtree("track_visuals")

conf_threshold = 2.0
max_num_track_pts = 100
query_frame_num = total_frame_num // 2
mask = depth_conf > conf_threshold

corr_mask = np.zeros(depth_conf.shape[0], dtype=bool)
rest_frame_idx = np.where(~corr_mask)[0].tolist()

tracks_list = []
vis_scores_list = []
conf_scores_list = []
frame_idx_list = []
tgt_idx_list = []
iteration = 0

while all(corr_mask) is False:

    query_points_list = []
    if len(rest_frame_idx) > 0:
        start_idx = rest_frame_idx.pop(0)
    else:
        break
    valid_index = np.where(mask[start_idx].flatten())[0].tolist()
    # num_track_pts = min(max_num_track_pts, len(valid_index))
    num_track_pts = max_num_track_pts
    mark_points = np.random.choice(valid_index, num_track_pts, replace=True).tolist()
    for point in mark_points:
        y, x = divmod(point, depth_conf[start_idx].shape[1])
        query_points_list.append([x, y])
    
    if len(query_points_list) == 0:
        print(f"No valid points found for frame {start_idx}, skipping...")
        continue

    query_points = torch.FloatTensor(query_points_list).to(device)

    # reorder the image, make the start_idx image the first one
    reordered_idx = list(range(start_idx, total_frame_num)) + list(range(0, start_idx))
    # reordered_idx = [start_idx] + random.sample(reordered_idx[1:], query_frame_num - 1)
    with torch.no_grad():
        print(f"Number of query points: {len(query_points_list)}, start index: {start_idx}, {sum(~corr_mask)} frames are rest")
        track_list, vis_score, conf_score = model.track_head.tracker(query_points=query_points[None], fmaps=track_feats[:, reordered_idx], iters=model.track_head.iters)
        valid_track_score_mask = (conf_score > 0.2) & (vis_score > 0.2)
        valid_track_num = valid_track_score_mask.sum(dim=-1)
        valid_track_num_mask = valid_track_num > num_track_pts // 4
        valid_idx = torch.where(valid_track_num_mask[0])[0].tolist()

        if len(valid_idx) <= 1:
            print(f"No valid tracks found for frame {start_idx}, skipping...")
            continue

        tracks_list.append(track_list[-1][0, valid_idx].cpu().numpy())
        vis_scores_list.append(vis_score[0, valid_idx].cpu().numpy())
        conf_scores_list.append(conf_score[0, valid_idx].cpu().numpy())
        tgt_idx_list += [len(frame_idx_list)] * len(valid_idx)
        valid_frame_idx = [reordered_idx[i] for i in valid_idx]
        frame_idx_list += valid_frame_idx
        
    corr_mask[valid_frame_idx] = True
    rest_frame_idx = np.where(~corr_mask)[0].tolist()
    
    visualize_tracks_on_images(images[reordered_idx][valid_idx], track_list[-1][:, valid_idx], valid_track_score_mask[:, valid_idx], out_dir=f"track_visuals/{iteration:04d}_start_{start_idx:04d}")

print(f"Total {len(frame_idx_list)} correspondence pairs are predicted.")
print(f"Total {sum(~corr_mask)} frames are find no correspondence.")
# visualize_tracks_on_images(images, track_list[-1], (conf_score>0.2) & (vis_score>0.2), out_dir="track_visuals")

Number of query points: 100, start index: 0, 141 frames are rest


[INFO] Saved color-by-XY track visualization grid -> track_visuals/0000_start_0000/tracks_grid.png
[INFO] Saved 72 individual frames to track_visuals/0000_start_0000/frame_*.png
Number of query points: 100, start index: 9, 69 frames are rest
[INFO] Saved color-by-XY track visualization grid -> track_visuals/0000_start_0009/tracks_grid.png
[INFO] Saved 36 individual frames to track_visuals/0000_start_0009/frame_*.png
Number of query points: 100, start index: 14, 60 frames are rest
[INFO] Saved color-by-XY track visualization grid -> track_visuals/0000_start_0014/tracks_grid.png
[INFO] Saved 3 individual frames to track_visuals/0000_start_0014/frame_*.png
Number of query points: 100, start index: 15, 58 frames are rest
[INFO] Saved color-by-XY track visualization grid -> track_visuals/0000_start_0015/tracks_grid.png
[INFO] Saved 15 individual frames to track_visuals/0000_start_0015/frame_*.png
Number of query points: 100, start index: 18, 49 frames are rest
No valid tracks found for fram

In [10]:
import kornia
from scipy.ndimage import map_coordinates

image_names = np.array([base_image_path_list[i] for i in frame_idx_list])
tracks = np.concatenate(tracks_list, axis=0)
vis_scores = np.concatenate(vis_scores_list, axis=0)
conf_scores = np.concatenate(conf_scores_list, axis=0)
target_indexes = np.array(tgt_idx_list)
frame_indexes = np.array(frame_idx_list)

frame_indexes = frame_indexes[:, None].repeat(tracks.shape[1], axis=1)
frame_names = image_names[:, None].repeat(tracks.shape[1], axis=1)
valid_track_score_mask = (conf_scores > 0.2) & (vis_scores > 0.2)

corr_points_i = tracks[valid_track_score_mask]
corr_points_j = tracks[target_indexes][valid_track_score_mask]

image_names_i = frame_names[valid_track_score_mask]
image_names_j = frame_names[target_indexes][valid_track_score_mask]

frame_indexes_i = frame_indexes[valid_track_score_mask]
frame_indexes_j = frame_indexes[target_indexes][valid_track_score_mask]

vis_scores_i = vis_scores[valid_track_score_mask]
vis_scores_j = vis_scores[target_indexes][valid_track_score_mask]

conf_scores_i = conf_scores[valid_track_score_mask]
conf_scores_j = conf_scores[target_indexes][valid_track_score_mask]

same_pt_mask = (image_names_i == image_names_j)
corr_points_i = corr_points_i[~same_pt_mask]
corr_points_j = corr_points_j[~same_pt_mask]
image_names_i = image_names_i[~same_pt_mask]
image_names_j = image_names_j[~same_pt_mask]
frame_indexes_i = frame_indexes_i[~same_pt_mask]
frame_indexes_j = frame_indexes_j[~same_pt_mask]
vis_scores_i = vis_scores_i[~same_pt_mask]
vis_scores_j = vis_scores_j[~same_pt_mask]
conf_scores_i = conf_scores_i[~same_pt_mask]
conf_scores_j = conf_scores_j[~same_pt_mask]

# depths_i = depth_map[frame_indexes_i, np.around(corr_points_i[:, 1]).astype(int), np.around(corr_points_i[:, 0]).astype(int)]
# depths_j = depth_map[frame_indexes_j, np.around(corr_points_j[:, 1]).astype(int), np.around(corr_points_j[:, 0]).astype(int)]
# corr_points_i = np.concatenate([corr_points_i, depths_i], axis=1)
# corr_points_j = np.concatenate([corr_points_j, depths_j], axis=1)

intrinsic_i = np.zeros((corr_points_i.shape[0], 4, 4), dtype=np.float32)
intrinsic_j = np.zeros((corr_points_j.shape[0], 4, 4), dtype=np.float32)
intrinsic_i[:, :3, :3] = intrinsic[frame_indexes_i]
intrinsic_j[:, :3, :3] = intrinsic[frame_indexes_j]
intrinsic_i[:, 3, 3] = 1.0
intrinsic_j[:, 3, 3] = 1.0

extrinsic_i = np.zeros((corr_points_i.shape[0], 4, 4), dtype=np.float32)
extrinsic_j = np.zeros((corr_points_j.shape[0], 4, 4), dtype=np.float32)
extrinsic_i[:, :3, :4] = extrinsic[frame_indexes_i]
extrinsic_j[:, :3, :4] = extrinsic[frame_indexes_j]
extrinsic_i[:, 3, 3] = 1.0
extrinsic_j[:, 3, 3] = 1.0

corr_points_i_tensor = torch.FloatTensor(corr_points_i).to(device)
corr_points_j_tensor = torch.FloatTensor(corr_points_j).to(device)
weight_i = torch.FloatTensor(vis_scores_i * conf_scores_i).to(device)
weight_j = torch.FloatTensor(vis_scores_j * conf_scores_j).to(device)
intrinsic_i_tensor = torch.FloatTensor(intrinsic_i).to(device)
intrinsic_j_tensor = torch.FloatTensor(intrinsic_j).to(device)
extrinsic_i_tensor = torch.FloatTensor(extrinsic_i).to(device)
extrinsic_j_tensor = torch.FloatTensor(extrinsic_j).to(device)

with torch.no_grad():
    P_i = intrinsic_i_tensor @ extrinsic_i_tensor
    P_j = intrinsic_j_tensor @ extrinsic_j_tensor
    Fm = kornia.geometry.epipolar.fundamental_from_projections(P_i[:, :3], P_j[:, :3])
    err = kornia.geometry.symmetrical_epipolar_distance(corr_points_i_tensor[:, None, :2], corr_points_j_tensor[:, None, :2], Fm, squared=False, eps=1e-08)
    weight = torch.sqrt(weight_i * weight_j)
    err = err * weight[:, None] / weight.mean()

# show the error distribution
plt.figure(figsize=(10, 5))
plt.hist(err.cpu().numpy(), bins=100, density=True)
plt.title(f"Symmetrical Epipolar Distance Distribution of {len(err)} correspondences, average: {err.mean().item():.3f}, std: {err.std().item():.3f}")
plt.xlabel("Distance")
plt.ylabel("Density")
plt.tight_layout()
plt.show()

sampled_i = torch.zeros((corr_points_i_tensor.shape[0], 3), dtype=images.dtype, device=images.device)
sampled_j = torch.zeros((corr_points_j_tensor.shape[0], 3), dtype=images.dtype, device=images.device)

with torch.no_grad():
    corr_points_i_normalized = corr_points_i_tensor.to(images.device) / original_coords[frame_indexes_i][:, 2:4] * 2 - 1
    corr_points_j_normalized = corr_points_j_tensor.to(images.device) / original_coords[frame_indexes_j][:, 2:4] * 2 - 1

    for frame_idx in tqdm(np.unique(frame_indexes_i)):
        sampled_i[frame_indexes_i==frame_idx] = F.grid_sample(
            images[frame_idx].unsqueeze(0),
            corr_points_i_normalized[frame_indexes_i==frame_idx][None, None],
            align_corners=True,
            mode='bilinear'
        ).squeeze().permute(1, 0)

    for frame_idx in tqdm(np.unique(frame_indexes_j)):
        sampled_j[frame_indexes_j==frame_idx] = F.grid_sample(
            images[frame_idx].unsqueeze(0),
            corr_points_j_normalized[frame_indexes_j==frame_idx][None, None],
            align_corners=True,
            mode='bilinear'
        ).squeeze().permute(1, 0)

    err_rgb = torch.norm(sampled_i - sampled_j, dim=-1)
# show the error distribution
plt.figure(figsize=(10, 5))
plt.hist(err_rgb.cpu().numpy(), bins=100, density=True)
plt.title(f"RGB Distance Distribution of {len(err_rgb)} correspondences, average: {err_rgb.mean().item():.3f}, std: {err_rgb.std().item():.3f}")
plt.xlabel("Distance")
plt.ylabel("Density")
plt.tight_layout()
plt.show()

  plt.tight_layout()
100%|██████████| 131/131 [00:00<00:00, 3854.22it/s]
100%|██████████| 19/19 [00:00<00:00, 1794.78it/s]
  plt.tight_layout()


In [11]:
# align with 3DGS training
intrinsic_i_normalized = intrinsic_i.copy()
intrinsic_j_normalized = intrinsic_j.copy()
intrinsic_i_normalized[:, :2] /= original_coords[frame_indexes_i][:, 2:4, None]
intrinsic_j_normalized[:, :2] /= original_coords[frame_indexes_j][:, 2:4, None]

image_path_list_track = [os.path.join(images_dir+"_4", images_gt_updated[id].name) for id in images_gt_updated.keys()]
images, original_coords_track = load_and_preprocess_images_ratio(image_path_list_track)

intrinsic_i_tensor_ = torch.FloatTensor(intrinsic_i_normalized)
intrinsic_j_tensor_ = torch.FloatTensor(intrinsic_j_normalized)
intrinsic_i_tensor_[:, :2] *= original_coords_track[frame_indexes_i][:, 2:4, None]
intrinsic_j_tensor_[:, :2] *= original_coords_track[frame_indexes_j][:, 2:4, None]

In [12]:
# align with 3DGS training
intrinsic_i_normalized = intrinsic_i.copy()
intrinsic_j_normalized = intrinsic_j.copy()
intrinsic_i_normalized[:, :2] /= original_coords[frame_indexes_i][:, 2:4, None]
intrinsic_j_normalized[:, :2] /= original_coords[frame_indexes_j][:, 2:4, None]

intrinsic_i_tensor_ = torch.FloatTensor(intrinsic_i_normalized)
intrinsic_j_tensor_ = torch.FloatTensor(intrinsic_j_normalized)
intrinsic_i_tensor_[:, :2] *= original_coords_track[frame_indexes_i][:, 2:4, None]
intrinsic_j_tensor_[:, :2] *= original_coords_track[frame_indexes_j][:, 2:4, None]
intrinsic_i_tensor_ = intrinsic_i_tensor_.to(intrinsic_i_tensor.device)
intrinsic_j_tensor_ = intrinsic_j_tensor_.to(intrinsic_j_tensor.device)

In [20]:
search_radius = 3  # search radius in pixels
S, C, H, W = images.shape
sampled_i = torch.zeros((corr_points_i_normalized.shape[0], 3), dtype=images.dtype, device=images.device)
sampled_j = torch.zeros((corr_points_j_normalized.shape[0], 3), dtype=images.dtype, device=images.device)

corr_points_i_normalized_ = torch.zeros_like(corr_points_i_normalized, device=images.device)

for frame_idx in tqdm(np.unique(frame_indexes_j)):
    sampled_j[frame_indexes_j==frame_idx] = F.grid_sample(
        images[frame_idx].unsqueeze(0),
        corr_points_j_normalized[frame_indexes_j==frame_idx][None, None],
        align_corners=True,
        mode='bilinear'
    ).squeeze().permute(1, 0)

for frame_idx in tqdm(np.unique(frame_indexes_i)):
    offsets = torch.stack(torch.meshgrid(
        torch.arange(-search_radius * 2 / W, (search_radius + 1) * 2 / W, 2 / W, device=images.device),
        torch.arange(-search_radius * 2 / H, (search_radius + 1) * 2 / H, 2 / H, device=images.device),
        indexing='xy'
    ), dim=-1).view(-1, 2)
    
    corr_points_i_normalized_offset = corr_points_i_normalized[frame_indexes_i==frame_idx][:, None, :] + offsets[None, :, :]

    sampled_i_temp = F.grid_sample(
        images[frame_idx].unsqueeze(0),
        corr_points_i_normalized_offset[None],
        align_corners=True,
        mode='bilinear'
    ).squeeze(0).permute(1, 2, 0)

    rgb_diff = torch.norm(sampled_i_temp - sampled_j[frame_indexes_i==frame_idx][:, None, :], dim=-1)
    min_diff_idx = rgb_diff.argmin(dim=1)

    sampled_i[frame_indexes_i==frame_idx] = sampled_i_temp[torch.arange(sampled_i_temp.shape[0]), min_diff_idx]
    corr_points_i_normalized_[frame_indexes_i==frame_idx] = corr_points_i_normalized_offset[torch.arange(corr_points_i_normalized_offset.shape[0]), min_diff_idx]


with torch.no_grad():
    P_i = intrinsic_i_tensor_ @ extrinsic_i_tensor
    P_j = intrinsic_j_tensor_ @ extrinsic_j_tensor
    
    corr_points_i_tensor_ = (corr_points_i_normalized_ + 1) * original_coords_track[frame_indexes_i][:, 2:4] / 2
    corr_points_j_tensor_ = (corr_points_j_normalized + 1) * original_coords_track[frame_indexes_j][:, 2:4] / 2
    corr_points_i_tensor_ = corr_points_i_tensor_.to(P_i.device)
    corr_points_j_tensor_ = corr_points_j_tensor_.to(P_j.device)

    Fm = kornia.geometry.epipolar.fundamental_from_projections(P_i[:, :3], P_j[:, :3])
    err = kornia.geometry.symmetrical_epipolar_distance(corr_points_i_tensor_[:, None, :2], corr_points_j_tensor_[:, None, :2], Fm, squared=False, eps=1e-08)
    weight = torch.sqrt(weight_i * weight_j)
    err = err * weight[:, None] / weight.mean()

# show the error distribution
plt.figure(figsize=(10, 5))
plt.hist(err.cpu().numpy(), bins=100, density=True)
plt.title(f"Symmetrical Epipolar Distance Distribution of {len(err)} correspondences, average: {err.mean().item():.3f}, std: {err.std().item():.3f}, search_radius: {search_radius}")
plt.xlabel("Distance")
plt.ylabel("Density")
plt.tight_layout()
plt.show()

err_rgb = torch.norm(sampled_i - sampled_j, dim=-1)

# show the error distribution
plt.figure(figsize=(10, 5))
plt.hist(err_rgb.cpu().numpy(), bins=100, density=True)
plt.title(f"RGB Distance Distribution of {len(err_rgb)} correspondences, average: {err_rgb.mean().item():.3f},  std: {err_rgb.std().item():.3f}, search_radius: {search_radius}")
plt.xlabel("Distance")
plt.ylabel("Density")
plt.tight_layout()
plt.show()

100%|██████████| 19/19 [00:00<00:00, 53.93it/s]
100%|██████████| 131/131 [00:01<00:00, 129.06it/s]
  plt.tight_layout()
  plt.tight_layout()


In [21]:
save_dir = sparse_dir_pred.split('/sparse/')[0]
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

np.save(os.path.join(save_dir, f'corr_s{search_radius}.npy'), {
    'corr_points_i_normalized': corr_points_i_normalized_.cpu().numpy(),
    'corr_points_j_normalized': corr_points_j_normalized.cpu().numpy(),
    'image_names_i': image_names_i,
    'image_names_j': image_names_j,
})

print(f"Point correspondences saved to {os.path.join(save_dir, f'corr_s{search_radius}.npy')}.")

Point correspondences saved to ../data/MipNeRF360_vggt/treehill/corr_s3.npy.
