<a href="https://colab.research.google.com/github/yongsun-yoon/academic-sentence-retriever/blob/main/01_register_paper.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Register paper

## 1. Setup

In [None]:
# Quietly (-q) install the PDF reader, the Hugging Face model library and the
# CPU build of FAISS. nltk, easydict and torch are imported below but are not
# installed here -- presumably the Colab runtime already ships them; verify
# when running elsewhere.
!pip install -q PyPDF2 transformers faiss-cpu

In [None]:
import os
import re
import nltk
import faiss
import PyPDF2
import sqlite3
import easydict
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel

# nltk.sent_tokenize needs the Punkt sentence-tokenizer data. NLTK 3.9 and
# later load it from the 'punkt_tab' resource instead of the pickled 'punkt'
# model, so request both. With the default raise_on_error=False, a resource
# unknown to the installed NLTK is reported and skipped rather than raising.
for resource in ('punkt', 'punkt_tab'):
    nltk.download(resource)

In [None]:
# Notebook-wide settings, read as attributes (cfg.basedir, cfg.model_name).
cfg = easydict.EasyDict(
    # Root folder for the papers, the SQLite database, the FAISS index and the
    # local model copy. The path lives under /content/drive, so Google Drive
    # presumably has to be mounted first -- TODO confirm; no mount call is
    # visible in this notebook.
    basedir = '/content/drive/MyDrive/project/academic-sentence-retriever',
    # Hugging Face model id. NOTE(review): not read in the visible cells; the
    # git clone in section 4 hard-codes the same id in its URL.
    model_name = 'yongsun-yoon/bilingual-sentence-embedder-mMiniLMv2-L6-H384'
)

## 2. Read PDF

In [None]:
def replace_newline(text):
    """Flatten ``text`` onto one line, undoing end-of-line hyphenation.

    A hyphen immediately followed by a line break is treated as a word that
    was split across two lines, so both characters are dropped and the two
    halves are joined. Every other line break becomes a single space.

    Note that a genuine hyphen that happens to sit at the end of a line
    (e.g. in a compound word) is removed as well.

    Args:
        text: The string to flatten.

    Returns:
        A copy of ``text`` that contains no newline characters.
    """
    # Operating on the whole string (instead of inspecting text[i-1] per
    # character) means a newline at position 0 is never compared against the
    # last character of the text through negative indexing.
    # Order matters: hyphen+newline pairs must be removed before the remaining
    # newlines are turned into spaces.
    return text.replace('-\n', '').replace('\n', ' ')


def extract_sentences(reader):
    """Collect the sentences found on every page of a PDF.

    Each page's extracted text is stripped, flattened with
    ``replace_newline`` and split with ``nltk.sent_tokenize``. Pages are
    tokenized one at a time, so a sentence never spans a page break.

    Args:
        reader: Object exposing ``pages``, where each page provides
            ``extract_text()`` (a ``PyPDF2.PdfReader`` in this notebook).

    Returns:
        A list of sentence strings in page order.
    """
    collected = []
    for pdf_page in reader.pages:
        flattened = replace_newline(pdf_page.extract_text().strip())
        collected.extend(nltk.sent_tokenize(flattened))
    return collected

In [None]:
# Paper to register; its PDF is expected at <basedir>/papers/<arxiv_id>.pdf.
arxiv_id = '2109.06349'
pdf_path = f'{cfg.basedir}/papers/{arxiv_id}.pdf'
# Fail fast and name the missing file. An explicit raise (unlike a bare
# assert) carries a message and is not stripped when Python runs with -O.
if not os.path.exists(pdf_path):
    raise FileNotFoundError(f'paper PDF not found: {pdf_path}')

In [None]:
# Read the PDF inside a context manager so the file handle is released as soon
# as the sentences have been extracted, even if extraction raises.
# `reader` must not be used after this block: its underlying file is closed.
with open(pdf_path, 'rb') as pdf:
    reader = PyPDF2.PdfReader(pdf)
    sentences = extract_sentences(reader)
print(len(sentences))

## 3. SQLite

In [None]:
# Open (creating it if needed) the project's SQLite database under cfg.basedir.
# The connection stays open across the following cells and is closed by the
# insert cell.
conn = sqlite3.connect(f'{cfg.basedir}/data.sqlite')
cursor = conn.cursor()

# Uncomment the next line to discard all stored sentences and start over.
# cursor.execute('DROP TABLE IF EXISTS sents')
# One row per sentence: integer primary key, sentence text, source paper id.
cursor.execute("CREATE TABLE IF NOT EXISTS sents (id integer PRIMARY KEY, sent text, arxiv_id text)")

In [None]:
# Number of sentences already stored; the insert cell uses it as the first id
# for this paper's rows.
(rowid,) = cursor.execute('SELECT COUNT(*) FROM sents').fetchone()
print(rowid)

In [None]:
# Refuse to register the same paper twice: with GROUP BY, the query yields no
# row at all when no sentence carries this arxiv_id, so fetchone() is None.
# The id is bound as a parameter rather than formatted into the SQL text.
# Interpolated bare, it is parsed as a numeric literal, so an id such as
# '2109.06340' would be compared as 2109.0634 and never match, and an
# old-style id like 'hep-th/9901001' would not even be valid SQL.
cursor.execute(
    'SELECT COUNT(*) FROM sents WHERE arxiv_id = ? GROUP BY arxiv_id',
    (arxiv_id,),
)
res = cursor.fetchone()
assert res is None, f'arxiv_id {arxiv_id} is already registered ({res[0]} sentences)'

In [None]:
# Give this paper's sentences consecutive ids that continue from the current
# row count, store them in one batch, then commit and release the connection.
inputs = [
    (sent_id, sent, arxiv_id)
    for sent_id, sent in enumerate(sentences, start=rowid)
]
cursor.executemany("INSERT INTO sents(id, sent, arxiv_id) VALUES(?,?,?)", inputs)
conn.commit()
conn.close()

## 4. FAISS

In [None]:
def mean_pooling(token_embeddings, attention_mask):
    """Average token embeddings along dim 1, weighting each token by its mask.

    The mask gets a trailing axis and is expanded to the shape of
    ``token_embeddings``, so positions whose mask value is 0 contribute
    nothing to the sum. The divisor is clamped to at least 1e-9 so a fully
    masked row yields zeros instead of a division by zero.

    Args:
        token_embeddings: Tensor of per-token vectors (assumed to be
            ``(batch, seq, hidden)`` -- TODO confirm against the caller).
        attention_mask: Tensor with one fewer dimension than
            ``token_embeddings``; non-zero marks tokens to include.

    Returns:
        Tensor with dim 1 of ``token_embeddings`` reduced away.
    """
    weights = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    weighted_sum = (token_embeddings * weights).sum(1)
    token_count = weights.sum(1).clamp(min=1e-9)
    return weighted_sum / token_count

def encode(model, tokenizer, sentences, batch_size=16, max_length=256):
    """Embed sentences as L2-normalized, mean-pooled vectors.

    Sentences are tokenized and run through ``model`` in batches; each
    batch's ``last_hidden_state`` is mean-pooled over the attention mask and
    normalized to unit L2 norm along dim 1.

    Args:
        model: Model whose call returns an object with ``last_hidden_state``
            (a Hugging Face ``AutoModel`` in this notebook). Must expose
            ``parameters()`` so its device can be determined.
        tokenizer: Callable tokenizer producing ``'pt'`` tensors, including
            an ``attention_mask``.
        sentences: Sliceable sequence of strings to embed.
        batch_size: Number of sentences per forward pass.
        max_length: Token length beyond which inputs are truncated.

    Returns:
        A CPU tensor with one row per sentence, rows in input order.

    Raises:
        RuntimeError: From ``torch.cat`` when ``sentences`` is empty, since
            there is no batch to concatenate.
    """
    # Feed the inputs to whatever device the model lives on; leaving them on
    # the CPU would fail for a model that was moved to a GPU.
    device = next(model.parameters()).device
    embeds = []
    # Inference only: make sure no autograd graph is built or kept alive,
    # regardless of whether the caller froze the model's parameters.
    with torch.no_grad():
        for i in tqdm(range(0, len(sentences), batch_size)):
            batch_sentences = sentences[i:i+batch_size]
            batch_inputs = tokenizer(
                batch_sentences,
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors='pt',
            ).to(device)
            batch_outputs = model(**batch_inputs).last_hidden_state
            batch_embeds = mean_pooling(batch_outputs, batch_inputs.attention_mask)
            batch_embeds = F.normalize(batch_embeds, p=2, dim=1)
            # Collect results on the CPU so accelerator memory is not held
            # for the whole corpus.
            embeds.append(batch_embeds.cpu())
    return torch.cat(embeds, dim=0)

In [None]:
# Local copy of the embedding model, kept under cfg.basedir.
model_path = f'{cfg.basedir}/model'
if not os.path.exists(model_path):
    # No local copy yet: install git-lfs and clone the model repository from
    # the Hugging Face Hub into the same folder that model_path points to.
    !apt-get install git-lfs -y
    !git-lfs install
    !git clone https://huggingface.co/yongsun-yoon/bilingual-sentence-embedder-mMiniLMv2-L6-H384 "{cfg.basedir}/model"
    print('clone from Huggingface')
    
# Load the tokenizer and model from the local folder.
tokenizer = AutoTokenizer.from_pretrained(model_path)
# NOTE(review): trust_remote_code=True permits Python code shipped with the
# model files to run during loading -- confirm this model actually needs it.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
# Switch to eval mode and turn off gradient tracking for all parameters;
# assigning to `_` keeps the returned model out of the cell output.
_ = model.eval().requires_grad_(False)

In [None]:
# Embed every extracted sentence with encode()'s default batch size and
# maximum length.
embeds = encode(model, tokenizer, sentences)
# Trailing expression: shown as the cell output to check the result's shape.
embeds.shape

In [None]:
# Reuse the FAISS index saved under cfg.basedir when there is one; otherwise
# start a flat L2 index whose dimensionality is the last axis of `embeds`.
index_path = f'{cfg.basedir}/data.faiss'
if not os.path.exists(index_path):
    index = faiss.IndexFlatL2(embeds.shape[-1])
    print('create new index.')
else:
    index = faiss.read_index(index_path)
    print('load existed index.')

In [None]:
# Hand FAISS an explicit contiguous float32 NumPy array instead of relying on
# implicit conversion of a torch tensor, which fails for tensors that live on
# a GPU or still track gradients.
vectors = embeds.detach().cpu().to(torch.float32).contiguous().numpy()
index.add(vectors)
# Persist the updated index. NOTE(review): the SQLite rows were committed in
# an earlier cell, so a failure here leaves the database and the index out of
# step -- presumably row ids are meant to match index positions; verify.
faiss.write_index(index, index_path)