# Imports

In [None]:
import os
import numpy as np
import cv2
import csv
import pandas as pd
from glob import glob
import matplotlib.pyplot as plt
from collections import namedtuple
from copy import deepcopy
from tqdm import tqdm
import random
import gc

# Loading data

# Define some useful functions

In [None]:
def FlattenMatrix(M, num_digits=8):
    '''Serialize a matrix as a single space-separated string for a CSV cell.

    Entries are read with ``M.flatten()`` (row-major for NumPy's default
    order) and each is rendered in scientific notation with ``num_digits``
    digits after the decimal point.
    '''
    spec = f'.{num_digits}e'
    return ' '.join(format(value, spec) for value in M.flatten())

# Drawing options: only inliers and tentative matches are drawn.
draw_dict = dict(
    inlier_color=(0.2, 1, 0.2),        # Green: inliers.
    tentative_color=(1, 1, 0.2, 0.5),  # Light yellow (alpha 0.5): tentative matches.
    feature_color=None,                # No colour assigned to raw features.
    vertical=False,
)

def ArrayFromCvKps(kps):
    '''Convert a sequence of OpenCV keypoints into an (N, 2) float array.

    Only each keypoint's 2-D ``pt`` coordinates are kept. An empty input
    yields shape (0, 2) rather than (0,), so callers can always index
    ``result[idx]`` / ``result[:, 0]`` uniformly.
    '''
    return np.array([kp.pt for kp in kps], dtype=np.float64).reshape(-1, 2)

def ExtractSiftFeatures(image, detector, num_features):
    '''Detect keypoints and compute descriptors for one image.

    Args:
        image: RGB image (as produced by the loading cell, which converts
            BGR to RGB), or an already single-channel 2-D image.
        detector: OpenCV feature detector exposing ``detectAndCompute`` and
            ``descriptorSize`` (a SIFT detector in this notebook).
        num_features: maximum number of keypoints/descriptors to keep.

    Returns:
        (keypoints, descriptors): the first ``num_features`` keypoints and
        the matching descriptor rows. When nothing is detected, the
        descriptors are an empty float32 array of shape
        (0, detector.descriptorSize()) instead of ``None``.
    '''
    # Colour input is converted to grayscale; 2-D input is used as-is.
    gray = image if image.ndim == 2 else cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    kp, desc = detector.detectAndCompute(gray, None)
    if desc is None:
        # OpenCV's Python binding returns None (not an empty array) when no
        # keypoints are found; slicing None would raise TypeError.
        # float32 matches SIFT descriptors.
        desc = np.empty((0, detector.descriptorSize()), dtype=np.float32)
    return kp[:num_features], desc[:num_features]

def ReadCovisibilityData(filename):
    '''Read a CSV file into a dict mapping column-0 strings to float column-1 values.

    The first row is treated as a header and skipped. Blank lines are ignored.

    Returns:
        dict[str, float]; empty for an empty or header-only file.

    Raises:
        ValueError: if a second-column value is not a valid float.
        IndexError: if a non-blank row has fewer than two columns.
    '''
    covisibility_dict = {}
    # newline='' is what the csv module documentation requires for files it reads.
    with open(filename, newline='') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # Skip header (no-op on an empty file).
        for row in reader:
            if not row:
                # csv.reader yields [] for a blank line.
                continue
            covisibility_dict[row[0]] = float(row[1])

    return covisibility_dict

# Make submission

In [None]:
test = pd.read_csv('../input/image-matching-challenge-2022/test.csv')

# Per-row identifiers from test.csv, in file order, so that batch_ids[k] and
# sample_ids[k] describe the same row.
batch_ids, sample_ids = (list(test[column]) for column in ('batch_id', 'sample_id'))

batch_ids[0]  # Bare expression: the notebook displays the first batch id.

In [None]:
src = '../input/image-matching-challenge-2022/test_images'

# Map each batch_id to the list of RGB images found in its directory.
# Downstream, images [0] and [1] of each batch are matched against each other.
# os.listdir returns entries in arbitrary order, so names are sorted to make
# that choice deterministic.
# NOTE(review): if test.csv names the two images of each pair explicitly,
# selecting by those ids would be more robust than directory order — confirm.
images_dict = {}

for batch_id in batch_ids:
    if batch_id in images_dict:
        # Several test rows may share a batch; decode its images only once.
        continue
    batch_dir = os.path.join(src, batch_id)
    img_list = []
    for image_name in sorted(os.listdir(batch_dir)):
        filename = os.path.join(batch_dir, image_name)
        bgr = cv2.imread(filename)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f'Could not read image: {filename}')
        img_list.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    images_dict[batch_id] = img_list

In [None]:
num_features = 2000

# Rows whose index is >= how_many_to_fill get a random placeholder F instead
# of a real estimate; a negative value disables this shortcut.
how_many_to_fill = -1

# Very negative contrast/edge thresholds effectively disable SIFT's keypoint
# rejection; num_features caps how many of the best keypoints are retained.
sift_detector = cv2.SIFT_create(num_features, contrastThreshold=-10000, edgeThreshold=-10000)

# Brute-force L2 matcher with cross-checking; created once and reused per pair.
bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=True)


def _estimate_fundamental(image_1, image_2, detector, matcher, max_features):
    '''Estimate the fundamental matrix between two RGB images.

    Features are extracted with ``ExtractSiftFeatures``, matched with
    ``matcher`` and fed to MAGSAC via ``cv2.findFundamentalMat``.

    Returns:
        A 3x3 array. ``np.zeros((3, 3))`` when either image has no
        descriptors, there are 8 or fewer matches, or no model is found.
    '''
    keypoints_1, descriptors_1 = ExtractSiftFeatures(image_1, detector, max_features)
    keypoints_2, descriptors_2 = ExtractSiftFeatures(image_2, detector, max_features)
    if any(d is None or len(d) == 0 for d in (descriptors_1, descriptors_2)):
        return np.zeros((3, 3))

    cv_matches = matcher.match(descriptors_1, descriptors_2)
    # Require more than 8 correspondences before attempting estimation.
    if len(cv_matches) <= 8:
        return np.zeros((3, 3))

    # Convert keypoints and matches to plain arrays of coordinates / indices.
    cur_kp_1 = ArrayFromCvKps(keypoints_1)
    cur_kp_2 = ArrayFromCvKps(keypoints_2)
    matches = np.array([[m.queryIdx, m.trainIdx] for m in cv_matches])

    F, _inlier_mask = cv2.findFundamentalMat(cur_kp_1[matches[:, 0]],
                                             cur_kp_2[matches[:, 1]],
                                             cv2.USAC_MAGSAC,
                                             ransacReprojThreshold=0.25,
                                             confidence=0.99999, maxIters=10000)
    # findFundamentalMat returns None when the estimator finds no model.
    if F is None or F.shape != (3, 3):
        return np.zeros((3, 3))
    return F


F_dict = {}
for row_idx, (sample_id, batch_id) in enumerate(zip(sample_ids, batch_ids)):
    # row_idx is this row's position in test.csv.
    if 0 <= how_many_to_fill <= row_idx:
        F_dict[sample_id] = np.random.rand(3, 3)
        continue

    images = images_dict[batch_id]
    if len(images) < 2:
        # Not enough images in the batch to form a pair.
        F_dict[sample_id] = np.zeros((3, 3))
        continue

    F_dict[sample_id] = _estimate_fundamental(images[0], images[1],
                                              sift_detector, bf, num_features)
    gc.collect()


# One line per sample: sample_id, then the flattened entries of F.
with open('submission.csv', 'w') as f:
    f.write('sample_id,fundamental_matrix\n')
    for sample_id, F in F_dict.items():
        f.write(f'{sample_id},{FlattenMatrix(F)}\n')