In [None]:
#| default_exp representation.search

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| export
import torch, hnswlib, numpy as np, torch.nn.functional as F
from typing import Optional, Union

from xcai.core import *

## IndexSearch

Approximate nearest-neighbour search backed by an `hnswlib` HNSW index.

In [None]:
#| export
class IndexSearch:
    """Approximate nearest-neighbour search over dense vectors using an `hnswlib` HNSW index.

    `build` indexes a matrix of vectors (one row per item); `proc` returns the top-`n_bm`
    neighbours of every query row as flat tensors plus a per-query neighbour count.
    """

    def __init__(self, 
                 index:Optional[hnswlib.Index]=None, 
                 space:Optional[str]='cosine', 
                 efc:Optional[int]=200, 
                 m:Optional[int]=16, 
                 efs:Optional[int]=50, 
                 n_bm:Optional[int]=50, 
                 n_threads:Optional[int]=84):
        """
        Args:
            index: a pre-built `hnswlib.Index` to search, or None to create one in `build`.
                NOTE(review): `set_ef(efs)` is only applied in `build`, not to a pre-built index.
            space: hnswlib distance space passed to `hnswlib.Index`.
            efc: `ef_construction` used by `build`.
            m: `M` used by `build`.
            efs: query-time `ef` set in `build`; also caps the neighbours returned by `proc`.
            n_bm: default number of neighbours returned per query.
            n_threads: `num_threads` passed to `add_items` in `build`.
        """
        store_attr('index,space,efc,m,efs,n_bm,n_threads')

    @staticmethod
    def _to_numpy(x):
        "Return `x` as a numpy array when it is a torch tensor (detached, on CPU); otherwise unchanged."
        return x.detach().cpu().numpy() if isinstance(x, torch.Tensor) else x

    def build(self, data:Optional[Union[torch.Tensor,np.ndarray]], info:Optional[Union[torch.Tensor,np.ndarray]]=None):
        """Create a fresh HNSW index over the rows of `data`.

        Args:
            data: 2-D array/tensor of vectors, one row per item (tensors may live on any device).
            info: integer label for each row of `data`; defaults to `0..len(data)-1`.

        Raises:
            ValueError: if `data` and `info` have a different number of rows.
        """
        if info is None: info = np.arange(data.shape[0])
        if data.shape[0] != info.shape[0]: 
            raise ValueError(f'`data`({data.shape[0]}) and `info`({info.shape[0]}) should have same size.')

        # hnswlib consumes numpy buffers; detach so grad-tracking tensors are accepted too.
        data, info = self._to_numpy(data), self._to_numpy(info)

        self.index = hnswlib.Index(space=self.space, dim=data.shape[1])
        self.index.init_index(max_elements=data.shape[0], ef_construction=self.efc, M=self.m)
        self.index.add_items(data, info, num_threads=self.n_threads)
        self.index.set_ef(self.efs)

    def proc(self, inputs:Optional[Union[torch.Tensor,np.ndarray]], n_bm:Optional[int]=None):
        """Return the nearest indexed items for every query row in `inputs`.

        Args:
            inputs: query vectors of shape (n_queries, dim), or a single (dim,) vector which
                is treated as a batch of one.
            n_bm: neighbours per query; defaults to `self.n_bm`. It is clamped to `self.efs`
                and to the number of indexed items.

        Returns:
            dict with flat `info2data_idx` (int64 labels), `info2data_score` (`1 - distance`,
            i.e. cosine similarity for the 'cosine' space) and `info2data_data2ptr`
            (number of neighbours returned for each query).

        Raises:
            RuntimeError: if there is no index (neither built nor passed to `__init__`).
        """
        if self.index is None:
            raise RuntimeError('`IndexSearch.build` must be called before `proc`.')
        n_bm = self.n_bm if n_bm is None else n_bm
        # hnswlib's docs say ef must not be below k, and k cannot exceed the indexed item count.
        n_bm = min(n_bm, self.efs, self.index.get_current_count())

        # Promote a single vector to a batch so `ptr` gets one entry per query.
        inputs = np.atleast_2d(self._to_numpy(inputs))
        info, sc = self.index.knn_query(inputs, k=n_bm)
        info, sc, ptr = torch.tensor(info.reshape(-1).astype(np.int64)), torch.tensor(sc.reshape(-1)), torch.full((inputs.shape[0],), n_bm)
        return {'info2data_idx':info, 'info2data_score':1.0-sc, 'info2data_data2ptr':ptr}
        

### Example

In [None]:
# Random example vectors, seeded so the cells below are reproducible.
torch.manual_seed(0)
n, dim = 10000, 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sample with the CPU generator and then move, so the values do not depend on the device.
data, info = torch.rand((n, dim)).to(device), torch.arange(n)

In [None]:
# HNSW index over `data`; row i is stored under the label info[i].
index = IndexSearch()
index.build(data, info=info)

In [None]:
output = index.proc(inputs=data)  # query every indexed vector against the index

In [None]:
tuple(output[key].shape for key in ('info2data_idx', 'info2data_score', 'info2data_data2ptr'))

(torch.Size([500000]), torch.Size([500000]), torch.Size([10000]))

In [None]:
# First ten flattened scores (all from the first query when n_bm >= 10).
output['info2data_score'][0:10]

tensor([1.0000, 0.8246, 0.8245, 0.8227, 0.8227, 0.8207, 0.8207, 0.8206, 0.8205,
        0.8205])

## BruteForceSearch

Exact cosine-similarity search: L2-normalised vectors, a dense matrix product, then `torch.topk`.

In [None]:
#| export
class BruteForceSearch:
    """Exact top-k cosine-similarity search using a dense matrix product and `torch.topk`."""
    
    def __init__(self, 
                 index:Optional[tuple]=None, 
                 n_bm:Optional[int]=50):
        """
        Args:
            index: `(normalised_data, info)` pair as produced by `build`, or None.
            n_bm: default number of neighbours returned per query.
        """
        store_attr('index,n_bm')
        
    def build(self, data:Optional[torch.Tensor], info:Optional[torch.Tensor]=None):
        """Store L2-normalised `data` (one vector per row) together with optional 1-D row labels `info`.

        Raises:
            ValueError: if `info` is given and its length differs from the number of rows of `data`.
        """
        if info is not None and data.shape[0] != info.shape[0]: 
            raise ValueError(f'`data`({data.shape[0]}) and `info`({info.shape[0]}) should have same size.')
        self.index = (F.normalize(data, dim=1), info)
    
    def proc(self, inputs:Optional[torch.Tensor], n_bm:Optional[int]=None):
        """Return the `n_bm` most cosine-similar indexed rows for every query row in `inputs`.

        Args:
            inputs: query vectors of shape (n_queries, dim), or a single (dim,) vector which
                is treated as a batch of one.
            n_bm: neighbours per query for this call only; defaults to `self.n_bm` and is
                clamped to the number of indexed rows. The instance default is not modified.

        Returns:
            dict with flat `info2data_idx` (labels from `info`, or row positions when `info`
            is None), `info2data_score` (cosine similarities, descending per query) and
            `info2data_data2ptr` (number of neighbours per query).

        Raises:
            RuntimeError: if called before `build`.
        """
        if self.index is None:
            raise RuntimeError('`BruteForceSearch.build` must be called before `proc`.')
        index, info = self.index
        n_bm = self.n_bm if n_bm is None else n_bm
        n_bm = min(index.shape[0], n_bm)

        if inputs.dim() == 1: inputs = inputs.unsqueeze(0)
        inputs = F.normalize(inputs, dim=1)
        
        sc, idx = torch.topk(inputs@index.T, n_bm, dim=1, largest=True)
        # Map row positions to labels; `info` may live on a different device than the scores.
        info = idx if info is None else info.to(idx.device)[idx]
            
        info, sc, ptr = info.reshape(-1), sc.reshape(-1), torch.full((inputs.shape[0],), n_bm, device=inputs.device)
        return {'info2data_idx':info, 'info2data_score':sc, 'info2data_data2ptr':ptr}
        

### Example

In [None]:
# Random example vectors, seeded so the cells below are reproducible.
torch.manual_seed(0)
n, dim = 10000, 128
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Sample with the CPU generator and then move, so the values do not depend on the device.
data = torch.rand((n, dim)).to(device)

In [None]:
# Exact search; without `info`, results are row positions in `data`.
index = BruteForceSearch()
index.build(data, info=None)

In [None]:
output = index.proc(inputs=data)  # query every indexed vector against the index

In [None]:
# First ten flattened scores (all from the first query when n_bm >= 10).
output['info2data_score'][0:10]

tensor([1.0000, 0.8248, 0.8246, 0.8245, 0.8227, 0.8227, 0.8222, 0.8221, 0.8213,
        0.8207], device='cuda:0')