In [1]:
from transformers import AutoImageProcessor, SuperPointForKeypointDetection

processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")

In [34]:
from PIL import Image
import requests

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image1 = Image.open(requests.get(url, stream=True).raw)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image2 = Image.open(requests.get(url, stream=True).raw)

images = [image1, image2]

In [35]:
inputs = processor(images, return_tensors="pt")

In [36]:
outputs = model(**inputs)

In [46]:
import torch
import cv2


def extract_keypoints(outputs, index):
    image_mask = outputs.mask[index]
    indices = torch.nonzero(image_mask).squeeze()

    # Ensure indices are not empty
    if indices.numel() == 0:
        return [], None, None  # Handle case with no keypoints

    keypoints = outputs.keypoints[index][indices]
    descriptors = outputs.descriptors[index][indices]  
    scores = outputs.scores[index][indices]

    # Convert keypoints to a format compatible with OpenCV
    keypoints_cv = []
    for i, keypoint in enumerate(keypoints):
        kp = cv2.KeyPoint(x=keypoint[0].item(), y=keypoint[1].item(), size=1, response=scores[i].item())
        keypoints_cv.append(kp)

    descriptors = descriptors.detach().cpu().numpy()
    print(type(descriptors))
    print(descriptors.dtype)
    print(descriptors.shape)

    return keypoints_cv, descriptors, scores

In [47]:
x = extract_keypoints(outputs, 0)
y = extract_keypoints(outputs, 1)

<class 'numpy.ndarray'>
float32
(568, 256)
<class 'numpy.ndarray'>
float32
(568, 256)


In [48]:
keypoints_cv, descriptors, scores = x
keypoints_cv2, descriptors2, scores2 = y

In [49]:
FLANN_INDEX_LSH = 6
index_params = dict(
    algorithm=FLANN_INDEX_LSH,
    table_number=6,
    key_size=12,
    multi_probe_level=1
)
search_params = dict(checks=50)

flann = cv2.FlannBasedMatcher(
    indexParams=index_params,
    searchParams=search_params
)

In [54]:
all_matches = flann.knnMatch(descriptors, descriptors2, k=2)

In [53]:
descriptors = descriptors.astype('uint8')
type(descriptors)

descriptors2 = descriptors2.astype('uint8')
type(descriptors2)

numpy.ndarray

In [55]:
import numpy as np

In [56]:
# Apply ratio test to filter matches
matches = []

for x in all_matches:
    if len(x) != 2:
        continue

    m, n = x
    if m.distance < 0.7 * n.distance:
        matches.append(m)

# Sort matches based on distance
matches = sorted(matches, key=lambda x: x.distance)

# Extract matched points from both images
q1 = np.float32([keypoints_cv[m.queryIdx].pt for m in matches])
q2 = np.float32([keypoints_cv2[m.trainIdx].pt for m in matches])
