In [None]:
import hnswlib
from PIL import Image
from MoCoFeatureExtractor import MoCoFeatureExtractor
from ResNetFeatureExtractor import ResNetFeatureExtractor
from glob import glob
import random
import numpy as np
import ipyplot
from tqdm import tqdm
import math
import pickle

In [None]:
# --- Patch / feature configuration -------------------------------------------
feature_dim = 2048   # length of one feature vector (presumably the extractor's output size -- TODO confirm)
image_scale = 1      # multiplicative resize factor applied to every image (and mask) before patching
window_size = 288    # side length, in pixels, of each square patch fed to the CNN
stride = 144         # distance, in pixels, between neighbouring patch centres

# --- Approximate nearest-neighbour index -------------------------------------
# The index dimension is taken from `feature_dim` (it was previously a second,
# independent literal 2048) so that it cannot drift from the reshape applied to
# the CNN output before `add_items` / `knn_query`.
index = hnswlib.Index(space='cosine', dim=feature_dim)
# max_elements caps how many patch vectors can be stored; ef_construction and M
# are HNSW build-time parameters (see the hnswlib documentation).
index.init_index(max_elements=1000000, ef_construction=300 * 2, M=64)
# Query-time ef; hnswlib expects this to be larger than the k passed to knn_query.
index.set_ef(300)

# Feature extractor used for both indexing and querying.
cnn = ResNetFeatureExtractor()

In [None]:
def extract_windows(frame, pos, window_size):
    """Crop one square window per centre in `pos` and stack them in a single array.

    Parameters
    ----------
    frame : image array handed unchanged to `extract_window`.
    pos : sequence of (row, col) window centres.
    window_size : side length of every window, in pixels.

    Returns
    -------
    A uint8 array of shape (len(pos), window_size, window_size, 3); entry `k`
    holds the crop centred on `pos[k]`.

    Each crop must fit a (window_size, window_size, 3) slot, so a window that
    does not lie fully inside a 3-channel `frame` makes the assignment fail.
    """
    stacked = np.empty((len(pos), window_size, window_size, 3), dtype=np.uint8)

    for slot, centre in enumerate(pos):
        stacked[slot] = extract_window(frame, centre, window_size)

    return stacked


def extract_window(frame, pos, window_size):
    """Crop the square window of side `window_size` centred on `pos` from `frame`.

    Parameters
    ----------
    frame : array sliceable as ``frame[rows, cols]`` (extra trailing axes are kept).
    pos : (row, col) centre of the window; may be fractional.
    window_size : side length of the window, in pixels.

    Returns
    -------
    ``frame[top:top + window_size, left:left + window_size]`` where
    ``top``/``left`` are the rounded centre minus half the window size.
    A window that overhangs any edge of `frame` is clipped to the part that
    lies inside it, so the result can be smaller than `window_size`.
    """
    half_w = window_size / 2.0

    top = int(round(pos[0] - half_w))
    left = int(round(pos[1] - half_w))
    bottom = top + window_size
    right = left + window_size

    # Clamp at zero: a negative slice bound is interpreted as an offset from
    # the END of the axis, which would return an empty or wrapped-around crop
    # for windows overhanging the top/left edge. Overhang past the bottom/right
    # edge is already clipped by slicing itself.
    return frame[max(top, 0):max(bottom, 0), max(left, 0):max(right, 0)]

In [None]:
# Build the nearest-neighbour index: cut every training image into a regular
# grid of overlapping patches, embed each patch with the CNN and store the
# feature vector under a unique integer id.
image_files = glob("/home/ubuntu/dataset_1000/train/*/*.jpg")

# Images with a side longer than this (after scaling) are skipped.
max_image_side = 1024

# Next unused patch id. Deliberately not called `id`, which would shadow the
# builtin for every later cell of the notebook.
next_patch_id = 0

# Maps patch id -> path of the image the patch was cropped from.
patch_dict = {}

for image_file in tqdm(image_files, total=len(image_files)):
    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    # Skip images that are too small for at least two windows per side, or too large.
    height, width = image.shape[0], image.shape[1]
    if height < window_size * 2 or width < window_size * 2 or height > max_image_side or width > max_image_side:
        continue

    # Patch centres lie on a grid with spacing `stride`; `margin` is the extra
    # border a window needs beyond its stride-sized cell, and `offsets` centres
    # the grid inside the image.
    margin = window_size - stride
    grid_shape = (math.floor((height - margin) / stride), math.floor((width - margin) / stride))
    offsets = (round((height - grid_shape[0] * stride) / 2), round((width - grid_shape[1] * stride) / 2))

    points = [(offsets[0] + y * stride + stride / 2, offsets[1] + x * stride + stride / 2)
              for y in range(grid_shape[0]) for x in range(grid_shape[1])]

    # Reserve one consecutive id per patch of this image.
    ids = list(range(next_patch_id, next_patch_id + len(points)))
    next_patch_id += len(points)

    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    try:
        feats = cnn.evalRGB(windows)
    except Exception:
        # Report the offending file and the shapes involved (not the pixel
        # array itself), then let the error propagate.
        print("ERROR cnn.evalRGB", image_file, image.shape, windows.shape)
        raise

    feats = feats.reshape((windows.shape[0], feature_dim))

    index.add_items(feats, ids)

    # Remember which image every patch id came from.
    patch_dict.update(dict.fromkeys(ids, image_file))

index.save_index("nn_index_supervised.idx")
# Use a context manager so the pickle file is flushed and closed deterministically.
with open("patch_dict.pkl_supervised", 'wb') as patch_dict_file:
    pickle.dump(patch_dict, patch_dict_file)

In [None]:
def mask_locations(mask, stride, grid_shape, offsets, min_coverage=0.3):
    """Return the grid cells whose stride-sized mask window is sufficiently filled.

    Parameters
    ----------
    mask : 2-D array passed to `extract_window`; its values are summed per
        cell (assumes a boolean / 0-1 object mask -- TODO confirm).
    stride : side length of one grid cell, in pixels.
    grid_shape : (rows, cols) number of grid cells.
    offsets : (row, col) pixel offset of the grid's top-left corner.
    min_coverage : minimum ratio of the cell's mask sum to `stride * stride`
        for the cell to be kept. Defaults to 0.3.

    Returns
    -------
    A set of (y, x) grid indices of the cells that meet the threshold.
    """
    # Loop-invariant: the mask sum a cell must reach to be kept.
    threshold = stride * stride * min_coverage

    object_grid_locations = set()

    for y in range(grid_shape[0]):
        for x in range(grid_shape[1]):
            # Centre of grid cell (y, x) in pixel coordinates.
            p = (offsets[0] + y * stride + 0.5 * stride, offsets[1] + x * stride + 0.5 * stride)
            w = extract_window(mask, p, stride)

            if np.sum(w) >= threshold:
                object_grid_locations.add((y, x))

    return object_grid_locations

# Evaluate retrieval: for every object patch of every test image, look up its
# nearest neighbours in the index and count how many of them come from an image
# in the same class directory as the test image.
test_image_files = sorted(glob("/home/ubuntu/dataset_100/test/*/*.jpg"))
test_mask_files = sorted(glob("/home/ubuntu/dataset_100/test/*/*.mask.png"))

# Images and masks are paired purely by their position in the sorted lists, so
# a missing or extra file would silently pair images with the wrong masks.
if len(test_image_files) != len(test_mask_files):
    raise ValueError(
        "found %d test images but %d masks; cannot pair them by position"
        % (len(test_image_files), len(test_mask_files))
    )

# Number of nearest neighbours retrieved per query patch.
num_neighbours = 10

correct = 0
total = 0

for image_file, mask_file in zip(test_image_files, test_mask_files):

    image_correct = 0
    image_total = 0

    print(image_file)

    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    # NOTE(review): the indexing cell also skips images with a side above
    # 1024 px; no such limit is applied here -- confirm this is intended.
    if image.shape[0] < window_size * 2 or image.shape[1] < window_size * 2:
        print("image too small", image_file)
        continue

    # Object mask, converted to 1-bit and resized with the same scale factor as the image.
    pil_mask = Image.open(mask_file).convert('1')
    pil_mask = pil_mask.resize((int(round(pil_mask.size[0] * image_scale)), int(round(pil_mask.size[1] * image_scale))))
    mask = np.array(pil_mask)

    # Same patch grid geometry as used when the index was built.
    margin = window_size - stride
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride) / 2), round((image.shape[1] - grid_shape[1] * stride) / 2))

    # Only query patches whose grid cell is sufficiently covered by the mask.
    grid_locations = list(mask_locations(mask, stride, grid_shape, offsets))
    points = [(y * stride + stride / 2 + offsets[0], x * stride + stride / 2 + offsets[1]) for (y, x) in grid_locations]

    if not points:
        # Nothing to query; skipping also avoids a division by zero in the
        # per-image score below.
        print("no object patches", image_file)
        continue

    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_dim))

    nn_ids, _ = index.knn_query(feats, num_neighbours)

    # Class label = name of the directory that contains the image.
    true_label = image_file.split('/')[-2]

    for neighbour_ids in nn_ids:
        for neighbour_id in neighbour_ids:
            # Direct indexing: an id missing from patch_dict fails here with a
            # KeyError naming the id, rather than later on a None value.
            neighbour_label = patch_dict[neighbour_id].split('/')[-2]

            if neighbour_label == true_label:
                image_correct += 1
                correct += 1

            image_total += 1
            total += 1

    print("score", image_correct/image_total)

if total:
    print("final score", correct/total)
else:
    print("final score undefined: no test patches were evaluated")

unsupervised final score: 0.8863636363636364  
supervised final score: 0.8692307692307693  