In [2]:
from glob import glob
import random
import math

import numpy as np
import ipyplot
from PIL import Image
from sklearn.cluster import KMeans
from itertools import compress
import pickle
import pandas as pd
import cv2
import ipyplot
import gensim
from ast import literal_eval
import hdbscan
# from ShufflePatchModel16 import ShufflePatchFeatureExtractor
# from VggFeatureExtractor import VggFeatureExtractor
from MoCoFeatureExtractor import MoCoFeatureExtractor


### Parameters and Utility methods for extracting patches

In [3]:
window_size = 32
stride = 24
kp_margin = 16 # keypoint detector has a margin around image where it can not find keypoints
n_clusters = 100

max_files = 1000

walk_length = 10
walks_per_image = 100
word_format_string = '{:02d}'

image_scales = [1/3]

cnn = MoCoFeatureExtractor()
#cnn = VggFeatureExtractor()
#cnn = ShufflePatchFeatureExtractor("/Users/racoon/Desktop/rotation_jigsaw_migrated_0710_0.0001_1.9522_43.62.pt")

def extract_windows(frame, pos, window_size):
    windows = np.empty((len(pos), window_size, window_size, 3), dtype=np.uint8)

    for i in range(len(pos)):
        windows[i] = extract_window(frame, pos[i], window_size)

    return windows


def extract_window(frame, pos, window_size):
    half_w = window_size/2.0

    top_left = [int(round(pos[0]-half_w)), int(round(pos[1]-half_w))]
    bottom_right = [top_left[0]+window_size, top_left[1]+window_size]

    return frame[top_left[0]:bottom_right[0], top_left[1]:bottom_right[1]]



def get_rad_grid(grid_pos, rad, grid_shape):

    top_left = (grid_pos[0]-rad, grid_pos[1]-rad)

    res = []

    for i in range(2*rad+1):
        p = (top_left[0]+i, top_left[1])
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)
 
    for i in range(2*rad+1):
        p = (top_left[0]+i, top_left[1]+(2*rad))
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)

    for i in range(2*rad-1):
        p = (top_left[0], top_left[1]+(i+1))
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)

    for i in range(2*rad-1):
        p = (top_left[0]+(2*rad), top_left[1]+(i+1))
        if p[0] >= 0 and p[1] >= 0 and p[0] < grid_shape[0] and p[1] < grid_shape[1]:
            res.append(p)

    return res



def next_pos(salient_grid_positions, grid_shape, current_position):
    
    if current_position is not None:

        rad_grid = get_rad_grid(current_position, 1, grid_shape)

        # print('rad_grid', current_position, rad_grid)
        
        if len(rad_grid) == 0:
            print("frame empty?")
            
        else:
            random.shuffle(rad_grid)
            for loc in rad_grid:
                if loc in salient_grid_positions:
                    return loc
    
    return random.sample(salient_grid_positions,1)[0]

### Sample and cluster patches
This simply uses a fixed grid system. Future patch sampling methods could incorporate an intrest point detector.

In [None]:
image_files = glob("dataset_100/train/*/*.jpg")[:max_files]

X = []
#P = []

for image_scale in image_scales:
    for idx, image_file in enumerate(image_files):
        print(idx, image_scale, image_file)
        pil_image = Image.open(image_file).convert('RGB')
        pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
        image = np.array(pil_image)

        margin = max(window_size, kp_margin*2)
        grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
        offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

        points = [(offsets[0]+y*stride+stride/2,offsets[1]+x*stride+stride/2) for y in range(grid_shape[0]) for x in range(grid_shape[1])]

        patches = extract_windows(image, points, window_size)

        #P.extend(list(patches))

        windows = patches.astype(np.float64)

        feats = cnn.evalRGB(windows)
        print('feats.shape', feats.shape, windows.shape)
        feats = feats.reshape((windows.shape[0], 2048))
        X.extend(list(feats))


print("Clustering with KMeans")
clusters = hdbscan.HDBSCAN(min_cluster_size=10)
# clusters = KMeans(n_clusters=n_clusters, verbose=False)
clusters.fit(np.array(X, dtype=np.float32))

cluster_count = len(np.unique(clusters.labels_))
print("done", cluster_count)

pickle.dump(clusters, open("clusters.pkl", "wb"))

# for i in range(cluster_count):
#     print([j for (j, x) in enumerate(clusters.labels_) if x == i])
#     patch_cluster_samples = [P[j] for (j, x) in enumerate(clusters.labels_) if x == i][:10]
#     ipyplot.plot_images(patch_cluster_samples)

0 0.3333333333333333 dataset_100/train/car/46938b4c628ce00e.jpg
feats.shape (84, 2048, 1, 1) (84, 32, 32, 3)
1 0.3333333333333333 dataset_100/train/car/8c4b9d096f6423ed.jpg
feats.shape (96, 2048, 1, 1) (96, 32, 32, 3)
2 0.3333333333333333 dataset_100/train/car/fc418b3caef440aa.jpg
feats.shape (108, 2048, 1, 1) (108, 32, 32, 3)
3 0.3333333333333333 dataset_100/train/car/eac45380074ba8c8.jpg
feats.shape (108, 2048, 1, 1) (108, 32, 32, 3)
4 0.3333333333333333 dataset_100/train/car/ea9a6d46a1279f85.jpg
feats.shape (72, 2048, 1, 1) (72, 32, 32, 3)
5 0.3333333333333333 dataset_100/train/car/9893ae3d876f9c1c.jpg
feats.shape (96, 2048, 1, 1) (96, 32, 32, 3)
6 0.3333333333333333 dataset_100/train/car/f6b73bb2536fdcb7.jpg
feats.shape (144, 2048, 1, 1) (144, 32, 32, 3)
7 0.3333333333333333 dataset_100/train/car/c1ae01ffc0c505f4.jpg
feats.shape (108, 2048, 1, 1) (108, 32, 32, 3)
8 0.3333333333333333 dataset_100/train/car/0a41cda5f44baaf6.jpg
feats.shape (96, 2048, 1, 1) (96, 32, 32, 3)
9 0.3333333

### Generate sequence dataset with random walks

In [4]:
def salient_grid_locations(image, stride, grid_shape, offsets, orb, mask = None):
    
    if mask is not None:
        object_grid_locations = set()

        for y in range(grid_shape[0]):
            for x in range(grid_shape[1]):
                p = (offsets[0] + y * stride + 0.5 * stride, offsets[1] + x * stride + 0.5 * stride)
                w = extract_window(mask, p, stride)

                # print(np.sum(w))
                if np.sum(w) >= stride * stride * 0.3:
                    object_grid_locations.add((y, x))
                
                
    kp = orb.detect(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY), None)

    grid = set()

    for k in kp:
        p = (k.pt[1], k.pt[0]) 
        g = (int(math.floor((p[0]-offsets[0])/stride)), int(math.floor((p[1]-offsets[1])/stride)))
        
        if g[0] < 0 or g[0] >= grid_shape[0] or g[1] < 0 or g[1] >= grid_shape[1]:
            continue
            
        if mask is not None: 
            if g in object_grid_locations:
                grid.add(g)
        else:
             grid.add(g)
            
    
    return grid


def generate_image_sequences(image_file, image_scale, clusters, feature_extractor, orb, mask_file = None, seq_count=walks_per_image):
    # print("generate_image_sequences", image_file)

    pil_image = Image.open(image_file).convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    mask = None
    
    if mask_file is not None:
        pil_mask = Image.open(mask_file).convert('1')
        pil_mask = pil_mask.resize((int(round(pil_mask.size[0] * image_scale)), int(round(pil_mask.size[1] * image_scale))))
        mask = np.array(pil_mask)
        
    
    margin = max(window_size, kp_margin*2)
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    grid_locations_set = salient_grid_locations(image, stride, grid_shape, offsets, orb, mask)
    grid_locations_list = list(grid_locations_set)
    
    # print(grid_shape, grid_locations_list)
    points = [(y*stride + stride/2 + offsets[0], x*stride + stride/2 + offsets[1]) for (y,x) in grid_locations_list]
    
    # print('salient grid locations', grid_shape, len(grid_locations_list), 'of', grid_shape[0] * grid_shape[1])
        
    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    feats = cnn.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], 512))

    grid_cluster_ids = clusters.predict(feats)

    grid_location_to_cluster_id = dict([(grid_locations_list[i], grid_cluster_ids[i]) for i in range(len(grid_locations_list))])
    #grid_location_to_patches = dict([(grid_locations_list[i], patches[i]) for i in range(len(grid_locations_list))])
        
    cluster_seqs = []
    point_seqs = []
    
    for i in range(seq_count):
        cluster_seq = []
        point_seq = []
        
        pos = None
        
        for t in range(walk_length):
            pos = next_pos(grid_locations_set, grid_shape, pos)
            cluster_seq.append(grid_location_to_cluster_id[pos])
            point_seq.append((pos[0]*stride + stride/2 + offsets[0], pos[1]*stride + stride/2 + offsets[1]))
            
        cluster_seqs.append([word_format_string.format(w) for w in cluster_seq])
        point_seqs.append(point_seq)
        #print(seqs[-1])
        
        #ipyplot.plot_images(patch_seq, seq)
        
    return cluster_seqs, point_seqs


In [5]:
image_files = glob("dataset_100/train/*/*.jpg")[:max_files]
image_files.sort()

clusters = pickle.load(open("clusters.pkl", "rb"))
orb = cv2.ORB_create(nfeatures=100000, fastThreshold=7)


cluster_seqs = []
point_seqs = []
image_file_colummn = []
image_scale_colummn = []

for image_scale in image_scales:
    for idx, image_file in enumerate(image_files):
        print(idx, image_file)
        c_seqs, p_seqs = generate_image_sequences(image_file, image_scale, clusters, cnn, orb)
        cluster_seqs.extend(c_seqs)
        point_seqs.extend(p_seqs)
        image_file_colummn.extend([image_file] * walks_per_image)
        image_scale_colummn.extend([image_scale] * walks_per_image)
    
    
data_frame = pd.DataFrame({'words':cluster_seqs, 'file':image_file_colummn, 'points': point_seqs, 'scale': image_scale_colummn})

data_frame.to_csv('sequences.csv')

print("done")

0 dataset_100/train/airplane/0276aef954d36f18.jpg
1 dataset_100/train/airplane/033045b8afa879ec.jpg
2 dataset_100/train/airplane/046ba24119dc2170.jpg
3 dataset_100/train/airplane/06c3a9a0ebc0af5d.jpg
4 dataset_100/train/airplane/07393565a9b3dd6f.jpg
5 dataset_100/train/airplane/074e5b1be1e568f6.jpg
6 dataset_100/train/airplane/0b9f3a3b87af742f.jpg
7 dataset_100/train/airplane/11b50d4333e1e68c.jpg
8 dataset_100/train/airplane/13ab2b3daff530fe.jpg
9 dataset_100/train/airplane/177a89876068beb6.jpg
10 dataset_100/train/airplane/1a2cdeeba6547bab.jpg
11 dataset_100/train/airplane/1a8c1a1a20e5a8ef.jpg
12 dataset_100/train/airplane/1bbbc7d31d0e1225.jpg
13 dataset_100/train/airplane/1bbf78b691f9cdad.jpg
14 dataset_100/train/airplane/1f46a9399acfb6f3.jpg
15 dataset_100/train/airplane/22cdb9abd3ac400c.jpg
16 dataset_100/train/airplane/2f118b9b64e097ab.jpg
17 dataset_100/train/airplane/38d146c322d738ff.jpg
18 dataset_100/train/airplane/3976b08facf36cc2.jpg
19 dataset_100/train/airplane/3aad6b9135f

In [6]:
class callback(gensim.models.callbacks.CallbackAny2Vec):
    '''Callback to print loss after each epoch.'''

    def __init__(self):
        self.epoch = 0

    def on_epoch_end(self, model):
        print('epoch {}'.format(self.epoch))
        self.epoch += 1
              
def read_corpus(fname, tokens_only=False):
    data_frame = pd.read_csv('sequences.csv',converters={"words": literal_eval, "points": literal_eval})
    
    for index, row in data_frame.iterrows():
        if tokens_only:
            yield row['words']
        else:
            yield gensim.models.doc2vec.TaggedDocument(row['words'], [index])

train_corpus = list(read_corpus('sequences.csv'))
print(train_corpus[:2])

model = gensim.models.doc2vec.Doc2Vec(vector_size=50, epochs=40)
model.build_vocab(train_corpus)
model.train(train_corpus, total_examples=model.corpus_count, epochs=model.epochs, callbacks=[callback()])

model.save('doc2vec.model')

print("done")

[TaggedDocument(words=['045', '045', '045', '047', '047', '047', '045', '064', '037', '036'], tags=[0]), TaggedDocument(words=['083', '064', '083', '074', '083', '045', '083', '005', '036', '005'], tags=[1])]
epoch 0
epoch 1
epoch 2
epoch 3
epoch 4
epoch 5
epoch 6
epoch 7
epoch 8
epoch 9
epoch 10
epoch 11
epoch 12
epoch 13
epoch 14
epoch 15
epoch 16
epoch 17
epoch 18
epoch 19
epoch 20
epoch 21
epoch 22
epoch 23
epoch 24
epoch 25
epoch 26
epoch 27
epoch 28
epoch 29
epoch 30
epoch 31
epoch 32
epoch 33
epoch 34
epoch 35
epoch 36
epoch 37
epoch 38
epoch 39
done


In [8]:
clusters = pickle.load(open("clusters.pkl", "rb"))
orb = cv2.ORB_create(nfeatures=100000, fastThreshold=7)
model = gensim.models.doc2vec.Doc2Vec.load('doc2vec.model')

data_frame = pd.read_csv('sequences.csv',converters={"words": literal_eval, "points": literal_eval})

test_image_files = glob("dataset_100/test/*/*.jpg")
test_mask_files = glob("dataset_100/test/*/*.mask.png")

test_image_files.sort()
test_mask_files.sort()


for image_scale in image_scales:
    
    correct = 0
    total = 0

    for i in range(len(test_image_files)):

        image_file = test_image_files[i]
        mask_file = test_mask_files[i]

        print("test", image_file)

        c_seqs, p_seqs = generate_image_sequences(image_file, image_scale, clusters, cnn, orb, mask_file = mask_file, seq_count=100)
        #print('seqs', seqs)

        vectors = [[model.infer_vector(s)] for s in c_seqs]
        #print('vectors', vectors)

        for v in vectors:
            similar = model.docvecs.most_similar(v, topn=10)
            #print('similar', similar)

            for s in similar:
                f = data_frame.loc[s[0],'file']
                if '/airplane/' in image_file and '/airplane/' in f:
                    correct += 1
                elif '/car/' in image_file and '/car/' in f:
                    correct += 1
                elif '/horse/' in image_file and '/horse/' in f:
                    correct += 1

                #print("similar to", f)

                total += 1
    
    print("score", correct/total)

test dataset_100/test/airplane/35b11a04c24db20c.jpg
test dataset_100/test/airplane/4ad0f079b979be5d.jpg
test dataset_100/test/airplane/839ce813ca97084c.jpg
test dataset_100/test/airplane/93b5bf58149adefd.jpg
test dataset_100/test/airplane/9dc879c35a26d2d3.jpg
test dataset_100/test/airplane/a48f1d15812036fa.jpg
test dataset_100/test/airplane/b6ac22d7db1769ee.jpg
test dataset_100/test/airplane/d5422871fd63b8b8.jpg
test dataset_100/test/airplane/e95bc413d4b748ba.jpg
test dataset_100/test/airplane/fbe835c5944f93e5.jpg
test dataset_100/test/car/455c29cd8db5b225.jpg
test dataset_100/test/car/56d1d8aca15ae219.jpg
test dataset_100/test/car/89297009b1d18663.jpg
test dataset_100/test/car/9bac3f90244fef7e.jpg
test dataset_100/test/car/a051a80f600fd919.jpg
test dataset_100/test/car/b30bc1fe86057942.jpg
test dataset_100/test/car/cc0c6a5753fbe006.jpg
test dataset_100/test/car/cdbfe2973b6fdaf2.jpg
test dataset_100/test/car/d12414ad4d3e845e.jpg
test dataset_100/test/car/d16cb785f98e3d1c.jpg
test datas