In [26]:
import matplotlib.pyplot as plt
import cv2
import os
import math

In [31]:
# Keypoint-based search: returns the corner coordinates of the matched region

def _file_stem(path):
    """Return the file name of ``path`` without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _load_grayscale(path):
    """Read ``path`` as a single-channel image, failing loudly if it cannot be read."""
    image = cv2.imread(path, 0)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise OSError(f"Could not read image: {path}")
    return image


def matchTemplate(path_img, path_pattern, ratio=0.4):
    """Locate a pattern inside an image using SIFT keypoints and FLANN matching.

    Side effects: writes keypoint visualisations to
    ``key_points/input/kp_<name>.jpg`` and ``key_points/pattern/kp_<name>.jpg``
    and the match visualisation to ``result/result_<name>.jpg`` (directories
    are created when missing).

    Args:
        path_img: Path of the image to search in.
        path_pattern: Path of the pattern (template) image.
        ratio: Ratio-test threshold; a match is kept only when the best
            neighbour's distance is below ``ratio`` times the second best.

    Returns:
        ``(highlight_start, highlight_end)``: integer ``(x, y)`` corners of the
        axis-aligned box enclosing every accepted keypoint of the input image
        (left/top inclusive, right/bottom exclusive). When no match passes the
        ratio test the empty box ``((0, 0), (0, 0))`` is returned.

    Raises:
        OSError: If either image cannot be read.
    """
    input_image = _load_grayscale(path_img)
    pattern_template = _load_grayscale(path_pattern)

    input_name = _file_stem(path_img)
    pattern_name = _file_stem(path_pattern)

    for directory in ('key_points/input', 'key_points/pattern', 'result'):
        os.makedirs(directory, exist_ok=True)

    sift = cv2.SIFT_create()

    # outImage is left as None so the source images are never used as a
    # drawing target and stay untouched for the matching step below.
    input_kp, input_desc = sift.detectAndCompute(input_image, None)
    input_image_kp = cv2.drawKeypoints(input_image, input_kp, None)
    cv2.imwrite('key_points/input/kp_' + input_name + '.jpg', input_image_kp)

    pattern_kp, pattern_desc = sift.detectAndCompute(pattern_template, None)
    pattern_image_kp = cv2.drawKeypoints(pattern_template, pattern_kp, None)
    cv2.imwrite('key_points/pattern/kp_' + pattern_name + '.jpg', pattern_image_kp)

    # k=2 matching needs descriptors on both sides and at least two candidates
    # in the pattern; otherwise there is nothing to run the ratio test on.
    matches = []
    if input_desc is not None and pattern_desc is not None and len(pattern_desc) >= 2:
        # NOTE(review): OpenCV tutorials pair `trees` with algorithm=1 (KD-tree);
        # 0 presumably selects linear search - confirm this is intended.
        index_params = dict(algorithm=0, trees=5)
        search_params = dict()
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(input_desc, pattern_desc, k=2)

    points = []
    xs = []
    ys = []
    for pair in matches:
        if len(pair) < 2:
            # Not enough neighbours were returned for this descriptor.
            continue
        best, second = pair
        if best.distance < ratio * second.distance:
            points.append(best)
            x, y = input_kp[best.queryIdx].pt
            xs.append(x)
            ys.append(y)

    result = cv2.drawMatches(input_image, input_kp, pattern_template, pattern_kp, points, None)
    cv2.imwrite('result/' + 'result_' + input_name + '.jpg', result)

    if not points:
        # No reliable match: report an empty box rather than an inverted one.
        return (0, 0), (0, 0)

    # +1 on the far edges makes the box half-open, so the pixel holding the
    # right-most / bottom-most keypoint is still inside it.
    highlight_start = (math.floor(min(xs)), math.floor(min(ys)))
    highlight_end = (math.floor(max(xs)) + 1, math.floor(max(ys)) + 1)

    return highlight_start, highlight_end


# Additional

In [32]:
def show(input_image, pattern_template, highlight_start, highlight_end,row):
    """Plot the pattern next to the input image with both boxes drawn on it.

    Two rectangles are drawn onto ``input_image`` via ``cv2.rectangle``: the
    box taken from ``row`` (keys ``x1``, ``y1``, ``x2``, ``y2``) and the box
    spanning ``highlight_start`` to ``highlight_end``. The pattern and the
    annotated image are then shown side by side in one matplotlib figure.
    """
    annotated_top_left = (row["x1"], row["y1"])
    annotated_bottom_right = (row["x2"], row["y2"])
    cv2.rectangle(input_image, annotated_top_left, annotated_bottom_right, (255, 0, 0), 2)
    cv2.rectangle(input_image, highlight_start, highlight_end, 255, 2)

    # Left panel: the pattern; right panel: the input image with both boxes.
    fig_instance, axes_arr = plt.subplots(1, 2, figsize=(10, 5))
    panels = (
        (pattern_template, 'Pattern Template'),
        (input_image, 'Pattern Highlighted'),
    )
    for axis, (image, title) in zip(axes_arr, panels):
        axis.imshow(image, cmap='gray')
        axis.set_title(title)

    plt.show()

In [33]:
# iou

def iou(row, highlight_start, highlight_end):
    """Intersection-over-union of the box in ``row`` and the highlighted box.

    Both boxes are treated as half-open pixel ranges (left/top inclusive,
    right/bottom exclusive), so the result equals counting shared pixels, but
    it is computed from the rectangle areas instead of enumerating pixels.

    Args:
        row: Mapping with integer keys ``x1``, ``y1``, ``x2``, ``y2``.
        highlight_start: ``(x, y)`` top-left corner of the second box.
        highlight_end: ``(x, y)`` bottom-right corner of the second box.

    Returns:
        The IoU as a float in ``[0, 1]``; ``0.0`` when both boxes are empty.
    """
    true_x1, true_y1 = int(row["x1"]), int(row["y1"])
    true_x2, true_y2 = int(row["x2"]), int(row["y2"])
    pred_x1, pred_y1 = int(highlight_start[0]), int(highlight_start[1])
    pred_x2, pred_y2 = int(highlight_end[0]), int(highlight_end[1])

    # A box whose end does not exceed its start contains no pixels.
    true_area = max(0, true_x2 - true_x1) * max(0, true_y2 - true_y1)
    pred_area = max(0, pred_x2 - pred_x1) * max(0, pred_y2 - pred_y1)

    overlap_w = max(0, min(true_x2, pred_x2) - max(true_x1, pred_x1))
    overlap_h = max(0, min(true_y2, pred_y2) - max(true_y1, pred_y1))
    intersection = overlap_w * overlap_h

    union = true_area + pred_area - intersection
    if union == 0:
        # Both boxes are empty; avoid dividing by zero.
        return 0.0
    return intersection / union

# Output

In [34]:
# Load the images and iterate over them
import statistics
from PIL import Image, ImageFilter
import numpy as np
import pandas as pd
# Annotation table; the loop below reads the columns id, x1, y1, x2, y2 from each row.
dataset = pd.read_csv('dataset/annotation.csv', delimiter=';')

# Upper bound on the number of rows evaluated. The loop is additionally capped
# by the table length so a shorter annotation file does not raise IndexError.
MAX_IMAGES = 36
iou_array = []

for i in range(min(MAX_IMAGES, len(dataset))):
    row = dataset.iloc[i]
    highlight_start, highlight_end = matchTemplate(f"dataset/{row['id']}.jpg", f"dataset/pattern/cropped_img_{row['id']}.jpg")
    iou_metric = iou(row, highlight_start, highlight_end)
    print(iou_metric)
    iou_array.append(iou_metric)

# statistics.mean raises on an empty list, so fall back to NaN when nothing was evaluated.
iou_mean = statistics.mean(iou_array) if iou_array else float('nan')
print("iou_mean",iou_mean)

0.5228365384615384
0.822673031026253
0.7174857142857143
0.2898717948717949
0.8027156549520766
0.44515197826456104
0.8179614641962122
0.6414894596336187
0.20031400573872557
0.48528278625366006


KeyboardInterrupt: 