# SiLK Submission

This notebook generates the official IMC2022 results of the [SiLK](https://github.com/facebookresearch/silk) keypoint model.

In [1]:
# Debug switch. When True: print the parsed test pairs, plot the inlier matches
# of every pair, and echo the generated submission.csv.
dry_run: bool = False

In [2]:
# Remove preinstalled packages that are presumably tied to the preinstalled torch
# (1.13.0 per the cell output), which is replaced by the 1.11.0 wheel below.
!pip uninstall -y torchtext torchaudio fastai allennlp

# offline: install wheels from the attached Kaggle dataset only.
# `-f DIR` makes pip look for wheels in DIR; `--no-index` stops it from contacting PyPI.
# hydra-core and loguru are unpinned: whichever version is in the wheel directory is used.
!pip install -f /kaggle/input/imc2022-dependencies-silk/wheels --no-index torch==1.11.0+cu113
!pip install -f /kaggle/input/imc2022-dependencies-silk/wheels --no-index torchvision==0.12.0+cu113
!pip install -f /kaggle/input/imc2022-dependencies-silk/wheels --no-index hydra-core
!pip install -f /kaggle/input/imc2022-dependencies-silk/wheels --no-index loguru

Found existing installation: torchtext 0.14.0
Uninstalling torchtext-0.14.0:
  Successfully uninstalled torchtext-0.14.0
Found existing installation: torchaudio 0.13.0
Uninstalling torchaudio-0.13.0:
  Successfully uninstalled torchaudio-0.13.0
Found existing installation: fastai 2.7.12
Uninstalling fastai-2.7.12:
  Successfully uninstalled fastai-2.7.12
[0mLooking in links: /kaggle/input/imc2022-dependencies-silk/wheels
Processing /kaggle/input/imc2022-dependencies-silk/wheels/torch-1.11.0+cu113-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torch
  Attempting uninstall: torch
    Found existing installation: torch 1.13.0
    Uninstalling torch-1.13.0:
      Successfully uninstalled torch-1.13.0
Successfully installed torch-1.11.0+cu113
[0mLooking in links: /kaggle/input/imc2022-dependencies-silk/wheels
Processing /kaggle/input/imc2022-dependencies-silk/wheels/torchvision-0.12.0+cu113-cp37-cp37m-linux_x86_64.whl
Installing collected packages: torchvisi

In [3]:
import csv
import os
import random
import sys
from functools import partial
from glob import glob

import numpy as np
import torch
import torchvision
import h5py
import matplotlib.pyplot as plt
import cv2
from torchvision.transforms.functional import resize, InterpolationMode

# SiLK is imported from the attached source tree rather than installed; its
# example helpers (which provide `common`) live under scripts/examples.
sys.path.append(os.path.join('/kaggle/input/silk-kp/silk-main/silk-main'))
sys.path.append(os.path.join('/kaggle/input/silk-kp/silk-main/silk-main/scripts/examples'))

import silk
import common
from silk.backbones.silk.silk import from_feature_coords_to_image_coords

# Central configuration for the notebook; every tunable is read via conf[...].
conf = dict(
    common=dict(
        # Root folder holding one sub-folder of PNG images per batch_id.
        paths=dict(
            images="/kaggle/input/image-matching-challenge-2022/test_images/",
        ),
        nms=0,  # 0 = disabled; forwarded to common.get_model
        device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
        topk=30_000,  # keypoints kept per image by get_top_k
        # Forwarded to cv2.findFundamentalMat (USAC_MAGSAC).
        ransac=dict(
            max_iter=200_000,
            confidence=0.99999,
            reproj_threshold=0.25,  # pixels
        ),
    ),
    model=dict(
        checkpoint="/kaggle/input/silk-kp/coco-rgb-aug.ckpt",
        image_max_side=720,  # images are downscaled so their longest side is at most this
        # Forwarded to silk.models.silk.matcher.
        matcher=dict(
            postprocessing="double-softmax",
            threshold=0.99,
            temperature=0.1,
        ),
    ),
)

# Monkeypatch the device used by SiLK's example helpers in `common`.
# common.get_model receives the device explicitly, but common.load_images does not,
# so presumably it reads this module-level global — confirm against common.py.
common.DEVICE = conf["common"]["device"] # hacky


In [4]:
def FlattenMatrix(M, num_digits=8):
    '''Serialize a matrix as space-separated scientific-notation values in row-major order.

    Used to write fundamental matrices into the submission CSV.

    Args:
        M: numpy array of any shape.
        num_digits: digits after the decimal point of each value.
    '''
    spec = '.%de' % num_digits
    return ' '.join(format(value, spec) for value in M.flatten())


def BuildCompositeImage(im1, im2, axis=1, margin=0, background=1):
    '''Stack two (possibly differently sized) 3-channel images on one uint8 canvas.

    Args:
        im1, im2: arrays of shape (H, W, 3).
        axis: 1 places im2 to the right of im1, 0 places it below.
        margin: gap in pixels between the two images (filled with the background).
        background: 1 for a white canvas, 0 for black; any other value falls back to 1.

    Returns:
        (composite, (voff1, voff2), (hoff1, hoff2)): the canvas plus the vertical and
        horizontal pixel offsets at which im1 and im2 were pasted.

    Raises:
        RuntimeError: if axis is neither 0 nor 1.
    '''
    if background not in (0, 1):
        background = 1
    if axis not in (0, 1):
        raise RuntimeError('Axis must be 0 (vertical) or 1 (horizontal)')

    h1, w1, _ = im1.shape
    h2, w2, _ = im2.shape
    fill = 255 * background

    if axis == 1:
        composite = np.full((max(h1, h2), w1 + w2 + margin, 3), fill, dtype=np.uint8)
        # Center the shorter image vertically.
        if h1 > h2:
            voff1, voff2 = 0, (h1 - h2) // 2
        else:
            voff1, voff2 = (h2 - h1) // 2, 0
        hoff1, hoff2 = 0, w1 + margin
    else:
        composite = np.full((h1 + h2 + margin, max(w1, w2), 3), fill, dtype=np.uint8)
        # Center the narrower image horizontally.
        if w1 > w2:
            hoff1, hoff2 = 0, (w1 - w2) // 2
        else:
            hoff1, hoff2 = (w2 - w1) // 2, 0
        voff1, voff2 = 0, h1 + margin
    composite[voff1:voff1 + h1, hoff1:hoff1 + w1, :] = im1
    composite[voff2:voff2 + h2, hoff2:hoff2 + w2, :] = im2

    return (composite, (voff1, voff2), (hoff1, hoff2))


def DrawMatches(im1, im2, kp1, kp2, matches, axis=1, margin=0, background=0, linewidth=2):
    '''Draw all keypoints of both images plus the given matches on a composite canvas.

    Every keypoint gets a small (255, 0, 0) cross. Each matched pair gets a larger
    (0, 0, 255) cross on both ends and a connecting line (red / blue respectively
    when the images are RGB).

    Args:
        im1, im2: (H, W, 3) images.
        kp1, kp2: (N, 2) arrays of (x, y) keypoint positions in im1 / im2 pixels.
            kp1 and kp2 may have different lengths.
        matches: iterable of (idx_a, idx_b) pairs indexing into kp1 and kp2.
        axis, margin, background: layout options forwarded to BuildCompositeImage.
        linewidth: thickness in pixels of the lines connecting matched keypoints.

    Returns:
        The annotated composite image.
    '''
    composite, v_offset, h_offset = BuildCompositeImage(im1, im2, axis, margin, background)

    def canvas_point(x, y, side):
        # Shift image-local (x, y) of image `side` (0 or 1) into canvas pixel coordinates.
        return (int(x + h_offset[side]), int(y + v_offset[side]))

    # Draw every keypoint of each image independently, so that neither set is
    # truncated to the length of the other.
    for side, kps in ((0, kp1), (1, kp2)):
        for kp in kps:
            composite = cv2.drawMarker(composite, canvas_point(kp[0], kp[1], side), color=(255, 0, 0), markerType=cv2.MARKER_CROSS, markerSize=5, thickness=1)

    # Draw matches, and highlight keypoints used in matches.
    match_color = (0, 0, 255)
    for idx_a, idx_b in matches:
        pt_a = canvas_point(kp1[idx_a][0], kp1[idx_a][1], 0)
        pt_b = canvas_point(kp2[idx_b][0], kp2[idx_b][1], 1)
        for pt in (pt_a, pt_b):
            composite = cv2.drawMarker(composite, pt, color=match_color, markerType=cv2.MARKER_CROSS, markerSize=12, thickness=1)
        composite = cv2.line(composite, pt_a, pt_b, color=match_color, thickness=int(linewidth))

    return composite

In [5]:
# Read the pairs file: each row is (sample_id, batch_id, image_1_id, image_2_id).
src = '/kaggle/input/image-matching-challenge-2022/'

# newline='' is what the csv module requires for correct handling of quoted newlines.
with open(os.path.join(src, 'test.csv'), newline='') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip header
    # csv.reader yields [] for blank lines, which would break the 4-way unpacking later.
    test_samples = [row for row in reader if row]

if dry_run:
    for sample in test_samples:
        print(sample)

In [6]:
def run_silk(model, image, max_size = 1024):
    '''Run SiLK on one image and return keypoints in the original image's pixel frame.

    The image is downscaled (never upscaled) so its longest side is at most
    `max_size` before inference, and keypoint positions are mapped back afterwards.

    Args:
        model: SiLK model, e.g. as returned by common.get_model.
        image: batch of exactly one image, shape (1, C, H, W); a 3-channel image
            is converted to grayscale first.
        max_size: longest side allowed at inference time; None disables resizing.

    Returns:
        keypoints: (N, 3) tensor. Columns 0-1 are positions in original-image pixels,
            presumably (row, col) since extract() swaps them to (x, y) for OpenCV.
            Column 2 is the score used by get_top_k.
        descriptors: (HW, 128) tensor, one row per location of the flattened
            descriptor map. get_top_k indexes it with keypoint indices, which
            presumably requires NMS to be disabled — confirm.
        prob: the model's third output for this image (unused by extract).
    '''
    assert image.shape[0] == 1, 'run_silk expects a batch of exactly one image'
    orig_h, orig_w = image.shape[-2:]

    # Downscale factor (>= 1, i.e. never upscale).
    scale = max(max(orig_h, orig_w) / max_size, 1.) if max_size is not None else 1.

    if scale != 1.:
        new_h, new_w = int(orig_h / scale), int(orig_w / scale)
        image = resize(image, size=(new_h, new_w), interpolation=InterpolationMode.BILINEAR)
        # int() truncation makes the true per-axis factors differ slightly from `scale`
        # (about 1 px at the image border); use the exact ratios to map keypoints back.
        scale_y, scale_x = orig_h / new_h, orig_w / new_w
    else:
        scale_y = scale_x = 1.

    # to greyscale if necessary
    if image.shape[1] == 3:
        image = torchvision.transforms.functional.rgb_to_grayscale(image)

    # Pure inference: don't build an autograd graph.
    with torch.no_grad():
        keypoints, descriptors, prob = model(image)
        keypoints = from_feature_coords_to_image_coords(model, keypoints)
    descriptors = descriptors.reshape(1, 128, -1).permute(0, 2, 1)

    # Rescale only the position columns; the score column is left untouched.
    keypoints = keypoints[0].clone()
    keypoints[:, 0] *= scale_y
    keypoints[:, 1] *= scale_x
    descriptors = descriptors[0]
    prob = prob[0]

    return keypoints, descriptors, prob

def get_top_k(keypoints, descriptors, k):
    '''Keep the k highest-scoring keypoints.

    Args:
        keypoints: (N, 3) tensor; columns 0-1 are positions, column 2 is the score.
        descriptors: (N, D) tensor, row i describing keypoint i.
        k: number of keypoints to keep. k >= N keeps all of them; k == 0 keeps none.

    Returns:
        (positions, descriptors / 1.41, scores) of the kept keypoints, in ascending
        score order. extract() multiplies the descriptors by 1.41 again, so the
        division is kept for compatibility with that caller.
    '''
    positions, scores = keypoints[:, :2], keypoints[:, 2]
    n = scores.shape[0]

    # Slice from an explicit start index: `argsort()[-k:]` would return
    # *every* keypoint when k == 0, because -0 == 0.
    idxs = scores.argsort()[max(n - k, 0):]

    return positions[idxs], descriptors[idxs] / 1.41, scores[idxs]

def extract(model, max_side, top_k, *images):
    '''Run SiLK plus top-k selection on each image.

    Returns two parallel lists with one entry per input image: keypoint positions
    with their two columns swapped (to the (x, y) order used by OpenCV), and the
    corresponding descriptors.
    '''
    def _process(image):
        keypoints, descriptors, _ = run_silk(model, image, max_size = max_side)
        positions, descriptors, _ = get_top_k(keypoints, descriptors, k = top_k)
        # get_top_k divides descriptors by 1.41; undo that here.
        return positions[:, [1, 0]], descriptors * 1.41

    per_image = [_process(image) for image in images]
    all_positions = [positions for positions, _ in per_image]
    all_descriptors = [descriptors for _, descriptors in per_image]

    return all_positions, all_descriptors

def inlier_matches(inlier_mask, matches):
    '''Select the matches flagged as inliers by the robust estimator.

    Args:
        inlier_mask: per-match inlier flags as returned by cv2.findFundamentalMat
            (e.g. an (M, 1) uint8 array), or None if estimation failed.
        matches: (M, 2) array of (idx_a, idx_b) index pairs.

    Returns:
        The subset of `matches` whose mask entry is non-zero, always shaped (K, 2)
        with the dtype of `matches`. K == 0 when `inlier_mask` is None or all-zero.
    '''
    matches = np.asarray(matches)
    if inlier_mask is None:
        return matches[:0]
    keep = np.asarray(inlier_mask).reshape(-1).astype(bool)
    return matches[keep]
    

def display(batch_id, image_1_id, image_2_id, positions_1, positions_2, matches, *crops):
    '''Plot both images of a pair with all keypoints and the given matches.

    NOTE(review): in Jupyter this name shadows IPython's `display` helper; the
    name is kept so existing calls keep working.

    Args:
        batch_id, image_1_id, image_2_id: identify the PNGs under conf's image root.
        positions_1, positions_2: (N, 2) arrays of (x, y) keypoints in each image.
        matches: (K, 2) array of (idx_1, idx_2) index pairs to draw.
        *crops: optional (x, y, dx, dy) boxes drawn in green on the composite canvas.

    Raises:
        FileNotFoundError: if either image cannot be read.
    '''
    image_dir = conf["common"]["paths"]["images"]

    def _load_rgb(image_id):
        path = f'{image_dir}/{batch_id}/{image_id}.png'
        bgr = cv2.imread(path)
        if bgr is None:  # cv2.imread signals failure by returning None, not by raising
            raise FileNotFoundError(f'Could not read image: {path}')
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    canvas = DrawMatches(_load_rgb(image_1_id), _load_rgb(image_2_id), positions_1, positions_2, matches)

    for x, y, dx, dy in crops:
        canvas = cv2.rectangle(canvas, (int(x), int(y)), (int(x + dx), int(y + dy)), color=(0, 255, 0), thickness=1)

    fig, ax = plt.subplots(figsize=(15, 15))
    ax.set_title(f'{image_1_id}-{image_2_id}')
    ax.imshow(canvas)
    ax.axis('off')
    plt.show()
    plt.close(fig)  # called once per pair in a loop; free the figure
    
def write_submission(F_dict):
    '''Write submission.csv with one `sample_id,fundamental_matrix` row per entry.

    Args:
        F_dict: mapping sample_id -> 3x3 fundamental matrix. A None entry (failed
            estimation) is written as an all-zero matrix, so every sample still
            gets a row instead of the whole write crashing.
    '''
    with open('submission.csv', 'w') as f:
        f.write('sample_id,fundamental_matrix\n')
        for sample_id, F in F_dict.items():
            if F is None:
                F = np.zeros((3, 3))
            f.write(f'{sample_id},{FlattenMatrix(F)}\n')

    if dry_run:
        # Pure-Python equivalent of `!cat submission.csv`; also works outside IPython.
        with open('submission.csv') as f:
            print(f.read(), end='')

In [7]:
# load model
model = common.get_model(
    checkpoint=conf["model"]["checkpoint"],
    nms=conf["common"]["nms"],
    device=conf["common"]["device"],
)

# create matcher
matcher = silk.models.silk.matcher(
    postprocessing=conf["model"]["matcher"]["postprocessing"],
    threshold=conf["model"]["matcher"]["threshold"],
    temperature=conf["model"]["matcher"]["temperature"],
)

F_dict = {}
for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row
    print(f"process {sample_id}")

    batch_path = os.path.join(conf["common"]["paths"]["images"], batch_id)
    image_1_path = os.path.join(batch_path, f"{image_1_id}.png")
    image_2_path = os.path.join(batch_path, f"{image_2_id}.png")

    # load
    images_1 = common.load_images(image_1_path)
    images_2 = common.load_images(image_2_path)

    # extract
    (positions_1, positions_2), (descriptors_1, descriptors_2) = extract(
        model,
        conf["model"]["image_max_side"],
        conf["common"]["topk"],
        images_1,
        images_2,
    )

    # match
    matches = matcher(descriptors_1, descriptors_2)
    matches = matches.cpu().numpy()

    # estimate fundamental matrix
    cur_F, inlier_mask = cv2.findFundamentalMat(
        positions_1.cpu().numpy()[matches[:, 0]],
        positions_2.cpu().numpy()[matches[:, 1]],
        cv2.USAC_MAGSAC,
        ransacReprojThreshold=conf["common"]["ransac"]["reproj_threshold"],
        confidence=conf["common"]["ransac"]["confidence"],
        maxIters=conf["common"]["ransac"]["max_iter"],
    )

    print(f"n matches:{matches.shape[0]}")
    print(f"n inliers:{inlier_mask.sum()}")

    F_dict[sample_id] = cur_F

    # optional display
    if dry_run:
        display(batch_id, image_1_id, image_2_id, positions_1.cpu().numpy(), positions_2.cpu().numpy(), inlier_matches(inlier_mask, matches))

write_submission(F_dict)

  "`pytorch_lightning.utilities.cloud_io.load` has been deprecated in v1.8.0 and will be"


process googleurban;1cf87530;a5a9975574c94ff9a285f58c39b53d2c-0143f47ee9e54243a1b8454f3e91621a
n matches:3301
n inliers:720
process googleurban;6ceaefff;39563e58b2b7411da3f06427c9ee4239-0303b05ca0cb46959eac430e4b2472ca
n matches:3559
n inliers:637
process googleurban;d91db836;81dd07fb7b9a4e01996cee637f91ca1a-0006b1337a0347f49b4e651c035dfa0e
n matches:1673
n inliers:61
