In [30]:
from MoCoFeatureExtractor import MoCoFeatureExtractor
import hnswlib
from glob import glob
from PIL import Image
import numpy as np
import math

In [31]:
window_size = 96
stride = 24
n_clusters = 100
image_scale = 1/3

feature_lenth = 2048

cnn = MoCoFeatureExtractor()

def extract_windows(frame, pos, window_size):
    windows = np.empty((len(pos), window_size, window_size, 3), dtype=np.uint8)

    for i in range(len(pos)):
        windows[i] = extract_window(frame, pos[i], window_size)

    return windows


def extract_window(frame, pos, window_size):
    half_w = window_size/2.0

    top_left = [int(round(pos[0]-half_w)), int(round(pos[1]-half_w))]
    bottom_right = [top_left[0]+window_size, top_left[1]+window_size]

    return frame[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]]

In [33]:
image_files = glob("dataset_100/train/*/*.jpg")

patch_file_names = []
X = []

for idx, image_file in enumerate(image_files):
    print(idx, image_scale, image_file)
    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    margin = window_size
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    points = [(offsets[0]+y*stride+stride/2,offsets[1]+x*stride+stride/2) for y in range(grid_shape[0]) for x in range(grid_shape[1])]

    patches = extract_windows(image, points, window_size)

    windows = patches.astype(np.float64)

    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_lenth))
    X.extend(list(feats))
    patch_file_names.extend([image_file]*len(windows))

    
index = hnswlib.Index(space='cosine', dim=feature_lenth) 
index.init_index(max_elements=250000, ef_construction=300 * 2, M=64)
index.set_ef(300)
index.add_items(X, list(range(len(X))))

print("Done")

0 0.3333333333333333 dataset_100/train/car/46938b4c628ce00e.jpg
1 0.3333333333333333 dataset_100/train/car/8c4b9d096f6423ed.jpg
2 0.3333333333333333 dataset_100/train/car/fc418b3caef440aa.jpg
3 0.3333333333333333 dataset_100/train/car/eac45380074ba8c8.jpg
4 0.3333333333333333 dataset_100/train/car/ea9a6d46a1279f85.jpg
5 0.3333333333333333 dataset_100/train/car/9893ae3d876f9c1c.jpg
6 0.3333333333333333 dataset_100/train/car/f6b73bb2536fdcb7.jpg
7 0.3333333333333333 dataset_100/train/car/c1ae01ffc0c505f4.jpg
8 0.3333333333333333 dataset_100/train/car/0a41cda5f44baaf6.jpg
9 0.3333333333333333 dataset_100/train/car/0c9f9b713f229fba.jpg
10 0.3333333333333333 dataset_100/train/car/3fe04f7604431846.jpg
11 0.3333333333333333 dataset_100/train/car/d14658b78cec2cf7.jpg
12 0.3333333333333333 dataset_100/train/car/cef82902c2f6cac2.jpg
13 0.3333333333333333 dataset_100/train/car/87a618d5e8e769f4.jpg
14 0.3333333333333333 dataset_100/train/car/240d6d636ae47e4c.jpg
15 0.3333333333333333 dataset_100/t

In [34]:
def object_grid_locations(image, stride, grid_shape, offsets, mask):
    
    grid_locations = set()

    for y in range(grid_shape[0]):
        for x in range(grid_shape[1]):
            p = (offsets[0] + y * stride + 0.5 * stride, offsets[1] + x * stride + 0.5 * stride)
            w = extract_window(mask, p, stride)

            if np.sum(w) >= stride * stride * 0.3:
                grid_locations.add((y, x))
                
    return grid_locations


def extract_object_feats(image_file, mask_file, feature_extractor):

    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    pil_mask = Image.open(mask_file).convert('1')
    pil_mask = pil_mask.resize((int(round(pil_mask.size[0] * image_scale)), int(round(pil_mask.size[1] * image_scale))))
    mask = np.array(pil_mask)
         
    margin = window_size
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    grid_locations_set = object_grid_locations(image, stride, grid_shape, offsets, mask)
    grid_locations_list = list(grid_locations_set)
    
    points = [(y*stride + stride/2 + offsets[0], x*stride + stride/2 + offsets[1]) for (y,x) in grid_locations_list]
         
    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_lenth))
        
    return feats

In [36]:
test_image_files = glob("dataset_100/test/*/*.jpg")
test_mask_files = glob("dataset_100/test/*/*.mask.png")

test_image_files.sort()
test_mask_files.sort()


correct = 0
total = 0


for i in range(len(test_image_files)):

    similar_files_set = set()

    image_file = test_image_files[i]
    mask_file = test_mask_files[i]

    print("test", image_file)

    try:
        feats = extract_object_feats(image_file, mask_file, cnn)
    except:
        print("error")
        continue
        
    nn_ids, nn_dis = index.knn_query(feats, 10)
    
    for s in nn_ids:
        for n in s:
            f = patch_file_names[n]

            if '/airplane/' in image_file and '/airplane/' in f:
                similar_files_set.add(f)
                correct += 1
            elif '/car/' in image_file and '/car/' in f:
                similar_files_set.add(f)
                correct += 1
            elif '/horse/' in image_file and '/horse/' in f:
                similar_files_set.add(f)
                correct += 1


            total += 1

    print('len(similar_files_set)', len(similar_files_set))


print("score", correct/total)

test dataset_100/test/airplane/35b11a04c24db20c.jpg
len(similar_files_set) 18
test dataset_100/test/airplane/4ad0f079b979be5d.jpg
len(similar_files_set) 19
test dataset_100/test/airplane/839ce813ca97084c.jpg
len(similar_files_set) 27
test dataset_100/test/airplane/93b5bf58149adefd.jpg
len(similar_files_set) 12
test dataset_100/test/airplane/9dc879c35a26d2d3.jpg
len(similar_files_set) 19
test dataset_100/test/airplane/a48f1d15812036fa.jpg
len(similar_files_set) 21
test dataset_100/test/airplane/b6ac22d7db1769ee.jpg
len(similar_files_set) 26
test dataset_100/test/airplane/d5422871fd63b8b8.jpg
len(similar_files_set) 12
test dataset_100/test/airplane/e95bc413d4b748ba.jpg
len(similar_files_set) 20
test dataset_100/test/airplane/fbe835c5944f93e5.jpg
len(similar_files_set) 14
test dataset_100/test/car/455c29cd8db5b225.jpg
len(similar_files_set) 36
test dataset_100/test/car/56d1d8aca15ae219.jpg
len(similar_files_set) 35
test dataset_100/test/car/89297009b1d18663.jpg
error
test dataset_100/test