<a href="https://colab.research.google.com/github/yongsun-yoon/academic-sentence-retriever/blob/main/02_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Search

## 1. Setup

In [None]:
!pip install -q transformers faiss-cpu

In [None]:
import faiss
import sqlite3
import easydict
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

In [None]:
# Notebook-wide settings: the directory the later cells load artifacts from,
# and the number of nearest neighbours requested from the index.
cfg = easydict.EasyDict({
    'basedir': '/content/drive/MyDrive/project/academic-sentence-retriever',
    'topk': 30,
})

## 2. Search

In [None]:
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings along dim 1, weighting each token by its mask value.

    The mask gets a trailing axis, is broadcast to the shape of
    ``token_embeddings`` and cast to float. The masked sum is then divided by
    the per-row mask total, clamped to at least 1e-9 so that a fully masked
    row produces zeros rather than a division by zero.
    """
    mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    masked_sum = (token_embeddings * mask).sum(1)
    mask_total = mask.sum(1).clamp(min=1e-9)
    return masked_sum / mask_total

def encode(model, tokenizer, sentences, batch_size=16, max_length=256):
    """Embed sentences as L2-normalised, mean-pooled vectors.

    Args:
        model: Encoder called as ``model(**inputs)``; its output must expose
            ``last_hidden_state``.
        tokenizer: Callable tokenizer that returns PyTorch tensors, including
            an ``attention_mask``.
        sentences: A single string or a sequence of strings.
        batch_size: Maximum number of sentences per forward pass.
        max_length: Token length at which inputs are truncated.

    Returns:
        A CPU tensor with one unit-length row per input sentence, in input
        order.

    Raises:
        ValueError: If ``sentences`` is empty or ``batch_size`` is below 1.
    """
    sentences = [sentences] if isinstance(sentences, str) else list(sentences)
    if not sentences:
        raise ValueError('sentences must contain at least one sentence')
    if batch_size < 1:
        raise ValueError('batch_size must be a positive integer')

    # Run on whatever device the model lives on; fall back to CPU for a
    # parameter-free module.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    chunks = []
    # Inference only: skip autograd bookkeeping regardless of how the model
    # was configured by the caller.
    with torch.no_grad():
        for start in range(0, len(sentences), batch_size):
            batch = sentences[start:start + batch_size]
            inputs = tokenizer(
                batch,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            ).to(device)
            outputs = model(**inputs).last_hidden_state
            # Padding is per batch; the attention mask zeroes padded positions
            # out of the mean.
            embeds = mean_pooling(outputs, inputs.attention_mask)
            embeds = F.normalize(embeds, p=2, dim=1)
            # Collect on the CPU so accelerator memory is released per batch.
            chunks.append(embeds.cpu())
    return torch.cat(chunks, dim=0)

In [None]:
# Open the SQLite database and the FAISS index stored under cfg.basedir.
conn = sqlite3.connect(f'{cfg.basedir}/data.sqlite')
cursor = conn.cursor()
index = faiss.read_index(f'{cfg.basedir}/data.faiss')

# Report how many rows the `sents` table holds.
(num_sents,) = cursor.execute('SELECT COUNT(*) FROM sents').fetchone()
print(num_sents)

In [None]:
# Load the tokenizer and encoder saved under cfg.basedir.
model_path = f'{cfg.basedir}/model'
tokenizer = AutoTokenizer.from_pretrained(model_path)
# NOTE(review): trust_remote_code=True lets the checkpoint supply its own model
# code -- confirm the directory contents are trusted.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

# Switch to evaluation mode and freeze the parameters; assigning the last
# result to `_` keeps the notebook from echoing the model.
model.eval()
_ = model.requires_grad_(False)

In [None]:
query = '성능이 향상되었다.'
query_embed = encode(model, tokenizer, query)

# FAISS's Python API is documented for float32 NumPy arrays, so convert
# explicitly instead of relying on implicit tensor-to-array coercion.
D, I = index.search(query_embed.detach().cpu().float().numpy(), cfg.topk)
# FAISS fills missing neighbours with -1 when the index holds fewer than topk
# vectors; drop that padding so it never reaches the SQL lookup.
indices = [i for i in I[0].tolist() if i >= 0]

# Fetch the matching rows with one parameterised IN (...) lookup.
# NOTE(review): assumes the FAISS ids equal `sents.id` -- confirm against the
# notebook that built the index.
if indices:
    where_clause = f"id IN ({', '.join('?' * len(indices))})"
    cursor.execute(f"SELECT * FROM sents WHERE {where_clause}", indices)
    results = cursor.fetchall()
else:
    # No neighbours: skip the query, since an empty condition is invalid SQL.
    results = []

In [None]:
# Map each retrieved id to its position in the search result list
# (0 = first hit returned by the index).
ranking = {sent_id: position for position, sent_id in enumerate(indices)}

# Name the columns in the constructor instead of assigning `.columns`
# afterwards: the assignment raises a length mismatch when the query returned
# no rows, and also when this cell is re-run on the already-built frame.
# NOTE(review): assumes `sents` has exactly these three columns, in this order.
results = pd.DataFrame(results, columns=['id', 'sent', 'arxiv_id'])
# SQLite returns rows in no guaranteed order, so restore the search order.
results['rank'] = results['id'].map(ranking)
results = results.sort_values('rank')

In [None]:
# Print one line per hit (rank, arXiv id, sentence), each followed by a rule.
separator = '-'*50
for row in results.itertuples():
    line = f'{row.rank:02d} | {row.arxiv_id} | {row.sent}'
    print(line)
    print(separator)