读取图像

In [1]:
import numpy as np
import os
import torch
import tqdm
import cv2
import matplotlib.pyplot as plt

from modules.xfeat import XFeat

# Initialise the XFeat model.
xfeat = XFeat()


def load_image(path):
    """Read an image with OpenCV and fail loudly if it cannot be loaded.

    ``cv2.imread`` does not raise when the file is missing or unreadable; it
    returns ``None``, which would otherwise only surface later as an opaque
    ``AttributeError`` (e.g. on ``.shape``).

    Args:
        path: Path of the image file to read.

    Returns:
        The decoded image as returned by ``cv2.imread``.

    Raises:
        FileNotFoundError: If OpenCV could not read the file.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"cv2.imread could not read image: {path}")
    return image


# Load the three input images.
im1 = load_image('./images/t_1.jpg')
im2 = load_image('./images/t_2.jpg')
im3 = load_image('./images/t_3.jpg')

# Display the shape of the first image as the cell output.
im1.shape

loading weights from: /home/ubuntu/workspaces/xfeat/modules/../weights/xfeat.pt


(4096, 3072, 3)

In [2]:
# Match im1 against im2 directly from the raw images.
# kp1[i] and kp2[i] are the coordinates of the i-th matched pair in im1 and
# im2 respectively (the following cells treat both as (N, 2) arrays).
# top_k=4096 is forwarded to XFeat — presumably the keypoint budget; confirm
# against the XFeat implementation.
kp1, kp2 = xfeat.match_xfeat_star(im1, im2, top_k=4096)



In [3]:
# Absolute vertical (y) offset between the two points of every match.
# NOTE: this is |y1 - y2| only, NOT the full Euclidean distance.
# kp1 and kp2 both have shape (N, 2), stored as (x, y).
MAX_Y_OFFSET_PX = 400  # matches with a vertical offset of this many pixels or more are rejected
distances = abs(kp1[:, 1] - kp2[:, 1])

# Boolean mask selecting matches whose vertical offset is below the threshold.
mask = distances < MAX_Y_OFFSET_PX

# Keep only the point pairs that pass the filter.
valid_kp1 = kp1[mask]
valid_kp2 = kp2[mask]

In [4]:
# Inspect the per-match vertical offsets (|y1 - y2|, in pixels).
distances

array([ 82.81763 ,  66.4978  ,  67.379395, ...,  87.38306 , 105.24902 ,
        23.199463], shape=(1278,), dtype=float32)

In [5]:



import cv2
import numpy as np

def _match_points_to_numpy(points):
    """Return ``points`` as a float32 NumPy array with one (x, y, ...) row per point.

    Accepts NumPy arrays, nested sequences and torch tensors; tensors are
    detached and copied to host memory first so that GPU tensors work too.
    """
    if hasattr(points, "detach"):  # torch.Tensor
        points = points.detach().cpu().numpy()
    arr = np.asarray(points, dtype=np.float32)
    if arr.size == 0:
        return arr.reshape(0, 2)
    if arr.ndim != 2 or arr.shape[1] < 2:
        raise ValueError(f"expected points of shape (N, 2), got {arr.shape}")
    return arr


def visualize_direct_matches(im1, im2, kp1, kp2, window_name="XFeat Direct Matches",
                             max_display=None, show=True):
    """Draw matched keypoints between two images placed side by side.

    Args:
        im1: Left image, 3-channel uint8 array of shape (H1, W1, 3).
        im2: Right image, 3-channel uint8 array of shape (H2, W2, 3).
        kp1: (N, 2) matched (x, y) coordinates in ``im1`` (array, list or tensor).
        kp2: (N, 2) matched (x, y) coordinates in ``im2``; row i pairs with ``kp1[i]``.
        window_name: Title of the OpenCV window used when ``show`` is true.
        max_display: If given and smaller than N, only this many randomly chosen
            matches are drawn to keep the picture readable. ``None`` draws all.
        show: When true (default) the result is shown with ``cv2.imshow`` and the
            call blocks until a key is pressed. Pass ``False`` to only build
            the image.

    Returns:
        The rendered canvas (downscaled to at most 1280 px wide).

    Raises:
        ValueError: If ``kp1`` and ``kp2`` do not contain the same number of points.
    """
    # 1. Normalise the keypoints to host-side float32 arrays.
    pts1 = _match_points_to_numpy(kp1)
    pts2 = _match_points_to_numpy(kp2)
    if len(pts1) != len(pts2):
        raise ValueError(
            f"kp1 and kp2 must have the same length, got {len(pts1)} and {len(pts2)}"
        )

    # 2. Place both images next to each other on a black canvas.
    h1, w1 = im1.shape[:2]
    h2, w2 = im2.shape[:2]
    canvas = np.zeros((max(h1, h2), w1 + w2, 3), dtype=np.uint8)
    canvas[:h1, :w1] = im1
    canvas[:h2, w1:] = im2

    # 3. Choose which matches to draw (optionally a random subset).
    num_matches = len(pts1)
    if max_display is None or max_display >= num_matches:
        indices = np.arange(num_matches)
    else:
        indices = np.random.choice(num_matches, max(int(max_display), 0), replace=False)

    for i in indices:
        # Random colour per match.
        color = np.random.randint(0, 255, (3,)).tolist()

        # Points of the right image are shifted by the width of the left image.
        # Truncate before adding the offset so float32 rounding cannot move the pixel.
        pt1 = (int(pts1[i][0]), int(pts1[i][1]))
        pt2 = (int(pts2[i][0]) + w1, int(pts2[i][1]))

        cv2.line(canvas, pt1, pt2, color, 4, cv2.LINE_AA)
        cv2.circle(canvas, pt1, 3, color, -1)
        cv2.circle(canvas, pt2, 3, color, -1)

    # 4. Downscale for display; INTER_AREA avoids aliasing when shrinking.
    max_w = 1280
    if canvas.shape[1] > max_w:
        scale = max_w / canvas.shape[1]
        new_h = max(1, int(canvas.shape[0] * scale))
        canvas = cv2.resize(canvas, (max_w, new_h), interpolation=cv2.INTER_AREA)

    if show:
        cv2.imshow(window_name, canvas)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    return canvas

# Visualise only the matches that passed the vertical-offset filter.
# Alternative matcher, kept for reference:
# kp1, kp2 = xfeat.match_xfeat(im1, im2, top_k=4096)
visualize_direct_matches(im1, im2, valid_kp1, valid_kp2)


关键点检测以及特征匹配

In [6]:
top_k = 4096

# Per the original note, detectAndCompute works in batch mode; one image is
# passed per call here, so element [0] is taken as that image's result.
r_1, r_2, r_3 = (
    xfeat.detectAndCompute(image, top_k=top_k)[0]
    for image in (im1, im2, im3)
)


输出属性

In [7]:
# Report the container type and entry count, then the shape stored under each key.
print("返回的数据类型是：", type(r_1), "共计元素：", len(r_1))
for name, value in r_1.items():
    print(name, "\t", value.shape)

返回的数据类型是： <class 'dict'> 共计元素： 3
keypoints 	 torch.Size([4096, 2])
scores 	 torch.Size([4096])
descriptors 	 torch.Size([4096, 64])


In [8]:
# Match descriptors between consecutive image pairs (1-2 and 2-3).
# index_a_b holds indices into image a's keypoints and index_b_a the paired
# indices into image b's keypoints (this is how the later cells consume them).
# min_cossim=0.1 is forwarded to xfeat.match — presumably a minimum
# cosine-similarity threshold; confirm against the XFeat implementation.
index_1_2, index_2_1 = xfeat.match(r_1["descriptors"], r_2["descriptors"], min_cossim=0.1)
index_2_3, index_3_2 = xfeat.match(r_2["descriptors"], r_3["descriptors"], min_cossim=0.1)

描述子匹配

In [9]:
import cv2
import numpy as np

def _to_host_list(values):
    """Convert a torch tensor, NumPy array or sequence into plain Python lists.

    Tensors are detached and copied to host memory in one go, which avoids a
    per-element conversion (and device synchronisation) later on.
    """
    if hasattr(values, "detach"):  # torch.Tensor
        values = values.detach().cpu().numpy()
    return np.asarray(values).tolist()


def get_match_visualization(im_a, im_b, r_a, r_b, idx_a, idx_b):
    """Render the matches between two images with ``cv2.drawMatches``.

    Args:
        im_a: First (query) image.
        im_b: Second (train) image.
        r_a: Feature dict of image A; ``r_a["keypoints"]`` holds (x, y) rows.
        r_b: Feature dict of image B; ``r_b["keypoints"]`` holds (x, y) rows.
        idx_a: Indices into ``r_a["keypoints"]``, one per match.
        idx_b: Indices into ``r_b["keypoints"]``; entry i pairs with ``idx_a[i]``.

    Returns:
        The image produced by ``cv2.drawMatches`` (unmatched keypoints are not drawn).

    Raises:
        ValueError: If ``idx_a`` and ``idx_b`` have different lengths.
    """
    # Validate the pairing first: zip() would silently drop the surplus entries.
    ids_a = _to_host_list(idx_a)
    ids_b = _to_host_list(idx_b)
    if len(ids_a) != len(ids_b):
        raise ValueError(
            f"idx_a and idx_b must have the same length, got {len(ids_a)} and {len(ids_b)}"
        )

    # 1. Build OpenCV keypoints from host-side Python floats
    #    (cv2.KeyPoint rejects tensor / NumPy scalar arguments).
    kp_a = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in _to_host_list(r_a["keypoints"])]
    kp_b = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in _to_host_list(r_b["keypoints"])]

    # 2. Build DMatch objects (queryIdx refers to image A, trainIdx to image B).
    matches = [cv2.DMatch(int(ia), int(ib), 0) for ia, ib in zip(ids_a, ids_b)]

    # 3. Draw the correspondences; single (unmatched) keypoints are skipped.
    res = cv2.drawMatches(im_a, kp_a, im_b, kp_b, matches, None,
                          flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS)
    return res

# --- Render both pairwise match images ---
res12 = get_match_visualization(im1, im2, r_1, r_2, index_1_2, index_2_1)
res23 = get_match_visualization(im2, im3, r_2, r_3, index_2_3, index_3_2)

# --- Stack them vertically ---
# Vertical stacking requires equal widths, so the second image is rescaled
# (keeping its aspect ratio) whenever the two widths differ.
if res23.shape[1] != res12.shape[1]:
    target_w = res12.shape[1]
    scale = target_w / res23.shape[1]
    scaled_h = int(res23.shape[0] * scale)
    res23 = cv2.resize(res23, (target_w, scaled_h))

combined_res = np.concatenate((res12, res23), axis=0)

# --- Limit the size of the displayed window ---
max_display_height = 900  # maximum height of the displayed image, in pixels
h, w = combined_res.shape[:2]
if h > max_display_height:
    scale = max_display_height / h
    display_size = (int(w * scale), max_display_height)
    combined_res = cv2.resize(combined_res, display_size, interpolation=cv2.INTER_AREA)

cv2.imshow("Matches 1-2 (Top) and 2-3 (Bottom)", combined_res)
cv2.waitKey(0)
cv2.destroyAllWindows()

In [10]:


# Find the image-2 keypoints that were matched in BOTH pairings (1-2 and 2-3).
# With return_indices=True, intersect1d returns the shared image-2 keypoint
# indices plus, for each shared value, the position of its first occurrence
# in index_2_1 (mask_12) and in index_2_3 (mask_23).
# Despite their names, mask_12 / mask_23 are integer positions, not boolean masks.
common_idx_in_2, mask_12, mask_23 = np.intersect1d( index_2_1, index_2_3, return_indices=True )
indices_1 = index_1_2[mask_12]  # image-1 keypoint paired with each shared image-2 keypoint
indices_2 = common_idx_in_2
indices_3 = index_3_2[mask_23]  # image-3 keypoint paired with each shared image-2 keypoint


# Gather the coordinates: row i of uv_1 / uv_2 / uv_3 forms one three-view track.
uv_1 = r_1["keypoints"][indices_1]
uv_2 = r_2["keypoints"][indices_2]
uv_3 = r_3["keypoints"][indices_3]

In [11]:
# Sanity check: the three coordinate sets should have the same number of rows.
uv_1.shape, uv_2.shape, uv_3.shape

(torch.Size([726, 2]), torch.Size([726, 2]), torch.Size([726, 2]))

可视化三幅图的特征匹配结果

In [12]:
def _track_points_to_numpy(points):
    """Return ``points`` as a float32 NumPy array with one (x, y, ...) row per point.

    Torch tensors are detached and copied to host memory once, instead of
    being converted element by element inside the drawing loop.
    """
    if hasattr(points, "detach"):  # torch.Tensor
        points = points.detach().cpu().numpy()
    arr = np.asarray(points, dtype=np.float32)
    if arr.size == 0:
        return arr.reshape(0, 2)
    if arr.ndim != 2 or arr.shape[1] < 2:
        raise ValueError(f"expected points of shape (N, 2), got {arr.shape}")
    return arr


def visualize_triple_matches(im1, im2, im3, uv_1, uv_2, uv_3):
    """Draw three-view tracks across three images placed side by side.

    Args:
        im1, im2, im3: 3-channel uint8 images; they may differ in size.
        uv_1, uv_2, uv_3: (N, 2) arrays/tensors of (x, y) coordinates. Row i of
            each array is the same track observed in image 1, 2 and 3.

    Returns:
        A uint8 canvas of shape (max height, summed width, 3) with a dot on
        every observation and lines joining 1->2 and 2->3 for each track.

    Raises:
        ValueError: If the three coordinate sets differ in length.
    """
    tracks = [_track_points_to_numpy(uv) for uv in (uv_1, uv_2, uv_3)]
    num_tracks = len(tracks[0])
    if any(len(points) != num_tracks for points in tracks):
        raise ValueError(
            "uv_1, uv_2 and uv_3 must have the same length, got "
            + ", ".join(str(len(points)) for points in tracks)
        )

    # 1. Build the canvas and remember the horizontal offset of every image.
    images = (im1, im2, im3)
    heights = [image.shape[0] for image in images]
    widths = [image.shape[1] for image in images]
    canvas = np.zeros((max(heights), sum(widths), 3), dtype=np.uint8)

    offsets = []
    x_start = 0
    for image, height, width in zip(images, heights, widths):
        canvas[:height, x_start:x_start + width] = image
        offsets.append(x_start)
        x_start += width

    # 2. One colour per track. A private RandomState keeps the palette stable
    #    (same stream as np.random.seed(42)) without resetting the global RNG.
    rng = np.random.RandomState(42)
    colors = rng.randint(0, 255, (num_tracks, 3)).tolist()

    # 3. Draw the observations and the connecting lines.
    for i in range(num_tracks):
        color = tuple(colors[i])

        # Shift x by the image offset; truncate first so float32 rounding
        # cannot move the pixel.
        pt1, pt2, pt3 = (
            (int(points[i][0]) + offset, int(points[i][1]))
            for points, offset in zip(tracks, offsets)
        )

        cv2.circle(canvas, pt1, 4, color, -1)
        cv2.circle(canvas, pt2, 4, color, -1)
        cv2.circle(canvas, pt3, 4, color, -1)

        # Connect 1 -> 2 and 2 -> 3.
        cv2.line(canvas, pt1, pt2, color, 4, cv2.LINE_AA)
        cv2.line(canvas, pt2, pt3, color, 4, cv2.LINE_AA)

    return canvas

In [13]:
# 1. Build the full-resolution visualisation.
result_img = visualize_triple_matches(im1, im2, im3, uv_1, uv_2, uv_3)

# 2. Downscale for display so the image is at most max_width pixels wide.
max_width = 1280
h, w = result_img.shape[:2]

if w > max_width:
    scale = max_width / w
    # Use max_width directly: int(w * scale) can round down to max_width - 1.
    new_w = max_width
    # Guard against a zero height for extremely wide images (cv2.resize rejects it).
    new_h = max(1, int(h * scale))
    # Actually resample the image; INTER_AREA suits downscaling.
    show_img = cv2.resize(result_img, (new_w, new_h), interpolation=cv2.INTER_AREA)
else:
    show_img = result_img

# 3. Show the (possibly downscaled) image and wait for a key press.
cv2.imshow("Scaled Triple Matches", show_img)
cv2.waitKey(0)
cv2.destroyAllWindows()