# ***INSTALL LIBS*** 

In [None]:
%%capture
# Install kornia (provides the LoFTR matcher) and kornia_moons (provides the
# draw_LAF_matches visualisation helper used below).
# The first two installs use local wheel files from ../input/kornia-loftr
# (kornia 0.6.4, kornia_moons 0.1.9).
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl
# NOTE(review): the next two installs need network access. The git install will
# likely replace the pinned kornia 0.6.4 wheel with whatever version is on GitHub,
# and %%capture hides pip's output, so a failure or version change goes unnoticed.
# Confirm which kornia version this notebook is meant to run against.
!pip install git+https://github.com/kornia/kornia
!pip install kornia_moons

# ***IMPORT DEPENDENCIES***

In [None]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import kornia
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF
import kornia.feature.loftr as LoFTR
import gc
import pandas as pd
import glob
import random

from PIL import Image


# ***LOAD MODEL***
# Use the GPU when available; fall back to CPU so the notebook still runs (slowly)
# on machines without CUDA instead of crashing in `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# LoFTR with the pretrained weights from ../input/kornia-loftr/loftr_outdoor.ckpt.
# NOTE(review): `matcher` is rebuilt from ../input/trained/last.ckpt in the next cell,
# so these weights are overwritten there; keep only the load you actually need.
matcher = KF.LoFTR(pretrained=None)
# map_location lets a checkpoint that was saved on a GPU be deserialized onto `device`.
matcher.load_state_dict(torch.load("../input/kornia-loftr/loftr_outdoor.ckpt", map_location=device)['state_dict'])
matcher = matcher.to(device).eval()

In [None]:
# Use the GPU when available; fall back to CPU so the cell does not crash without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# LoFTR with fine-tuned weights from ../input/trained/last.ckpt. To use the stock
# pretrained model instead, load ../input/kornia-loftr/loftr_outdoor.ckpt here.
matcher = LoFTR.LoFTR(pretrained=None)
# NOTE(review): assumes the checkpoint's 'state_dict' keys match LoFTR's parameter
# names exactly (load_state_dict is strict by default) — confirm for this training run.
# map_location lets a checkpoint that was saved on a GPU be deserialized onto `device`.
matcher.load_state_dict(torch.load("../input/trained/last.ckpt", map_location=device)['state_dict'])
matcher = matcher.to(device).eval()

In [None]:
src = '/kaggle/input/image-matching-challenge-2022/'

# Read the test pairs as lists of strings, header excluded. Expected columns:
# sample_id, batch_id, image_1_id, image_2_id.
# newline='' is what the csv module requires for correct handling of quoted fields.
with open(f'{src}/test.csv', newline='') as f:
    reader = csv.reader(f, delimiter=',')
    next(reader, None)  # skip the header row (no-op if the file is empty)
    test_samples = list(reader)


def FlattenMatrix(M, num_digits=8):
    '''Serialize a matrix into one CSV cell of the submission file.

    The entries of ``M`` are taken in ``M.flatten()`` order, rendered in
    scientific notation with ``num_digits`` digits after the decimal point,
    and joined with single spaces (e.g. ``1.00000000e+00 0.00000000e+00 ...``).
    '''
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())


def load_torch_image(fname, device, max_side=840):
    '''Load an image file as a normalized RGB float tensor on ``device``.

    The image is resized, up or down with its aspect ratio preserved, so that
    its longer side is about ``max_side`` pixels. Width and height are
    truncated to ints.

    Args:
        fname: path to an image readable by ``cv2.imread``.
        device: torch device that the returned tensor is moved to.
        max_side: target length of the longer image side (default 840).

    Returns:
        Float tensor with values scaled by 1/255 and channels in RGB order.
        ``K.image_to_tensor(..., False)`` adds a leading batch dimension, so the
        result is presumably shaped (1, 3, H, W).

    Raises:
        FileNotFoundError: if OpenCV cannot read ``fname``.
    '''
    img = cv2.imread(fname)
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None, not by raising.
        raise FileNotFoundError(f'Could not read image: {fname}')
    scale = max_side / max(img.shape[0], img.shape[1])
    w = int(img.shape[1] * scale)
    h = int(img.shape[0] * scale)
    img = cv2.resize(img, (w, h))
    img = K.image_to_tensor(img, False).float() / 255.
    # OpenCV loads BGR; convert to RGB for kornia.
    img = K.color.bgr_to_rgb(img)
    return img.to(device)

In [None]:

import time

# MAGSAC (cv2.USAC_MAGSAC) settings for the submitted fundamental matrices.
F_RANSAC_THRESHOLD = 0.200  # max point-to-epipolar-line distance (px) for an inlier
F_CONFIDENCE = 0.9999       # desired confidence of the robust estimate
F_MAX_ITERS = 250000        # iteration cap for the robust estimator
MIN_MATCHES = 8             # with fewer LoFTR matches no F is estimated
N_VISUALIZE = 3             # print shape/timing and draw matches for the first few pairs

F_dict = {}
for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row
    st = time.time()
    # Load both images of the pair and match their grayscale versions with LoFTR.
    image_1 = load_torch_image(f'{src}/test_images/{batch_id}/{image_1_id}.png', device)
    image_2 = load_torch_image(f'{src}/test_images/{batch_id}/{image_2_id}.png', device)
    if i < N_VISUALIZE:
        print(image_1.shape)
    input_dict = {"image0": K.color.rgb_to_grayscale(image_1),
                  "image1": K.color.rgb_to_grayscale(image_2)}

    with torch.no_grad():
        correspondences = matcher(input_dict)

    mkpts0 = correspondences['keypoints0'].cpu().numpy()
    mkpts1 = correspondences['keypoints1'].cpu().numpy()

    F = None
    if len(mkpts0) >= MIN_MATCHES:
        F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC,
                                            F_RANSAC_THRESHOLD, F_CONFIDENCE, F_MAX_ITERS)
    if F is None:
        # Too few matches, or the robust estimator found no model (OpenCV returns
        # None). Submitting a zero matrix keeps the row instead of crashing the run.
        F_dict[sample_id] = np.zeros((3, 3))
        continue
    inliers = inliers > 0
    assert F.shape == (3, 3), 'Malformed F?'
    F_dict[sample_id] = F
    gc.collect()
    nd = time.time()
    if i < N_VISUALIZE:
        print("Running time: ", nd - st, " s")
        draw_LAF_matches(
            KF.laf_from_center_scale_ori(torch.from_numpy(mkpts0).view(1, -1, 2),
                                         torch.ones(mkpts0.shape[0]).view(1, -1, 1, 1),
                                         torch.ones(mkpts0.shape[0]).view(1, -1, 1)),
            KF.laf_from_center_scale_ori(torch.from_numpy(mkpts1).view(1, -1, 2),
                                         torch.ones(mkpts1.shape[0]).view(1, -1, 1, 1),
                                         torch.ones(mkpts1.shape[0]).view(1, -1, 1)),
            torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
            K.tensor_to_image(image_1),
            K.tensor_to_image(image_2),
            inliers,
            draw_dict={'inlier_color': (0.2, 1, 0.2),
                       'tentative_color': None,
                       'feature_color': (0.2, 0.5, 1), 'vertical': False})

# One row per test sample: sample_id plus the 9 entries of F as space-separated floats.
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')

In [None]:
def plot_images(ims):
    '''Display up to nine image files in a 3x3 grid, filled column by column.

    Each image is resized to 300x300 for display and titled with the part of
    its path after the last '/'.
    '''
    _, axes = plt.subplots(3, 3, figsize=(20,20))

    for position, path in enumerate(ims):
        col, row = divmod(position, 3)
        ax = axes[row, col]
        thumbnail = Image.open(path).resize((300,300))
        ax.imshow(thumbnail)
        ax.set_title(path.rsplit('/', 1)[-1])

    plt.subplots_adjust(wspace=0, hspace=.2)
    plt.show()
    

def match_and_draw(img_in1, img_in2, ransac_thr=0.5, confidence=0.999, max_iters=100000):
    '''Match two image files with LoFTR and draw the resulting correspondences.

    Both images are loaded with ``load_torch_image``, converted to grayscale and
    passed to the module-level ``matcher``. A fundamental matrix is then fitted
    with OpenCV's USAC_MAGSAC, and the matches are drawn with inliers highlighted.
    If fewer than 8 matches are found, or no fundamental matrix can be estimated,
    a message is printed and nothing is drawn.

    Args:
        img_in1, img_in2: paths of the two images.
        ransac_thr: max point-to-epipolar-line distance (px) for an inlier.
        confidence: desired confidence of the robust estimate.
        max_iters: iteration cap for the robust estimator.

    Returns:
        The raw ``matcher`` output dict (it contains 'keypoints0'/'keypoints1').
    '''
    img1 = load_torch_image(img_in1, device)
    img2 = load_torch_image(img_in2, device)

    input_dict = {"image0": K.color.rgb_to_grayscale(img1),
                  "image1": K.color.rgb_to_grayscale(img2)}

    with torch.no_grad():
        correspondences = matcher(input_dict)

    mkpts0 = correspondences['keypoints0'].cpu().numpy()
    mkpts1 = correspondences['keypoints1'].cpu().numpy()

    # Dissimilar pairs can yield very few matches; require the same minimum (8)
    # as the submission loop before fitting F.
    if len(mkpts0) < 8:
        print(f'Only {len(mkpts0)} matches found; skipping fundamental-matrix fit and drawing.')
        return correspondences
    F, inliers = cv2.findFundamentalMat(mkpts0, mkpts1, cv2.USAC_MAGSAC,
                                        ransac_thr, confidence, max_iters)
    if F is None or inliers is None:
        print('No fundamental matrix could be estimated; skipping drawing.')
        return correspondences
    inliers = inliers > 0

    draw_LAF_matches(
        KF.laf_from_center_scale_ori(torch.from_numpy(mkpts0).view(1, -1, 2),
                                     torch.ones(mkpts0.shape[0]).view(1, -1, 1, 1),
                                     torch.ones(mkpts0.shape[0]).view(1, -1, 1)),
        KF.laf_from_center_scale_ori(torch.from_numpy(mkpts1).view(1, -1, 2),
                                     torch.ones(mkpts1.shape[0]).view(1, -1, 1, 1),
                                     torch.ones(mkpts1.shape[0]).view(1, -1, 1)),
        torch.arange(mkpts0.shape[0]).view(-1, 1).repeat(1, 2),
        K.tensor_to_image(img1),
        K.tensor_to_image(img2),
        inliers,
        draw_dict={'inlier_color': (0.2, 1, 0.2),
                   'tentative_color': None,
                   'feature_color': (0.2, 0.5, 1), 'vertical': False})
    return correspondences



def plot_matching(samples, files):
    '''Run ``match_and_draw`` on each image pair selected by ``samples``.

    ``samples`` is a 2-row integer array: column k holds the indices into
    ``files`` of the k-th pair (row 0 = first image, row 1 = second image).
    Each pair is announced with a print before it is matched and drawn.
    '''
    for first_idx, second_idx in zip(samples[0], samples[1]):
        path_a = files[first_idx]
        path_b = files[second_idx]
        print(f'Matching: {path_a} to {path_b}')
        match_and_draw(path_a, path_b)

    
    

# ***SIMILARITY CHECKS***

In [None]:
path =  '../input/image-matching-challenge-2022/train/trevi_fountain/images/'
# glob's result order depends on the filesystem; sorting makes trevi_fountain[i]
# refer to the same file on every run.
trevi_fountain = sorted(glob.glob(f'{path}*.jpg'))

# Show 9 randomly chosen Trevi Fountain images; the fixed seed keeps the choice reproducible.
plot_images(random.Random(42).sample(trevi_fountain, 9))

In [None]:
# Draw 4 random pairs of *distinct* Trevi Fountain images. Independent randint draws
# could pair an image with itself. Column k of `samples` is pair k, shape (2, 4).
N_PAIRS = 4
pair_rng = np.random.default_rng(42)  # fixed seed: same pairs on every re-run
samples = np.stack([pair_rng.choice(len(trevi_fountain), size=2, replace=False)
                    for _ in range(N_PAIRS)], axis=1)

plot_matching(samples, trevi_fountain)

# ***DISSIMILARITY CHECK***

In [None]:
sagrada_familia_path =  '../input/image-matching-challenge-2022/train/sagrada_familia/images/'
# glob's result order depends on the filesystem; sort so index 0 is the same file every run.
sagrada_familia_files = sorted(glob.glob(f'{sagrada_familia_path}*.jpg'))
# A single pair of images from two different scenes (Sagrada Familia vs Trevi Fountain),
# intended as a negative example for the matcher. Column 0 of `samples` indexes `files`.
samples = np.array([[0],[1]])
files = [sagrada_familia_files[0], trevi_fountain[0]]

plot_matching(samples, files)