This notebook is our final submission for the Image Matching Challenge 2022.

We referred to the following notebooks. Thank you for the awesome work! \
LoFTR: https://www.kaggle.com/code/cbeaud/imc-2022-kornia-score-0-725 \
SuperGlue: https://www.kaggle.com/code/losveria/superglue-baseline \
QuadTreeAttention: https://www.kaggle.com/code/dschettler8845/quadtree-image-matching-challenge-2022 \
Validation: https://www.kaggle.com/code/namgalielei/loftr-validation-score

## Import Packages

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import cv2
import csv
from glob import glob
import matplotlib.pyplot as plt
import gc
from collections import namedtuple
import random
import time
from tqdm.auto import tqdm
from PIL import Image
import torch
from torchvision import io
from torchvision import transforms as T
from PIL import Image

sys.path.append("../input/imc-utils")
from imc_metric import EvaluateSubmission, ReadCovisibilityData, FlattenMatrix

import warnings
warnings.simplefilter('ignore')

# Every matcher and input tensor in this notebook is moved to this device, so a
# CUDA-capable GPU runtime is required.
device = torch.device('cuda')

## Configuration

In [None]:
class CFG:
    """Global settings for the whole matching pipeline.

    ``mode`` selects the data source: "test" builds the submission from the
    competition test set, "val" scores a random subset of training pairs.
    """

    # --- run mode / reproducibility ---
    mode = "test"  # alternative: "val"
    seed = 2022

    # --- longest image side (pixels) fed to each matcher ---
    longest_imgsize_mf = 840
    longest_imgsize_loftr = 840
    longest_imgsize_qta = 1024
    # SuperPoint+SuperGlue is run once per listed scale and the matches are pooled
    # (we used 2 or 3 scales).
    longest_imgsize_sg = [2000, 1600, 1200]
    longest_imgsize_p2p = 2048

    # --- per-matcher cap on matches passed to F estimation (most confident kept) ---
    max_num_pairs_mf = 700
    max_num_pairs_loftr = 500
    max_num_pairs_qta = 700
    max_num_pairs_sg = 250
    max_num_pairs_p2p = 400

    # --- cv2.findFundamentalMat(..., cv2.USAC_MAGSAC, ...) parameters ---
    magsac_thresh = 0.2      # max point-to-epipolar-line distance (pixels) for an inlier
    magsac_conf = 0.99999
    magsac_maxiter = 8000

    # --- matcher toggles; the final submission uses MatchFormer, SuperGlue
    # and QuadTreeAttention ---
    use_loftr = False
    use_matchformer = True
    use_superglue = True
    use_quadtreeattention = True
    use_segmentation = False
    use_patch2pix = False

    # --- "val" mode: scene indices to evaluate (an empty list means all scenes)
    # and how many covisible pairs to sample per scene ---
    validate_scene_id = list(range(16))
    validate_num_pairs = 10

In [None]:
def seed_torch(seed=42):
    """Seed every random number generator the notebook relies on.

    Covers Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    the current CUDA device), and switches cuDNN to deterministic kernels.

    NOTE: setting PYTHONHASHSEED at runtime only affects subprocesses started
    afterwards; the running interpreter's hash seed is fixed at startup.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    # torch.cuda.manual_seed is silently ignored when CUDA is unavailable.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True

# Seed all RNGs once, before any model is built or any validation pair is sampled.
seed_torch(seed=CFG.seed)

## Patch2Pix

In [None]:
# Optional Patch2Pix matcher; disabled in the final submission (CFG.use_patch2pix = False).
if CFG.use_patch2pix:
    # These paths must be on sys.path before immatch is imported just below.
    sys.path.append("../input/transforms3d/transforms3d-0.3.1/")
    sys.path.append("../input/immatch/image-matching-toolbox/")
    sys.path.append('../input/immatch/image-matching-toolbox/third_party/patch2pix/')
    
    from immatch.modules.patch2pix import Patch2Pix, Patch2PixRefined
    
    # Which variant to build:
    #   "patch2pix"           -> plain Patch2Pix
    #   "patch2pix_superglue" -> Patch2PixRefined configured with SuperGlue as its
    #                            "coarse" stage (see cfg["coarse"] below)
    #method = "patch2pix"
    method = "patch2pix_superglue"
    
    # Local ResNet-34 backbone weights.
    backbone_dict = {
        'resnet34': '/kaggle/input/patch2pix-weights/resnet34-333f7ec4.pth'
    }

    if method == "patch2pix":
        cfg = {
            "ckpt": '../input/patch2pix-weights/patch2pix_pretrained.pth',
            "ksize": 2,
            "imsize": 2048,
            "match_threshold": 0.15, #0.25
            "backbone_dict": backbone_dict,
        }
        model_p2p = Patch2Pix(cfg)

    elif method == "patch2pix_superglue":
        cfg = {
            "ckpt": "../input/patch2pix-weights/patch2pix_pretrained.pth",
            "imsize": 2048,
            "match_threshold": 0.1,
            "backbone_dict": backbone_dict,
            "coarse": {
                "name": 'SuperGlue',
                "imsize": 2048,              # 1024
                "weights": 'outdoor',        # superglue 
                "sinkhorn_iterations": 20,   # superglue : 100 -> 20
                "match_threshold": 0.2,      # superglue : 0.2
                "max_keypoints": 2048,       # superpoint: 2048
                "nms_radius": 3,             # superpoint: 4 -> 3
                "keypoint_threshold": 0.005  # superpoint: 0.005
            }
        }
        model_p2p = Patch2PixRefined(cfg)  
    
    # for test
    # Manual smoke test, permanently disabled. If re-enabled, note that
    # drawMatches is only defined in a later cell.
    if False:
        im1 = "../input/image-matching-challenge-2022/test_images/1cf87530/0143f47ee9e54243a1b8454f3e91621a.png"
        im2 = "../input/image-matching-challenge-2022/test_images/1cf87530/a5a9975574c94ff9a285f58c39b53d2c.png"
        
        matches, kpts1, kpts2, scores = model_p2p.match_pairs(im1, im2)   
        drawMatches(cv2.imread(im1), matches[:, :2], cv2.imread(im2), matches[:, 2:], np.ones(len(matches)))
        drawMatches(cv2.imread(im1), kpts1, cv2.imread(im2), kpts2, np.ones(len(kpts1)))
else:
    # inference_Fstr skips Patch2Pix when model_p2p is None.
    model_p2p = None

## kornia installation & LoFTR

In [None]:
# Install kornia and kornia_moons from local wheel files under ../input.
!pip install ../input/kornia-loftr/kornia-0.6.4-py2.py3-none-any.whl
!pip install ../input/kornia-loftr/kornia_moons-0.1.9-py3-none-any.whl

# K is used below for image->tensor and colour-space conversion; KF provides LoFTR.
import kornia
from kornia_moons.feature import *
import kornia as K
import kornia.feature as KF

# loftr
# kornia's LoFTR with weights from a local outdoor checkpoint (pretrained=None,
# so the weights come from the checkpoint file). Disabled in the final
# submission (CFG.use_loftr = False).
if CFG.use_loftr:
    matcher_loftr = KF.LoFTR(pretrained=None)
    matcher_loftr.load_state_dict(torch.load("../input/kornia-loftr/loftr_outdoor.ckpt")['state_dict'])
    matcher_loftr = matcher_loftr.to(device).eval()
else:
    matcher_loftr = None

## MatchFormer

In [None]:
# einops is put on sys.path unconditionally, even when MatchFormer is disabled.
sys.path.append('../input/einops/einops-master')

if CFG.use_matchformer:    
    # Paths must be added before the MatchFormer imports below.
    sys.path.append('../input/pytorchimagemodels/pytorch-image-models-master')
    sys.path.append('../input/matchformer/MatchFormer-main')

    from yacs.config import CfgNode as CN
    from model.matchformer import Matchformer
    from config import defaultmf

    # Start from MatchFormer's defaults and select the large-LA outdoor model,
    # matching the outdoor-large-LA checkpoint loaded below.
    cfg = defaultmf.get_cfg_defaults()
    cfg.MATCHFORMER.BACKBONE_TYPE = 'largela'
    cfg.MATCHFORMER.SCENS = 'outdoor'
    cfg.MATCHFORMER.RESOLUTION = (8,2)
    # Coarse-match confidence threshold, lowered from 0.2 (presumably to keep more matches).
    cfg.MATCHFORMER.MATCH_COARSE.THR = 0.15 #0.2

    def lower_config(yacs_cfg):
        """Recursively convert a yacs CfgNode into nested dicts with lower-cased keys."""
        if not isinstance(yacs_cfg, CN):
            return yacs_cfg
        return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}

    _cfg = lower_config(cfg)

    matcher_mf = Matchformer(_cfg['matchformer'])

    pretrained_ckpt = '../input/matchformer/outdoor-large-LA.ckpt'
    # Remove 'matcher.' from checkpoint keys (presumably a training-wrapper prefix)
    # so they match the bare Matchformer module's parameter names.
    matcher_mf.load_state_dict({k.replace('matcher.',''):v  for k,v in torch.load(pretrained_ckpt, map_location='cpu').items()})
    matcher_mf = matcher_mf.to(device).eval()
else:
    matcher_mf = None

## Super Glue

In [None]:
if CFG.use_superglue:
    sys.path.append("../input/super-glue-pretrained-network")
    # Only Matching is used in this notebook; the models.utils names appear unused here.
    from models.matching import Matching
    from models.utils import (compute_pose_error, compute_epipolar_error,
                              estimate_pose, make_matching_plot,
                              error_colormap, AverageTimer, pose_auc, read_image,
                              rotate_intrinsics, rotate_pose_inplane,
                              scale_intrinsics)

    # SuperPoint detector + SuperGlue matcher settings, using the outdoor SuperGlue weights.
    config = {
        "superpoint": {
            "nms_radius": 3, #4,
            "keypoint_threshold": 0.005,
            "max_keypoints": 2048,
        },
        "superglue": {
            "weights": "outdoor",
            "sinkhorn_iterations": 20,
            "match_threshold": 0.2,
        }
    }
    matcher_sg = Matching(config).eval().to(device)
else:
    matcher_sg = None

## Quad Tree Attention

In [None]:
if CFG.use_quadtreeattention:
    # Copy the QuadTreeAttention sources into the working directory and install
    # them in editable mode without contacting a package index (--no-index)
    # or installing dependencies (--no-deps).
    !cp -r ../input/quadtreeattention/QuadTreeAttention-master/* .
    %cd ./QuadTreeAttention
    !pip install -e . --no-index --no-deps
    %cd /kaggle/working
    # insert(0, ...) makes these directories take precedence over everything
    # already on sys.path for the imports below.
    sys.path.insert(0, "/kaggle/working/FeatureMatching")
    sys.path.insert(0, "/kaggle/working/QuadTreeAttention")
    
    from src.loftr import LoFTR
    from src.config.default import get_cfg_defaults
    from src.loftr.utils.cvpr_ds_config import default_cfg
    from yacs.config import CfgNode as CN
    from configs.loftr.outdoor.loftr_ds_quadtree import cfg
    
    # Coarse-match confidence threshold for the QuadTree LoFTR model.
    cfg.LOFTR.MATCH_COARSE.THR = 0.15

    def lower_config(yacs_cfg):
        """Recursively convert a yacs CfgNode into nested dicts with lower-cased keys."""
        if not isinstance(yacs_cfg, CN):
            return yacs_cfg
        return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()}

    # Note: this rebinds the global name `cfg`, which earlier cells also use.
    cfg = lower_config(cfg)["loftr"]

    # Disables autograd globally for the rest of the session, not just this cell.
    torch.set_grad_enabled(False)

    # Initialize LoFTR
    matcher_qta = LoFTR(config=cfg)
    matcher_qta.load_state_dict(torch.load("../input/quadtreeattention/outdoor.ckpt")['state_dict'])
    matcher_qta = matcher_qta.eval().to(device=device)
else:
    matcher_qta = None

### Segmentation (did not improve CV or LB score)

In [None]:
# Optional semantic-segmentation model used to mask image regions before
# SuperGlue matching; disabled in the final submission (CFG.use_segmentation = False).
if CFG.use_segmentation:
    sys.path.append('../input/semantic-segmentation/semantic-segmentation-main')
    from semseg import show_models
    from semseg.models import *
    show_models()

    # SegFormer with a MiT-B1 backbone and 150 classes; the checkpoint name suggests
    # ADE20K weights. eval('SegFormer') only looks up the name SegFormer, presumably
    # provided by the star import above.
    segmodel = eval('SegFormer')(backbone='MiT-B1', num_classes=150)
    segmodel.load_state_dict(torch.load('../input/semantic-segmentation/segformer.b1.ade.pth'))

    #segmodel = eval('SegFormer')(backbone='MiT-B3', num_classes=150)
    #segmodel.load_state_dict(torch.load('../input/semantic-segmentation/segformer.b3.ade.pth'))

    segmodel = segmodel.to(device).eval()
else:
    segmodel = None


def erode(img, target_label):
    """Shrink the region labelled ``target_label`` by relabelling its rim.

    Pixels that carry ``target_label`` but do not survive one 5x5 binary
    erosion are overwritten with label 1. ``img`` must be a 2-D label map;
    it is modified in place and also returned.
    """
    assert img.ndim == 2
    region = (img == target_label).astype(np.uint8)
    core = cv2.erode(region, np.ones((5, 5), dtype=np.uint8), iterations=1)
    rim = (region == 1) & (core != 1)
    img[rim] = 1
    return img


def segmentation(img, segmodel, size=(512,512), device='cuda'):
    """Return a per-pixel class-label map for ``img`` at its original resolution.

    ``img`` is a BGR image (cv2 channel order). It is resized to ``size``,
    converted to RGB, scaled to [0, 1], normalised with the ImageNet mean/std
    and passed through ``segmodel``. The arg-max labels are resized back with
    nearest-neighbour interpolation so they remain valid class ids.
    """
    net_in = cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_BGR2RGB)
    net_in = torch.from_numpy(net_in.transpose((2, 0, 1))).float() / 255
    net_in = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(net_in)
    net_in = net_in.unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = segmodel(net_in)

    labels = logits.softmax(1).argmax(1).to(int)
    labels = labels.to('cpu').numpy().squeeze(0)
    out_h, out_w = img.shape[0], img.shape[1]
    return cv2.resize(labels, (out_w, out_h), interpolation=cv2.INTER_NEAREST)


def masking(img, segmodel, size=(512, 512), device='cuda'):
    """Black out selected semantic classes in a BGR image.

    When ``segmodel`` is None the image is returned untouched. Otherwise the
    image is segmented at ``size`` resolution. Each removed class is shrunk
    with ``erode`` (its rim relabelled to 1), the label map is resized back
    with nearest-neighbour interpolation, and every pixel of a removed class
    is set to (0, 0, 0). ``img`` is modified in place and returned.

    NOTE(review): the removed ids presumably index the 150 ADE20K classes of
    the SegFormer checkpoint loaded earlier; confirm which classes they are.
    """
    if segmodel is None:
        return img

    net_in = cv2.cvtColor(cv2.resize(img, size), cv2.COLOR_BGR2RGB)
    net_in = torch.from_numpy(net_in.transpose((2, 0, 1))).float() / 255
    net_in = T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))(net_in)
    net_in = net_in.unsqueeze(0).to(device)

    with torch.inference_mode():
        logits = segmodel(net_in)

    segmap = logits.softmax(1).argmax(1).to(int)
    segmap = segmap.to('cpu').numpy().squeeze(0)

    remove_label = [2, 12, 20, 116, 127]
    # Shrink each removed region at network resolution before upsampling.
    for label in remove_label:
        segmap = erode(segmap, label)
    segmap = cv2.resize(segmap, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST)
    img[np.isin(segmap, remove_label), :] = np.array([0, 0, 0])
    return img

## Keypoint Matching, F Matrix Estimation, Utilities

In [None]:
def FlattenMatrix(M, num_digits=8):
    """Serialize matrix ``M`` as space-separated scientific-notation numbers.

    Elements are emitted in row-major order with ``num_digits`` digits after
    the decimal point; used to fill the submission's fundamental_matrix
    column. Shadows the FlattenMatrix imported from imc_metric.
    """
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())


def get_F_matrix(mkpts0, mkpts1):
    """Estimate the fundamental matrix between two sets of matched points.

    Uses OpenCV's USAC_MAGSAC estimator with the threshold, confidence and
    iteration budget from ``CFG``.

    Args:
        mkpts0: (N, 2) array of keypoints in image 0 (original pixel coords).
        mkpts1: (N, 2) array of the corresponding keypoints in image 1.

    Returns:
        (F, inliers): a 3x3 fundamental matrix and a uint8 vector of length N
        with 1 marking MAGSAC inliers. If there are too few matches (<= 8) or
        the estimator finds no model, F is all zeros and every flag is 0.
    """
    num_matches = mkpts0.shape[0]
    if num_matches > 8:
        F, mask = cv2.findFundamentalMat(
            mkpts0, mkpts1, cv2.USAC_MAGSAC,
            CFG.magsac_thresh, CFG.magsac_conf, CFG.magsac_maxiter)
        # findFundamentalMat returns None when no model is found (e.g. degenerate
        # point sets); fall through to the zero fallback instead of crashing.
        if F is not None and mask is not None and F.size > 0:
            assert F.shape == (3, 3), 'Malformed F?'
            return F, mask.ravel().astype(np.uint8)

    return np.zeros((3, 3)), np.zeros(num_matches, dtype=np.uint8)


%matplotlib inline
def drawMatches(img0, mkpts0, img1, mkpts1, inliers):
    """Plot the inlier correspondences between two BGR images side by side.

    ``img1`` is placed to the right of ``img0`` on a black canvas. For every
    match whose ``inliers`` flag is > 0, a green line is drawn first, and then
    a circle with a random colour is drawn around each endpoint. The result is
    shown with matplotlib after converting BGR to RGB. Uses the module-level
    ``random`` generator, so calling this advances Python's global RNG.
    """
    h0, w0 = img0.shape[0], img0.shape[1]
    h1, w1 = img1.shape[0], img1.shape[1]
    canvas = np.zeros((max(h0, h1), w0 + w1, 3), dtype=np.uint8)
    canvas[:h0, :w0, :] = img0
    canvas[:h1, w0:, :] = img1

    kept = [(p0, p1) for p0, p1, flag in zip(mkpts0, mkpts1, inliers) if flag > 0]

    # All lines go down before any circle so the circles are drawn on top.
    for p0, p1 in kept:
        start = (int(p0[0]), int(p0[1]))
        end = (int(w0 + p1[0]), int(p1[1]))
        cv2.line(canvas, start, end, (0, 255, 0), 1)

    for p0, p1 in kept:
        color = tuple(random.randint(50, 255) for _ in range(3))
        cv2.circle(canvas, (int(p0[0]), int(p0[1])), 5, color, 2)
        cv2.circle(canvas, (int(w0 + p1[0]), int(p1[1])), 5, color, 2)

    rgb = canvas[:, :, ::-1]

    plt.figure(figsize=(15, 10))
    plt.imshow(rgb)
    plt.show()


def load_torch_image(fname, device, longest_imgsize, segmodel, padding=False):
    """Load an image file as a batched RGB float tensor for a matcher.

    The image is resized so that its longest side equals ``longest_imgsize``.
    With ``padding=True`` it is placed in the top-left corner of a black
    ``longest_imgsize`` x ``longest_imgsize`` canvas. If ``segmodel`` is given,
    classes are masked out via ``masking`` before conversion.

    Args:
        fname: path to the image file.
        device: torch device the returned tensor is moved to.
        longest_imgsize: target length (pixels) of the longest image side.
        segmodel: optional segmentation model passed to ``masking``; None disables masking.
        padding: pad the resized image to a square canvas.

    Returns:
        (tensor, org_w, org_h): ``tensor`` has shape (1, 3, H, W), RGB, values in
        [0, 1] (batch dim added by image_to_tensor(keepdim=False)). ``org_w`` and
        ``org_h`` are the original-image extent that the tensor's width/height
        span. That is the original size, or max(width, height) on both axes when
        padding, because the padded canvas covers that square. Callers
        rescale keypoints by ``org_w / W`` and ``org_h / H``.

    Raises:
        FileNotFoundError: if cv2 cannot read ``fname`` (cv2.imread returns None).
    """
    img = cv2.imread(fname)
    if img is None:
        raise FileNotFoundError(f'Could not read image: {fname}')

    org_h, org_w = img.shape[0], img.shape[1]
    scale = longest_imgsize / max(org_h, org_w)
    w = int(org_w * scale)
    h = int(org_h * scale)
    img = cv2.resize(img, (w, h))

    if padding:
        # The square canvas corresponds to a max(org_w, org_h) square in the original image.
        org_w = org_h = max(org_w, org_h)
        base = np.zeros((longest_imgsize, longest_imgsize, 3), dtype=np.uint8)
        base[:img.shape[0], :img.shape[1], :] = img
        img = base

    if segmodel is not None:
        img = masking(img, segmodel, size=(512, 512), device=device)

    img = K.image_to_tensor(img, False).float() / 255.
    img = K.color.bgr_to_rgb(img)
    return img.to(device), org_w, org_h


def match_mf(img_path0, img_path1, matcher, device, segmodel):
    """Match an image pair with MatchFormer and return keypoints in original pixels.

    Both images are resized to ``CFG.longest_imgsize_mf`` on the longest side and
    padded to a square. ``segmodel`` is accepted so that all matchers share one
    signature, but it is not used here.

    Returns:
        (mkpts0, mkpts1): (N, 2) arrays. When more than ``CFG.max_num_pairs_mf``
        matches are found, only the most confident are kept, in descending
        confidence order; otherwise the matcher's order is preserved.
    """
    size = CFG.longest_imgsize_mf
    img0, org_w0, org_h0 = load_torch_image(img_path0, device=device, longest_imgsize=size, segmodel=None, padding=True)
    img1, org_w1, org_h1 = load_torch_image(img_path1, device=device, longest_imgsize=size, segmodel=None, padding=True)

    batch = {"image0": K.color.rgb_to_grayscale(img0),
             "image1": K.color.rgb_to_grayscale(img1)}

    # MatchFormer writes its outputs (mconf, mkpts0_f, mkpts1_f) into ``batch`` in place.
    with torch.inference_mode():
        matcher(batch)

    conf = batch['mconf'].to('cpu').numpy()
    mkpts0 = batch['mkpts0_f'].to('cpu').numpy()
    mkpts1 = batch['mkpts1_f'].to('cpu').numpy()

    limit = CFG.max_num_pairs_mf
    if len(conf) > limit:
        top = np.argsort(-conf)[:limit]
        mkpts0 = mkpts0[top, :]
        mkpts1 = mkpts1[top, :]

    # Network (padded) pixels -> original-image pixels, per axis.
    for pts, net_img, org_w, org_h in ((mkpts0, img0, org_w0, org_h0),
                                       (mkpts1, img1, org_w1, org_h1)):
        pts[:, 0] = pts[:, 0] * org_w / net_img.shape[3]
        pts[:, 1] = pts[:, 1] * org_h / net_img.shape[2]

    return mkpts0, mkpts1


def match_loftr(img_path0, img_path1, matcher, device, segmodel):
    """Match an image pair with kornia's LoFTR and return keypoints in original pixels.

    Both images are resized to ``CFG.longest_imgsize_loftr`` on the longest side
    and padded to a square. ``segmodel`` is accepted so that all matchers share
    one signature, but it is not used here.

    Returns:
        (mkpts0, mkpts1): (N, 2) arrays. When more than ``CFG.max_num_pairs_loftr``
        matches are found, only the most confident are kept, in descending
        confidence order; otherwise the matcher's order is preserved.
    """
    size = CFG.longest_imgsize_loftr
    img0, org_w0, org_h0 = load_torch_image(img_path0, device=device, longest_imgsize=size, segmodel=None, padding=True)
    img1, org_w1, org_h1 = load_torch_image(img_path1, device=device, longest_imgsize=size, segmodel=None, padding=True)

    batch = {"image0": K.color.rgb_to_grayscale(img0),
             "image1": K.color.rgb_to_grayscale(img1)}

    with torch.inference_mode():
        out = matcher(batch)

    conf = out['confidence'].cpu().numpy()
    mkpts0 = out['keypoints0'].cpu().numpy()
    mkpts1 = out['keypoints1'].cpu().numpy()

    limit = CFG.max_num_pairs_loftr
    if len(conf) > limit:
        top = np.argsort(-conf)[:limit]
        mkpts0 = mkpts0[top, :]
        mkpts1 = mkpts1[top, :]

    # Network (padded) pixels -> original-image pixels, per axis.
    for pts, net_img, org_w, org_h in ((mkpts0, img0, org_w0, org_h0),
                                       (mkpts1, img1, org_w1, org_h1)):
        pts[:, 0] = pts[:, 0] * org_w / net_img.shape[3]
        pts[:, 1] = pts[:, 1] * org_h / net_img.shape[2]

    return mkpts0, mkpts1


def match_qta(img_path0, img_path1, matcher, device, segmodel):
    """Match an image pair with QuadTreeAttention LoFTR; return keypoints in original pixels.

    Both images are resized to ``CFG.longest_imgsize_qta`` on the longest side
    and padded to a square. ``segmodel`` is accepted so that all matchers share
    one signature, but it is not used here.

    Returns:
        (mkpts0, mkpts1): (N, 2) arrays with N <= ``CFG.max_num_pairs_qta``.
        When more matches are found, only the most confident are kept, in
        descending confidence order; otherwise the matcher's order is preserved.
    """
    img0, org_w0, org_h0 = load_torch_image(img_path0, device=device, longest_imgsize=CFG.longest_imgsize_qta, segmodel=None, padding=True)
    img1, org_w1, org_h1 = load_torch_image(img_path1, device=device, longest_imgsize=CFG.longest_imgsize_qta, segmodel=None, padding=True)

    input_dict = {"image0": K.color.rgb_to_grayscale(img0),
                  "image1": K.color.rgb_to_grayscale(img1)}

    # The matcher writes its outputs (mkpts0_f, mkpts1_f, mconf) into input_dict in place.
    with torch.inference_mode():
        matcher(input_dict)

    mkpts0 = input_dict['mkpts0_f'].cpu().numpy()
    mkpts1 = input_dict['mkpts1_f'].cpu().numpy()
    conf = input_dict['mconf'].cpu().numpy()

    # The guard and the slice must use the same QTA-specific cap.
    limit = CFG.max_num_pairs_qta
    if len(conf) > limit:
        top = np.argsort(-conf)[:limit]
        mkpts0 = mkpts0[top, :]
        mkpts1 = mkpts1[top, :]

    # Network (padded) pixels -> original-image pixels, per axis.
    mkpts0[:, 0] = mkpts0[:, 0] * org_w0 / img0.shape[3]
    mkpts0[:, 1] = mkpts0[:, 1] * org_h0 / img0.shape[2]

    mkpts1[:, 0] = mkpts1[:, 0] * org_w1 / img1.shape[3]
    mkpts1[:, 1] = mkpts1[:, 1] * org_h1 / img1.shape[2]

    return mkpts0, mkpts1


def _sg_match_single_scale(img_path0, img_path1, matcher, device, segmodel, longest_imgsize):
    """Run SuperPoint+SuperGlue at one resolution and return matches in original pixels.

    Keeps at most ``CFG.max_num_pairs_sg`` matches. When truncating, the most
    confident are kept in descending confidence order; otherwise the matcher's
    order is preserved.
    """
    inp0, org_w0, org_h0 = load_torch_image(img_path0, device=device, longest_imgsize=longest_imgsize, segmodel=segmodel, padding=False)
    inp1, org_w1, org_h1 = load_torch_image(img_path1, device=device, longest_imgsize=longest_imgsize, segmodel=segmodel, padding=False)

    with torch.inference_mode():
        pred = matcher({"image0": K.color.rgb_to_grayscale(inp0),
                        "image1": K.color.rgb_to_grayscale(inp1)})
    # Drop the batch dimension and move every output to NumPy.
    pred = {k: v[0].detach().cpu().numpy() for k, v in pred.items()}
    kpts0, kpts1 = pred["keypoints0"], pred["keypoints1"]
    matches, conf = pred["matches0"], pred["matching_scores0"]

    # matches0[i] is the index of the image-1 keypoint matched to image-0
    # keypoint i, or -1 when unmatched.
    valid = matches > -1
    mkpts0 = kpts0[valid]
    mkpts1 = kpts1[matches[valid]]
    mconf = conf[valid]

    limit = CFG.max_num_pairs_sg
    if len(mconf) > limit:
        top = np.argsort(-mconf)[:limit]
        mkpts0 = mkpts0[top, :]
        mkpts1 = mkpts1[top, :]

    # Resized-image pixels -> original-image pixels, per axis.
    mkpts0[:, 0] = mkpts0[:, 0] * org_w0 / inp0.shape[3]
    mkpts0[:, 1] = mkpts0[:, 1] * org_h0 / inp0.shape[2]

    mkpts1[:, 0] = mkpts1[:, 0] * org_w1 / inp1.shape[3]
    mkpts1[:, 1] = mkpts1[:, 1] * org_h1 / inp1.shape[2]

    return mkpts0, mkpts1


def match_sg(img_path0, img_path1, matcher, device, segmodel):
    """Match an image pair with SuperPoint+SuperGlue at several resolutions.

    The pair is matched once per entry of ``CFG.longest_imgsize_sg`` (longest
    image side, no padding). Each scale contributes up to
    ``CFG.max_num_pairs_sg`` matches. Results are stacked in the order of that
    list. ``segmodel``, when not None, masks both images before matching.

    Args:
        img_path0, img_path1: image file paths.
        matcher: SuperGlue ``Matching`` model (used directly; no global lookup).
        device: torch device for the input tensors.
        segmodel: optional segmentation model used by ``load_torch_image``.

    Returns:
        (mkpts0, mkpts1): stacked (N, 2) arrays in original-image pixels.
    """
    per_scale = [
        _sg_match_single_scale(img_path0, img_path1, matcher, device, segmodel, size)
        for size in CFG.longest_imgsize_sg
    ]
    mkpts0 = np.vstack([pts0 for pts0, _ in per_scale])
    mkpts1 = np.vstack([pts1 for _, pts1 in per_scale])
    return mkpts0, mkpts1


def match_p2p(img1_path, img2_path, model_p2p, segmodel):
    """Match an image pair with Patch2Pix.

    ``model_p2p.match_pairs`` takes the file paths directly. Its match rows are
    split into image-0 columns [:2] and image-1 columns [2:]. No rescaling is
    done here; the coordinates are presumably already in original-image
    pixels (confirm against immatch). ``segmodel`` is accepted so that all
    matchers share one signature, but it is not used.

    Returns:
        (mkpts0, mkpts1): at most ``CFG.max_num_pairs_p2p`` matches. When
        truncating, the highest-scoring are kept in descending score order.
    """
    with torch.inference_mode():
        matches, _kpts1, _kpts2, scores = model_p2p.match_pairs(img1_path, img2_path)
    mkpts0, mkpts1 = matches[:, :2], matches[:, 2:]

    order = np.argsort(-scores)
    limit = CFG.max_num_pairs_p2p
    if len(matches) > limit:
        keep = order[:limit]
        mkpts0 = mkpts0[keep, :]
        mkpts1 = mkpts1[keep, :]

    return mkpts0, mkpts1

## Inference

In [None]:
def inference_Fstr(img1_path, img2_path, matcher_mf, matcher_loftr, matcher_sg, model_p2p, matcher_qta, segmodel, device):
    """Run every enabled matcher on one pair, pool the matches and estimate F.

    Matchers passed as None are skipped and contribute no points. Each enabled
    matcher is called in a fixed order: MatchFormer, LoFTR, SuperGlue,
    Patch2Pix, QuadTreeAttention. Their matches are stacked in that same order
    before MAGSAC estimation.

    Returns:
        (F_str, mkpts0, mkpts1, inliers): the serialized fundamental matrix, the
        pooled (N, 2) keypoint arrays, and the per-match inlier flags.
    """
    runners = [
        (matcher_mf, lambda: match_mf(img1_path, img2_path, matcher_mf, device=device, segmodel=segmodel)),
        (matcher_loftr, lambda: match_loftr(img1_path, img2_path, matcher_loftr, device=device, segmodel=segmodel)),
        (matcher_sg, lambda: match_sg(img1_path, img2_path, matcher_sg, device=device, segmodel=segmodel)),
        (model_p2p, lambda: match_p2p(img1_path, img2_path, model_p2p, segmodel=segmodel)),
        (matcher_qta, lambda: match_qta(img1_path, img2_path, matcher_qta, device=device, segmodel=segmodel)),
    ]

    parts0, parts1 = [], []
    for model, run in runners:
        if model is None:
            pts0, pts1 = np.zeros((0, 2)), np.zeros((0, 2))
        else:
            pts0, pts1 = run()
        parts0.append(pts0)
        parts1.append(pts1)

    mkpts0 = np.vstack(parts0)
    mkpts1 = np.vstack(parts1)

    F, inliers = get_F_matrix(mkpts0, mkpts1)
    return FlattenMatrix(F), mkpts0, mkpts1, inliers


def _read_test_pairs(src):
    """Return [sample_id, batch_id, image_path_1, image_path_2] rows from ``src``/test.csv."""
    samples = []
    with open(f'{src}/test.csv', newline='') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the header row
        for row in reader:
            row[2] = f'{src}/test_images/{row[1]}/{row[2]}.png'
            row[3] = f'{src}/test_images/{row[1]}/{row[3]}.png'
            samples.append(row)
    return samples


def _read_val_pairs(src):
    """Sample covisible training pairs (covisibility >= 0.1) for validation.

    Scenes come from ``src``/scaling_factors.csv, filtered by
    ``CFG.validate_scene_id`` (empty means all scenes). Up to
    ``CFG.validate_num_pairs`` shuffled pairs are kept per scene. Rows have the
    same layout as test rows, with batch_id fixed to 0.
    """
    with open(f'{src}/scaling_factors.csv', newline='') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the header row
        scaling_dict = {row[0]: float(row[1]) for row in reader}

    samples = []
    for scene_id, scene in enumerate(scaling_dict):
        if CFG.validate_scene_id and scene_id not in CFG.validate_scene_id:
            continue
        covisibility_dict, _ = ReadCovisibilityData(f'{src}/{scene}/pair_covisibility.csv')
        pairs = [key for key, covis in covisibility_dict.items() if covis >= 0.1]
        random.shuffle(pairs)
        n = len(pairs)
        pairs = pairs[:CFG.validate_num_pairs]
        print(f'Loading covisibility data for "{scene}"... kept {len(pairs)} out of {n} covisible pairs')

        for pair in pairs:
            image_1_id, image_2_id = pair.split('-')
            image_1 = f'{src}/{scene}/images/{image_1_id}.jpg'
            image_2 = f'{src}/{scene}/images/{image_2_id}.jpg'
            samples.append([f'phototourism;{scene};{pair}', 0, image_1, image_2])
    return samples


def inference(matcher_mf, matcher_loftr, matcher_sg, model_p2p, matcher_qta, mode, device, segmodel):
    """Estimate a fundamental matrix for every pair selected by ``mode``.

    Args:
        matcher_mf, matcher_loftr, matcher_sg, model_p2p, matcher_qta: matchers
            forwarded to ``inference_Fstr``; None disables that matcher.
        mode: "test" reads the competition test.csv; "val" samples covisible
            training pairs (see ``_read_val_pairs``).
        device: torch device used for the matcher inputs.
        segmodel: optional segmentation model forwarded to ``inference_Fstr``.

    Returns:
        DataFrame with columns ``sample_id`` and ``fundamental_matrix``, one row per pair.

    Raises:
        ValueError: if ``mode`` is neither "test" nor "val".

    Side effect: plots the matches of the first three pairs with ``drawMatches``.
    """
    if mode == "test":
        pairs = _read_test_pairs('../input/image-matching-challenge-2022/')
    elif mode == "val":
        pairs = _read_val_pairs('../input/image-matching-challenge-2022/train')
    else:
        raise ValueError(f'Unknown mode {mode!r}; expected "test" or "val"')

    # Collect rows in a list; DataFrame.append is quadratic and was removed in pandas 2.0.
    records = []
    for i, (sample_id, batch_id, image_1, image_2) in enumerate(tqdm(pairs)):
        F_str, mkpts0, mkpts1, inliers = inference_Fstr(
            image_1, image_2, matcher_mf, matcher_loftr, matcher_sg, model_p2p, matcher_qta, segmodel, device)
        records.append({'sample_id': f'{sample_id}', 'fundamental_matrix': f'{F_str}'})

        # visualize
        if i < 3:
            drawMatches(cv2.imread(image_1), mkpts0, cv2.imread(image_2), mkpts1, inliers)

    return pd.DataFrame(records, columns=['sample_id', 'fundamental_matrix'])

In [None]:
# Run the pipeline on every pair selected by CFG.mode and write submission.csv,
# which the validation cell below also scores in "val" mode.
sub = inference(matcher_mf, matcher_loftr, matcher_sg, model_p2p, matcher_qta, CFG.mode, device, segmodel)
sub.to_csv('submission.csv', index=False)

## Checking Validation Score

In [None]:
if CFG.mode == 'val':
    src = '../input/image-matching-challenge-2022/train'

    # Per-scene scaling factors from scaling_factors.csv (first row is the header).
    with open(f'{src}/scaling_factors.csv') as f:
        rows = list(csv.reader(f, delimiter=','))
    scaling_dict = {row[0]: float(row[1]) for row in rows[1:]}

    # Threshold grids passed to EvaluateSubmission:
    # q: 10 linear steps over [1, 10]; t: 10 geometric steps over [0.2, 5].
    thresholds_q = np.linspace(1, 10, 10)
    thresholds_t = np.geomspace(0.2, 5, 10)

    print('--- Evaluate prediction ---')
    maa, maa_per_scene, errors_dict_q, errors_dict_t = EvaluateSubmission(
        'submission.csv', scaling_dict, thresholds_q, thresholds_t, src=src)
    for scene, cur_maa in maa_per_scene.items():
        n_pairs = len(errors_dict_q[scene])
        print(f'Scene "{scene}" ({n_pairs} pairs), mAA={cur_maa:.05f}')
    print()
    print(f'Full dataset: mAA={maa:.05f}')
    print()