### faiss 적용

In [23]:
import cv2
import numpy as np
import faiss

def compute_vlad(image, des, labels, centers):
    sift = cv2.SIFT_create(edgeThreshold=80)
    _, des = sift.detectAndCompute(image, None)
    
    if des is None: 
        return None
    
    vlad = np.zeros((centers.shape[0], des.shape[1]), dtype=np.float32)
    
    for i in range(des.shape[0]):
        vlad[labels[i]] += des[i] - centers[labels[i]]
        
    vlad = cv2.normalize(vlad, None).flatten()
    vlad /= np.linalg.norm(vlad)
    return vlad

def similar_images(query_image, category_images, k):
    sift = cv2.SIFT_create(edgeThreshold=80)
    
    query_image = cv2.resize(query_image, (300, 300))
    _, query_des = sift.detectAndCompute(query_image, None)
    if query_des is None:
        return None

    category_des = []
    for category, images in category_images.items():
        for image in images:
            _, des = sift.detectAndCompute(image, None)
            if des is not None:
                category_des.extend(des)

    category_des = np.array(category_des)

    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    flags = cv2.KMEANS_RANDOM_CENTERS
    _, labels, centers = cv2.kmeans(category_des, k, None, criteria, 5, flags)

    query_vlad = compute_vlad(query_image, query_des, labels, centers)
    if query_vlad is None:
        return None

    similarity_scores = {}
    for category, images in category_images.items():
        category_vlads = []
        for image in images:
            image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            image = cv2.resize(image, (300, 300))
            _, des = sift.detectAndCompute(image, None)
            if des is not None:
                category_vlads.append(compute_vlad(image, des, labels, centers))

        category_vlads = np.array(category_vlads)

        # Faiss Index 생성
        d = query_vlad.shape[0]
        index = faiss.IndexFlatL2(d)
        index.add(category_vlads)

        # Query 이미지를 Faiss Index에 검색
        distances, _ = index.search(query_vlad.reshape(1, -1), 10)
        similarity_scores[category] = np.mean(distances)

    # 0에 가까운 것이 유사도가 가장 높으므로 오름차순 정렬
    sorted_categories = sorted(similarity_scores.items(), key=lambda x: x[1])
    return sorted_categories

path = 'c:/data/temp/'

# query_image_path =  path + 'flower1.jpg'
# query_image_path =  path + 'mouse1.jpg'
# query_image_path =  path + 'umb1.jpg'
# query_image_path =  path + 'bal1.jpg'
# query_image_path =  path + 'muf1.jpg'
# query_image_path =  path + 'book1.jpg'
# query_image_path =  path + 'box1.jpg'
# query_image_path =  path + 'game1.jpg'
# query_image_path =  path + 'sci1.jpg'
query_image_path =  path + 'shirt1.jpg'

categories = {
    'mouse': [cv2.imread(path + 'mouse11.jpg'), cv2.imread(path + 'mouse2.jpg'), cv2.imread(path + 'mouse3.jpg'), 
              cv2.imread(path + 'mouse4.jpg'), cv2.imread(path + 'mouse5.jpg'), cv2.imread(path + 'mouse6.jpg'), 
              cv2.imread(path + 'mouse7.jpg'), cv2.imread(path + 'mouse8.jpg'), cv2.imread(path + 'mouse9.jpg'), 
              cv2.imread(path + 'mouse10.jpg')],
    'flower': [cv2.imread(path + 'flower11.jpg'), cv2.imread(path + 'flower2.jpg'), cv2.imread(path + 'flower3.jpg'), 
               cv2.imread(path + 'flower4.jpg'), cv2.imread(path + 'flower5.jpg'), cv2.imread(path + 'flower6.jpg'), 
               cv2.imread(path + 'flower7.jpg'), cv2.imread(path + 'flower8.jpg'), cv2.imread(path + 'flower9.jpg'), 
               cv2.imread(path + 'flower10.jpg')],
    'umbrella': [cv2.imread(path + 'umb11.jpg'), cv2.imread(path + 'umb2.jpg'), cv2.imread(path + 'umb3.jpg'),
                 cv2.imread(path + 'umb4.jpg'), cv2.imread(path + 'umb5.jpg'), cv2.imread(path + 'umb6.jpg'), 
                 cv2.imread(path + 'umb7.jpg'), cv2.imread(path + 'umb8.jpg'),cv2.imread(path + 'umb9.jpg'), 
                 cv2.imread(path + 'umb10.jpg')],
    'balloon': [cv2.imread(path + 'bal11.jpg'), cv2.imread(path + 'bal2.jpg'), cv2.imread(path + 'bal3.jpg'),
                 cv2.imread(path + 'bal4.jpg'), cv2.imread(path + 'bal5.jpg'), cv2.imread(path + 'bal6.jpg'), 
                 cv2.imread(path + 'bal7.jpg'), cv2.imread(path + 'bal8.jpg'),cv2.imread(path + 'bal9.jpg'), 
                 cv2.imread(path + 'bal10.jpg')],
    'muffin': [cv2.imread(path + 'muf11.jpg'), cv2.imread(path + 'muf2.jpg'), cv2.imread(path + 'muf3.jpg'),
                 cv2.imread(path + 'muf4.jpg'), cv2.imread(path + 'muf5.jpg'), cv2.imread(path + 'muf6.jpg'), 
                 cv2.imread(path + 'muf7.jpg'), cv2.imread(path + 'muf8.jpg'),cv2.imread(path + 'muf9.jpg'), 
                 cv2.imread(path + 'muf10.jpg')],
    'book': [cv2.imread(path + 'book11.jpg'), cv2.imread(path + 'book2.jpg'), cv2.imread(path + 'book3.jpg'),
                 cv2.imread(path + 'book4.jpg'), cv2.imread(path + 'book5.jpg'), cv2.imread(path + 'book6.jpg'), 
                 cv2.imread(path + 'book7.jpg'), cv2.imread(path + 'book8.jpg'),cv2.imread(path + 'book9.jpg'), 
                 cv2.imread(path + 'book10.jpg')],
    'box': [cv2.imread(path + 'box11.jpg'), cv2.imread(path + 'box2.jpg'), cv2.imread(path + 'box3.jpg'),
                 cv2.imread(path + 'box4.jpg'), cv2.imread(path + 'box5.jpg'), cv2.imread(path + 'box6.jpg'), 
                 cv2.imread(path + 'box7.jpg'), cv2.imread(path + 'box8.jpg'),cv2.imread(path + 'box9.jpg'), 
                 cv2.imread(path + 'box10.jpg')],
    'game': [cv2.imread(path + 'game11.jpg'), cv2.imread(path + 'game2.jpg'), cv2.imread(path + 'game3.jpg'),
                 cv2.imread(path + 'game4.jpg'), cv2.imread(path + 'game5.jpg'), cv2.imread(path + 'game6.jpg'), 
                 cv2.imread(path + 'game7.jpg'), cv2.imread(path + 'game8.jpg'),cv2.imread(path + 'game9.jpg'), 
                 cv2.imread(path + 'game10.jpg')],
    'scissors': [cv2.imread(path + 'sci11.jpg'), cv2.imread(path + 'sci2.jpg'), cv2.imread(path + 'sci3.jpg'),
                 cv2.imread(path + 'sci4.jpg'), cv2.imread(path + 'sci5.jpg'), cv2.imread(path + 'sci6.jpg'), 
                 cv2.imread(path + 'sci7.jpg'), cv2.imread(path + 'sci8.jpg'),cv2.imread(path + 'sci9.jpg'), 
                 cv2.imread(path + 'sci10.jpg')],
    'shirt': [cv2.imread(path + 'shirt11.jpg'), cv2.imread(path + 'shirt2.jpg'), cv2.imread(path + 'shirt3.jpg'),
                 cv2.imread(path + 'shirt4.jpg'), cv2.imread(path + 'shirt5.jpg'), cv2.imread(path + 'shirt6.jpg'), 
                 cv2.imread(path + 'shirt7.jpg'), cv2.imread(path + 'shirt8.jpg'),cv2.imread(path + 'shirt9.jpg'), 
                 cv2.imread(path + 'shirt10.jpg')],
}

query_image = cv2.imread(query_image_path, cv2.IMREAD_GRAYSCALE)

k = 2
categories = similar_images(query_image, categories, k)
print('입력 이미지: {}'.format(query_image_path))
print('제일 유사한 카테고리: {}'.format(categories[0]))
print('제일 유사하지 않은 카테고리: {}\n'.format(categories[-1]))
print('0에 가까울 수록 입력 이미지와 유사하다.')
for k, v in categories:
    print("카테고리: %-10s / 유사도: %1.2f"% (k , v))

AssertionError: 