In [1]:
from glob import glob
import random
import math
from tqdm import tqdm

import numpy as np
import ipyplot
from PIL import Image
from sklearn.cluster import KMeans
from itertools import compress
import pickle
import pandas as pd
import cv2
import ipyplot
import gensim
from ast import literal_eval
import pathlib

from MoCoFeatureExtractor import MoCoFeatureExtractor

cpu


### Parameters and Utility methods for extracting patches

In [3]:
# Tag embedded in the name of every artefact written by this notebook, so that
# different experiment variants do not overwrite each other.
version = 'c'

# Side lengths of the two square patches cut around every grid point.
window_size_outer = 288
window_size_inner = 144

# Distance between neighbouring grid points along each image axis.
stride = 72

# Number of KMeans clusters; each cluster id becomes one "word".
n_clusters = 100
# Format applied to a cluster id to turn it into a word token.
word_format_string = '{:02d}'

# Number of positions visited per random walk, and default number of walks per image.
walk_length = 10
walks_per_image = 100

# Upper bound on how many grid cells per image are sampled for clustering.
cluster_patches_per_image = 24

# Factor applied to image width and height before any processing.
image_scale = 1

# Size of the last axis used when reshaping the feature extractor output.
# presumably the length of one MoCo feature vector -- TODO confirm against MoCoFeatureExtractor
feature_dim = 2048

# Artefact file names; all are keyed by the patch geometry and the version tag.
cluster_file = f'clusters_{window_size_outer}_{window_size_inner}_{stride}_{version}.pkl'
image_cluster_grid_file = f'image_cluster_grids_{window_size_outer}_{window_size_inner}_{stride}_{version}.npy'
sequences_file = f'sequences_{window_size_outer}_{window_size_inner}_{stride}_{version}.csv'
doc2vec_file = f'doc2vec_{window_size_outer}_{window_size_inner}_{stride}_{version}.model'

# NOTE(review): hard-coded absolute, user-specific path -- the notebook will not
# run on another machine until this is made configurable.
cnn = MoCoFeatureExtractor(params_path='/Users/racoon/Desktop/moco_v2_800ep_pretrain.pth.tar')

# Training images, laid out as dataset_1000/train/<class>/<image>.jpg.
image_files = glob("dataset_1000/train/*/*.jpg")

# Maps image id (file name up to its first '.') to the name of its parent directory.
# NOTE(review): splitting on '/' assumes POSIX-style paths.
image_id_to_class = dict([(f.split('/')[-1].split('.')[0], f.split('/')[-2]) for f in image_files])

def extract_windows(frame, pos, window_size):
    """Cut one square window per centre point out of ``frame``.

    Args:
        frame: Array indexed as ``frame[row, col]``; every crop must fit a
            ``(window_size, window_size, 3)`` slot.
        pos: Sequence of ``(row, col)`` window centres.
        window_size: Side length of each window.

    Returns:
        ``np.uint8`` array of shape ``(len(pos), window_size, window_size, 3)``
        holding the crops in the same order as ``pos``.
    """
    count = len(pos)
    stack = np.empty((count, window_size, window_size, 3), dtype=np.uint8)

    for slot, centre in enumerate(pos):
        stack[slot] = extract_window(frame, centre, window_size)

    return stack


def extract_window(frame, pos, window_size):
    """Return the square slice of ``frame`` centred on ``pos``.

    Args:
        frame: Array indexed as ``frame[row, col, ...]``.
        pos: ``(row, col)`` centre of the window; may be fractional.
        window_size: Side length of the window.

    Returns:
        ``frame[r0:r0 + window_size, c0:c0 + window_size]`` where ``(r0, c0)`` is
        the rounded top-left corner. No clamping is done, so a window that
        reaches past the array edge yields a smaller slice.
    """
    half_w = window_size / 2.0

    row_start = int(round(pos[0] - half_w))
    col_start = int(round(pos[1] - half_w))

    return frame[row_start:row_start + window_size, col_start:col_start + window_size]



def get_rad_grid(grid_pos, rad, grid_shape):
    """List the in-bounds cells on the square ring of radius ``rad`` around ``grid_pos``.

    The ring is walked edge by edge: the edge at the smallest second coordinate,
    the edge at the largest second coordinate (both including corners), then the
    interiors of the edges at the smallest and largest first coordinate.

    Args:
        grid_pos: ``(i, j)`` centre cell.
        rad: Ring radius in cells.
        grid_shape: ``(n_i, n_j)`` extent of the grid; cells outside are dropped.

    Returns:
        List of ``(i, j)`` tuples in the walking order described above.
    """
    i0 = grid_pos[0] - rad
    j0 = grid_pos[1] - rad
    span = 2 * rad

    # Candidates are generated in a fixed order; callers shuffle the result, so
    # the order must stay stable for seeded runs to be reproducible.
    candidates = [(i0 + k, j0) for k in range(span + 1)]
    candidates += [(i0 + k, j0 + span) for k in range(span + 1)]
    candidates += [(i0, j0 + k) for k in range(1, span)]
    candidates += [(i0 + span, j0 + k) for k in range(1, span)]

    def inside(cell):
        return 0 <= cell[0] < grid_shape[0] and 0 <= cell[1] < grid_shape[1]

    return [cell for cell in candidates if inside(cell)]



def next_pos(salient_grid_positions, grid_shape, current_position):
    """Pick the next grid cell of a random walk.

    If ``current_position`` is given, a random salient cell among its eight
    neighbours (ring of radius 1) is returned when one exists. Otherwise, or
    when no neighbour is salient, a cell is drawn uniformly from all salient
    positions.

    Args:
        salient_grid_positions: Collection (typically a set) of ``(y, x)``
            cells the walk may visit.
        grid_shape: ``(rows, cols)`` extent of the grid.
        current_position: Current ``(y, x)`` cell, or ``None`` to start a walk.

    Returns:
        The chosen ``(y, x)`` cell.

    Raises:
        ValueError: If ``salient_grid_positions`` is empty.
    """
    if len(salient_grid_positions) == 0:
        raise ValueError("next_pos: no salient grid positions to choose from")

    if current_position is not None:

        rad_grid = get_rad_grid(current_position, 1, grid_shape)

        if len(rad_grid) == 0:
            print("frame empty?")

        else:
            random.shuffle(rad_grid)
            for loc in rad_grid:
                if loc in salient_grid_positions:
                    return loc

    # random.sample() no longer accepts sets (TypeError on Python 3.11+), and
    # callers pass a set, so draw from an explicit sequence instead.
    return random.choice(tuple(salient_grid_positions))

### Sample and cluster patches
This simply uses a fixed grid system. Future patch sampling methods could incorporate an interest point detector.

In [None]:
# Build and cache, for every training image, two grids of CNN features: one from
# the outer (large) window and one from the inner (small) window around each
# grid point. Each image gets one compressed .npz file under
# feat_grids_<outer>_<inner>_<stride>_<version>/<class>/<image_id>.npz.
feat_grid_root = f'feat_grids_{window_size_outer}_{window_size_inner}_{stride}_{version}'

for image_file in tqdm(image_files, total=len(image_files)):

    # Image.open is lazy; convert() loads the pixels, so the file can be closed right away.
    with Image.open(image_file) as raw_image:
        pil_image = raw_image.convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    # Skip images with a side shorter than two outer windows or longer than 1024 pixels.
    if image.shape[0] < window_size_outer * 2 or image.shape[1] < window_size_outer * 2 or image.shape[0] > 1024 or image.shape[1] > 1024:
        continue

    # The margin keeps an outer window centred on any grid cell inside the image;
    # the offsets centre the grid within the image.
    margin = window_size_outer-stride
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    # Centres of the grid cells, row-major, so features reshape straight into a grid.
    points = [(offsets[0]+y*stride+stride/2,offsets[1]+x*stride+stride/2) for y in range(grid_shape[0]) for x in range(grid_shape[1])]

    patches_outer = extract_windows(image, points, window_size_outer)
    windows_outer = patches_outer.astype(np.float64)

    patches_inner = extract_windows(image, points, window_size_inner)
    windows_inner = patches_inner.astype(np.float64)

    try:
        feats_outer = cnn.evalRGB(windows_outer)
        feats_inner = cnn.evalRGB(windows_inner)
    except Exception:
        # Report which file failed (not the raw pixel array) and re-raise.
        print("ERROR cnn.evalRGB", image_file, image.shape, windows_outer.shape, windows_inner.shape)
        raise

    feat_grid_outer = feats_outer.reshape((grid_shape[0], grid_shape[1], feature_dim))
    feat_grid_inner = feats_inner.reshape((grid_shape[0], grid_shape[1], feature_dim))

    path_parts = image_file.split('/')
    image_id = path_parts[-1].split('.')[0]
    image_class = path_parts[-2]

    class_dir = pathlib.Path(feat_grid_root) / image_class
    class_dir.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(f'{feat_grid_root}/{image_class}/{image_id}.npz', outer=feat_grid_outer, inner=feat_grid_inner)

## Cluster Outer

In [None]:
# Fit the KMeans vocabulary on features of the OUTER windows.
# The grids are read from the same relative directory that the feature
# extraction cell writes to, not from a machine-specific absolute path.
npz_files = glob(f'feat_grids_{window_size_outer}_{window_size_inner}_{stride}_{version}/*/*.npz')

X = []

for npz_file in tqdm(npz_files, total=len(npz_files)):

    # NpzFile keeps the archive open until closed; the context manager releases it.
    with np.load(npz_file) as loaded:
        feat_grid_outer = loaded['outer']

    grid_locs = [(y, x) for y in range(feat_grid_outer.shape[0]) for x in range(feat_grid_outer.shape[1])]

    # Cap the contribution of each image so large images do not dominate the clustering.
    if len(grid_locs) > cluster_patches_per_image:
        grid_locs = random.sample(grid_locs, cluster_patches_per_image)

    for l in grid_locs:
        X.append(feat_grid_outer[l])

if len(X) == 0:
    raise ValueError("No feature grids found to cluster; run the feature extraction cell first")

print("Clustering with KMeans: len(X)", len(X))

clusters = KMeans(n_clusters=n_clusters, verbose=False)
clusters.fit(np.array(X, dtype=np.float32))

with open(cluster_file, "wb") as cluster_fh:
    pickle.dump(clusters, cluster_fh)

print("done")

100%|██████████| 15112/15112 [06:36<00:00, 38.07it/s]


Clustering with KMeans: len(X) 362688


### Generate sequence dataset with random walks

In [None]:
def get_grid_locations(grid_shape, stride, offsets, mask=None):
    """Enumerate grid cells, optionally keeping only those covered by ``mask``.

    Args:
        grid_shape: ``(rows, cols)`` of the grid.
        stride: Side length of one grid cell in pixels.
        offsets: ``(row, col)`` pixel offset of the grid's top-left corner.
        mask: Optional 2-D array. When given, a cell is kept only if the sum of
            the mask over the cell is at least 30% of the cell area.

    Returns:
        A row-major list of every ``(y, x)`` cell when ``mask`` is ``None``,
        otherwise a set of the cells that pass the coverage test.
    """
    n_rows, n_cols = grid_shape[0], grid_shape[1]

    if mask is None:
        return [(row, col) for row in range(n_rows) for col in range(n_cols)]

    min_coverage = stride * stride * 0.3
    selected = set()

    for row in range(n_rows):
        for col in range(n_cols):
            centre = (offsets[0] + row * stride + 0.5 * stride, offsets[1] + col * stride + 0.5 * stride)
            cell = extract_window(mask, centre, stride)
            if np.sum(cell) >= min_coverage:
                selected.add((row, col))

    return selected
    
    
def generate_image_cluster_grid(image_file, image_scale, clusters, feature_extractor):
    """Assign a cluster id to every grid cell of one image.

    Args:
        image_file: Path of the image to process.
        image_scale: Factor applied to the image width and height before processing.
        clusters: Fitted clustering model exposing ``predict``.
        feature_extractor: Object exposing ``evalRGB``; when ``None`` the
            module-level ``cnn`` is used.

    Returns:
        Integer array of shape ``grid_shape`` holding one cluster id per cell,
        or ``None`` when the image is smaller than two windows along either axis.
    """
    extractor = cnn if feature_extractor is None else feature_extractor

    # The clustering model is fitted on outer-window features, so prediction
    # must use patches of the same size.
    window_size = window_size_outer

    # Image.open is lazy; convert() loads the pixels, so the file can be closed right away.
    with Image.open(image_file) as raw_image:
        pil_image = raw_image.convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    if image.shape[0] < window_size * 2 or image.shape[1] < window_size * 2:
        print("image too small", image_file)
        return None

    # The margin keeps a window centred on any grid cell inside the image;
    # the offsets centre the grid within the image.
    margin = window_size-stride
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    grid_locations_list = list(get_grid_locations(grid_shape, stride, offsets))

    points = [(y*stride + stride/2 + offsets[0], x*stride + stride/2 + offsets[1]) for (y, x) in grid_locations_list]

    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    feats = extractor.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_dim))

    grid_cluster_ids = clusters.predict(feats)

    # -1 marks cells that received no cluster id.
    cluster_grid = np.full(grid_shape, -1, dtype=int)

    for location, cluster_id in zip(grid_locations_list, grid_cluster_ids):
        cluster_grid[location] = cluster_id

    return cluster_grid


def generate_image_sequences(image_cluster_grid, seq_count=walks_per_image):
    """Turn a grid of cluster ids into word sequences via random walks.

    Args:
        image_cluster_grid: 2-D array holding one cluster id per grid cell.
        seq_count: Number of walks (sequences) to generate.

    Returns:
        List of ``seq_count`` sequences; each is a list of ``walk_length``
        tokens produced by applying ``word_format_string`` to the cluster id
        of each visited cell.
    """
    n_rows, n_cols = image_cluster_grid.shape[0], image_cluster_grid.shape[1]
    # Every cell of the grid is a valid walk position here.
    all_cells = {(row, col) for row in range(n_rows) for col in range(n_cols)}

    sequences = []

    for _ in range(seq_count):
        here = None  # None makes next_pos start the walk at a random cell
        words = []

        for _ in range(walk_length):
            here = next_pos(all_cells, image_cluster_grid.shape, here)
            words.append(word_format_string.format(image_cluster_grid[here]))

        sequences.append(words)

    return sequences


def mask_locations(mask, stride, grid_shape, offsets):
    """Select the grid cells that are sufficiently covered by ``mask``.

    Args:
        mask: 2-D array whose values are summed over each cell.
        stride: Side length of one grid cell in pixels.
        grid_shape: ``(rows, cols)`` of the grid.
        offsets: ``(row, col)`` pixel offset of the grid's top-left corner.

    Returns:
        Set of ``(y, x)`` cells whose mask sum is at least 30% of the cell area.
    """
    min_coverage = stride * stride * 0.3
    half_cell = 0.5 * stride

    def is_covered(row, col):
        centre = (offsets[0] + row * stride + half_cell, offsets[1] + col * stride + half_cell)
        return np.sum(extract_window(mask, centre, stride)) >= min_coverage

    return {
        (row, col)
        for row in range(grid_shape[0])
        for col in range(grid_shape[1])
        if is_covered(row, col)
    }

     
def generate_masked_image_sequences(image_file, mask_file, clusters, feature_extractor, seq_count=walks_per_image):
    """Generate random-walk word sequences restricted to the masked part of an image.

    Args:
        image_file: Path of the image to process.
        mask_file: Path of the mask image; it is converted to 1-bit mode and
            grid cells with at least 30% mask coverage are kept.
        clusters: Fitted clustering model exposing ``predict``.
        feature_extractor: Object exposing ``evalRGB``; when ``None`` the
            module-level ``cnn`` is used.
        seq_count: Number of walks (sequences) to generate.

    Returns:
        List of ``seq_count`` sequences of ``walk_length`` word tokens, or
        ``None`` when the image is too small or the mask selects no grid cell.
    """
    extractor = cnn if feature_extractor is None else feature_extractor

    # The clustering model is fitted on outer-window features, so prediction
    # must use patches of the same size.
    window_size = window_size_outer

    # Image.open is lazy; convert() loads the pixels, so the file can be closed right away.
    with Image.open(image_file) as raw_image:
        pil_image = raw_image.convert('RGB')
    pil_image = pil_image.resize((int(round(pil_image.size[0] * image_scale)), int(round(pil_image.size[1] * image_scale))))
    image = np.array(pil_image)

    if image.shape[0] < window_size * 2 or image.shape[1] < window_size * 2:
        print("image too small", image_file)
        return None

    with Image.open(mask_file) as raw_mask:
        pil_mask = raw_mask.convert('1')
    pil_mask = pil_mask.resize((int(round(pil_mask.size[0] * image_scale)), int(round(pil_mask.size[1] * image_scale))))
    mask = np.array(pil_mask)

    # The margin keeps a window centred on any grid cell inside the image;
    # the offsets centre the grid within the image.
    margin = window_size-stride
    grid_shape = (math.floor((image.shape[0] - margin) / stride), math.floor((image.shape[1] - margin) / stride))
    offsets = (round((image.shape[0] - grid_shape[0] * stride)/2), round((image.shape[1] - grid_shape[1] * stride)/2))

    grid_locations_set = mask_locations(mask, stride, grid_shape, offsets)
    if len(grid_locations_set) == 0:
        # Nothing to walk over; callers already treat None as "skip this image".
        print("no salient grid locations", image_file)
        return None
    grid_locations_list = list(grid_locations_set)

    points = [(y*stride + stride/2 + offsets[0], x*stride + stride/2 + offsets[1]) for (y, x) in grid_locations_list]

    patches = extract_windows(image, points, window_size)
    windows = patches.astype(np.float64)

    feats = extractor.evalRGB(windows)
    feats = feats.reshape((windows.shape[0], feature_dim))

    grid_cluster_ids = clusters.predict(feats)
    grid_location_to_cluster_id = dict(zip(grid_locations_list, grid_cluster_ids))

    cluster_seqs = []
    for _ in range(seq_count):
        cluster_seq = []

        pos = None  # None makes next_pos start the walk at a random salient cell

        for _ in range(walk_length):
            pos = next_pos(grid_locations_set, grid_shape, pos)
            cluster_seq.append(grid_location_to_cluster_id[pos])

        cluster_seqs.append([word_format_string.format(w) for w in cluster_seq])

    return cluster_seqs

In [None]:
# Compute the cluster-id grid of every training image and store them all in one file.
# NOTE: pickle.load executes arbitrary code from the file; only load cluster
# files produced by this notebook.
with open(cluster_file, "rb") as cluster_fh:
    clusters = pickle.load(cluster_fh)

image_cluster_grids = {}

for image_file in tqdm(image_files, total=len(image_files)):

    image_cluster_grid = generate_image_cluster_grid(image_file, image_scale, clusters, cnn)
    if image_cluster_grid is None:
        continue
    image_id = image_file.split('/')[-1].split('.')[0]
    image_cluster_grids[image_id] = image_cluster_grid

# The dict is stored as a pickled 0-d object array; read it back with
# np.load(..., allow_pickle=True).item().
np.save(image_cluster_grid_file, image_cluster_grids)

print("done")

In [None]:
# Generate the random-walk sequences of every image and write them to a CSV
# with one row per sequence ('words') and the id of its source image ('file').
image_cluster_grids = np.load(image_cluster_grid_file, allow_pickle=True).item()

cluster_seqs = []
image_file_colummn = []

# Iterate the ids directly instead of rebinding the global `image_files`
# (the list of image paths), which would break earlier cells on re-run.
for image_id, image_cluster_grid in tqdm(image_cluster_grids.items(), total=len(image_cluster_grids)):
    c_seqs = generate_image_sequences(image_cluster_grid)
    if c_seqs is None:
        continue
    cluster_seqs.extend(c_seqs)
    # One id per generated sequence keeps both columns the same length.
    image_file_colummn.extend([image_id] * len(c_seqs))

data_frame = pd.DataFrame({'words':cluster_seqs, 'file':image_file_colummn})

data_frame.to_csv(sequences_file)

print("done", len(cluster_seqs))

In [None]:
class callback(gensim.models.callbacks.CallbackAny2Vec):
    """Training callback that prints the index of each finished epoch."""

    def __init__(self):
        # Zero-based index of the epoch whose end will be reported next.
        self.epoch = 0

    def on_epoch_end(self, model):
        """Print the index of the epoch that just ended, then advance the counter."""
        finished = self.epoch
        print('epoch {}'.format(finished))
        self.epoch = finished + 1
              
def read_corpus(fname, tokens_only=False):
    """Yield the word sequences stored in a sequences CSV.

    Args:
        fname: Path of a CSV with a ``words`` column holding Python list
            literals (as written by ``DataFrame.to_csv``).
        tokens_only: When true, yield the bare token lists; otherwise yield
            ``gensim`` ``TaggedDocument`` objects tagged with the row index.

    Yields:
        One token list or ``TaggedDocument`` per CSV row, in file order.
    """
    # Read the file that was asked for rather than the global `sequences_file`.
    data_frame = pd.read_csv(fname, converters={"words": literal_eval})

    for index, words in data_frame['words'].items():
        if tokens_only:
            yield words
        else:
            yield gensim.models.doc2vec.TaggedDocument(words, [index])

# Materialise the corpus as a list: it is passed to both build_vocab() and
# train() below, which a one-shot generator could not support.
train_corpus = list(read_corpus(sequences_file))
print(train_corpus[:2])

# One document vector of size 256 is learned per sequence (each row of the CSV
# is its own tagged document); the callback prints the epoch index as training runs.
model = gensim.models.doc2vec.Doc2Vec(vector_size=256, epochs=30)
model.build_vocab(train_corpus)
model.train(train_corpus, total_examples=model.corpus_count, epochs=model.epochs, callbacks=[callback()])

# Persist the trained model so the evaluation cell can reload it.
model.save(doc2vec_file)

print("done")

In [None]:
# Evaluate: for every test image, generate masked random-walk sequences, infer a
# vector per sequence, look up the 10 most similar training sequences and count
# how many of them come from an image of the same class.
# NOTE: pickle.load executes arbitrary code from the file; only load cluster
# files produced by this notebook.
with open(cluster_file, "rb") as cluster_fh:
    clusters = pickle.load(cluster_fh)
model = gensim.models.doc2vec.Doc2Vec.load(doc2vec_file)

# Force 'file' to str: all-digit image ids would otherwise be parsed as
# numbers and miss the string keys of image_id_to_class.
data_frame = pd.read_csv(sequences_file, converters={"words": literal_eval}, dtype={"file": str})

test_image_files = sorted(glob("dataset_100/test/*/*.jpg"))
test_mask_files = sorted(glob("dataset_100/test/*/*.mask.png"))

# Images and masks are paired by sorted position, so a missing file would
# silently pair every later image with the wrong mask.
if len(test_image_files) != len(test_mask_files):
    raise ValueError(
        "found {} test images but {} masks; cannot pair them by position".format(
            len(test_image_files), len(test_mask_files)))

correct = 0
total = 0

for image_file, mask_file in zip(test_image_files, test_mask_files):

    image_correct = 0
    image_total = 0

    print("test", image_file)

    c_seqs = generate_masked_image_sequences(image_file, mask_file, clusters, cnn, seq_count=100)

    if c_seqs is None:
        continue

    # Class of the test image is the name of its parent directory.
    true_class = image_file.split('/')[-2]

    # most_similar expects a list of positive examples, hence the one-element lists.
    vectors = [[model.infer_vector(s)] for s in c_seqs]

    for v in vectors:
        similar = model.docvecs.most_similar(v, topn=10)

        for s in similar:
            # s[0] is the tag of a training sequence, i.e. its row index in the CSV.
            f = data_frame.loc[s[0],'file']

            if true_class == image_id_to_class[f]:
                image_correct += 1
                correct += 1

            image_total += 1
            total += 1

    if image_total:
        print("score", image_correct/image_total)

if total:
    print("final score", correct/total)
else:
    print("final score n/a (no test sequences were evaluated)")

## Results

```
96  72  1000 0.36  
96  72  100  0.45  
192 72  100  0.77
288 144 100  0.82
288 144 1000 
```