# XFeat minimal inference example

## Test on simple input (sparse setting)

In [1]:
import numpy as np
import os
import torch
import tqdm

from modules.xfeat import XFeat

xfeat = XFeat()

# Seed torch's RNG so the random input below is identical on every
# Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Random input: batch of 1, 3 channels, 480x640 (VGA)
x = torch.randn(1, 3, 480, 640)

# Simple inference with batch = 1. detectAndCompute returns one result dict
# per batch item, so [0] selects the only one.
output = xfeat.detectAndCompute(x, top_k = 4096)[0]
print("----------------")
print("keypoints: ", output['keypoints'].shape)
print("descriptors: ", output['descriptors'].shape)
print("scores: ", output['scores'].shape)
print("----------------\n")

loading weights from: /Users/farrelsalim/Documents/vebits-intern/xfeat-matching/modules/../weights/xfeat.pt
----------------
keypoints:  torch.Size([4096, 2])
descriptors:  torch.Size([4096, 64])
scores:  torch.Size([4096])
----------------



In [4]:
# Sanity-check that every detected keypoint lies inside the input image.
# Keypoints are (x, y) rows. Column 0 is checked against the width and
# column 1 against the height of the first cell's input `x`
# (assumed (B, C, H, W), as created there with torch.randn(1, 3, 480, 640)).
# detach()/cpu() are no-ops for plain CPU tensors, but they keep the numpy
# conversion working if the outputs live on a GPU or track gradients.
kpts = output['keypoints'].detach().cpu().numpy()
assert len(kpts) > 0, "no keypoints detected"
xs, ys = kpts[:, 0], kpts[:, 1]

print(np.min(xs), np.min(ys), np.max(xs), np.max(ys))
print(x.shape)

_, _, img_h, img_w = x.shape
assert xs.min() >= 0 and xs.max() < img_w, "keypoint x outside image width"
assert ys.min() >= 0 and ys.max() < img_h, "keypoint y outside image height"

0.0 0.0 638.0 478.0
torch.Size([1, 3, 480, 640])


## Stress test to check FPS on VGA (sparse setting)

In [5]:
# Run sparse extraction repeatedly on one random VGA (480x640) frame.
# tqdm's it/s readout serves as the throughput (FPS) estimate.
N_STRESS_ITERS = 100
x_vga = torch.randn(1, 3, 480, 640)
for _ in tqdm.tqdm(range(N_STRESS_ITERS), desc="Stress test on VGA resolution"):
    # Results are discarded on purpose. Assigning them to `output` would
    # replace the first cell's dict with a list and break the keypoint-range
    # check if that cell is re-run.
    xfeat.detectAndCompute(x_vga, top_k=4096)


Stress test on VGA resolution: 100%|██████████| 100/100 [00:13<00:00,  7.15it/s]


## Test with batched mode (sparse)

In [6]:
# Batched mode: 4 random VGA frames in one call. A dedicated name keeps the
# first cell's single-image `x` intact for the earlier checks.
x_batch = torch.randn(4, 3, 480, 640)
outputs = xfeat.detectAndCompute(x_batch, top_k=4096)

# detectAndCompute returns one result dict per batch item.
assert len(outputs) == x_batch.shape[0], "expected one result per batch item"
print("# detected features on each batch item:", [len(o['keypoints']) for o in outputs])

# detected features on each batch item: [4096, 4096, 4096, 4096]


## Match two images with the built-in MNN (mutual nearest neighbor) matcher (sparse mode)

In [7]:
# Two independent random VGA frames (drawn in order: x1 first, then x2),
# matched with XFeat's built-in sparse matcher.
x1, x2 = (torch.randn(1, 3, 480, 640) for _ in range(2))
mkpts_0, mkpts_1 = xfeat.match_xfeat(x1, x2)

## Match two images with semi-dense matching in batched mode (batch size = 4), for demonstration purposes

In [8]:
# Four random image pairs, stacked along the batch dimension
# (x1 is drawn before x2).
x1, x2 = (torch.randn(4, 3, 480, 640) for _ in range(2))

# Semi-dense matching yields one match tensor per image pair.
matches_list = xfeat.match_xfeat_star(x1, x2, top_k=5000)
print('number of img pairs', len(matches_list))
print(matches_list[0].shape)  # columns are (x1, y1, x2, y2)

number of img pairs 4
torch.Size([160, 4])
