In [None]:
# Sanity check: list the attached offline-dependencies dataset before installing from it.
!ls /kaggle/input/k/oldufo

In [None]:
!pip install kornia --no-index --find-links=file:///kaggle/input/k/oldufo/imc2022-dependencies/pip/kornia/ --upgrade 
!pip install kornia_moons --no-index --find-links=file:///kaggle/input/k/oldufo/imc2022-dependencies/pip/kornia_moons/ --no-deps  --upgrade 

In [None]:
import os
import numpy as np
import cv2
import csv
from glob import glob
import torch
import matplotlib.pyplot as plt
import kornia
import kornia.feature as kornia_feature
from kornia_moons.feature import *

In [None]:
# Report whether a CUDA device is visible; the feature pipeline below places its models on 'cuda'.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Kornia uses torch.hub to download pretrained models:
# https://github.com/kornia/kornia/blob/f967e7529dd85db263484ae7ccc027784c4ecaf0/kornia/feature/hardnet.py#L147

# Unfortunately, torch.hub attempts to download them even if we place the models at the location they are expected, which would fail when the notebook is submitted:
# https://github.com/pytorch/fairseq/issues/3105

# To work around this we can copy-paste a version of the Kornia KeyNetAffNetHardNet model and hardcode the location of the weights:
# https://github.com/kornia/kornia/blob/fc4b2fb7e7b0a8abbc802e67e8406a5a32c32d62/kornia/feature/integrated.py#L192.

class KeyNetAffNetHardNet(kornia_feature.LocalFeature):
    """Convenience module, which implements KeyNet detector + AffNet + HardNet descriptor.

    Offline copy of kornia's ``KeyNetAffNetHardNet``: every sub-network is built with
    ``pretrained=False`` and its weights are loaded from the attached Kaggle dataset, so
    nothing is fetched through ``torch.hub``.

    Args:
        num_features: maximum number of keypoints the KeyNet detector keeps per image.
        upright: accepted for signature compatibility with kornia. Only the upright variant
            is implemented here: orientation always goes through ``PassLAF`` and no OriNet
            weights are loaded, so this flag currently has no effect.
        device: device the detector and descriptor are moved to; the checkpoints are also
            loaded directly onto it.
    """
    def __init__(self,
                 num_features: int = 5000,
                 upright: bool = True,
                 device: torch.device = torch.device('cuda')):
        weights_dir = '/kaggle/input/k/oldufo/imc2022-dependencies/pretrained'
        # Keep the LAF orientation as-is (upright features) instead of estimating it.
        ori_module = kornia_feature.PassLAF()
        # num_features must be forwarded explicitly, otherwise KeyNetDetector uses its own default.
        detector = kornia_feature.KeyNetDetector(False,
                                                 num_features=num_features,
                                                 ori_module=ori_module,
                                                 aff_module=kornia_feature.LAFAffNetShapeEstimator(False).eval()).to(device)
        # KeyNet and AffNet checkpoints wrap their weights under a 'state_dict' key.
        detector.model.load_state_dict(
            torch.load(f'{weights_dir}/keynet_pytorch.pth', map_location=device)['state_dict'])
        detector.aff.load_state_dict(
            torch.load(f'{weights_dir}/AffNet.pth', map_location=device)['state_dict'])
        descriptor = kornia_feature.LAFDescriptor(kornia_feature.HardNet8(False),
                                                  patch_size=32,
                                                  grayscale_descriptor=True).to(device)
        # The HardNet8 checkpoint is loaded directly as a state dict (no wrapper key).
        descriptor.descriptor.load_state_dict(
            torch.load(f'{weights_dir}/hardnet8v2.pt', map_location=device))
        super().__init__(detector, descriptor)

In [None]:
# Read the pairs file.

src = '/kaggle/input/image-matching-challenge-2022/'

# Each data row is unpacked later as [sample_id, batch_id, image_1_id, image_2_id].
test_samples = []
# The csv module expects files opened with newline='' so quoted fields are parsed correctly.
with open(f'{src}/test.csv', newline='') as f:
    reader = csv.reader(f, delimiter=',')
    # Skip header (the default also tolerates a completely empty file).
    next(reader, None)
    test_samples.extend(reader)

In [None]:
# Echo each parsed row (sample_id, batch_id, image_1_id, image_2_id) for a quick visual check.
for pair_row in test_samples:
    print(pair_row)

In [None]:
def FlattenMatrix(M, num_digits=8):
    '''Convenience function to write CSV files.

    Returns the entries of ``M`` in ``M.flatten()`` order, each formatted in scientific
    notation with ``num_digits`` digits after the decimal point, joined by single spaces.
    '''
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())


# Styling for draw_LAF_matches: we draw only inliers and tentative matches.
draw_dict = {
    'inlier_color': (0.2, 1, 0.2),  # Green: inliers.
    'tentative_color': (1, 1, 0.2, 0.5),  # Light yellow, semi-transparent: tentative matches.
    'feature_color': None,  # No colour for unmatched features, so they are not drawn.
    'vertical': False,  # Do not stack the two images vertically.
}

In [None]:
import gc

# Maximum number of KeyNet keypoints requested per image.
num_features = 5000

# Compute this many samples, and fill the rest with random values, to generate a quick
# submission and check it works without waiting for a full run (e.g. 500).
# Set to -1 to use all samples.
how_many_to_fill = -1

device = torch.device('cuda')
keynet_affnet_hardnet8 = KeyNetAffNetHardNet(num_features, device=device).eval()
# Second-nearest-neighbour ratio-test matcher with threshold 0.9.
matcher = kornia_feature.DescriptorMatcher('snn', 0.9)
# Seeded generator so the placeholder matrices of a quick run are reproducible.
rng = np.random.default_rng(0)


def _load_rgb(path):
    """Read the image at `path` and convert it from OpenCV's BGR channel order to RGB.

    Raises FileNotFoundError instead of letting cv2.cvtColor fail on a None image.
    """
    image_bgr = cv2.imread(path)
    if image_bgr is None:
        raise FileNotFoundError(f'Could not read image: {path}')
    return cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)


def _to_gray_tensor(image_rgb):
    """Turn an RGB uint8 image into a batched grayscale float tensor scaled to [0, 1] on `device`."""
    timg = kornia.image_to_tensor(image_rgb, False).float() / 255.
    return kornia.color.rgb_to_grayscale(timg).to(device)


F_dict = {}
for i, row in enumerate(test_samples):
    sample_id, batch_id, image_1_id, image_2_id = row

    if how_many_to_fill >= 0 and i >= how_many_to_fill:
        F_dict[sample_id] = rng.random((3, 3))
        continue

    # Load the images.
    image_1 = _load_rgb(f'{src}/test_images/{batch_id}/{image_1_id}.png')
    image_2 = _load_rgb(f'{src}/test_images/{batch_id}/{image_2_id}.png')
    # Reset per sample so later cells never see a mask left over from a previous pair.
    inlier_mask = None

    # Extract and match features.
    with torch.no_grad():
        lafs1, resps1, descriptors_1 = keynet_affnet_hardnet8(_to_gray_tensor(image_1))
        lafs2, resps2, descriptors_2 = keynet_affnet_hardnet8(_to_gray_tensor(image_2))

        # No features detected in one of the images: nothing to match.
        if descriptors_1.size(1) == 0 or descriptors_2.size(1) == 0:
            F_dict[sample_id] = np.zeros((3, 3))
            continue

        dists, idxs = matcher(descriptors_1[0], descriptors_2[0])
        cur_kp1 = kornia_feature.get_laf_center(lafs1).detach().cpu().numpy().reshape(-1, 2)
        cur_kp2 = kornia_feature.get_laf_center(lafs2).detach().cpu().numpy().reshape(-1, 2)
        match_idxs = idxs.detach().cpu().numpy()

    # Only attempt the estimate with enough correspondences; otherwise fall back to zeros.
    F = None
    if len(match_idxs) > 8:
        F, inlier_mask = cv2.findFundamentalMat(cur_kp1[match_idxs[:, 0]], cur_kp2[match_idxs[:, 1]],
                                                cv2.USAC_MAGSAC,
                                                ransacReprojThreshold=0.5,
                                                confidence=0.99999,
                                                maxIters=10000)
    # When the robust estimator finds no model, OpenCV returns None; do not crash on that.
    if F is None or F.size == 0:
        F_dict[sample_id] = np.zeros((3, 3))
        inlier_mask = None
    else:
        assert F.shape == (3, 3), 'Malformed F?'
        F_dict[sample_id] = F
    gc.collect()

In [None]:
# Visualise the matches of the last sample processed by the matching loop.
# np.bool was removed in NumPy 1.24, so cast with the builtin bool. A None mask (no
# fundamental matrix estimated) is passed as None, which is draw_LAF_matches' default.
mask_to_draw = None if inlier_mask is None else inlier_mask.astype(bool)
draw_LAF_matches(lafs1.cpu(), lafs2.cpu(),
                 match_idxs, image_1, image_2,
                 inlier_mask=mask_to_draw, draw_dict=draw_dict)
plt.title(f'{image_1_id}-{image_2_id}')
plt.axis('off')
plt.show()

In [None]:
 for sample_id, F in F_dict.items():
        print(sample_id, F)

In [None]:
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')

!cat submission.csv