In [1]:
from glob import glob
import random
import math
from tqdm import tqdm

import numpy as np
import ipyplot
from PIL import Image
from sklearn.cluster import KMeans
from itertools import compress
import pickle
import pandas as pd
import cv2
import ipyplot
import gensim
from ast import literal_eval

# from ShufflePatchModel16 import ShufflePatchFeatureExtractor
# from VggFeatureExtractor import VggFeatureExtractor
from MoCoFeatureExtractor import MoCoFeatureExtractor


cuda


In [18]:
train = glob("dataset_1000/train/*/*.jpg")
data_frame = pd.DataFrame({'train':train})
data_frame.to_csv('dataset_1000_train.csv')

### Parameters and Utility methods for extracting patches

In [9]:
window_size = 96
stride = 72
kp_margin = 16 # keypoint detector has a margin around image where it can not find keypoints
n_clusters = 1000

walk_length = 10
walks_per_image = 100
word_format_string = '{:03d}'

cluster_patches_per_image = 20

image_scales = [1]

feature_dim = 2048
cluster_file = f'clusters_{window_size}_{stride}_{n_clusters}.pkl'
image_cluster_frid_file = f'image_cluster_grids_{window_size}_{stride}_{n_clusters}.npy'

cnn = MoCoFeatureExtractor(params_path='/home/ubuntu/moco_v2_800ep_pretrain.pth.tar')

image_files = glob("/home/ubuntu/dataset_1000/train/*/*.jpg")[:10]


def extract_windows(frame, pos, window_size):
    windows = np.empty((len(pos), window_size, window_size, 3), dtype=np.uint8)

    for i in range(len(pos)):
        windows[i] = extract_window(frame, pos[i], window_size)

    return windows


def extract_window(frame, pos, window_size):
    half_w = window_size/2.0

    top_left = [int(round(pos[0]-half_w)), int(round(pos[1]-half_w))]
    bottom_right = [top_left[0]+window_size, top_left[1]+window_size]

    return frame[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]]



def get_rad_grid(grid_pos, rad, grid_shape):

    top_left = (grid_pos[0]-rad, grid_pos[1]-rad)

    res = []

    for i in range(2*rad+1):
        p = (top_left[0]+i, top_left[1])
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)
 
    for i in range(2*rad+1):
        p = (top_left[0]+i, top_left[1]+(2*rad))
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)

    for i in range(2*rad-1):
        p = (top_left[0], top_left[1]+(i+1))
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)

    for i in range(2*rad-1):
        p = (top_left[0]+(2*rad), top_left[1]+(i+1))
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)

    return res



def next_pos(salient_grid_positions, grid_shape, current_position):
    
    if current_position is not None:

        rad_grid = get_rad_grid(current_position, 1, grid_shape)

        # print('rad_grid', current_position, rad_grid)
        
        if len(rad_grid) == 0:
            print("frame empty?")
            
        else:
            random.shuffle(rad_grid)
            for loc in rad_grid:
                if loc in salient_grid_positions:
                    return loc
    
    return random.sample(salient_grid_positions,1)[0]

### Sample and cluster patches
This simply uses a fixed grid system. Future patch sampling methods could incorporate an intrest point detector.

In [None]:
X = []

for image_scale in image_scales:
    for idx, image_file in tqdm(enumerate(image_files), total=len(image_files)):
        
        pil_image = Image.open(image_file).convert('RGB')
        pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
        image = np.array(pil_image)

        if image.shape[0] < window_size * 2 or image.shape[1] < window_size * 2:
            continue
            
        margin = max(window_size, kp_margin*2)
        grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
        offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

        points = [(offsets[0]+y*stride+stride/2,offsets[1]+x*stride+stride/2) for y in range(grid_shape[0]) for x in range(grid_shape[1])]
        
        if len(points) > cluster_patches_per_image:
            points = random.sample(points, cluster_patches_per_image)
        
        patches = extract_windows(image, points, window_size)

        windows = patches.astype(np.float64)

        try:
            feats = cnn.evalRGB(windows)
        except:
            print("ERROR windows.shape", windows.shape)
            raise
            
        feats = feats.reshape((windows.shape[0], feature_dim))
        X.extend(list(feats))


print("Clustering with KMeans: len(X)", len(X))

clusters = KMeans(n_clusters=n_clusters, verbose=False)
clusters.fit(np.array(X, dtype=np.float32))

pickle.dump(clusters, open(cluster_file, "wb"))

print("done")

### Generate sequence dataset with random walks

In [15]:
def get_grid_locations(grid_shape, stride, offsets, mask=None):
    
    if mask is not None:
        object_grid_locations = set()

        for y in range(grid_shape[0]):
            for x in range(grid_shape[1]):
                p = (offsets[0] + y * stride + 0.5 * stride, offsets[1] + x * stride + 0.5 * stride)
                w = extract_window(mask, p, stride)

                if np.sum(w) >= stride * stride * 0.3:
                    object_grid_locations.add((y, x))
        
        return object_grid_locations
    
    else:
        return [(y,x) for y in range(grid_shape[0]) for x in range(grid_shape[1])]
    
def generate_image_cluster_grid(image_file, image_scale, clusters, feature_extractor):
    # print("generate_image_sequences", image_file)

    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    if image.shape[0] < window_size * 2 or image.shape[1] < window_size * 2:
        print("image too small, image_file")
        return None
            
    margin = max(window_size, kp_margin*2)
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    grid_locations_set = get_grid_locations(grid_shape, stride, offsets)
    grid_locations_list = list(grid_locations_set)
    
    points = [(y*stride + stride/2 + offsets[0], x*stride + stride/2 + offsets[1]) for (y,x) in grid_locations_list]
        
    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    # print(windows.shape)
    
    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_dim))

    grid_cluster_ids = clusters.predict(feats)
        
    cluster_grid = np.full(grid_shape, -1, dtype=int)
    
    for i in range(len(grid_locations_list)):
        cluster_grid[grid_locations_list[i]] = grid_cluster_ids[i]
        
    return cluster_grid

def generate_image_sequences(image_file, image_scale, clusters, feature_extractor, mask_file = None, seq_count=walks_per_image):
    # print("generate_image_sequences", image_file)

    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    if image.shape[0] < window_size * 2 or image.shape[1] < window_size * 2:
        print("image too small, image_file")
        return None, None
            
    mask = None
    
    if mask_file is not None:
        pil_mask = Image.open(mask_file).convert('1')
        pil_mask = pil_mask.resize((int(round(pil_mask.size[0] * image_scale)), int(round(pil_mask.size[1] * image_scale))))
        mask = np.array(pil_mask)
        
    
    margin = max(window_size, kp_margin*2)
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    grid_locations_set = get_grid_locations(grid_shape, stride, offsets, mask)
    grid_locations_list = list(grid_locations_set)
    
    points = [(y*stride + stride/2 + offsets[0], x*stride + stride/2 + offsets[1]) for (y,x) in grid_locations_list]
        
    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    # print(windows.shape)
    
    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_dim))

    grid_cluster_ids = clusters.predict(feats)

    grid_location_to_cluster_id = dict([(grid_locations_list[i], grid_cluster_ids[i]) for i in range(len(grid_locations_list))])
        
    cluster_seqs = []
    point_seqs = []
    
    for i in range(seq_count):
        cluster_seq = []
        point_seq = []
        
        pos = None
        
        for t in range(walk_length):
            pos = next_pos(grid_locations_set, grid_shape, pos)
            cluster_seq.append(grid_location_to_cluster_id[pos])
            point_seq.append((pos[0]*stride + stride/2 + offsets[0], pos[1]*stride + stride/2 + offsets[1]))
            
        cluster_seqs.append([word_format_string.format(w) for w in cluster_seq])
        point_seqs.append(point_seq)
        
    return cluster_seqs, point_seqs


In [21]:
clusters = pickle.load(open(cluster_file, "rb"))

image_cluster_grids = {}

for image_scale in image_scales:
    for idx, image_file in tqdm(enumerate(image_files), total=len(image_files)):

        image_cluster_grid = generate_image_cluster_grid(image_file, image_scale, clusters, cnn)
        if image_cluster_grid is None:
            continue
        image_id = image_file.split('/')[-1].split('.')[0]
        image_cluster_grids[image_id] = image_cluster_grid    
    
np.save(image_cluster_frid_file, image_cluster_grids)   

print("done")
print(np.load(image_cluster_frid_file, allow_pickle=True))

100%|██████████| 10/10 [00:01<00:00,  6.78it/s]

done
{'763e6730eeb67b95': array([[286, 837, 837, 837, 202, 687, 687, 687, 837, 286,  91,  91],
       [ 91, 837, 799, 799, 202, 202, 286, 202, 202,  91, 837, 837],
       [ 91, 677,  30,  88,  31, 837,  91, 286, 202, 687, 687, 837],
       [ 91, 837, 401, 803, 996, 965, 286, 837, 599,  18, 266,  91],
       [ 91, 837, 885, 885, 965, 233, 523, 632, 599, 202, 806, 806],
       [806,  31, 885, 965, 601,  42, 632, 622,  91, 202, 286, 806],
       [806,  61,  61, 632,   8,   8, 601, 286,  91, 806, 286, 286],
       [286, 837, 275, 879, 171, 601,  15, 286, 806, 806, 286, 806]]), '78f43b2cda7e7083': array([[171, 599, 940, 940, 940, 940, 837, 286, 837, 837, 837, 286],
       [599, 709, 171, 940, 171, 171, 940, 171, 233, 996, 996, 837],
       [503, 233, 599, 599, 599, 940, 171, 171, 502, 996, 996, 996],
       [ 17, 263, 940, 601, 996, 266, 233, 996, 996, 242, 996, 996],
       [474, 996, 824, 601, 996, 996, 996, 996, 996, 996, 996, 996],
       [ 91, 286, 502, 774, 774, 774, 996, 433, 242, 99




In [8]:
clusters = pickle.load(open(cluster_file, "rb"))

cluster_seqs = []
point_seqs = []
image_file_colummn = []
image_scale_colummn = []

for image_scale in image_scales:
    for idx, image_file in tqdm(enumerate(image_files), total=len(image_files)):

        c_seqs, p_seqs = generate_image_sequences(image_file, image_scale, clusters, cnn)
        if c_seqs is None:
            continue
        cluster_seqs.extend(c_seqs)
        point_seqs.extend(p_seqs)
        image_file_colummn.extend([image_file] * walks_per_image)
        image_scale_colummn.extend([image_scale] * walks_per_image)
    
    
data_frame = pd.DataFrame({'words':cluster_seqs, 'file':image_file_colummn, 'points': point_seqs, 'scale': image_scale_colummn})

data_frame.to_csv('sequences.csv')

print("done")

 60%|█████▉    | 9245/15537 [24:48<16:52,  6.21it/s]  


KeyboardInterrupt: 

In [None]:
class callback(gensim.models.callbacks.CallbackAny2Vec):
    '''Callback to print loss after each epoch.'''

    def __init__(self):
        self.epoch = 0

    def on_epoch_end(self, model):
        print('epoch {}'.format(self.epoch))
        self.epoch += 1
              
def read_corpus(fname, tokens_only=False):
    data_frame = pd.read_csv('sequences.csv',converters={"words": literal_eval, "points": literal_eval})
    
    for index, row in data_frame.iterrows():
        if tokens_only:
            yield row['words']
        else:
            yield gensim.models.doc2vec.TaggedDocument(row['words'], [index])

train_corpus = list(read_corpus('sequences.csv'))
print(train_corpus[:2])

model = gensim.models.doc2vec.Doc2Vec(vector_size=50, epochs=40)
model.build_vocab(train_corpus)
model.train(train_corpus, total_examples=model.corpus_count, epochs=model.epochs, callbacks=[callback()])

model.save('doc2vec.model')

print("done")

In [None]:
clusters = pickle.load(open("clusters.pkl", "rb"))
orb = cv2.ORB_create(nfeatures=100000, fastThreshold=7)
model = gensim.models.doc2vec.Doc2Vec.load('doc2vec.model')

data_frame = pd.read_csv('sequences.csv',converters={"words": literal_eval, "points": literal_eval})

test_image_files = glob("/data/dataset_100/test/*/*.jpg")
test_mask_files = glob("/data/dataset_100/test/*/*.mask.png")

test_image_files.sort()
test_mask_files.sort()


for image_scale in image_scales:
    
    correct = 0
    total = 0

    for i in range(len(test_image_files)):

        image_file = test_image_files[i]
        mask_file = test_mask_files[i]

        print("test", image_file)

        c_seqs, p_seqs = generate_image_sequences(image_file, image_scale, clusters, cnn, orb, mask_file = mask_file, seq_count=100)
        #print('seqs', seqs)

        vectors = [[model.infer_vector(s)] for s in c_seqs]
        #print('vectors', vectors)

        for v in vectors:
            similar = model.docvecs.most_similar(v, topn=10)
            #print('similar', similar)

            for s in similar:
                f = data_frame.loc[s[0],'file']
                if '/airplane/' in image_file and '/airplane/' in f:
                    correct += 1
                elif '/car/' in image_file and '/car/' in f:
                    correct += 1
                elif '/horse/' in image_file and '/horse/' in f:
                    correct += 1

                #print("similar to", f)

                total += 1
    
    print("score", correct/total)