In [None]:
# List the attached input datasets, then install kornia and kornia_moons from
# wheels bundled in the imc2022-dependencies dataset. --no-index makes pip use
# only the local --find-links directory (presumably because the kernel has no
# internet access); kornia_moons is installed with --no-deps.
!ls /kaggle/input/
!pip install kornia --no-index --find-links=file:///kaggle/input/imc2022-dependencies/pip/kornia/ --upgrade 
!pip install kornia_moons --no-index --find-links=file:///kaggle/input/imc2022-dependencies/pip/kornia_moons/ --no-deps  --upgrade 
print('Done!')

In [None]:
# Install einops from a local wheel (presumably needed by the LoFTR reference
# code shipped in the loftrutils dataset — TODO confirm).
!pip install ../input/loftrutils/einops-0.4.1-py3-none-any.whl

In [None]:
# Copy the imutils source tree to a writable location and pip-install it from
# there (presumably because the input directory is read-only and pip needs to
# write build files). imutils is used by resize_keep_ratio.
!cp -r ../input/imutils/imutils-0.5.3/ /
!pip install /imutils-0.5.3/

In [None]:
# Make the original LoFTR repository importable so that `src.*` resolves
# (used for `src.utils.plotting` in the imports cell).
import sys
sys.path.append('../input/loftrutils/LoFTR-master/LoFTR-master/')

In [None]:
import cv2
import kornia as K
import kornia.feature as KF
from kornia.feature.loftr import LoFTR
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import glob
import random
from tqdm.notebook import tqdm
from PIL import Image
import matplotlib.cm as cm

from kornia_moons.feature import *

# from src.loftr import LoFTR, default_cfg
from src.utils.plotting import make_matching_figure

In [None]:
# Run LoFTR on the first GPU when one is available; fall back to CPU so the
# notebook still executes (slowly) on a CPU-only kernel.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Local LoFTR checkpoint (presumably the official outdoor dual-softmax weights).
WEIGHT_PATH = '../input/loftrutils/outdoor_ds.ckpt'
# LONGEST_EDGE = 1500  # unused; kept for the disabled resize_keep_ratio path

In [None]:
# Build kornia's LoFTR without downloading weights, then load the local checkpoint.
matcher = LoFTR(pretrained=None)
# map_location='cpu' deserialises the tensors on the host first, so loading
# does not depend on which device the checkpoint was saved from; the model is
# moved to DEVICE afterwards.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
checkpoint = torch.load(WEIGHT_PATH, map_location='cpu')
matcher.load_state_dict(checkpoint['state_dict'])
# Assigning (rather than leaving `.eval()` as the last expression) keeps the
# cell from printing the full module repr.
matcher = matcher.to(DEVICE).eval()
# Drop the host-side copy of the weights now that they live in the model.
del checkpoint

In [None]:
import csv

# Competition data root; later cells build paths as f'{src}/...'.
src = '/kaggle/input/image-matching-challenge-2022/'

# Data rows of test.csv (header excluded). Later cells unpack each row as
# (sample_id, batch_id, image_0_id, image_1_id).
# newline='' is what the csv module documentation requires when opening files
# for csv.reader, so quoted fields and line endings are handled correctly.
with open(f'{src}/test.csv', newline='') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip the header row; no-op on an empty file
    test_samples = list(reader)

In [None]:
def plot_images(ims, nrows=3, ncols=3, thumb_size=(300, 300)):
    """Show image files as resized thumbnails in an ``nrows`` x ``ncols`` grid.

    Images are placed column-major (down the first column, then the next),
    which matches the original fixed 3x3 layout. Each subplot is titled with
    the file's base name.

    Args:
        ims: Iterable of image file paths.
        nrows: Number of grid rows (default 3).
        ncols: Number of grid columns (default 3).
        thumb_size: ``(width, height)`` passed to ``PIL.Image.resize``.

    Raises:
        ValueError: If there are more images than grid cells.
    """
    ims = list(ims)
    n_cells = nrows * ncols
    if len(ims) > n_cells:
        raise ValueError(
            f'plot_images got {len(ims)} images but the grid has only {n_cells} cells')

    # squeeze=False keeps `axes` 2-D even when nrows or ncols is 1.
    fig, axes = plt.subplots(nrows, ncols, figsize=(20, 20), squeeze=False)

    for idx, img in enumerate(ims):
        i = idx % nrows   # row
        j = idx // nrows  # column
        # The context manager closes the file; resize() returns an independent copy.
        with Image.open(img) as image:
            thumb = image.resize(thumb_size)
        axes[i, j].imshow(thumb)
        axes[i, j].set_title(img.split('/')[-1])

    plt.subplots_adjust(wspace=0, hspace=.2)
    plt.show()

In [None]:
# def load_image(fname, target_size=(640, 480)):
#     img_raw = cv2.imread(fname, cv2.IMREAD_GRAYSCALE)
#     img_raw = cv2.resize(img_raw, target_size)
#     return img_raw

# def to_tensor(img_raw):
#     img = torch.from_numpy(img_raw)[None][None].cuda() / 255.
#     return img

import imutils
def resize_keep_ratio(img, longest_size=1500):
    """Scale ``img`` so that its longer side becomes ``longest_size`` pixels.

    Only one target dimension is passed to ``imutils.resize``, which keeps the
    aspect ratio. Square images are resized by height. Smaller images are
    scaled up as well.
    """
    rows, cols = img.shape[:2]
    target_dim = 'height' if rows >= cols else 'width'
    return imutils.resize(img, **{target_dim: longest_size})

def resize(img, target_size=(640,480)):
    """Resize ``img`` to exactly ``target_size`` with OpenCV (aspect ratio not kept).

    ``target_size`` follows OpenCV's dsize convention: (width, height).
    """
    width, height = target_size
    return cv2.resize(img, (width, height))

def load_torch_image(fname):
    """Load an image file as a float RGB tensor scaled to [0, 1].

    Args:
        fname: Path to an image readable by OpenCV.

    Returns:
        The tensor produced by ``K.image_to_tensor(img, False)``, converted to
        RGB; presumably shape (1, 3, H, W) for a colour image — TODO confirm.

    Raises:
        FileNotFoundError: If OpenCV cannot read ``fname`` (``cv2.imread``
            returns None instead of raising).
    """
    img = cv2.imread(fname)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {fname}')
    # NOTE: images are fed at native resolution. The original LoFTR code asks
    # for sides divisible by 8; if that becomes a problem, resize here
    # (e.g. with resize_keep_ratio and rounding each side down to a multiple of 8).
    img = K.image_to_tensor(img, False).float() / 255.
    # OpenCV decodes to BGR channel order; convert to RGB.
    img = K.color.bgr_to_rgb(img)
    return img

In [None]:
def match(img_path0, img_path1, matcher, device=DEVICE):
    """Run ``matcher`` (LoFTR) on two image files and return matched keypoints.

    Each image is loaded with ``load_torch_image``, converted to grayscale and
    moved to ``device``. Inference runs under ``torch.no_grad()``.

    Returns:
        (mkpts0, mkpts1): numpy arrays from the matcher's 'keypoints0' and
        'keypoints1' outputs (presumably shape (N, 2) each — TODO confirm).
    """
    gray0, gray1 = (
        K.color.rgb_to_grayscale(load_torch_image(path)).to(device)
        for path in (img_path0, img_path1)
    )

    with torch.no_grad():
        result = matcher({"image0": gray0, "image1": gray1})

    return (result['keypoints0'].cpu().numpy(),
            result['keypoints1'].cpu().numpy())
        
def get_F_matrix(mkpts0, mkpts1, ransac_thr=0.5, confidence=0.999, max_iters=100000):
    """Estimate the fundamental matrix from matched keypoints with OpenCV USAC_MAGSAC.

    Never raises on estimation failure, so the submission loop keeps running.
    A 3x3 zero matrix is returned when there are too few matches (<= 8) or
    when OpenCV finds no valid model.

    Args:
        mkpts0, mkpts1: Matched keypoint coordinates in image 0 and image 1
            (presumably float32 arrays of shape (N, 2) as returned by ``match``).
        ransac_thr: Inlier threshold (max point-to-epipolar-line distance, px).
        confidence: Desired confidence level passed to OpenCV.
        max_iters: Maximum number of robust-estimation iterations.

    Returns:
        np.ndarray: The 3x3 fundamental matrix, or zeros on failure.
    """
    # Too few correspondences for a meaningful estimate.
    if len(mkpts0) <= 8:
        return np.zeros((3, 3))

    F, _ = cv2.findFundamentalMat(
        mkpts0, mkpts1, cv2.USAC_MAGSAC, ransac_thr, confidence, max_iters)

    # OpenCV returns None (rather than raising) when no model is found.
    if F is None or F.shape != (3, 3):
        return np.zeros((3, 3))
    return F

In [None]:
def _magsac_inlier_mask(mkpts0, mkpts1):
    """Return a 1-D boolean inlier mask from USAC_MAGSAC fundamental-matrix fitting.

    The mask is all False when there are <= 8 matches or when OpenCV fails to
    find a model (it then returns None for both F and the mask).
    """
    mask = np.zeros(len(mkpts0), dtype=bool)
    if len(mkpts0) > 8:
        F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC, 0.5, 0.999, 100000)
        if F is not None and inliers is not None:
            mask = inliers.reshape(-1) > 0
    return mask


def _points_to_lafs(pts):
    """Wrap keypoints as LAFs for kornia_moons plotting.

    Scale and orientation are filled with ones, as in the original plotting
    code. Explicit sizes (instead of ``view(1, -1, ...)``) keep this working
    when there are zero matches.
    """
    n = pts.shape[0]
    return KF.laf_from_center_scale_ori(torch.from_numpy(pts)[None],
                                        torch.ones(1, n, 1, 1),
                                        torch.ones(1, n, 1))


def match_and_draw(img_path0, img_path1, matcher, device=DEVICE, drop_outliers=False):
    """Match two images with LoFTR and draw the USAC_MAGSAC inlier matches.

    Args:
        img_path0, img_path1: Image file paths.
        matcher: Model called with {"image0", "image1"} grayscale tensors and
            returning 'keypoints0' / 'keypoints1'.
        device: Device the input tensors are moved to.
        drop_outliers: If True, keep only inliers before drawing (prints the
            match count before and after filtering).
    """
    img0 = load_torch_image(img_path0)
    img1 = load_torch_image(img_path1)

    print(img0.shape, img1.shape)

    input_dict = {"image0": K.color.rgb_to_grayscale(img0).to(device),
                  "image1": K.color.rgb_to_grayscale(img1).to(device)}

    with torch.no_grad():
        correspondences = matcher(input_dict)

    mkpts0 = correspondences['keypoints0'].cpu().numpy()
    mkpts1 = correspondences['keypoints1'].cpu().numpy()

    # Always defined (possibly all False), so drawing/filtering never hits an
    # unbound name when there are too few matches or estimation fails.
    inliers = _magsac_inlier_mask(mkpts0, mkpts1)

    if drop_outliers:
        print(len(mkpts0))
        mkpts0 = mkpts0[inliers]
        mkpts1 = mkpts1[inliers]
        inliers = inliers[inliers]
        print(len(mkpts0))

    draw_LAF_matches(
        _points_to_lafs(mkpts0),
        _points_to_lafs(mkpts1),
        torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
        K.tensor_to_image(img0),
        K.tensor_to_image(img1),
        inliers,
        draw_dict={'inlier_color': (0.2, 1, 0.2),
                   'tentative_color': None,
                   'feature_color': (0.2, 0.5, 1), 'vertical': False})

    # Release GPU memory held by this pair before the next call.
    del correspondences, input_dict
    torch.cuda.empty_cache()

In [None]:
def plot_matching(samples, files):
    """Draw matches for each image pair indexed by ``samples``, one figure per pair.

    ``samples`` must expose ``.shape[1]`` and be indexable as samples[0][k] and
    samples[1][k]. Presumably it is a 2 x K array of indices into ``files``
    (TODO confirm). Uses the global ``matcher``.
    """
    n_pairs = samples.shape[1]
    for k in range(n_pairs):
        path0, path1 = files[samples[0][k]], files[samples[1][k]]
        print(f'Matching: {path0} to {path1}')
        match_and_draw(path0, path1, matcher)
        plt.show()

In [None]:
%%time
print('Ploting on sample test set')
for i, row in tqdm(enumerate(test_samples), total=len(test_samples)):
    print(i)
    sample_id, batch_id, image_0_id, image_1_id = row
    img_path0 = f'{src}/test_images/{batch_id}/{image_0_id}.png'
    img_path1 = f'{src}/test_images/{batch_id}/{image_1_id}.png'
    match_and_draw(img_path0, img_path1, matcher)    
    
    if i >= 3:
        break

# Run submission

In [None]:
%%time
F_dict = {}
for i, row in tqdm(enumerate(test_samples), total=len(test_samples)):
    sample_id, batch_id, image_0_id, image_1_id = row
    img_path0 = f'{src}/test_images/{batch_id}/{image_0_id}.png'
    img_path1 = f'{src}/test_images/{batch_id}/{image_1_id}.png'
    mkpts0, mkpts1 = match(img_path0, img_path1, matcher)
    F_dict[sample_id] = get_F_matrix(mkpts0, mkpts1)
    

In [None]:
def FlattenMatrix(M, num_digits=8):
    '''Render an array as space-separated scientific-notation values.

    Values are taken in ``M.flatten()`` order (row-major) and each uses
    ``num_digits`` digits after the decimal point. This is the text written
    to the submission's fundamental_matrix column.
    '''
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())

In [None]:
# Submission file: a header line, then one "sample_id,<flattened F>" row per pair.
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    f.writelines(f'{sample_id},{FlattenMatrix(F)}\n'
                 for sample_id, F in F_dict.items())

In [None]:
# Read the written file back and display it to eyeball the submission format.
pd.read_csv('submission.csv')