In [103]:
import pickle
import faiss

Reference: [Comprehensive Guide to Approximate Nearest Neighbors Algorithms](https://towardsdatascience.com/comprehensive-guide-to-approximate-nearest-neighbors-algorithms-8b94f057d6b6)

In [104]:
import pandas as pd
import numpy as np
from time import time
from tqdm import tqdm
# Enable tqdm's pandas integration (progress_apply / progress_map on pandas objects).
tqdm.pandas()

# Identifiers of the experiment whose result files are read below.
model_name, fold = 'seresnext101', 2
checkpoint, algo = '17600', 'dist_global_org'

# Positional arguments ({0}..{5}) for the result-path templates.
nums = [model_name, fold, checkpoint, model_name, fold, algo]

In [105]:
# Read the header-less submission file and index it by its first column.
top20_path = '../WC_result/{0}_{1}/out_{2}/{3}_sub_fold{4}_{5}.csv'.format(*nums)
top20 = pd.read_csv(top20_path, header=None).set_index(0)

In [106]:
# Load the header-less encoding file: column 0 holds the image file name and
# every remaining column holds one component of that image's embedding.
enc_raw = pd.read_csv('../WC_result/{0}_{1}/out_{2}/encoding_org_img.csv'.format(*nums), header=None)

# Pack all embedding columns into a single list per image. Slicing with
# iloc[:, 1:] adapts to whatever embedding width the file has, instead of
# relying on the hard-coded column position 2050-1 (valid only for exactly
# 2048 embedding columns).
enc = pd.DataFrame({
    'img': enc_raw[0],
    'embeddings': enc_raw.iloc[:, 1:].values.tolist(),
})

enc.head(1)

Unnamed: 0,img,embeddings
0,PM-WWA-20180811-093.jpg,"[0.10390169, -1.9826837e-05, -0.031419944, -0...."


## Exhaustive Search Usage


In [123]:
USE_GPU = True


class ExactIndex():
    """Exhaustive (brute-force) L2 nearest-neighbour search backed by faiss.

    Parameters
    ----------
    vectors : array-like of shape (n_items, dimension)
        Vectors to index. They are stored as a C-contiguous float32 array.
    labels : sequence
        ``labels[i]`` is the label returned for ``vectors[i]``.

    Call :meth:`build` once before :meth:`query`.
    """

    def __init__(self, vectors, labels):
        # Convert once up front so build() can hand the array straight to faiss.
        self.vectors = np.ascontiguousarray(vectors, dtype='float32')
        self.dimension = self.vectors.shape[1]
        self.labels = labels
        self.index = None          # created by build()
        self.gpu_resources = None  # set by build() when USE_GPU is true

    def build(self):
        """Create the flat L2 index (moved to GPU 0 when ``USE_GPU``) and add the vectors."""
        self.index = faiss.IndexFlatL2(self.dimension)
        if USE_GPU:
            # Hold the resources object on the instance so it lives as long
            # as the GPU index that uses it.
            self.gpu_resources = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(self.gpu_resources, 0, self.index)
        self.index.add(self.vectors)

    def query(self, vectors, k=10):
        """Return the labels of the (up to) ``k`` nearest neighbours of one query.

        Parameters
        ----------
        vectors : array-like of shape (1, dimension) or (dimension,)
            The query. If several rows are passed, only the first row's
            neighbours are returned.
        k : int
            Number of neighbours requested.

        Returns
        -------
        list
            Labels in the order returned by the index. May be shorter than
            ``k`` when the index holds fewer than ``k`` vectors.

        Raises
        ------
        RuntimeError
            If :meth:`build` has not been called.
        """
        if self.index is None:
            raise RuntimeError("ExactIndex.build() must be called before query()")
        queries = np.ascontiguousarray(vectors, dtype='float32')
        if queries.ndim == 1:
            queries = queries.reshape(1, -1)
        _, indices = self.index.search(queries, k)
        # faiss pads missing results with -1; indexing labels with -1 would
        # silently return the last label, so drop those slots.
        return [self.labels[i] for i in indices[0] if i >= 0]

In [124]:
# Stack the per-image embedding lists into one 2-D array in a single step;
# Series.apply(pd.Series) builds a Series object for every row and is far slower.
embedding_matrix = np.array(enc["embeddings"].tolist())
index = ExactIndex(embedding_matrix, enc["img"].values)
index.build()

In [125]:
# `dat` was displayed here before any cell assigned it, which raises NameError
# on a fresh kernel (Restart & Run All). Define it before displaying it.
dat = np.ascontiguousarray(np.array(enc["embeddings"].tolist(), dtype='float32'))
dat

array([[ 1.0390169e-01, -1.9826837e-05, -3.1419944e-02, ...,
        -4.3996390e-02, -1.5706321e-02,  5.4578841e-02],
       [-1.0056405e-02,  1.8141091e-04,  6.5355293e-02, ...,
        -3.1188192e-02,  1.1094182e-02, -1.6914042e-02],
       [-1.0010423e-02,  4.0318569e-06,  8.4382081e-03, ...,
         7.5220470e-03, -1.7691132e-02, -2.1372322e-02],
       ...,
       [-8.0038095e-03, -7.6144803e-05, -1.8051082e-02, ...,
        -3.2090891e-02, -3.6907820e-03,  3.1203418e-03],
       [ 1.1608690e-01,  3.2977765e-05, -3.6970368e-03, ...,
        -1.5181725e-02, -1.3473527e-02, -1.9494843e-02],
       [-1.4533627e-02, -3.1686645e-06,  1.5872167e-02, ...,
        -2.8415434e-02, -1.1559112e-02,  3.5376597e-02]], dtype=float32)

In [126]:
# Float32, C-contiguous matrix of all embeddings, one row per row of `enc`.
# Built directly from the lists instead of the slow per-row apply(pd.Series).
dat = np.ascontiguousarray(np.array(enc["embeddings"].tolist(), dtype='float32'))

In [127]:
# Show the first row of `enc` again (columns: img, embeddings).
enc.head(1)

Unnamed: 0,img,embeddings
0,PM-WWA-20180811-093.jpg,"[0.10390169, -1.9826837e-05, -0.031419944, -0...."


In [133]:
# Exploratory check: inspect the distance-computer object exposed by the underlying faiss index.
index.index.get_distance_computer()

<faiss.swigfaiss_avx2.DistanceComputer; proxy of <Swig Object of type 'faiss::DistanceComputer *' at 0x7fd7d1e231b0> >

In [128]:
# Query with the first embedding as a single-row matrix; ask for 21 hits and
# drop the first one.
first_query = dat[0].reshape(1, -1)
index.query(first_query, 21)[1:]

['PM-WWA-20160408-598.jpg',
 'PM-WWA-20100723-339.jpg',
 'PM-WWA-20100723-355.jpg',
 'PM-WWA-20180627-038.jpg',
 'PM-WWA-20160722-035.jpg',
 'PM-WWA-20140820-095.jpg',
 'PM-WWA-20090702-046.jpg',
 'PM-WWA-20060608-026.jpg',
 'PM-WWA-20050612-045.jpg',
 'PM-WWA-20181019-008.jpg',
 'PM-WWA-20160801-046.jpg',
 'PM-WWA-20110728-029.jpg',
 'PM-WWA-20110724-031.jpg',
 'PM-WWA-20140627-075.jpg',
 'PM-WWA-20050705-151.jpg',
 'PM-WWA-20100421-181.jpg',
 'PM-WWA-20120617-014.jpg',
 'PM-WWA-20160617-346.jpg',
 'PM-WWA-20180914-075.jpg',
 'PM-WWA-20100930-017.jpg']

In [115]:
# Keep only the images that appear in top20's index. reset_index() turns the
# original row position in `enc` into a column, named 'id' below.
in_top20 = enc['img'].isin(top20.index)
test = enc.loc[in_top20].reset_index()
test.columns = ['id', 'img', 'embeddings']

In [117]:
# Number of (rows, columns) in the filtered frame.
test.shape

(808, 3)

In [118]:
# Sanity check: re-applying the top20 membership filter should leave the shape unchanged.
test[test.img.isin(top20.index)].shape

(808, 3)

In [120]:
TOP_K = 20  # neighbours kept per query image


def _neighbours_excluding_self(row_id, img_name, k=TOP_K):
    """Return up to ``k`` neighbour image names for ``dat[row_id]``, without ``img_name``.

    One extra hit (k + 1) is requested because the query image is itself in
    the index. It is removed by name rather than by position: with tied or
    duplicate embeddings the query image is not guaranteed to be hit #0, and
    slicing off the first hit would then drop a real neighbour and keep the
    query image.
    """
    hits = index.query(dat[row_id].reshape(1, -1), k + 1)
    return [name for name in hits if name != img_name][:k]


# 'id' is the row position of each image in `enc`, and therefore in `dat`.
test['top20imgs'] = pd.Series(
    [_neighbours_excluding_self(row_id, img_name)
     for row_id, img_name in zip(test['id'], test['img'])],
    index=test.index,
)

In [121]:
# Preview the first rows, now including the 'top20imgs' neighbour lists.
test.head()

Unnamed: 0,id,img,embeddings,top20imgs
0,0,PM-WWA-20180811-093.jpg,"[0.10390169, -1.9826837e-05, -0.031419944, -0....","[PM-WWA-20160408-598.jpg, PM-WWA-20100723-339...."
1,18,PM-WWA-20170710-031.jpg,"[-0.007411923000000001, 2.2409202e-05, 0.01138...","[PM-WWA-20160319-207.jpg, PM-WWA-20110724-031...."
2,19,PM-WWA-20170622-226.jpg,"[-0.009813612, -6.898617400000001e-06, 0.06699...","[PM-WWA-20170625-283.jpg, PM-WWA-20060819-009...."
3,20,PM-WWA-20180813-271.jpg,"[-0.00257763, 6.6202214e-05, -0.012052906, -0....","[PM-WWA-20060818-192.jpg, PM-WWA-20080516-102...."
4,21,PM-WWA-20180506-348.jpg,"[0.15255915, -2.6346093e-05, 0.06690723, -0.01...","[PM-WWA-20060618-010.jpg, PM-WWA-20110720-104...."


In [122]:
# One output row per query image: its file name followed by its neighbours,
# one neighbour per column; written without header or index.
neighbour_cols = pd.DataFrame(test['top20imgs'].tolist(), index=test.index)
submission = test[['img']].join(neighbour_cols)
submission.to_csv('faiss.csv', header=False, index=False)