In [None]:
import math
import pathlib
import pickle
import random
from glob import glob

import hnswlib
import ipyplot
import numpy as np
from PIL import Image
from tqdm import tqdm

from MoCoFeatureExtractor import MoCoFeatureExtractor

### Parameters and Utility methods for extracting patches

In [None]:
# Tag that becomes part of the feature-grid output directory name
# (feat_grids_<window_size>_<stride>_<version>); presumably used to keep
# different extraction runs apart.
version = 'e'

# Side length in pixels of the square patches handed to the feature extractor.
window_size = 224
# Target lengths of the shorter image side; every image is resized once per entry.
image_sizes = [224, 448, 672]
# Distance in pixels between neighbouring patch centres (half a window here,
# so adjacent patches overlap by 50%).
stride = 112
# Length of one feature vector; used to reshape the extractor output.
feature_dim = 2048

# cnn = ResNetFeatureExtractor()
# NOTE(review): absolute, machine-specific checkpoint path -- adjust before
# running this notebook on another machine.
cnn = MoCoFeatureExtractor(params_path='/Users/racoon/Desktop/moco_v2_800ep_pretrain.pth.tar')
# Training images; the name of each file's parent directory is later used as
# its class label.
image_files = glob("dataset_1000/train/*/*.jpg")


def extract_windows(frame, pos, window_size):
    """Cut one square window out of ``frame`` for every centre in ``pos``.

    Args:
        frame: array indexed as ``frame[row, col]``; the output buffer is
            allocated for 3 channels, so an RGB image is expected.
        pos: sequence of ``(row, col)`` window centres.
        window_size: side length of each window in pixels.

    Returns:
        ``np.uint8`` array of shape ``(len(pos), window_size, window_size, 3)``
        holding the windows in the order of ``pos``.
    """
    count = len(pos)
    stack = np.empty((count, window_size, window_size, 3), dtype=np.uint8)

    # Fill the pre-allocated buffer slot by slot; a window that does not have
    # the full size (e.g. it crosses the image border) fails on assignment.
    for slot, centre in enumerate(pos):
        stack[slot] = extract_window(frame, centre, window_size)

    return stack


def extract_window(frame, pos, window_size):
    """Return the ``window_size`` x ``window_size`` slice of ``frame`` centred on ``pos``.

    Args:
        frame: array indexed as ``frame[row, col, ...]``.
        pos: ``(row, col)`` centre of the window; may be fractional.
        window_size: side length of the window in pixels.

    Returns:
        A view ``frame[r0:r0 + window_size, c0:c0 + window_size]`` where
        ``(r0, c0)`` is the rounded top-left corner. No bounds checking is
        done, so the usual slicing rules apply near the borders.
    """
    radius = window_size / 2.0

    # Built-in round() is used on purpose (round-half-to-even) to place the corner.
    row_start = int(round(pos[0] - radius))
    col_start = int(round(pos[1] - radius))

    rows = slice(row_start, row_start + window_size)
    cols = slice(col_start, col_start + window_size)

    return frame[rows, cols]

### Extract features from each image for each image size

In [None]:
# For every training image: resize it to each target size, cut it into a regular
# grid of overlapping window_size x window_size patches, run the patches through
# the feature extractor and store one feature grid per size in
# feat_grids_<window_size>_<stride>_<version>/<class>/<image_id>.npz.
output_root = pathlib.Path(f'feat_grids_{window_size}_{stride}_{version}')

for image_file in tqdm(image_files):

    # Close the file handle right away; convert() returns an independent image.
    with Image.open(image_file) as opened_image:
        pil_image = opened_image.convert('RGB')

    # Maps str(image_size) -> feature grid of shape (grid_rows, grid_cols, feature_dim).
    image_grids = {}

    for image_size in image_sizes:
        # PIL sizes are (width, height): scale so that the shorter side becomes
        # image_size while keeping the aspect ratio.
        width, height = pil_image.size
        if height > width:
            resized_image = pil_image.resize((image_size, int(round(height / width * image_size))))
        else:
            resized_image = pil_image.resize((int(round(width / height * image_size)), image_size))

        image = np.array(resized_image)

        # Number of full windows that fit per axis at the given stride, and the
        # offsets that centre the resulting grid inside the image.
        margin = window_size - stride
        grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
        offsets = (round((image.shape[0] - grid_shape[0] * stride) / 2), round((image.shape[1] - grid_shape[1] * stride) / 2))

        # Window centres in row-major order, matching the reshape below.
        points = [
            (offsets[0] + y * stride + stride / 2, offsets[1] + x * stride + stride / 2)
            for y in range(grid_shape[0])
            for x in range(grid_shape[1])
        ]

        patches = extract_windows(image, points, window_size)
        windows = patches.astype(np.float64)

        try:
            feats = cnn.evalRGB(windows)
        except Exception:
            # Report which input failed, then let the error propagate.
            print("ERROR cnn.evalRGB", image_file, image.shape, windows.shape)
            raise

        feat_grid = feats.reshape((grid_shape[0], grid_shape[1], feature_dim))

        image_grids[str(image_size)] = feat_grid

    # <...>/<class>/<image_id>.jpg -> class label and image id.
    image_path = pathlib.Path(image_file)
    image_id = image_path.name.split('.')[0]
    image_class = image_path.parent.name

    class_dir = output_root / image_class
    class_dir.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(str(class_dir / f'{image_id}.npz'), **image_grids)

### Create a nearest-neighbour (HNSW) index of all patch features from the training images at a given image size

In [None]:
# Image size whose feature grids are indexed here; the evaluation cell further
# down reads image_size, index and patch_dict from this cell.
image_size = 224

# NOTE(review): absolute, machine-specific path, while the extraction cell writes
# to a relative 'feat_grids_...' directory -- confirm both refer to the same folder.
# Sorted so that patch ids (and therefore the index) are reproducible across runs.
npz_files = sorted(glob(f'/Users/racoon/Desktop/feat_grids_{window_size}_{stride}_{version}/*/*.npz'))
if not npz_files:
    raise FileNotFoundError('no feature grid files (*.npz) found; run the extraction cell or fix the path')

# Approximate nearest-neighbour index over all patch features (cosine distance).
# NOTE(review): max_elements caps the total number of patches that can be added.
index = hnswlib.Index(space='cosine', dim=feature_dim)
index.init_index(max_elements=1000000, ef_construction=300 * 2, M=64)
index.set_ef(300)

# Running label for the next patch; labels are consecutive integers from 0.
next_id = 0

# Maps patch id -> path of the .npz file (i.e. the training image) it came from.
patch_dict = {}

for npz_file in tqdm(npz_files):

    # NpzFile keeps the archive open; close it as soon as the array is read.
    with np.load(npz_file) as loaded:
        feat_grid = loaded[str(image_size)]

    # Flatten the (rows, cols, dim) grid into one feature vector per row.
    feats = feat_grid.reshape((feat_grid.shape[0] * feat_grid.shape[1], feat_grid.shape[2]))

    ids = list(range(next_id, next_id + feats.shape[0]))
    next_id += feats.shape[0]

    index.add_items(feats, ids)
    patch_dict.update(dict.fromkeys(ids, npz_file))

index.save_index(f'nn_index_{image_size}.idx')
with open(f'patch_dict_{image_size}.pkl', 'wb') as patch_dict_file:
    pickle.dump(patch_dict, patch_dict_file)

### Evaluate on the test set: how often the nearest training patches share the test image's class

In [None]:
def mask_locations(mask, stride, grid_shape, offsets, min_coverage=None):
    """Return the set of ``(row, col)`` grid cells to evaluate.

    Args:
        mask: object mask indexed as ``mask[row, col]``; only read when
            ``min_coverage`` is given.
        stride: side length in pixels of one grid cell.
        grid_shape: ``(rows, cols)`` of the patch grid.
        offsets: ``(row, col)`` pixel offset of the grid's top-left corner.
        min_coverage: optional fraction in ``[0, 1]``. When given, a cell is
            kept only if at least ``stride * stride * min_coverage`` of its
            mask pixels are non-zero. With the default ``None`` every cell is
            returned and the mask is not inspected, as this function has
            always done.

    Returns:
        Set of ``(y, x)`` grid indices.
    """
    n_rows, n_cols = grid_shape[0], grid_shape[1]

    if min_coverage is None:
        # No filtering requested: skip the per-cell mask extraction entirely.
        return {(y, x) for y in range(n_rows) for x in range(n_cols)}

    min_pixels = stride * stride * min_coverage
    object_grid_locations = set()

    for y in range(n_rows):
        for x in range(n_cols):
            # Centre of the stride-sized cell in pixel coordinates.
            centre = (offsets[0] + y * stride + 0.5 * stride, offsets[1] + x * stride + 0.5 * stride)
            cell = extract_window(mask, centre, stride)

            # count_nonzero works for boolean as well as 0/255 masks.
            if np.count_nonzero(cell) >= min_pixels:
                object_grid_locations.add((y, x))

    return object_grid_locations

# Evaluation: for every test image, extract patch features on the same grid as
# for training, query the nearest training patches of each test patch and count
# how many of them come from a training image of the same class.
# Relies on image_size, index and patch_dict from the index-building cell.

def _resize_shorter_side(pil_img, target):
    """Return ``pil_img`` resized so its shorter side equals ``target`` (aspect ratio kept)."""
    width, height = pil_img.size  # PIL sizes are (width, height)
    if height > width:
        return pil_img.resize((target, int(round(height / width * target))))
    return pil_img.resize((int(round(width / height * target)), target))


test_image_files = sorted(glob("dataset_1000/test/*/*.jpg"))
test_mask_files = sorted(glob("dataset_1000/test/*/*.mask.png"))

# Images and masks are paired purely by position after sorting, so a missing
# file would silently shift every following pair; at least catch a count mismatch.
if len(test_image_files) != len(test_mask_files):
    raise ValueError(f'found {len(test_image_files)} test images but {len(test_mask_files)} mask files')

# Number of nearest training patches retrieved per test patch.
num_neighbours = 10

correct = 0
total = 0

for image_file, mask_file in zip(test_image_files, test_mask_files):

    image_correct = 0
    image_total = 0

    print(image_file)

    with Image.open(image_file) as opened_image:
        pil_image = opened_image.convert('RGB')
    image = np.array(_resize_shorter_side(pil_image, image_size))

    with Image.open(mask_file) as opened_mask:
        pil_mask = opened_mask.convert('1')
    mask = np.array(_resize_shorter_side(pil_mask, image_size))

    # Same grid layout as used when extracting the training features.
    margin = window_size - stride
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride) / 2), round((image.shape[1] - grid_shape[1] * stride) / 2))

    grid_locations = list(mask_locations(mask, stride, grid_shape, offsets))
    points = [(y * stride + stride / 2 + offsets[0], x * stride + stride / 2 + offsets[1]) for (y, x) in grid_locations]
    print('len(points)', len(points))

    if not points:
        # Nothing to query for this image; avoids a division by zero below.
        print("score", "n/a (no patches)")
        continue

    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_dim))

    # One row of neighbour ids per test patch; distances are not needed.
    nn_ids, _distances = index.knn_query(feats, num_neighbours)

    # Class label = name of the directory that contains the file.
    image_class = pathlib.Path(image_file).parent.name

    for neighbour_ids in nn_ids:
        for neighbour_id in neighbour_ids:
            # Direct lookup: an unknown id raises KeyError instead of failing later on None.
            neighbour_file = patch_dict[int(neighbour_id)]
            neighbour_class = pathlib.Path(neighbour_file).parent.name

            if neighbour_class == image_class:
                image_correct += 1
                correct += 1

            image_total += 1
            total += 1

    print("score", image_correct / image_total)

if total:
    print("final score", correct / total)
else:
    print("final score", "n/a (no test patches were evaluated)")

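## Testing with dataset_100
Final score as printed by the evaluation cell (fraction of retrieved nearest training patches, over all test patches, whose class matches the test image's class), per `image_size`:

- 672: final score 0.8107003891050584
- 448: final score 0.8972972972972973
- 224: final score 0.9666666666666667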

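## Testing with dataset_1000
Final score as printed by the evaluation cell (fraction of retrieved nearest training patches, over all test patches, whose class matches the test image's class), per `image_size`:

- 224: final score 0.8968137254901961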